博客地址:https://www.cnblogs.com/zylyehuo/
多层感知机(MLP)训练过程
可以把多层感知机(MLP)的训练过程想象成“小学生做练习题 + 老师批改并讲解”的一个循环
1. 拿到一张“试卷”(输入数据)
- 每次训练,我们从训练集中抽一小批图片(比如 256 张 28×28 的 Fashion‑MNIST 图像),就好像给学生发 256 道习题。
- 每张图片对应一个“正确答案”(标签),比如“鞋子”“衬衫”……
2. 学生先自己做一遍(前向传播)
Flatten → 全连→ReLU→全连→输出 10 维分数
- 扁平化:把 28×28 的像素摊平成长度 784 的一行题干;
- 第一道 “全连接题”(784→256):学生做题时先根据当前掌握的“权重”(W₁、b₁)算出 256 个中间答题状态;
- ReLU 激活:遇到负分数就当 0 处理(“不懂的题目先跳过”),遇到正分数就保留;
- 第二道 “全连接题”(256→10):再根据 W₂、b₂ 计算出 10 个“原始分数”(logits),代表对每个类别的“信心值”。
3. 老师批改给分(计算损失)
- 老师把学生给出的 10 个分数经过 Softmax 变成概率(总和为 1),然后看正确答案对应的概率有多大:
- 损失越大,说明学生这道题做得越“离谱”。
4. 学生据此改正(反向传播 + 参数更新)
反向传播:
- 老师会告诉学生「每道题哪个步骤出错最多」,即算出 ∂损失/∂W₁、∂损失/∂b₁、∂损失/∂W₂、∂损失/∂b₂。
梯度下降(SGD):
- 学生根据老师的反馈,用一个小步长(学习率 lr)去调整自己的“解题方法”(权重矩阵 W₁、b₁、W₂、b₂),让下次做同样的题时分数更高(损失更低)。
5. 重复做题—循序渐进(多轮 Epoch)
- 把全班题(整个训练集)分一小批一小批做,每批都反复「做→批改→调整」。
- 做完一遍称为一个 Epoch。训练 10 个 Epoch,就相当于学生做了 10 遍全套试卷。
6. 定期模拟考试(评估集 / 测试集)
- 每隔一段时间(或每做完一遍试卷),让学生做一套“模拟考试”——从没见过的测试数据里抽一批。
- 测一次做对的百分比(准确率),观察「训练时成绩」和「模拟考试成绩」是否都在稳步上升。
为什么能学会“识别”?
第一层全连接 + ReLU
- 相当于学生在画面里学会“识别简单特征”,比如“一个区域像线条”、“这个区域像块状”等。
第二层全连接
- 把这些“线条和块状特征”重新组合,形成更高级的“形状概念”(鞋头、衣领、口袋边缘……),最终映射到 10 类之一。
通过大量不同款式的衣服图片训练,学生(模型)能逐渐:
- 在第一层 捕捉到“哪儿有亮度突变”(边缘)或“哪儿比较暗”(阴影)这类低级特征;
- 在第二层 组合出“这更像一只鞋子还是一件上衣”这样的高级判断;
- 最终层 做出“我最相信它是第 3 类(运动鞋)”的预测。
每一轮「做题→批改→调整」都会让模型的参数更趋向于“在所有训练样本上都做对”,从而也能在新的、没见过的补充测试题上取得不错成绩。
- import torch
- from torch import nn
- from d2l import torch as d2l
- # 1. 定义模型结构:Flatten + 两个隐层 + 激活 + 输出层
- net = nn.Sequential(
- nn.Flatten(), # 扁平化输入
- nn.Linear(784, 256), # 第一隐藏层:784 → 256
- nn.ReLU(), # 激活函数:引入非线性
- nn.Linear(256, 10) # 输出层:256 → 10
- )
- # 2. 参数初始化:同样给所有 Linear 层赋小随机值
- def init_weights(m):
- if isinstance(m, nn.Linear):
- nn.init.normal_(m.weight, std=0.01)
- net.apply(init_weights)
- # 3. 超参数设置
- batch_size, lr, num_epochs = 256, 0.1, 10
- loss = nn.CrossEntropyLoss(reduction='none')
- trainer = torch.optim.SGD(net.parameters(), lr=lr)
- train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
- # 4. 训练并可视化(同前)
- d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
复制代码
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作! |