本文是利用
torch.nn
实现一个简单的线性回归模型, 主要是根据nn.Module
和torch.nn
中的Linear
模型, 来创建一个自己的线性模型.适合与上篇手动创建模型进行对比学习.
导入所需模块
1 | import torch |
生成数据
首先我们都知道, 线性模型应该是满足如下线性函数的:
\[ f(\mathbf{x};\mathbf{w},b) = \mathbf{w}^T\mathbf{x}+b \]
那么我们可以通过自定义权重向量 \(\mathbf{w}\) 和偏置 \(b\) 来生成我们的线性模型数据.
假设真实权重 \(\mathbf{w} = [4.5, -1.7]\), 真实偏差 \(b = 2.8\), 以及一个随机噪声 \(\epsilon\) 来确定一个训练数据集 \(\mathcal{D} = \{\mathbf{X},Y\}^{1000 \times 2}\).
1 | num_inputs = 2 # 输入特征数 |
与上文PyTorch | 手动实现线性回归使用的数据集一致.
读取数据集
通过继承 nn.Module
来创建自己的网络, 就需要让数据集也适应网络的数据格式, 那么就可以使用 Dataset 类来实现数据的抽象. 但我们还需要对数据进行并行加速, 这就可以使用 DataLoader
来实现这些功能.
1 | batch_size = 10 # 小批量下降时的 batch 大小 |
需要注意的是 Windows 下
num_workers
只能设置为0
. 除非放在main
函数中使用, 具体可自行查找.
构建模型
torch.nn
的核心是 Module
, 这是一个抽象的概念, 既可以表示网络中的一个层, 也可以表示包含很多层的一个网络.
常见的写法是定义模型类来继承 nn.Module
, 在此基础上来构建网络. 主要特点是:
- 可以自动检测
parameter
, 并进行学习. (包括weight
,bias
等) - 主
Module
可以递归查找子Module
的parameter
.
需要做的是实现 Module
的两个基本函数:
- 构造函数
__init__()
: 在其中实现层的参数定义, 一般是可导参数 - 前向计算函数
forward()
: 实现前向运算
1 | # 方法一:继承nn.Module构建网络 |
或者可以使用 nn.Squential
和 nn.Linear
构建. 可以有三种形式:
1 | # 形式1 |
1 | # 形式2: |
1 | # 形式3: |
其中, 形式1对每层模型并没有设置名字, 如果需要使用相应层, 则可以使用 net[i]
来表示, 其他的可以使用例如 net.linear
表示.
初始化模型参数
在模型训练之前, 需要对模型参数进行初始化, 权重 w
初始化为均值为 0
, 标准差为 0.01
的正态分布, 偏置 b
初始化为 0
.
可以使用 nn.init
来初始化参数, normal
是正态分布, constant
是常量. 常量也可以使用如下另一种方法进行修改:
1 | nn.init.normal_(net.linear.weight, mean=0, std=0.01) |
这里不需要像之前手动创建线性模型时需要设置参数可导, 因为在构建模型时, 我们需要的参数都包含在 nn.Linear()
模型中了, 并且默认可导.
定义损失函数和优化算法
损失函数选择简单的平方损失(MSE)函数:
\[ loss = \frac{1}{N} \sum^{N}_{n=1}(y_i-\hat{y}_i)^2 \]
1 | loss = nn.MSELoss() |
常用损失函数还有:
- 平均绝对误差(MAE):
loss = nn.L1Loss()
- 二元交叉熵损失函数:
loss = nn.BCELoss()
- 包含 sigmoid 的二元交叉损失函数:
loss = nn.BCEWithLogitsLoss()
- 交叉熵损失函数:
nn.CrossEntropyLoss()
优化算法选择小批量梯度下降算法, 该优化算法 SGD
在 torch.optim
中已经内置了, 还包括 Adam
和 RMSProop
等, 此处直接使用 SGD
, 并且设定学习率 lr
:
1 | optimizer = optim.SGD(net.parameters(), lr=0.03) |
训练模型
接下来就开始训练模型, 小批量梯度下降的学习率 lr
和批次大小 batch_size
在 SGD
和 data_iter
中就已经设置好了, 这里需要设置训练的轮次 num_epochs
即可.
1 | num_epochs = 3 |
最后让我们来查看一下结果, 可以发现, 模型的拟合效果也是棒棒哒👍!
1 | print(f'true_w:{true_w}, w: {net.linear.weight.tolist()}') |
个人收获
使用 nn.Module
相比于手动创建更简单一些, 屏蔽了底层的细节, 更注重于网络的设计, 使得代码更加清晰明了.
不过各种内置的方法和函数需要花更多的时间去了解, 这样才能知道有哪些模型、哪些损失函数、哪些优化方法等, 可以直接使用.