反向传播#
反向传播(英语:Backpropagation,意为误差反向传播,缩写为 BP)是对多层人工神经网络进行梯度下降的算法,也就是用链式法则以网络每层的权重为变数计算损失函数的梯度,以更新权重来最小化损失函数。
简单例子计算#
符号 | |
---|---|
$ w_{ij} $ | 权重 |
$ z_i $ | 输入 |
$ y_i $ | 输出 |
$ E = \frac{1}{2}(y_p-y_a)^2$ | 损失 |
$ f(x) = \frac{1}{1+e^{-x}}$ | 激活函数 |
例如更新 $w_{53}$
主要思想就是我们无法找到 $E$ 和 $w_{53}$ 的直接关系,所以使用链式求导法则来间接求梯度:
代码实现#
from math import exp, sqrt
def f(x):
"""returns sigmoid(x)"""
return 1 / (1 + exp(-x))
def d_f(x):
"""return d(sigmoid(x))/dx"""
return f(x) * (1 - f(x))
def E(yp, ya):
"""return the error in squre"""
return 1/2*(yp-ya)**2
def d_E(yp, ya):
"""return dE/d(yp)"""
return yp - ya
ya = 0.5
z = {1: 0.35, 2: 0.9, 3: 0.0, 4: 0.0, 5: 0.0}
w = {(1, 3): 0.1, (2, 3): 0.8, (1, 4): 0.4,
(2, 4): 0.6, (3, 5): 0.3, (4, 5): 0.9}
y = {i: 0.0 for i in range(1, 6)}
s = {(1, 3): 0.0, (2, 3): 0.0, (1, 4): 0.0,
(2, 4): 0.0, (3, 5): 0.0, (4, 5): 0.0}
def forward():
"""正向计算"""
y[1], y[2] = z[1], z[2]
z[3], z[4] = w[1, 3] * y[1] + w[2, 3] * \
y[2], w[1, 4] * y[1] + w[2, 4] * y[2]
y[3], y[4] = f(z[3]), f(z[4])
z[5] = w[3, 5] * y[3] + w[4, 5] * y[4]
y[5] = f(z[5])
e = E(y[5], ya)
print(f'yp = {y[5]}, Error = {e}')
return e
def update(n, g):
"""update weights"""
w[n] -= g
def back():
"""back propagation"""
g = d_E(y[5], ya) * d_f(z[5]) * y[3]
update((3, 5), g)
g = d_E(y[5], ya) * d_f(z[5]) * y[4]
update((4, 5), g)
g = d_E(y[5], ya) * d_f(z[5]) * w[3, 5] * d_f(z[3]) * y[1]
update((1, 3), g)
g = d_E(y[5], ya) * d_f(z[5]) * w[3, 5] * d_f(z[3]) * y[2]
update((2, 3), g)
g = d_E(y[5], ya) * d_f(z[5]) * w[4, 5] * d_f(z[4]) * y[1]
update((1, 4), g)
g = d_E(y[5], ya) * d_f(z[5]) * w[4, 5] * d_f(z[4]) * y[2]
update((2, 4), g)
if __name__ == '__main__':
step = 1
while True:
print(f'step = {step}:', end=' ')
if forward() < 1e-7:
break
step += 1
back()
运行代码,发现迭代 111 次后,损失才达到预期。
➜ python3 main.py
step = 1: yp = 0.6902834929076443, Error = 0.01810390383656677
step = 2: yp = 0.6820312027460466, Error = 0.016567679386586154
step = 3: yp = 0.6739592936119999, Error = 0.015130917916992996
step = 4: yp = 0.6660860653430867, Error = 0.013792290550574028
...
...
...
step = 110: yp = 0.500455314229855, Error = 1.036555239542297e-07
step = 111: yp = 0.5004302844701667, Error = 9.2572362633313e-08
动量法优化#
def update(n, g):
"""update weights"""
lr = 0.1
beta = 0.999
epsilon = 0.01
s[n] = beta * s[n] + (1 - beta) * g**2
w[n] -= lr / (sqrt(s[n]) + epsilon) * g
代码如上,使用动量法优化后,仅迭代 11 步损失就达到了预期。
➜ python3 main.py
step = 1: yp = 0.6902834929076443, Error = 0.01810390383656677
step = 2: yp = 0.6118790538896405, Error = 0.006258461349620539
step = 3: yp = 0.5593494607066376, Error = 0.0017611792430843605
step = 4: yp = 0.5301730589054645, Error = 0.00045520674185631563
step = 5: yp = 0.5151484953395415, Error = 0.00011473845552605596
step = 6: yp = 0.5075810359788381, Error = 2.873605325621872e-05
step = 7: yp = 0.5037909668833773, Error = 7.185714955431785e-06
step = 8: yp = 0.5018953359787233, Error = 1.7961492361214424e-06
step = 9: yp = 0.5009475298860605, Error = 4.48906442488917e-07
step = 10: yp = 0.5004736747645289, Error = 1.1218389127573745e-07
step = 11: yp = 0.5002367820786155, Error = 2.803287637675012e-08
使用 PyTorch 实现#
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self) -> None:
super(Net, self).__init__()
hidden = nn.Linear(2, 2, bias=False)
output = nn.Linear(2, 1, bias=False)
hidden.weight.data = torch.tensor([[0.1, 0.8], [0.4, 0.6]])
output.weight.data = torch.tensor([0.3, 0.9])
self.net = nn.Sequential(
hidden,
nn.Sigmoid(),
output,
nn.Sigmoid()
)
def forward(self, x):
return self.net(x)
def MSE(pred, act):
return 1/2*torch.mean((pred - act)**2)
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.tensor([0.35, 0.9]).to(device)
y = torch.tensor(0.5).to(device)
net = Net().to(device)
optimizer = torch.optim.RMSprop(
net.parameters(), lr=0.1, alpha=0.999, eps=1e-2)
def train():
optimizer.zero_grad()
pred = net(x)
loss = MSE(pred, y)
loss.backward()
optimizer.step()
err = loss.detach().item()
print(f'pred = {pred}, error = {err}')
return err
if __name__ == '__main__':
step = 1
while True:
print(f'step = {step}:', end=' ')
if train() < 1e-7:
break
step += 1
References#
https://zh.m.wikipedia.org/zh-cn/%E5%8F%8D%E5%90%91%E4%BC%A0%E6%92%AD%E7%AE%97%E6%B3%95