概述
MegEngine是一个强大的深度学习推理与训练的框架,这里向大家介绍以下MegEngine中端上训练的功能~
那么,什么是端上训练呢?顾名思义,端上训练就是在手机和一些边缘式设备上进行深度学习模型的训练。这些设备的普遍特点是计算能力尤其是并行计算能力不强,端上训练的特点则是本地训练,即数据不回传到服务器。
咦,那么既然本来可以数据回传服务器进行训练,那还要端上训练干嘛呢?
有很多人遇到过这样一个现象,自己和朋友聊着天,说起了某个事情,比如好久没吃火锅了,过一会打开小红书微博淘宝等APP,首页推荐赫然是火锅。很多朋友和我说起这个问题感觉自己被窃听了,抱怨自己的隐私可能受到了侵犯。
是否真的窃听了我们这里姑且不论,但毫无疑问,对于用户而言,绝大多数情况下不希望自己的隐私暴露出去。根据中国的法律法规,侵犯用户隐私的行为也是不被允许的。并且在现在几乎人手一部智能手机的时代,手机中包含了大量的用户隐私,更有很多处于边缘地带的隐私,比如用户点单的时间,稍有不慎则会有侵犯隐私权之嫌。
然而,继续拿上面的例子来说,人们并不会因为看到首页的火锅,觉得这是自己不想看到的推荐而感到反感,而是不想泄露隐私。相反,这种推荐技术一定程度上对用户是有利的,如果你在某宝搜索了好几次电饭煲,可是都没找到自己喜欢的,假如时候闲着无聊随便刷刷某宝App,是不是如果出现电饭煲的商品信息会乐意点进去一看呢?事实上,很多应用都需要将用户数据回传进行模型的训练,比如我们身边无处不在的推荐系统,精准地猜测出你想吃什么,想看什么类型的电影等等,与之相对应的就是用户数据的上传用于推荐系统的训练,抖音等平台的巨大成功也无疑可以印证这种模式在大数据时代的威力。抛开隐私不谈,用户数据回馈模型会有利于帮助服务提供者提升用户的体验。
然而技术往往也是把双刃剑,比如人脸识别应用中的人脸解锁屏幕,几乎所有用过的手机用户都会觉得这样非常方便。新的人脸数据需要对模型进行微调,也即训练模型的一部分参数来达到好的效果。可是如果人脸数据被上传至了某手机厂商的服务器,那用户知道后,想必心里就会有个梗。
总结起来可以归纳为以下几点:
- 用户对于深度学习在移动端的应用给自己带来的便利是喜闻乐见的。
- 用户不喜欢、且法律不准许随意将移动设备采集到的个人数据,尤其是包含生物特征的数据进行上传。
- 在许多移动端深度学习应用场景中,需要对模型参数进行微调。
在这种情况下,想到在移动端进行训练便是顺理成章的了。如果可以进行端上训练,那么数据就只会留存在用户的设备本地,而不会通过网络上传,那么对于用于来讲隐私权得到了保护,对于服务提供者而言则可以在很好地规避法律问题的同时,提升其服务的质量。
此外,MegEngine对于异构计算的支持也可以很好的适应端上训练的需求,只需要一份相同的代码即可以在不同设备上运行。比如你可以在PC端构建样例并验证,然后在移动端进行部署,而不必每次测试都要在移动端进行。
那么接下来,就来看一下如何在MegEngine里面进行端上训练吧~
仍然是老规矩,拿Mnist数据集来进行试手,模型选用LeNet。在我们的内部测试中,调用端上训练接口的代码可以直接在手机上运行,并且效果和通用的Python训练接口完全对齐。
回顾在Pytorch、Tensorflow等框架建立训练流程时候做的事情,我们可以发现主要包括:
- 搭建模型;
- 添加Loss与Optimizer;
- 导入数据集;
- 设置学习率、训练轮数等超参数并训练。
搭建模型
模型的搭建其实是构造前向计算图的一个过程,通过调用算子,获取与输入相对应的输出。
从LeNet的模型结构容易得知,我们需要调用2次卷积算子,2次池化算子,1次Flatten算子,2次矩阵乘算子,以及若干次四则运算的算子。
在MegEngine中,算子只是负责执行运算的一个“黑盒子”,我们需要提前设置好参数,然后将参数与数据一起“喂”给算子。如下图所示,数据永远是逐层进行传递的,且其Layout会被自动计算,而参数则需要我们手动进行设置。
对于LeNet这种前馈神经网络,我们只需要将前面算子的输出与下一组参数链接到下一个算子,就可以将计算过程连接起来。
由于此处代码比较冗长,这里给出一个简化版的代码示例。可以看出,其实和调用通用的Python接口写法差别不大,甚至是一一对应的,比如opr::Convolution
对应nn.Conv2d
, opr::MatrixMul
对应nn.Linear
,只是由于C++语言特性和Python不同,所以写起来会有一些差异。
SymbolVar symbol_input =
opr::Host2DeviceCopy::make(*graph, m_input); // 初始化输入数据
SymbolVar symbol_conv =
opr::Convolution::make(symbol_input, symbol_conv_weight, conv_param); // symbol_weighs[0]即我们提前设置好的卷积filter权重
symbol_conv = opr::relu(symbol_conv + symbol_conv_bias); //加偏置之后激活
SymbolVar symbol_maxpool =
opr::Pooling::make(symbol_conv, pooling_param)
.reshape({batchsize, fc_shape[0]}); //池化之后进行展平
SymbolVar symbol_fc =
opr::MatrixMul::make(symbol_maxpool, symbol_fc_weight) +
symbol_fc_bias;
symbol_fc1= opr::relu(symbol_fc); //通过矩阵乘运算构造全连接层
通过这种方式,我们即可以将算子、数据与参数进行组合,构建出我们需要的前向计算图。
调用Loss与Optimizer
现在MegEngine中已经在C++层面对Loss和Optimizer进行了封装,下面我们以Mnist数据集训练中的交叉熵损失以及SGD优化器为例讲解。
在MegEngine中,一切推理与训练实际上都是在一张计算图上进行,而Loss与Optimizer本质上不过是将构造计算图的一部分任务封装了起来以供用户直接调用,而无需重复“造轮子”。例如,我们最熟悉的均方误差中,实际上是调用一次减法算子之后再调用一次乘方算子。
M S E = ( y − y ’ ) 2 MSE,,=,,left( y-y^’ right) ^2 MSE=(y−y’)2
明白了这一点之后,我们只需要继续上一步,在我们的模型输出后面调用Loss的API并进行拼接就可以,代码非常简单,和Pytorch中训练十分相似。
CrossEntopyLoss loss_func; // 先定义一个损失函数的实例,这里选取交叉熵损失
SymbolVar symbol_loss = loss_func(symbol_fc, symbol_label); // 将模型输出与标签作为输入,调用损失函数
这时,我们得到的symbol_loss
就是我们训练过程中的损失。
与调用Loss API类似,我们也可以很轻松地调用优化器插入到已有计算图中。
SGD optimizer = SGD(0.01f, 5e-4f, .9f); //实例化SGD优化器并设置参数
SymbolVarArray symbol_updates =
optimizer.make_multiple(symbol_weights, symbol_grads, graph); // 将Optimizer插入到计算图中
这样一来,在反向传播之后,梯度就会被Optimizer进行处理并更新模型参数。
导入数据集
既然模型参数是我们手动定义,那肯定会注意到一个问题就是我们的数据集怎么转化成参与计算图计算的数据呢?
这个当然MegEngine已经准备好了办法,可以通过继承一个接口并实现其中的get_item
与size
方法,并将这个类的实例输入到DataLoader中,那么就可以完成数据集的转换啦~
我们要继承的接口定义如下。咦,这里平时用Pytorch的小伙伴肯定已经闻到了熟悉的味道。
class IDataView {
public:
virtual DataPair get_item(int idx) = 0;
virtual size_t size() = 0;
virtual ~IDataView() = default;
};
话不多说直接上一个示例,这里只示意如何继承接口并得到DataLoader,如果有兴趣看具体实现的小伙伴可以去关注MegEngine~
class MnistDataset : public IDataView {
public:
MnistDataset(std::string dir_name); // 初始化数据集,指定数据集存放路径
void load_data(Mode mode, std::string dir_name); //读取Mnist数据集,存到dataset列表中。
DataPair get_item(int idx); // 实现接口
size_t size(); //实现接口
protected:
std::vector<DataPair> dataset;
};
// 实例化上面定义的数据集类
auto train_dataset = std::make_shared<MnistDataset>(dataset_dir);
// 用这个实例来获取对应的DataLoader
auto train_dataloader =
DataLoader(train_dataset, batchsize);
训练
既然完成了各个步骤,那么接下来的事情就是让训练跑起来~这里也是给出简单的伪代码示例。唔……这里使用Pytorch的小伙伴看了也会感到非常熟悉,也就是循环每个epoch,每个epoch中又循环每组数据与标签,不同的是在这里我们不需要在循环中调用Loss与Optimizer,因为前面已经构造好了完整的计算图,这里只需要执行我们编译后的计算图即可。
func = graph->compile(); // 编译计算图
for (int epoch = 0; epoch < epochs; epoch++) {
for (size_t i = 0; i < train_dataloader.size(); i++) {
data = train_dataloader.next(); // 从DataLoader中获取数据
func->execute(); // 执行计算图
}
}
通过我的以身试法(x),发现在端上训练可以达到用Pytorch以及MegEngine的Python训练接口训练的相同准确率~到这里我们的验证即获成功!
看到这里,相信你已经了解了如何在MegEngine中进行端上训练了,那么Loss和Optimizer又到底是什么样的接口呢?
Loss与Optimizer的封装
有的时候,我们会遇到需要封装自己需要的Loss和Optimizer的情况,这时候了解Loss和Optimizer的API就显得比较重要。
Loss的接口十分简单,可以归结为如下所示:
class ILoss {
public:
virtual mgb::SymbolVar operator()(mgb::SymbolVar symbol_pred,
mgb::SymbolVar symol_label) = 0;
virtual ~ILoss() = default;
};
只要输入预测值和标签值两个计算节点,能对应输出一个计算节点即可,这里细心的小伙伴可能已经注意到SymbolVar就是前面构建前向计算图的时候用到的类,这也是为什么说Loss的本质就是帮助你在计算图中插入一段计算过程。
Optimizer的接口也很简明,可以归结为下面的代码:
class IOptimizer {
public:
virtual mgb::SymbolVarArray make_multiple(
mgb::SymbolVarArray symbol_weights,
mgb::SymbolVarArray symbol_grads,
std::shared_ptr<mgb::cg::ComputingGraph> graph) = 0;
virtual mgb::SymbolVar make(
mgb::SymbolVar symbol_weight, mgb::SymbolVar symbol_grad,
std::shared_ptr<mgb::cg::ComputingGraph> graph) = 0;
virtual ~IOptimizer() = default;
};
class Optimizer : public IOptimizer {
public:
mgb::SymbolVarArray make_multiple(
mgb::SymbolVarArray symbol_weights,
mgb::SymbolVarArray symbol_grads,
std::shared_ptr<mgb::cg::ComputingGraph> graph); // 注意这里并不是纯虚函数
virtual mgb::SymbolVar make(
mgb::SymbolVar symbol_weight, mgb::SymbolVar symbol_grad,
std::shared_ptr<mgb::cg::ComputingGraph> graph) = 0;
virtual ~Optimizer() = default;
};
与Loss类似,这里我们也是输入计算节点,然后对应输出一个计算节点。值得注意的是Optimizer分为了两部分,一部分是纯粹的接口IOptimizer
,另一部分是继承了这个接口的抽象类Optimizer
。事实上,由于很多情况下,我们习惯于用一个数组或列表来存放我们的参数与得到的梯度,这时候由于静态语言的限制,不能直接将这种情况归并到单一输入的情况中,但是实际上只要我们实现了Make
接口,输入是数组的情况也自然会得到解决。但是考虑到接口与类应当进行分离的理念,这里进行了抽离,变成了一个接口、一个抽象类,且抽象类中包含了对数组输入的情况(make_multiple
接口)的默认实现。
倘若需要添加一个自定义的Loss或Optimizer,只需要继承相应的接口或抽象类并实现即可。
例如对均方误差MSE的实现:
mgb::SymbolVar MSELoss::operator()(
mgb::SymbolVar symbol_pred, mgb::SymbolVar symol_label) {
return opr::pow(symbol_pred - symol_label, symbol_pred.make_scalar(2));
}
总结与展望
看到这里,也许你会充满好奇,也许你会一脸嫌弃……
端上训练作为一个尚在探索中的方向,现在的确和已有的训练、推理框架没法比较,但MegEngine提供端上训练的功能会在你需要的时候为你提供一种选择。在这样一个手机越来越占据人们生活的时代,以及人们对服务质量的需求不断提高的时代,想必端上训练会有用武之地。
当前MegEngine端上训练的主要问题与下一步可能的改进点有:
- 模型的构建过程当前比较原始,可以进一步的封装出类似
nn.module
的模块。 - 有时候手里已经有了带有计算图信息的某个权重文件,不希望再次搭建计算图,而是直接读取现有的计算图并插入训练过程,可以提供类似的API
- 在C++侧进行数据的读取会比较麻烦
欢迎大家来尝试使用MegEngine搭建端上训练应用,也欢迎大家能指出当前MegEngine中端上训练存在的不足以便我们改进,也可以来提PR一起解决问题~
MegEngine项目地址:https://github.com/MegEngine/MegEngine
MegEngine官网:https://megengine.org.cn/
最后
以上就是野性唇膏为你收集整理的备用-暂时发布-MegEngine端上训练的全部内容,希望文章能够帮你解决备用-暂时发布-MegEngine端上训练所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复