我是靠谱客的博主 重要母鸡,最近开发中收集的这篇文章主要介绍Scaffold 基于fedavg方法的改进,代码复现(联邦学习),觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

当前有个工作需要实现 SCAFFOLD 算法。该方法通过引入控制变量(修正项)c 来纠正客户端漂移(client drift)现象。
在参考 GitHub 上的相关框架后,复现了该算法。
算法分为三个模块:
optimizer: 重写的 SGD 优化器(在梯度中加入修正项 c - c_i)
clientscaffold:客户端操作
serverscaffold:服务端操作

optimizer部分代码:

import torch
from torch.optim import Optimizer
class SCAFFOLDOptimizer(Optimizer):
    """SGD variant that applies the SCAFFOLD control-variate correction.

    Each trainable parameter ``p`` is updated as::

        p <- p - lr * (grad + weight_decay * p + c - c_i)

    where ``c`` is the server control variate and ``c_i`` the client control
    variate belonging to that parameter.
    """

    def __init__(self, params, lr, weight_decay=0.0):
        """Create the optimizer.

        Args:
            params: Iterable of parameters or of param-group dicts.
            lr: Local learning rate.
            weight_decay: L2 penalty coefficient added to the gradient
                (0 disables it).
        """
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, server_controls, client_controls, closure=None):
        """Perform one control-variate-corrected SGD step.

        Args:
            server_controls: Server control variates ``c``, one tensor per
                trainable (``requires_grad=True``) parameter, in parameter
                order across all param groups.
            client_controls: Client control variates ``c_i``, same layout.
            closure: Optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure
            was given.

        Raises:
            ValueError: If there are fewer control variates than trainable
                parameters.
        """
        loss = None
        if closure is not None:
            # The closure must be *called*; it needs autograd enabled even
            # though the update itself runs under no_grad.
            with torch.enable_grad():
                loss = closure()

        # A single iterator keeps the controls aligned with the parameters
        # across *all* param groups (zipping per group would restart the
        # control lists for every group).
        controls = zip(server_controls, client_controls)
        for group in self.param_groups:
            lr = group['lr']
            weight_decay = group['weight_decay']
            for p in group['params']:
                # Callers allocate controls only for trainable params, so a
                # frozen param must not consume a control pair.
                if not p.requires_grad:
                    continue
                pair = next(controls, None)
                if pair is None:
                    raise ValueError("fewer control variates than trainable parameters")
                c, ci = pair
                if p.grad is None:
                    continue
                # Corrected gradient: g + c - c_i (this is how c updates the local model).
                d_p = p.grad + c - ci
                if weight_decay != 0:
                    d_p = d_p.add(p, alpha=weight_decay)
                p.add_(d_p, alpha=-lr)
        return loss

serverscaffold:

