我是靠谱客的博主 精明月光,最近开发中收集的这篇文章主要介绍在anti-spoofing中,在OULU数据集上求APCER,BPCER,ACER上的一个注意事项,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

OULU数据集用于anti-spoofing的测试,里面包含了打印攻击、回放攻击等一些攻击类型,想要具体了解OULU的数据概况的学者们,请读下OULU的原文,本人就不在这里详述了。OULU原文链接如下:OULU数据集 PDF版原文

 在用OULU对你设计的模型进行指标测试时,需要注意的一点是:你需要在OULU提供的四个协议上分别进行测试,OULU的前两个协议文件分别只有一个训练txt文件,而后两个协议则分别包含了6个训练文件,dev,test文件在这四个协议中也是如此!这四个协议是作者分别针对不同的攻击环境来进行设置的,具体如下:

OULU的四个协议
        协议名含义
P1评估在背景与光照上的泛化能力(未知的光照与背景变化的泛化能力)
P2对不同类型的打印和视频攻击进行训练和测试,来评估泛化能力(未知的攻击介质的泛化能力)
P3对图像捕获设备的类型进行泛化评估
P4结合P1-P3上的所有变化,来评估泛化能力

借此机会,我把SIW数据集的三个测试协议也放上,感兴趣的学者们可以了解下:

SIW的三个协议
协议名含义
P1评估对姿势和表情变化的泛化能力
P2评估在不同回放攻击设备上的泛化能力
P3通过在专门包含重放攻击或打印攻击视频的数据集上进行

一些研究者在OULU数据集上的测试结果:

 上图第一列代表的是测试协议名,第二列代表模型名字,其中协议3与协议4中值代表“均值±方差”;

 在用OULU的四个协议对你的模型测试时,切记,千万不要直接用求出来的TN,TP,FN,FP来直接求APCER,BPCER,ACER,这样的测试方式是不对的,本人一开始就是用这样的方式来求APCER这些指标的,结果发现测出来的指标出奇的好,后来是通过读OULU那篇原论文才发现,这样得测试方式是不对的!害,白白地让我瞎折腾了将近2个月,泪奔!

对于正确的测试方法,原文是这样写的:

它的大致意思就是:APCER与BPCER依赖于一个决策阈值,development集(可以理解为一个验证集)作为一个分离的验证集运行,用于调节模型的参数,并预测阈值以供测试集使用;也就是说:由验证集development set得到一个阈值threshold,而测试集利用验证集得到的这个threshold进行APCER,BPCER的指标测试

用代码的表示上述的测试方式如下:

#创建一个用于验证集测试的类变量
pad_meter_val = PADMeter()
#图像送入模型
output,_ = model(img)
#预测概率
class_output = nn.functional.softmax(output, dim=1)
#测量TP,TN,FP,FN
pad_meter_val.update(target.cpu().data.numpy(),class_output.cpu().data.numpy())
#求eer,阈值thr
pad_meter_val.get_eer_and_thr()
#用阈值thr求hter,apcer,bpcer
pad_meter_val.get_hter_apcer_etal_at_thr(pad_meter_val.threshold)
#用阈值thr求acc
pad_meter_val.get_accuracy(pad_meter_val.threshold)
####################################################
#创建一个用于测试集测试的类变量
pad_meter_test = PADMeter()
#图像送入模型
output,_ = model(img)
#预测概率
class_output1 = nn.functional.softmax(output, dim=1)
#得到TP,TN,FP,FN
pad_meter_test.update(target.cpu().data.numpy(),class_output1.cpu().data.numpy())
#在验证集得到的阈值下,测量测试集的hter,apcer等指标
pad_meter_test.get_hter_apcer_etal_at_thr(pad_meter_val.threshold)
#验证集得到的阈值下,测量测试集的acc指标
pad_meter_test.get_accuracy(pad_meter_val.threshold)

PADMeter()类封装的一些具体实现细节如下:

import math
import numpy as np
from sklearn.metrics import roc_curve, accuracy_score
from sklearn.metrics import roc_auc_score
from torch import nn

