[transformer] [pytorch] Data augmentation in DeiT. Written by 危机月饼 on 靠谱客 and shared here as a reference.

1 Relevant arguments in main

import argparse


def get_args_parser():
    # Only the augmentation-related arguments are listed here.
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--input-size', default=224, type=int, help='images input size')

    # Color jitter
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    # Config string passed on to rand_augment_transform
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". (default: rand-m9-mstd0.5-inc1)')
    # Interpolation method
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    # Repeated augmentation (enabled by default)
    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=True)

    # Random erasing
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')
    return parser
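As a quick check of how these flags behave, the sketch below rebuilds only a handful of the arguments in a standalone parser and parses an empty and a customised command line. The helper name `make_aug_parser` is made up for this example and is not part of DeiT.

```python
import argparse


def make_aug_parser():
    # Hypothetical helper: a cut-down copy of the augmentation flags listed above.
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1')
    parser.add_argument('--reprob', type=float, default=0.25)
    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    # set_defaults overrides the per-action defaults, so repeated_aug starts as True.
    parser.set_defaults(repeated_aug=True)
    return parser


defaults = make_aug_parser().parse_args([])
custom = make_aug_parser().parse_args(['--no-repeated-aug', '--reprob', '0'])

print(defaults.aa, defaults.reprob, defaults.repeated_aug)
print(custom.reprob, custom.repeated_aug)
```

Because both flags write to the same `repeated_aug` destination, repeated augmentation stays on unless `--no-repeated-aug` is given, and passing `--reprob 0` is enough to disable random erasing further down the pipeline.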

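2 The build_transform function in datasets

The function does not return the library's create_transform result directly, so that the resulting transform can still be modified.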

from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms


def build_transform(is_train, args):
    resize_im = args.input_size > 32
    # Training
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with
            # RandomCrop
            # create_transform returns a Compose; its .transforms list can be edited in place
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    # Evaluation
    t = []
    if resize_im:
        size = int((256 / 224) * args.input_size)
        t.append(
            transforms.Resize(size, interpolation=3),  # to maintain same ratio w.r.t. 224 images (3 = PIL bicubic)
        )
        t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())  # these last two steps must not be forgotten
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)  # wrap the assembled list in Compose
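The evaluation branch resizes to a slightly larger image before the center crop, keeping the 256:224 ratio used for 224-pixel images. A minimal sketch of that arithmetic follows. The function name `eval_resize_size` is an assumption made for this example.

```python
def eval_resize_size(input_size):
    # Same arithmetic as the evaluation branch of build_transform.
    return int((256 / 224) * input_size)


for input_size in (224, 384):
    # The image is resized to this size, then center-cropped to input_size.
    print(input_size, '->', eval_resize_size(input_size))
```

For a 224 input this resizes to 256 and crops back to 224. For 384 the resize target is 438, because `int` truncates 438.86 downwards.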

3 The create_transform function

Source: transforms_factory.py (timm library)

# Parameters followed by a "#" are the ones build_transform passes in.
def create_transform(
        input_size,  #
        is_training=False,  #
        use_prefetcher=False,
        no_aug=False,
        scale=None,
        ratio=None,
        hflip=0.5,
        vflip=0.,
        color_jitter=0.4,  #
        auto_augment=None,  # args.aa = rand-m9-mstd0.5-inc1
        interpolation='bilinear',  # bicubic
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        re_prob=0.,  #
        re_mode='const',  #
        re_count=1,  #
        re_num_splits=0,
        crop_pct=None,
        tf_preprocessing=False,
        separate=False):
    ...  # branches not used by DeiT are omitted; img_size is derived from input_size there
    # Standard ImageNet training pipeline; mean and std default to the ImageNet statistics
    transform = transforms_imagenet_train(
        img_size,
        scale=scale,
        ratio=ratio,
        hflip=hflip,
        vflip=vflip,
        color_jitter=color_jitter,
        auto_augment=auto_augment,
        interpolation=interpolation,
        use_prefetcher=use_prefetcher,
        mean=mean,
        std=std,
        re_prob=re_prob,
        re_mode=re_mode,
        re_count=re_count,
        re_num_splits=re_num_splits,
        separate=separate)

4 The transforms_imagenet_train function

Source: same as above

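1) What the function returns:
 * Either three separate transforms or a single merged one; the separate flag controls this.
 * The first entry is RandomResizedCropAndInterpolation. This is the transform.transforms[0] that build_transform replaces for small inputs.
2) primary_tfl:
 * RandomResizedCropAndInterpolation
 * RandomHorizontalFlip (optional)
 * RandomVerticalFlip (optional)
3) secondary_tfl:
 * rand_augment_transform (when auto_augment starts with 'rand')
 * ColorJitter (only when auto_augment is not set)
4) final_tfl:
 * ToTensor
 * Normalize
 * RandomErasing (when re_prob > 0)
5) aa_params:
 * translate_const
 * img_mean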
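How the three lists are merged, and why index 0 can be swapped afterwards, can be sketched without any dependencies. The `Compose` class and the `tag` helper below are simplified stand-ins written for this example, not the torchvision or timm implementations.

```python
class Compose:
    # Minimal stand-in for torchvision.transforms.Compose, for illustration only.
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x


def tag(name):
    # Hypothetical "transform" that only records its own name.
    return lambda history: history + [name]


primary_tfl = [tag('RandomResizedCropAndInterpolation'), tag('RandomHorizontalFlip')]
secondary_tfl = [tag('rand_augment_transform')]
final_tfl = [tag('ToTensor'), tag('Normalize'), tag('RandomErasing')]

# separate=False: one pipeline built from the concatenated lists.
merged = Compose(primary_tfl + secondary_tfl + final_tfl)
# What build_transform does when input_size <= 32: swap the first entry.
merged.transforms[0] = tag('RandomCrop')

order = merged([])
print(order)
```

The printed order shows the crop first, then the flip, the RandAugment step, and finally tensor conversion, normalisation, and random erasing.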
def transforms_imagenet_train(
        img_size=224,
        scale=None,
        ratio=None,
        hflip=0.5,
        vflip=0.,
        color_jitter=0.4,
        auto_augment=None,
        interpolation='random',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        re_prob=0.,
        re_mode='const',
        re_count=1,
        re_num_splits=0,
        separate=False,
):
    """
    If separate==True, the transforms are returned as a tuple of 3 separate transforms
    for use in a mixing dataset that passes
     * all data through the first (primary) transform, called the 'clean' data
     * a portion of the data through the secondary transform
     * normalizes and converts the branches above with the third, final transform
    """
    scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
    ratio = tuple(ratio or (3./4., 4./3.))  # default imagenet ratio range
    primary_tfl = [
        RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)]
    if hflip > 0.:
        primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
    if vflip > 0.:
        primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]

    secondary_tfl = []
    if auto_augment:  # 'rand-m9-mstd0.5-inc1' in DeiT
        assert isinstance(auto_augment, str)
        if isinstance(img_size, tuple):
            img_size_min = min(img_size)
        else:
            img_size_min = img_size
        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
        )
        if interpolation and interpolation != 'random':
            aa_params['interpolation'] = _pil_interp(interpolation)
        # From here on, pick which auto-augment variant to use
        if auto_augment.startswith('rand'):  # the branch DeiT takes
            secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
        elif auto_augment.startswith('augmix'):
            aa_params['translate_pct'] = 0.3
            secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
        else:
            secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
    elif color_jitter is not None:
        # color jitter is enabled when not using AA
        if isinstance(color_jitter, (list, tuple)):
            # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
            # or 4 if also augmenting hue
            assert len(color_jitter) in (3, 4)
        else:
            # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
            color_jitter = (float(color_jitter),) * 3
        secondary_tfl += [transforms.ColorJitter(*color_jitter)]

    final_tfl = []
    if use_prefetcher:  # False in DeiT
        # prefetcher and collate will handle tensor conversion and norm
        final_tfl += [ToNumpy()]
    else:
        final_tfl += [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ]
        if re_prob > 0.:
            final_tfl.append(
                RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu'))

    if separate:
        return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
    else:
        return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
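Two small pieces of arithmetic in this function are easy to check by hand: the aa_params values and the expansion of a scalar color_jitter. The sketch below reproduces both in plain Python. The helper names `build_aa_params` and `expand_color_jitter` are made up for this example, and the mean is the ImageNet default that timm uses.

```python
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)


def build_aa_params(img_size, mean):
    # Mirrors the aa_params construction: translation limit and fill colour.
    img_size_min = min(img_size) if isinstance(img_size, tuple) else img_size
    return dict(
        translate_const=int(img_size_min * 0.45),
        img_mean=tuple(min(255, round(255 * x)) for x in mean),
    )


def expand_color_jitter(color_jitter):
    # A scalar is duplicated for brightness, contrast and saturation (no hue).
    if isinstance(color_jitter, (list, tuple)):
        assert len(color_jitter) in (3, 4)
        return tuple(color_jitter)
    return (float(color_jitter),) * 3


aa_params = build_aa_params(224, IMAGENET_DEFAULT_MEAN)
jitter = expand_color_jitter(0.4)
print(aa_params)
print(jitter)
```

With a 224-pixel input the translation limit is 100 pixels, and the mean is converted to the 0..255 range as (124, 116, 104). Note that with DeiT's default `--aa` setting the color jitter branch is not reached, since it sits behind `elif`.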

5 The rand_augment_transform function

# Call site: secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
#   auto_augment: 'rand-m9-mstd0.5-inc1'
#   aa_params: translate_const, img_mean (plus interpolation when a fixed interpolation is set)
import re


def rand_augment_transform(config_str, hparams):
    """
    Create a RandAugment transform

    :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
    dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
    sections, not order specific, determine
        'm' - integer magnitude of rand augment
        'n' - integer num layers (number of transform ops selected per image)
        'w' - integer probability weight index (index of a set of weights to influence choice of op)
        'mstd' - float std deviation of magnitude noise applied
        'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
                (1 selects the ops whose severity grows with the magnitude)
    Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
    'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2

    :param hparams: Other hparams (kwargs) for the RandAugmentation scheme

    :return: A PyTorch compatible Transform
    """
    magnitude = _MAX_LEVEL  # default to _MAX_LEVEL for magnitude (currently 10)
    num_layers = 2  # default to 2 ops per image
    weight_idx = None  # default to no probability weights for op choice
    transforms = _RAND_TRANSFORMS
    config = config_str.split('-')  # ['rand', 'm9', 'mstd0.5', 'inc1']
    assert config[0] == 'rand'
    config = config[1:]  # ['m9', 'mstd0.5', 'inc1']
    for c in config:
        cs = re.split(r'(\d.*)', c)  # split each section into key and numeric value
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == 'mstd':
            # noise param injected via hparams for now
            hparams.setdefault('magnitude_std', float(val))
        elif key == 'inc':
            # val is a string here, so bool(val) is True for any non-empty value
            if bool(val):
                transforms = _RAND_INCREASING_TRANSFORMS
        elif key == 'm':
            magnitude = int(val)
        elif key == 'n':
            num_layers = int(val)
        elif key == 'w':
            weight_idx = int(val)
        else:
            assert False, 'Unknown RandAugment config section'
    # All candidate augmentation ops, built from the selected list of names
    ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
    # None when no 'w' section is given; otherwise the weights from _RAND_CHOICE_WEIGHTS_0
    choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
    # RandAugment picks ops at random: ra_ops is the list of candidate ops,
    # num_layers the number of ops applied per image, choice_weights their sampling weights
    return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)


_RAND_INCREASING_TRANSFORMS = [
    'AutoContrast',
    'Equalize',
    'Invert',
    'Rotate',
    'PosterizeIncreasing',
    'SolarizeIncreasing',
    'SolarizeAdd',
    'ColorIncreasing',
    'ContrastIncreasing',
    'BrightnessIncreasing',
    'SharpnessIncreasing',
    'ShearX',
    'ShearY',
    'TranslateXRel',
    'TranslateYRel',
    #'Cutout'  # NOTE I've implement this as random erasing separately
]

# These experimental weights are based loosely on the relative improvements mentioned in paper.
# They may not result in increased performance, but could likely be tuned to so.
_RAND_CHOICE_WEIGHTS_0 = {
    'Rotate': 0.3,
    'ShearX': 0.2,
    'ShearY': 0.2,
    'TranslateXRel': 0.1,
    'TranslateYRel': 0.1,
    'Color': .025,
    'Sharpness': 0.025,
    'AutoContrast': 0.025,
    'Solarize': .005,
    'SolarizeAdd': .005,
    'Contrast': .005,
    'Brightness': .005,
    'Equalize': .005,
    'Posterize': 0,
    'Invert': 0,
}
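To see what the parsing loop extracts from DeiT's default string, the sketch below applies the same split to 'rand-m9-mstd0.5-inc1' using the pattern r'(\d.*)'. The helper `parse_rand_config` is hypothetical and only collects the raw key/value strings. It does not build any transform.

```python
import re


def parse_rand_config(config_str):
    # Hypothetical helper mirroring the parsing loop; returns raw strings.
    sections = config_str.split('-')
    assert sections[0] == 'rand'
    parsed = {}
    for c in sections[1:]:
        # The capturing group keeps the numeric part: 'mstd0.5' -> ['mstd', '0.5', '']
        cs = re.split(r'(\d.*)', c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        parsed[key] = val
    return parsed


parsed = parse_rand_config('rand-m9-mstd0.5-inc1')
print(parsed)
# The 'inc' value is a string, so bool() is True even for '0'.
print(bool('0'))
```

So the default string means magnitude 9, a magnitude noise standard deviation of 0.5, and the "increasing" op list, with num_layers left at its default of 2.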

6 The RandAugment class

import numpy as np


class RandAugment:
    def __init__(self, ops, num_layers=2, choice_weights=None):
        self.ops = ops
        self.num_layers = num_layers
        self.choice_weights = choice_weights

    def __call__(self, img):
        # no replacement when using weighted choice
        ops = np.random.choice(
            self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
        for op in ops:
            img = op(img)
        return img
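The sampling behaviour can be tried out without NumPy or image data. The sketch below is a simplified stand-in that covers only the unweighted case (choice_weights is None, so ops are drawn with replacement). `TinyRandAugment` and `make_op` are names invented for this example, and the "ops" only record their names instead of changing an image.

```python
import random


class TinyRandAugment:
    # Hypothetical stand-in using the standard library; unweighted case only.
    def __init__(self, ops, num_layers=2, seed=None):
        self.ops = ops
        self.num_layers = num_layers
        self.rng = random.Random(seed)

    def __call__(self, img):
        # Sampling with replacement: the same op may be picked twice.
        for op in self.rng.choices(self.ops, k=self.num_layers):
            img = op(img)
        return img


def make_op(name):
    # Stand-in op that appends its name to a trace list.
    return lambda trace: trace + [name]


names = ['AutoContrast', 'Equalize', 'Rotate', 'ShearX']
augment = TinyRandAugment([make_op(n) for n in names], num_layers=2, seed=0)
trace = augment([])
print(trace)
```

Each call applies exactly num_layers ops in the order they were drawn, which is why DeiT's default (no 'n' section) applies two ops per image.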
