概述
最近英伟达发布了一个开源项目,https://github.com/NVIDIA/retinanet-examples,查看源码我们发现在RetinaNet/model.py 中将将pytorch的pth模型转化为onnx时,代码中有这样一段代码:
import torch.onnx.symbolic
# Override Upsample's ONNX export until new opset is supported
@torch.onnx.symbolic.parse_args('v', 'is')
def upsample_nearest2d(g, input, output_size):
height_scale = float(output_size[-2]) / input.type().sizes()[-2]
width_scale = float(output_size[-1]) / input.type().sizes()[-1]
return g.op("Upsample", input,
scales_f=(1, 1, height_scale, width_scale),
mode_s="nearest")
torch.onnx.symbolic.upsample_nearest2d = upsample_nearest2d
后面发现有人在https://github.com/onnx/onnx-tensorrt/issues/77 中提到,目前onnx-tensorrt 项目的upsample 这个layer会报错:
Attribute not found: height_scale
然后onnx-tensorrt 项目源码中将这个bug修复了,即使用
onnx2trt my_model.onnx -o my_engine.trt
会正常将onnx模型序列化,但是在运行这个序列化文件时,还是会报
Attribute not found: height_scale
错误。然后再做个实验,我直接使用tensorrt5.0的API接口:
void onnxToTRTModel(const std::string& modelFile, // name of the onnx model
unsigned int maxBatchSize, // batch size - NB must be at least as large as the batch we want to run with
nvinfer1::IHostMemory*& trtModelStream,
nvinfer1::DataType dataType,
nvinfer1::IInt8Calibrator* calibrator,
std::string save_name) // output buffer for the TensorRT model
{
int verbosity = (int)nvinfer1::ILogger::Severity::kINFO;
// create the builder
nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(gLogger);
nvinfer1::INetworkDefinition* network = builder->createNetwork();
auto parser = nvonnxparser::createParser(*network, gLogger);
if (!parser->parseFromFile(modelFile.c_str(), verbosity))
{
string msg("failed to parse onnx file");
gLogger.log(nvinfer1::ILogger::Severity::kERROR, msg.c_str());
exit(EXIT_FAILURE);
}
if ((dataType == nvinfer1::DataType::kINT8 && !builder->platformHasFastInt8()) )
exit(EXIT_FAILURE); //如果不支持kint8或不支持khalf就返回false
// Build the engine
builder->setMaxBatchSize(maxBatchSize);
builder->setMaxWorkspaceSize(4_GB); //不能超过你的实际能用的显存的大小,例如我的1060的可用为4.98GB,超过4.98GB会报错
builder->setInt8Mode(dataType == nvinfer1::DataType ::kINT8); //
builder->setInt8Calibrator(calibrator); //
samplesCommon::enableDLA(builder, gUseDLACore);
nvinfer1::ICudaEngine* engine = builder->buildCudaEngine(*network);
assert(engine);
// we can destroy the parser
parser->destroy();
// serialize the engine, then close everything down 序列化
trtModelStream = engine->serialize();
gieModelStream.write((const char*)trtModelStream->data(), trtModelStream->size());
std::ofstream SaveFile(save_name, std::ios::out | std::ios::binary);
SaveFile.seekp(0, std::ios::beg);
SaveFile << gieModelStream.rdbuf();
gieModelStream.seekg(0, gieModelStream.beg);
engine->destroy();
network->destroy();
builder->destroy();
}
在执行parser->parseFromFile(modelFile.c_str(), verbosity) 这句代码时,直接段错误,完全无法定位错误原因。但是事实上错误原因很简单,tensorrt5.0支持的onnx 的opset版本是9 ,但是目前pytorch导出的onnx已经是10了。
总结
目前tensorrt5.0 出来的时候,pytorch1.0未正式发布,所以tensorrt5.0是按照pytorch0.4.1进行开发的,pytorch1.0以后onnx导出的版本又发生了变化,但是tensorrt5.0未更新,所以我们必须要修改所有pytorch1.0及以上版本的onnx导出规则,即在运行代码中按照https://github.com/NVIDIA/retinanet-examples所做的那样,代码中加入upsample_nearest2d的重载,这样就可以正常使用tensorrt5.0 的onnx解析功能了。
最后
以上就是闪闪棒球为你收集整理的pytorch1.0,1.0.1-- onnx --tensorRT5.0.2.6的upsample_nearest2d BUG总结的全部内容,希望文章能够帮你解决pytorch1.0,1.0.1-- onnx --tensorRT5.0.2.6的upsample_nearest2d BUG总结所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复