class PADMeter(object):
    """Presentation Attack Detection Meter"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.label = np.ones(0)
        self.output = np.ones(0)
        self.threshold = None
        self.grid_density = 10000

    def update(self, label, output):
        #一般取softmax之后,第1维度上的第1列数据
        if len(output.shape) > 1 and output.shape[1] > 1:
            output = output[:, 1]
        elif len(output.shape) > 1 and output.shape[1] == 1:
            output = output[:, 0]
        #拼接,将新的数据拼接到已有的后方
        self.label = np.hstack([self.label, label])
        self.output = np.hstack([self.output, output])

    def get_tpr(self, fixed_fpr):
        fpr, tpr, thr = roc_curve(self.label, self.output)
        tpr_filtered = tpr[fpr <= fixed_fpr]
        if len(tpr_filtered) == 0:
            self.tpr = 0.0
        self.tpr = tpr_filtered[-1]

    def eval_stat(self, thr):

        pred = self.output >= thr
        TN = np.sum((self.label == 0) & (pred == False))
        FN = np.sum((self.label == 1) & (pred == False))
        FP = np.sum((self.label == 0) & (pred == True))
        TP = np.sum((self.label == 1) & (pred == True))
        if TN + FP == 0:
            TN += 0.0001
        if TP + FN == 0:
            TP += 0.0001
        return TN, FN, FP, TP

    def get_eer_and_thr(self):

        thresholds = []
        Min, Max = min(self.output), max(self.output)
        for i in range(self.grid_density + 1):
            thresholds.append(Min + i * (Max - Min) / float(self.grid_density))
        min_dist = 1.0
        min_dist_stats = []
        for thr in thresholds:
            TN, FN, FP, TP = self.eval_stat(thr)
            far = FP / float(TN + FP)
            frr = FN / float(TP + FN)
            dist = math.fabs(far - frr)
            if dist < min_dist:
                min_dist = dist
                min_dist_stats = [far, frr, thr]

        # for exception
        if len(min_dist_stats) >= 2:
            self.eer = (min_dist_stats[0] + min_dist_stats[1]) / 2.0
            self.threshold = min_dist_stats[2]
        else:
            self.eer = 0.5
            self.threshold = 0.5

    def get_hter_apcer_etal_at_thr(self, thr=None):
        if thr is None:
            self.get_eer_and_thr()
            thr = self.threshold
        TN, FN, FP, TP = self.eval_stat(thr)

        far = FP / float(TN + FP)
        frr = FN / float(TP + FN)
        fpr = FP / float(FP + TN)
        tpr = TP / float(TP + FN)
        fpr1, tpr1, _ = roc_curve(self.label, self.output, pos_label=1)
        # print(type(fpr),fpr.shape,fpr.size())
        # TPR@FPR=e-2
        # tpr1 = np.array(tpr)
        # fpr1 = np.array(fpr)
        # print(type(fpr1), fpr1.shape,len(fpr1))
        # print(fpr1)
        score_1 = tpr1[np.where(fpr1 >= 0.01)[0][0]]
        # TPR@FPR=e-3
        score_2 = tpr1[np.where(fpr1 >= 0.001)[0][0]]
        # TPR@FPR=e-4
        score_3 = tpr1[np.where(fpr1 >= 0.0001)[0][0]]
        self.TN = TN
        self.FN = FN
        self.FP = FP
        self.TP = TP
        self.apcer = far
        self.bpcer = frr
        self.acer = (self.apcer + self.bpcer) / 2.0
        self.hter = (far + frr) / 2.0
        self.fpr = fpr
        self.tpr = tpr
        self.score1 = score_1
        self.score2 = score_2
        self.score3 = score_3
        try:
            self.auc = roc_auc_score(self.label, self.output)
        except ValueError:
            pass

    def get_accuracy(self, thr=None):
        if thr == None:
            self.get_eer_and_thr()
            thr = self.threshold
        TN, FN, FP, TP = self.eval_stat(thr)
        self.accuracy = accuracy = float(TP + TN) / len(self.output)

到此就结束了,在此祝各位学者科研顺利!

最后

以上就是精明月光为你收集整理的在anti-spoofing中,在OULU数据集上求APCER,BPCER,ACER上的一个注意事项的全部内容,希望文章能够帮你解决在anti-spoofing中,在OULU数据集上求APCER,BPCER,ACER上的一个注意事项所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(61)

评论列表共有 0 条评论

立即
投稿
返回
顶部