概述
写在前面的话
该问题在 GitHub的 detectron2 的 issues 上被提出,有人解决了(如下图所示)
提示一下,去 GitHub 上的 issues 搜索问题,尽量找【closed】标签的,这些基本都是有解决方法的问题。
这里只做个记录,仅供学习使用
参考GitHub链接:
How do I compute validation loss during training?
这个实现的很巧妙,直接把训练集的替换成验证集,用原本的训练集的计算loss的方法做计算
添加的包
from detectron2.engine import HookBase
from detectron2.data import build_detection_train_loader
import detectron2.utils.comm as comm
功能函数
class ValidationLoss(HookBase):
def __init__(self, cfg, DATASETS_VAL_NAME):#多加一个DATASETS_VAL_NAME参数(小改动)
super().__init__()
self.cfg = cfg.clone()
self.cfg.DATASETS.TRAIN = DATASETS_VAL_NAME##
self._loader = iter(build_detection_train_loader(self.cfg))
def after_step(self):
data = next(self._loader)
with torch.no_grad():
loss_dict = self.trainer.model(data)
losses = sum(loss_dict.values())
assert torch.isfinite(losses).all(), loss_dict
loss_dict_reduced = {"val_" + k: v.item() for k, v in
comm.reduce_dict(loss_dict).items()}
losses_reduced = sum(loss for loss in loss_dict_reduced.values())
if comm.is_main_process():
self.trainer.storage.put_scalars(total_val_loss=losses_reduced,
**loss_dict_reduced)
使用方法
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = Trainer(cfg)
val_loss = ValidationLoss(cfg, cfg.DATASETS.VAL) ##多加的参数
trainer.register_hooks([val_loss])
# swap the order of PeriodicWriter and ValidationLoss
trainer._hooks = trainer._hooks[:-2] + trainer._hooks[-2:][::-1]
trainer.resume_or_load(resume=False)
trainer.train()
实现效果
total_val_loss
val_loss_cls
val_loss_box_reg
最后
以上就是冷酷小霸王为你收集整理的detectron2 在训练过程中输出 validation loss(验证集的损失)的全部内容,希望文章能够帮你解决detectron2 在训练过程中输出 validation loss(验证集的损失)所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复