from flcore.clients.clientscaffold import clientScaffold
from flcore.servers.serverbase import Server
from utils.data_utils import read_client_data
from threading import Thread
import torch
import random
class Scaffold(Server):
    """SCAFFOLD server.

    Each round it broadcasts the global model and the server control
    variate ``c``, lets the selected clients train, then adds the clients'
    weighted model deltas and control-variate deltas to the global state.
    """

    def __init__(self, device, dataset, algorithm, model, batch_size, learning_rate, global_rounds, local_steps,
                 join_clients, num_clients, times, eval_gap, client_drop_rate, train_slow_rate, send_slow_rate,
                 time_select, goal, time_threthold):
        super().__init__(dataset, algorithm, model, batch_size, learning_rate, global_rounds, local_steps,
                         join_clients, num_clients, times, eval_gap, client_drop_rate, train_slow_rate,
                         send_slow_rate, time_select, goal, time_threthold)
        # select slow clients
        self.set_slow_clients()
        self.global_model = model
        for i, train_slow, send_slow in zip(range(self.num_clients), self.train_slow_clients,
                                            self.send_slow_clients):
            train, test = read_client_data(dataset, i)
            client = clientScaffold(device, i, train_slow, send_slow, train, test, model, batch_size,
                                    learning_rate, local_steps)
            self.clients.append(client)
        print(f"\nJoin clients / total clients: {self.join_clients} / {self.num_clients}")
        # Server control variate c: one zero tensor per trainable parameter.
        self.server_controls = [torch.zeros_like(p.data) for p in model.parameters() if p.requires_grad]

    def train(self):
        """Run ``global_rounds + 1`` rounds of SCAFFOLD, then report and save results."""
        for i in range(self.global_rounds + 1):
            # Broadcast the global model and the server control variate c.
            self.send_parameters()
            if i % self.eval_gap == 0:
                print(f"\n-------------Round number: {i}-------------")
                print("\nEvaluate global model")
                self.evaluate()
            self.selected_clients = self.select_clients()
            for client in self.selected_clients:
                client.train()
            self.aggregate_parameters()
        print("\nBest global results.")
        self.print_(max(self.rs_test_acc), max(self.rs_train_acc), min(self.rs_train_loss))
        self.save_results()
        self.save_global_model()

    def send_parameters(self):
        """Copy the global model and server control variate to every client."""
        assert (len(self.clients) > 0)
        for client in self.clients:
            client.set_parameters(self.global_model)
            for control, new_control in zip(client.server_controls, self.server_controls):
                # Clone so clients never alias the server's control tensors.
                control.data = new_control.data.clone()

    def aggregate_parameters(self):
        """Aggregate deltas from a random subset of the selected clients.

        ``(1 - client_drop_rate) * join_clients`` clients are sampled (capped
        at the number actually selected) to simulate drop-outs. Each one is
        weighted by its share of the active clients' training samples.
        """
        assert (len(self.selected_clients) > 0)
        num_active = int((1 - self.client_drop_rate) * self.join_clients)
        # random.sample raises ValueError if asked for more items than exist.
        num_active = min(num_active, len(self.selected_clients))
        active_clients = random.sample(self.selected_clients, num_active)
        active_train_samples = sum(client.train_samples for client in active_clients)
        self.uploaded_weights = []
        if active_train_samples == 0:
            # Nothing to aggregate (no active clients or no samples); avoid dividing by zero.
            return
        for client in active_clients:
            self.uploaded_weights.append(client.train_samples / active_train_samples)
        for user, w in zip(active_clients, self.uploaded_weights):
            self.add_parameters(user, active_train_samples, w)

    def add_parameters(self, user, total_samples, w):
        """Add one client's weighted deltas to the global model and server control.

        Args:
            user: Client exposing ``delta_model`` and ``delta_controls``.
            total_samples: Unused; kept for interface compatibility.
            w: The client's weight (share of active training samples).
        """
        for param, control, del_control, del_model in zip(self.global_model.parameters(), self.server_controls,
                                                          user.delta_controls, user.delta_model):
            # Because the data is non-IID, each client's sample proportion w is
            # used instead of the client-count scaling (1/|S| for the model,
            # 1/N for the control variate).
            param.data = param.data + del_model.data * w
            control.data = control.data + del_control.data * w

clientscaffold:

