人工智能中的编程 - 第9章: 自动微分(Automatic Differentiation)
自动微分在 AI 框架中的地位
自动微分(Automatic Differentiation, AutoDiff)是深度学习框架的核心功能之一。它解决了高效计算神经网络梯度的问题,是反向传播算法(backpropagation)的高级实现方式。
AI 框架的六大任务
在 AI 框架中,自动微分与其他组件密切配合:
- 神经网络编程(NN Programming):CNN, RNN, GraphNet, Transformer, INR 等
- 自动微分(AutoDiff):计算梯度的自动化
- 数据管理与处理(Data Management):训练、验证、测试数据集和模型参数
- 模型训练与部署(Model Training & Deployment):SGD 等优化算法和部署工具
- 硬件加速(Hardware Acceleration):GPU, TPU, NPU 等
- 分布式执行(Distributed Execution):跨机器、跨 GPU 的并行计算
AI 框架的职责
现代 AI 框架(如 PyTorch、TensorFlow、JAX)需要:
灵活的编程模型:
- 提供直观的神经网络构建接口
- 自动推导计算图
- 自动计算微分(梯度)
- 自动优化网络
高效的计算能力:
- 自动编译和优化算法(子表达式消除、内核融合、内存优化)
- 根据硬件自动并行化和分布式化
- 支持多种处理器(CPU、GPU、TPU)
微分方法对比
三种微分方法
计算函数导数有三种主要方法,各有优缺点:
1. 符号微分(Symbolic Differentiation)
原理:使用数学符号操作和求导规则自动推导导数表达式
优点:
- 结果精确,无舍入误差
- 适合简单、结构固定的表达式
缺点:
- 表达式爆炸:导数表达式可能比原函数复杂百倍
- 效率低,执行速度慢
- 难以处理控制流(if/while 等)
- 对于复杂的函数或神经网络不实用
应用:
- Mathematica、Maple 等计算机代数系统
- 不适合深度学习框架
2. 数值微分(Numerical Differentiation)
原理:利用导数的定义近似计算
一阶前向差分: \(f'(x) \approx \frac{f(x+h) - f(x)}{h}\)
更精确的中心差分(二阶精度): \(f'(x) \approx \frac{f(x+h) - f(x-h)}{2h}\)
误差分析:
对于前向差分,截断误差为 \(O(h)$,舍入误差为\)O(\epsilon/h)\((其中\)\epsilon\(是机器精度)。总误差为:\)E(h) = \frac{Mh}{2} + \frac{\epsilon}{h}$$
最优步长为 \(h^* = 2\sqrt{\epsilon/M}$。对于双精度浮点数(\)\epsilon \approx 10^{-16}\(),最优步长约为\)10^{-8}$$。
具体例子:
计算 \(f(x) = x^2\) 在 \(x=2\) 处的导数(精确值为 4):
- 使用 \(h=0.01\):\(f'(2) \approx \frac{(2.01)^2 - 2^2}{0.01} = \frac{4.0401 - 4}{0.01} = 4.01\),误差 \(0.01\)
- 使用 \(h=10^{-8}\)(中心差分):\(f'(2) \approx \frac{(2+10^{-8})^2 - (2-10^{-8})^2}{2 \times 10^{-8}} \approx 4.000000\),误差 \(\sim 10^{-8}\)
优点:
- 实现简单,易于理解
- 对任何黑盒函数都适用
缺点:
- 精度差:在实际计算中,由于舍入误差和截断误差的权衡,精度难以提高
- 计算复杂度高:对于 \(n\) 维输入需要 \(n+1\) 次函数调用(或 \(2n\) 次用中心差分)
- 不可行性:对于现代神经网络(百万级参数),需要百万次前向计算
- 步长选择困难:需要仔细调整步长 \(h\),不同函数最优值不同
- 数值不稳定:在非光滑函数处表现差
应用:
- 梯度检验(Gradient Checking):验证自动微分实现的正确性
- 计算数值梯度和自动微分梯度,比较两者差异
- 相对误差应小于 \(10^{-7}\)(对于中心差分)
- 简单场景:黑盒优化问题
梯度检验示例代码:
def numerical_gradient(f, x, h=1e-5):
"""计算 f 在 x 处的数值梯度"""
grad = np.zeros_like(x)
for i in range(x.size):
x_h = x.copy().reshape(-1)
x_h[i] += h
fxh_pos = f(x_h.reshape(x.shape))
x_h[i] -= 2 * h
fxh_neg = f(x_h.reshape(x.shape))
grad.flat[i] = (fxh_pos - fxh_neg) / (2 * h)
return grad
# 验证自动微分
x = torch.randn(10, requires_grad=True)
y = (x ** 2).sum()
y.backward()
auto_grad = x.grad.numpy()
# 计算数值梯度
def f_numpy(x):
return (x ** 2).sum()
num_grad = numerical_gradient(f_numpy, x.detach().numpy())
# 检查相对误差
rel_error = np.linalg.norm(auto_grad - num_grad) / np.linalg.norm(auto_grad + num_grad)
print(f"相对误差: {rel_error}") # 应该 < 1e-7
3. 自动微分(Automatic Differentiation)
原理:将函数分解为基本操作,逐步应用链式法则
根据计算顺序分为:
- 前向模式(Forward Mode):从输入到输出
- 反向模式(Reverse Mode):从输出到输入
优点:
- 精度高:只有浮点舍入误差
- 效率高:\(O(1)\) 到 \(O(n)\) 次基本操作
- 支持控制流(if/while)
- 适合深度学习
缺点:
- 实现复杂度较高
- 需要维护计算图或中间变量
应用:
- 深度学习框架(PyTorch、TensorFlow)
- 科学计算库(JAX)
方法对比表
| 特性 | 符号微分 | 数值微分 | 自动微分 |
|---|---|---|---|
| 精度 | 精确 | 近似(低) | 精确(舍入误差) |
| 复杂度 | \(O(n)\) 个导数表达式 | \(O(n)\) 次函数调用 | \(O(n)\) 个算子 |
| 速度 | 慢(表达式爆炸) | 慢(多次计算) | 快 |
| 控制流支持 | 困难 | 支持 | 支持 |
| 实现复杂度 | 高 | 低 | 中等 |
| 适用范围 | 简单表达式 | 黑盒函数 | 复杂函数/神经网络 |
自动微分的基础
链式法则(Chain Rule)
自动微分的数学基础是链式法则。对于复合函数 \(y = f(g(x))\):
\[\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx}\]其中 \(u = g(x)\)
对于多元函数,偏导数的链式法则为:
\[\frac{\partial y}{\partial x_i} = \sum_j \frac{\partial y}{\partial u_j} \cdot \frac{\partial u_j}{\partial x_i}\]具体例子:
考虑函数 \(y = \sin(x^2)\),我们要计算 $$\frac{dy}{dx}$:
设 \(u = x^2\),则 \(y = \sin(u)\)
\[\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx} = \cos(u) \cdot 2x = 2x \cos(x^2)\]| 在 \(x = 1\) 处:$$\frac{dy}{dx} | _{x=1} = 2 \cos(1) \approx 2 \times 0.5403 \approx 1.0806$$ |
多元函数例子:
对于 \(z = (x + y)^2\),求 \(\frac{\partial z}{\partial x}$ 和\)\frac{\partial z}{\partial y}$$:
设 \(u = x + y\),则 \(z = u^2\)
\[\frac{\partial z}{\partial x} = \frac{\partial z}{\partial u} \cdot \frac{\partial u}{\partial x} = 2u \cdot 1 = 2(x + y)\] \[\frac{\partial z}{\partial y} = \frac{\partial z}{\partial u} \cdot \frac{\partial u}{\partial y} = 2u \cdot 1 = 2(x + y)\]计算图(Computational Graph)
函数的计算可以表示为计算图,其中:
- 节点表示基本操作(加、乘、激活函数等)
- 边表示数据流动
- 叶子节点表示输入变量
- 根节点表示输出(损失函数)
例如,\(y = (x_1 + x_2) \times (x_1 - x_2)\) 的计算图:
x₁ x₂
\ /
(加法) (减法)
\ /
(乘法)
|
y
前向模式(Forward Mode AD)
思想:从输入开始,沿着计算图正向传播,逐层计算导数(也称 Tangent Linear)
详细过程:
- 初始化:对要求导的输入 \(x_i\),设 \(\dot{x}_i = 1\)(其他输入的导数为 0)
- 前向传播:对计算图中每个中间节点 \(u_j\),计算: \(\dot{u}_j = \sum_k \frac{\partial u_j}{\partial u_k} \dot{u}_k\)
- 最后得到输出的导数:\(\dot{y} = \frac{dy}{dx_i}\)
具体例子:
计算 \(y = (x_1 + x_2) \times (x_1 - x_2)\) 对 \(x_1\) 的导数,在 \(x_1 = 3, x_2 = 2\) 处:
计算过程:
- \(u_1 = x_1 + x_2 = 5\),\(\dot{u}_1 = \dot{x}_1 + \dot{x}_2 = 1 + 0 = 1\)
- \(u_2 = x_1 - x_2 = 1\),\(\dot{u}_2 = \dot{x}_1 - \dot{x}_2 = 1 - 0 = 1\)
- \(y = u_1 \times u_2 = 5\),\(\dot{y} = u_2 \cdot \dot{u}_1 + u_1 \cdot \dot{u}_2 = 1 \times 1 + 5 \times 1 = 6\)
验证:\(\frac{\partial y}{\partial x_1} = \frac{\partial}{\partial x_1}[(x_1+x_2)(x_1-x_2)] = (x_1-x_2) + (x_1+x_2) = 2x_1 = 6\) ✓
计算量分析:
对于 \(n\) 个输入、1 个输出,需要 \(n\) 次前向传播(每次对应一个输入变量)
若计算图有 \(m\) 条边,总复杂度为 \(O(n \times m)\)
适用场景:
- 输入少、输出多的函数
- 例如:神经网络中的向量场计算,梯度方向场等
反向模式(Reverse Mode AD)
思想:从输出开始,沿着计算图逆向传播,逐层反向计算导数(也称 Adjoint 或 Cotangent Linear)
详细过程:
- 前向阶段:计算函数值,保存所有中间结果
- 初始化:\(\bar{y} = \frac{\partial L}{\partial y} = 1\)(对于标量输出)
- 反向传播:从输出向输入逐层计算伴随变量(adjoint): \(\bar{u}_j = \sum_k \bar{u}_k \frac{\partial u_k}{\partial u_j}\)
- 最后得到:\(\bar{x}_i = \frac{\partial L}{\partial x_i}\)
具体例子:
同上,计算 \(y = (x_1 + x_2) \times (x_1 - x_2)\) 对 \(x_1, x_2\) 的导数:
前向阶段:
- \[u_1 = x_1 + x_2 = 5\]
- \[u_2 = x_1 - x_2 = 1\]
- \[y = u_1 \times u_2 = 5\]
反向阶段:
- \[\bar{y} = 1\]
- \[\bar{u}_1 = \bar{y} \cdot \frac{\partial y}{\partial u_1} = 1 \times u_2 = 1\]
- \[\bar{u}_2 = \bar{y} \cdot \frac{\partial y}{\partial u_2} = 1 \times u_1 = 5\]
- \[\bar{x}_1 = \bar{u}_1 \cdot \frac{\partial u_1}{\partial x_1} + \bar{u}_2 \cdot \frac{\partial u_2}{\partial x_1} = 1 \times 1 + 5 \times 1 = 6\]
- \[\bar{x}_2 = \bar{u}_1 \cdot \frac{\partial u_1}{\partial x_2} + \bar{u}_2 \cdot \frac{\partial u_2}{\partial x_2} = 1 \times 1 + 5 \times (-1) = -4\]
验证:\(\frac{\partial y}{\partial x_2} = 2x_2 = 4$... 不对,让我重算:\)y = x_1^2 - x_2^2\(,\)\frac{\partial y}{\partial x_2} = -2x_2 = -4$$ ✓
计算量分析:
无论有多少个输入,反向模式只需 1 次前向 + 1 次反向,总复杂度为 \(O(m)\)(其中 \(m\) 为计算图的边数)
内存成本:
需要保存前向计算的所有中间结果,内存复杂度为 $$O(m + n)$
适用场景:
- 神经网络训练:输入参数多(百万级),输出少(标量损失)
- 这是深度学习框架中使用反向传播的根本原因
反向模式 vs 前向模式的效率对比:
对于有 \(n = 10^6\) 个参数、单个标量输出的神经网络:
- 前向模式:需要 \(10^6\) 次前向传播 → 总计 \(10^6\) 倍的计算量
- 反向模式:1 次前向 + 1 次反向 → 总计 2 倍的计算量
性能比率:\(10^6 : 2 \approx 500000 : 1\),这说明反向模式快 50 万倍!
深度学习中的自动微分
反向传播(Backpropagation)
深度学习使用反向模式自动微分来高效计算损失函数对所有参数的梯度。
过程:
- 前向传播:计算损失 \(L = f(x, \theta)\)
- 初始化:\(\bar{L} = \frac{\partial L}{\partial L} = 1\)
- 反向传播:逐层计算 \(\bar{\theta}^{(l)} = \frac{\partial L}{\partial \theta^{(l)}}\)
- 参数更新:\(\theta \leftarrow \theta - \eta \bar{\theta}\)
计算效率分析
对于有 \(n\) 个参数的神经网络:
- 数值微分:需要 \(n+1\) 次前向传播,复杂度 \(O(n)\)
- 反向传播:只需 1 次前向 + 1 次反向传播,复杂度 \(O(1)\)
这是反向传播相比数值微分的关键优势。
面向对象实现
操作算子的包装
在自动微分框架中,每个基本操作都需要包装为操作对象,包含:
前向计算:\(y = f(x_1, \ldots, x_n)\)
反向梯度计算:\(\bar{x}_i = \bar{y} \cdot \frac{\partial y}{\partial x_i}\)
计算图的构建
动态图(Dynamic Graph,如 PyTorch):
- 在代码执行时动态构建计算图
- 灵活支持控制流
- 调试友好
静态图(Static Graph,如 TensorFlow 早期):
- 先定义计算图,再执行
- 优化空间大
- 部署效率高
现代框架(PyTorch 的 eager execution、TensorFlow 的 eager mode)倾向于动态图。
常见的自动微分库
PyTorch
import torch
x = torch.tensor([2.0, 3.0], requires_grad=True)
y = (x ** 2).sum()
y.backward() # 自动计算梯度
print(x.grad) # 输出: tensor([4., 6.])
特点:
- 命令式编程,易于学习
- 动态计算图
- 研究友好
TensorFlow
import tensorflow as tf
x = tf.Variable([2.0, 3.0])
with tf.GradientTape() as tape:
y = tf.reduce_sum(x ** 2)
grads = tape.gradient(y, x)
print(grads) # 输出: [4. 6.]
特点:
- 支持声明式和命令式编程
- 静态优化与动态灵活性的平衡
- 生产部署强大
JAX
import jax
import jax.numpy as jnp
def f(x):
return jnp.sum(x ** 2)
grad_f = jax.grad(f)
x = jnp.array([2.0, 3.0])
print(grad_f(x)) # 输出: [4. 6.]
特点:
- 函数式编程风格
- 可组合的转换(grad、vmap、jit)
- 适合研究创新
总结
自动微分是 AI 框架的核心,通过高效的链式法则计算梯度。相比数值微分的低精度和高计算量,自动微分提供了:
- 精确性:只有浮点舍入误差
- 效率:反向模式对神经网络训练最优(\(O(1)\) 倍计算)
- 灵活性:支持动态控制流
- 可扩展性:支持复杂的神经网络架构
现代深度学习框架通过包装基本操作算子,自动构建和执行计算图,为用户屏蔽了自动微分的复杂性,使得写神经网络代码就像写普通 Python 代码一样简单。
Enjoy Reading This Article?
Here are some more articles you might like to read next: