概述
算法进阶迁移学习
- 加载数据集
- 训练
加载数据集
class Image_Data(Dataset):
def __init__(self,img_h=128,img_w=128,path,mode='train',process=True):
self.img_h=img_h
self.img_w=img_w
self.path=path
self.mode=mode
self.process=process
if self.mode is 'train':
self.path=self.path+'/train'
self.transform=transforms.Compose([
transforms.Resize([img_h,img_w]),
transform.ToTensor(),
transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])
else:
self.path=self.path+'./val'
self.transfoem=transforms.Compose([
transforms.Resize([img_h,img_w]),
transforms.ToTensor(),
transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])
self.file=self.file_list()
random.shuffle(self.file)
def file_list(self):
ants=os.listdir(self.path+'/ants')
bees=os.listdir(self.path+'/bees')
file_list=['ants/'+file for file in ants]+['bees/'+file for file in bees]
return file_list
def __len__(self):
return len(self.file_list())
def __getitem__(self,item):
image_name=self.file[item]
if 'ants' in image_name:
label=1
else:
label=0
return image,label
训练
class Trainer(object):
def __init__(self,lr=0.005,batch_size=50,num_epochs=50,train_data=None,test_data=None):
self.lr=lr
self.batch_size=batch_size
self.epoch=num_epoches
self.train_loader=DataLoader(dataset=train_data,batch_size=self.batch_size,shuffle=True)
self.test_loader=DataLoader(dataset=test_data,batch_size=self.batch_size,shuffle=True)
self.mode=models.resnet18()
self.loss=nn.CrossEntropyLoss()
#对模型全连接层进行修改
num=self.mode.fc.in_features
self.mode.fc=nn.Linear(num,2)
self.optimizer=torch.optim.Adam(self.mode.parameters(),lr=self.lr)
def train(self):
for epoch in range(self.epoch):
epoch_loss=0
for i,(bx,by) in enumerate(self.train_loader);
bx_gen=self.mode(bx)
bx_gen=torch.sigmoid(bx_gen)
loss=self.loss(bx_gen,by)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
epoch_loss+=loss.item()
acc=self.test()
def test(self):
acc=0
for i,(bx,by) in enumerate(self.test_loader);
bx_gen=self.mode(bx)
_,resule=torch.max(bx_gen,1)
acc+=torch.sum(resule==by.data)
acc=acc/self.teat_loader.dataset.__len__()
return acc.item()
最后
以上就是开放枕头为你收集整理的算法进阶迁移学习加载数据集训练的全部内容,希望文章能够帮你解决算法进阶迁移学习加载数据集训练所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复