概述
资料:
github代码链接:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing
b站一个不错的up主讲解视频:https://www.bilibili.com/video/BV1of4y1m7nj?t=99&p=2
数据集
数据集使用Pascal VOC2012 (共20个分类)
Pascal VOC2012 train/val数据集下载地址:http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
文件结构
├── backbone: 特征提取网络,可以根据自己的要求选择
├── network_files: Faster R-CNN网络(包括Fast R-CNN以及RPN等模块)
├── train_utils: 训练验证等相关模块(包括cocotools)
├── my_dataset.py: 自定义dataset用于读取VOC数据集
├── train_mobilenet.py: 以MobileNetV2做为backbone进行训练
├── train_resnet50_fpn.py: 以resnet50+FPN做为backbone进行训练
├── train_multi_GPU.py: 针对使用多GPU的用户使用
├── predict.py: 简易的预测脚本,使用训练好的权重进行预测测试
├── valisation.py: 利用训练好的权重验证/测试数据的COCO指标,并生成record_mAP.txt文件
└── pascal_voc_classes.json: pascal_voc标签文件
9
predict.py操作步骤
1.确定设备情况
2.定义模型
3.载入定义好的模型
4.载入模型权重
5.读取pascal_voc的索引文件,即类别和索引
6.载入一张图片对其进行预处理
7.
1.定义网络架构(模型)
def create_model(num_classes):
backbone = resnet50_fpn_backbone()
model = FasterRCNN(backbone=backbone, num_classes=num_classes)
return model
定义一个名为create_model的函数,包含变量num_classes。其中backbone用的是之前定义的resnet50_fpn_backbone,模型直接调用FasterRCNN,在FasterRCNN中给主干网络backbone和标签类别num_classes进行赋值。
2.测试代码在cuda运行时间
def time_synchronized():
torch.cuda.synchronize() if torch.cuda.is_available() else None
return time.time()
torch.cuda.synchronize()是测试时间的函数,完成的命令是等待当前设备上所有流中的所有核心完成。一般使用该操作来等待GPU全部执行结束,CPU才可以读取时间信息。
torch.cuda.is_available()函数用来查看GPU是否可用,如果torch.cuda.is_available()返回Ture说明GPU可用。
3.定义主函数
def main():
# get devices
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("using {} device.".format(device))
# create model
model = create_model(num_classes=21)
# load train weights
train_weights = "./save_weights/model.pth"
assert os.path.exists(train_weights), "{} file dose not exist.".format(train_weights)
model.load_state_dict(torch.load(train_weights, map_location=device)["model"])
#加入["model"]用来载入模型相关的一些权重文件,比如之前保存的优化器啊学习率的调整等权重。
model.to(device)
#将模型指配到设备当中
# read class_indict,读取pascal_voc的索引文件,即类别和索引
label_json_path = './pascal_voc_classes.json'
assert os.path.exists(label_json_path), "json file {} dose not exist.".format(label_json_path)
json_file = open(label_json_path, 'r')
class_dict = json.load(json_file)
category_index = {v: k for k, v in class_dict.items()}
# load image载入一张图片对其进行预处理,这里只将它转化为tensor,因为在Fasterrcnn中已经包含的有一些图像的预处理。
original_img = Image.open("./test.jpg")
# from pil image to tensor, do not normalize image
data_transform = transforms.Compose([transforms.ToTensor()])
img = data_transform(original_img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)#增加一个维度
model.eval()
# 进入验证模式
with torch.no_grad():#在上下文关联器中即torch.no_grad中去把图像指认到设备当中
# init
img_height, img_width = img.shape[-2:]
init_img = torch.zeros((1, 3, img_height, img_width), device=device)
model(init_img)
t_start = time_synchronized()
predictions = model(img.to(device))[0]#把图像指认到设备当中,传入模型进行正向传播得到预测结果
t_end = time_synchronized()
print("inference+NMS time: {}".format(t_end - t_start))
#将结果获取放入到cpu上
predict_boxes = predictions["boxes"].to("cpu").numpy()
predict_classes =predictions["labels"].to("cpu").numpy()
predict_scores = predictions["scores"].to("cpu").numpy()
if len(predict_boxes) == 0:
print("没有检测到任何目标!")
#通过draw_box方法来绘制图像
draw_box(original_img,
predict_boxes,
predict_classes,
predict_scores,
category_index,
thresh=0.5,
line_thickness=3)
plt.imshow(original_img)
plt.show()
# 保存预测的图片结果
original_img.save("test_result.jpg")
主函数步骤:
1)获取设备信息
获取gpu设备信息的操作要放在读取数据之前。
如果torch.cuda.is_available()返回Ture,即gpu可用,此时Tensor分配到第一台(“0”)gpu.否则使用cpu。
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("using {} device.".format(device))
常用的关于查询gpu设备信息的操作如下:
torch.cuda.is_available()----cuda是否可用
torch.cuda.device_count()----返回gpu数量
torch.cuda.get_device_name(0)----返回gpu名字,设备索引默认从0开始
torch.cuda.current_device()----返回当前设备索引
device = torch.device(‘cuda’)----将数据转移到GPU
device = torch.device(‘cpu’)----将数据转移的cpu
2)给模型赋值
标签数为21,将其传递给刚才定义的create_model函数,返回值赋给model。
model = create_model(num_classes=21)
3)下载训练权重
# load train weights
train_weights = "./save_weights/model.pth"
assert os.path.exists(train_weights), "{} file dose not exist.".format(train_weights)
model.load_state_dict(torch.load(train_weights, map_location=device)["model"])
model.to(device)
assert其作用是如果它的条件返回错误,则终止程序执行。
os.path.exists()就是判断括号里的文件是否存在的意思,括号内的可以是文件路径。存在时True,不存在是False。
format函数用于字符串的格式化,比如。
通过关键字:
print(’{name}在{option}’.format(name=“谢某人”,option=“写代码”))
结果:谢某人在写代码
通过位置:
print(‘name={} path={}’.format(‘zhangsan’, ‘/’)
结果:name=zhangsan path=/
state_dict 是一个简单的python的字典对象,作用是将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)。
torch.load(文件名,设备)用来加载模型。
4)读取分类信息
# read class_indict
label_json_path = './pascal_voc_classes.json'
assert os.path.exists(label_json_path), "json file {} dose not exist.".format(label_json_path)
json_file = open(label_json_path, 'r')
class_dict = json.load(json_file)
category_index = {v: k for k, v in class_dict.items()}
5)加载测试集
original_img = Image.open("./test.jpg")
6) 处理测试集图片(格式和维度)
# from pil image to tensor, do not normalize image
data_transform = transforms.Compose([transforms.ToTensor()])
img = data_transform(original_img)
定义和调用data_transform函数对测试集图片进行数据转化,转化为张量。
# expand batch dimension
img = torch.unsqueeze(img, dim=0)
squeeze的用法主要就是对数据的维度进行压缩或者解压。
squeeze()函数功能:
主要对数据的维度进行压缩(默认为1)。也可以通过dim指定位置,删掉指定位置的维数。
unsqueeze()函数功能:
对数据维度进行扩充。dim指定位置,添加指定位置的维数添加1。
详细:https://blog.csdn.net/xiexu911/article/details/80820028
7)验证
model.eval()
# 进入验证模式
with torch.no_grad():
# init
img_height, img_width = img.shape[-2:]
init_img = torch.zeros((1, 3, img_height, img_width), device=device)
model(init_img)
t_start = time_synchronized()
predictions = model(img.to(device))[0]
t_end = time_synchronized()
print("inference+NMS time: {}".format(t_end - t_start))
predict_boxes = predictions["boxes"].to("cpu").numpy()
predict_classes = predictions["labels"].to("cpu").numpy()
predict_scores = predictions["scores"].to("cpu").numpy()
if len(predict_boxes) == 0:
print("没有检测到任何目标!")
draw_box(original_img,
predict_boxes,
predict_classes,
predict_scores,
category_index,
thresh=0.5,
line_thickness=3)
plt.imshow(original_img)
plt.show()
# 保存预测的图片结果
original_img.save("test_result.jpg")
最后
以上就是开朗镜子为你收集整理的Faster R-CNN代码讲解之predict.py的全部内容,希望文章能够帮你解决Faster R-CNN代码讲解之predict.py所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复