我是靠谱客的博主 听话奇异果,最近开发中收集的这篇文章主要介绍模型训练与测试-----pytorch和keras的软融合----GAP与全连接的作用,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

最近因为项目需要做细粒度,需要将POI的全局特征,和OCR的文字特征进行融合,然后进行分类从而让模型具有区分细粒度的特点。 POI模型是pytorch,但是OCR那边的模型是keras, 因为实习时间有限,如果通过H5融合的话其实对于两个模型的后期调整非常麻烦,就是每次你都要调整,跑完全部数据,然后保存对应的全部H5,如果训练数据多的话,光是保存H5就需要时间了。而且OCR我们想提取哪一层特征也不是很明确,所以这样做的话时间会比较久。更新迭代的版本也是有限的,所以我决定尝试软融合,也就是我不通过保存H5,用pytorch作为主体框架,用keras模型部分,作为pytorch的一部分。选择用pytorch,第一是我可能不太了解tf,但是tf之前做过,但是感觉tensorflow可拓展性不是很强,而且里面涉及到了graph和session的概念,所以我还是用了pytorch.比如pytorch可以自己去定义一些层用nn.module只要将反向传播写好了就可以。pytorch的灵活性体现在它可以任意拓展我们所需要的内容

Tensorflow学习资料
https://zhuanlan.zhihu.com/p/47136826
https://zhuanlan.zhihu.com/p/45476929

pytorch灵活性的资料:
https://blog.csdn.net/weixin_37721058/article/details/97167620

pytorch的debug日记
https://blog.csdn.net/weixin_37721058/article/details/97146452

– NCHW 还是 NHWC

首先是pytorch 和keras 的tensor本身的排布是不一样的,这里是需要注意的,避免在后面混合的时候出问题。

def ocr_feature(self, x):

        #input pytorch type: (B,C,H,W) NCHW
        out_feature = x.permute(0,2,3,1).contiguous() 
        #input keras type:(B,H,W,C)

        ocr_feature = ocr_model(out_feature) NHWC
        #output keras :(B,H,W,C)
        out_feature = ocr_feature.permute(0,3,1,2).contiguous()
        
        #outp

最后

以上就是听话奇异果为你收集整理的模型训练与测试-----pytorch和keras的软融合----GAP与全连接的作用的全部内容,希望文章能够帮你解决模型训练与测试-----pytorch和keras的软融合----GAP与全连接的作用所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(34)

评论列表共有 0 条评论

立即
投稿
返回
顶部