概述
一、前言
这次主要围绕~/src/train_softmax.py
脚本中的两个类来进行记录,AccMetric类
和LossValueMetric类
。
目录地址:insightface人脸识别代码记录(总)(基于MXNet)
二、主要内容
以下为train_softmax.py
中的部分代码,仅包含两个评价验证类。可以看出,这两个类均继承mxnet.metric.EvalMetric
基础类。
两个类均是在update()
中计算得到self.sum_metric
和self.num_inst
,这两个类似于分子和分母,self.sum_metric
存放的是预测的和实际一样的,self.num_inst
存放的是预测的总和。而评价函数的主要工作就是改写update()
。
然后,通过mx.metric.create()
来管理你的评价函数,或者通过另一种方式,mxnet.metric.CompositeEvalMetric类
也可以,具体见下面的示例。最后把这个eval_metric
送入fit()
中,就没我们啥事了。但是,metric在fit()
中究竟是怎么计算呢?
请接着往下看。
train_softmax.py
:
class AccMetric(mx.metric.EvalMetric):
def __init__(self):
self.axis = 1
super(AccMetric, self).__init__(
'acc', axis=self.axis,
output_names=None, label_names=None)
self.losses = []
self.count = 0
def update(self, labels, preds):
self.count+=1
label = labels[0]
pred_label = preds[1]
if pred_label.shape != label.shape:
pred_label = mx.ndarray.argmax(pred_label, axis=self.axis)
pred_label = pred_label.asnumpy().astype('int32').flatten()
label = label.asnumpy()
if label.ndim==2:
label = label[:,0]
label = label.astype('int32').flatten()
assert label.shape==pred_label.shape
self.sum_metric += (pred_label.flat == label.flat).sum()
self.num_inst += len(pred_label.flat)
class LossValueMetric(mx.metric.EvalMetric):
def __init__(self):
self.axis = 1
super(LossValueMetric, self).__init__(
'lossvalue', axis=self.axis,
output_names=None, label_names=None)
self.losses = []
def update(self, labels, preds):
loss = preds[-1].asnumpy()[0]
self.sum_metric += loss
self.num_inst += 1.0
gt_label = preds[-2].asnumpy()
#print(gt_label)
...
...
metric1 = AccMetric()
eval_metrics = [mx.metric.create(metric1)]
if args.ce_loss:
metric2 = LossValueMetric()
eval_metrics.append( mx.metric.create(metric2) )
# eval_metrics = mx.metric.CompositeEvalMetric()
# for child_metric in [metrics_1, metrics_2]:
# eval_metrics.add(child_metric)
...
...
model.fit(train_dataiter,
begin_epoch = begin_epoch,
num_epoch = end_epoch,
eval_data = val_dataiter,
eval_metric = eval_metrics,
kvstore = 'device',
optimizer = opt,
#optimizer_params = optimizer_params,
initializer = initializer,
arg_params = arg_params,
aux_params = aux_params,
allow_missing = True,
batch_end_callback = _batch_callback,
epoch_end_callback = epoch_cb )
这就得去官网看源码了,请看以下截图:
首先来到module模块中,即https://github.com/apache/incubator-mxnet/tree/master/python/mxnet/module,进入base_module.py中,我们便可以看到fit()
的原型。如下图:
然后我们可以在fit()
中找到评价函数到底是如何运作计算的。如下图:
然后跟着来到metric.py脚本中,https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/metric.py,找到刚才的那个方法,get_global_name_value()
;然后根据self._has_global_stats
这个参数,分为是否全局两种情况,分别计算,如下图:
总结一下这个fit()
的过程:
首先,fit()
是Module类的一个方法,所以我们要去module.py脚本找,这个脚本里定义的Module类
,fit()方法继承自BaseModule类并且在BaseModule类中实现,Module类中并没有做修改,所以需要在BaseModule类
中查看fit()方法。再来到base_module.py,即实现BaseModule类
的地方,找到fit()
,在其中找到metric调用的get_global_name_value()
,最后来到metric.py,找到其计算方式。
三、结尾
这里就是关于Insightface中关于识别的评价函数的介绍。针对情况的不同,评价函数自然不同,但是万变不离其宗。
最后
以上就是单身蜡烛为你收集整理的insightface人脸识别代码记录(三)(评价函数)一、前言二、主要内容三、结尾的全部内容,希望文章能够帮你解决insightface人脸识别代码记录(三)(评价函数)一、前言二、主要内容三、结尾所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复