概述
最近因为项目需要做细粒度,需要将POI的全局特征,和OCR的文字特征进行融合,然后进行分类从而让模型具有区分细粒度的特点。 POI模型是pytorch,但是OCR那边的模型是keras, 因为实习时间有限,如果通过H5融合的话其实对于两个模型的后期调整非常麻烦,就是每次你都要调整,跑完全部数据,然后保存对应的全部H5,如果训练数据多的话,光是保存H5就需要时间了。而且OCR我们想提取哪一层特征也不是很明确,所以这样做的话时间会比较久。更新迭代的版本也是有限的,所以我决定尝试软融合,也就是我不通过保存H5,用pytorch作为主体框架,用keras模型部分,作为pytorch的一部分。选择用pytorch,第一是我可能不太了解tf,但是tf之前做过,但是感觉tensorflow可拓展性不是很强,而且里面涉及到了graph和session的概念,所以我还是用了pytorch.比如pytorch可以自己去定义一些层用nn.module只要将反向传播写好了就可以。pytorch的灵活性体现在它可以任意拓展我们所需要的内容
Tensorflow学习资料
https://zhuanlan.zhihu.com/p/47136826
https://zhuanlan.zhihu.com/p/45476929
pytorch灵活性的资料:
https://blog.csdn.net/weixin_37721058/article/details/97167620
pytorch的debug日记
https://blog.csdn.net/weixin_37721058/article/details/97146452
– NCHW 还是 NHWC
首先是pytorch 和keras 的tensor本身的排布是不一样的,这里是需要注意的,避免在后面混合的时候出问题。
def ocr_feature(self, x):
#input pytorch type: (B,C,H,W) NCHW
out_feature = x.permute(0,2,3,1).contiguous()
#input keras type:(B,H,W,C)
ocr_feature = ocr_model(out_feature) NHWC
#output keras :(B,H,W,C)
out_feature = ocr_feature.permute(0,3,1,2).contiguous()
#outp
最后
以上就是听话奇异果为你收集整理的模型训练与测试-----pytorch和keras的软融合----GAP与全连接的作用的全部内容,希望文章能够帮你解决模型训练与测试-----pytorch和keras的软融合----GAP与全连接的作用所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复