机器学习框架通常提供高层语言(比如 Python)作为给用户使用的接口,在 Python 层用户可以做的事有:
-
数据处理:从不同的数据源读取数据,还可以做预处理,比如增强;
-
模型定义:
通常会定义一个或多个模型的结构,一般长这样:
class Net(): def __init__(self): xxx def forward(self, input): xxx
__init__()
里定义模型需要的一些 module 或者 op,forward()
用来定义模型的前向结构,比如哪一层由哪些 op 哪些输入做了什么样的计算,也可以说这里是用户在定义“计算图”。反向就不用写了,现在框架基本都支持自动微分,梯度更新的时候框架会根据前向结构计算梯度;
-
优化器定义:模型的输出需要和用户的标记进行对比,这个对比差异一般通过损失函数(Loss function)来进行评估。因此,优化器定义API允许用户定义自己的损失函数,并且根据损失来引入(Import)和定义各种优化算法(Optimisation algorithms)来计算梯度(Gradient),完成对模型参数的更新;
-
训练:有了 dataset、模型、损失函数、优化器,用户一般需要自己写一个循环(比如多少个 epoch)来跑模型,跑一个训练集叫作一个 epoch,跑一个 batch size 叫作一个 step(或者 iter);
-
测试:练过程中,用户需要测试API来对当前模型的精度进行评估。当精度达到目标后,训练结束。这一过程中,用户往往需要调试API来完成对模型的性能和正确性进行验证;
-
(可选)保存模型: 训练模型的过程中模型的参数会通过反向传播得到更新(一般初始时参数值是随机的),得到更新之后的模型参数我们可以用来部署或者继续训练(微调),有时候我们需要把模型参数保存下来,之后再训练可以在此基础上继续进行。