import torch
import torch.nn as nn
from flcore.clients.clientbase import Client
import numpy as np
import time
import copy
from flcore.optimizers.fedoptimizer import *
import math
class clientScaffold(Client):
    """SCAFFOLD client.

    Runs local SGD corrected by the control variates (c - c_i). After
    training it stores the model delta (y_i - x) in ``delta_model`` and the
    control-variate delta (c_i^+ - c_i) in ``delta_controls`` for the server
    to aggregate.
    """

    def __init__(self, device, numeric_id, train_slow, send_slow, train_data, test_data, model, batch_size,
                 learning_rate, local_steps):
        super().__init__(device, numeric_id, train_slow, send_slow, train_data, test_data, model, batch_size,
                         learning_rate, local_steps)
        self.loss = nn.CrossEntropyLoss()
        # Weight-decay (L2 regularization) coefficient; 0 = no regularization.
        weight_decay = 0
        # Custom optimizer that adds the (c - c_i) correction to every step.
        # NOTE(review): assumes the base Client sets self.model / self.learning_rate.
        self.optimizer = SCAFFOLDOptimizer(self.model.parameters(), lr=self.learning_rate,
                                           weight_decay=weight_decay)

        def zeros_for_trainable():
            # One zero tensor per trainable parameter, matching the optimizer's control layout.
            return [torch.zeros_like(p.data) for p in self.model.parameters() if p.requires_grad]

        self.controls = zeros_for_trainable()          # client control variate c_i
        self.server_controls = zeros_for_trainable()   # local copy of server control variate c
        self.delta_controls = zeros_for_trainable()    # c_i^+ - c_i, read by the server
        self.delta_model = zeros_for_trainable()       # y_i - x, read by the server
        self.server_model = zeros_for_trainable()      # snapshot x of the received global model
        self.local_model = copy.deepcopy(list(self.model.parameters()))

    def set_grads(self, new_grads):
        """Overwrite the model parameters' ``.data`` with ``new_grads``.

        Despite the name, this assigns parameter values (``.data``), not
        ``.grad``. If ``new_grads`` is an ``nn.Parameter`` it is zipped
        against the model parameters; if it is a list it is indexed in
        parameter order; any other type is ignored.
        """
        if isinstance(new_grads, nn.Parameter):
            for model_grad, new_grad in zip(self.model.parameters(), new_grads):
                model_grad.data = new_grad.data
        elif isinstance(new_grads, list):
            for idx, model_grad in enumerate(self.model.parameters()):
                model_grad.data = new_grads[idx]

    def train(self):
        """Run local SCAFFOLD training, then compute model/control deltas.

        The client control variate is refreshed with option II of the
        SCAFFOLD paper: ``c_i+ = c_i - c + (x - y_i) / (K * lr)``, where K
        is the number of local steps actually taken.
        """
        start_time = time.time()
        self.model.train()

        max_local_steps = self.local_steps
        if self.train_slow:
            # randint's upper bound is exclusive; keep it >= 2 so small
            # local_steps values do not produce an empty range (ValueError).
            max_local_steps = np.random.randint(1, max(2, max_local_steps // 2))

        for _ in range(max_local_steps):
            if self.train_slow:
                time.sleep(0.1 * np.abs(np.random.rand()))
            x, y = self.get_next_train_batch()
            self.optimizer.zero_grad()
            output = self.model(x)
            loss = self.loss(output, y)
            loss.backward()
            self.optimizer.step(self.server_controls, self.controls)

        # Model delta y_i - x: current local params minus the received global params.
        for local, server, delta in zip(self.model.parameters(), self.server_model, self.delta_model):
            delta.data = local.data - server.data

        # 1 / (K * lr) with K = local steps actually taken (one batch per step);
        # guarded so local_steps == 0 does not divide by zero.
        scale = 1 / (max(max_local_steps, 1) * self.learning_rate)
        for control, server_control, delta_control, delta in zip(self.controls, self.server_controls,
                                                                  self.delta_controls, self.delta_model):
            new_control = control.data - server_control.data - delta.data * scale
            delta_control.data = new_control - control.data
            control.data = new_control

        self.train_time_cost['num_rounds'] += 1
        self.train_time_cost['total_cost'] += time.time() - start_time

    def set_parameters(self, server_model):
        """Load the global model and record snapshots of it.

        Clones every parameter of ``server_model`` into ``self.model``,
        ``self.local_model`` and ``self.server_model``. ``server_model`` is
        the reference x used to compute ``delta_model`` after training.
        """
        for old_param, new_param, local_param, server_param in zip(
                self.model.parameters(), server_model.parameters(), self.local_model, self.server_model):
            old_param.data = new_param.data.clone()
            local_param.data = new_param.data.clone()
            server_param.data = new_param.data.clone()

最后

以上就是重要母鸡为你收集整理的Scaffold 基于fedavg方法的改进,代码复现(联邦学习)的全部内容,希望文章能够帮你解决Scaffold 基于fedavg方法的改进,代码复现(联邦学习)所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(50)

评论列表共有 0 条评论

立即
投稿
返回
顶部