我是靠谱客的博主 大意老鼠,最近开发中收集的这篇文章主要介绍流量预测之联邦学习时间序列预测算法之联邦学习,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

时间序列预测算法之联邦学习

介绍

设三个节点,其中一个中心节点,两个子节点,子节点利用LSTM模型训练,保证每个epoch完跟中心节点进行交互,完成参数融合

  1. 子节点部分代码
class LSTM(nn.Module):
def __init__(self, input_size=2, hidden_size=4, output_size=1, num_layer=1):
super(LSTM, self).__init__()
self.layer1 = nn.LSTM(input_size, hidden_size, num_layer)
self.layer2 = nn.Linear(hidden_size, output_size)
def forward(self, x):
x, _ = self.layer1(x)
x = torch.relu(x)
s, b, h = x.size()
x = x.view(s * b, h)
x = self.layer2(x)
x = x.view(s, b, -1)
return x
# 二、模型构建
model = LSTM(look_back, 4, 1, 2)
# print(model)
loss_fun = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
# s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# # 定义中心节点的地址端口号
# host = '127.0.0.1'
# port = 9999
# #建立链接
# s.connect((host, port))
# 三、开始训练
losses = list()
steps = list()
for epoch in range(1, EPOCH + 1):
log("33[1;31;40m第33[1;31;40m%s33[1;31;40m轮开始训练!33[1;31;40m" % str(epoch))
# 第一个网络
for t in range(10):
loss_t = list()
# 前向传播
out = model(var_x)
loss = loss_fun(out, var_y)
loss_t.append(loss.item())
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses.append(sum(loss_t)/len(loss_t))
steps.append(epoch)
plt.plot(steps, losses, "o-")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.draw()
plt.pause(0.1)
log("建立连接并上传......")
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# 定义中心节点的地址端口号
host = '127.0.0.1'
port = 9999
# 建立链接
s.connect((host, port))
# 序列化
data = {}
data['num'] = epoch
data['model'] = model.state_dict()
keys = model.state_dict().keys()
data = pickle.dumps(data)
print(s.send(data))
# 等待待收
log("等待接收......")
try:
s.settimeout(30000)
data = s.recv(1024 * 100)
# print(data)
data = pickle.loads(data)
print(data['num'], epoch)
if data['num'] == epoch:
global_state_dict = data['model']
else:
global_state_dict = model.state_dict()
except Exception as e:
print(e)
# s.sendto(data, (host, port))
log("没有在规定时间收到正确的包, 利用本地参数更新")
global_state_dict = model.state_dict()
# print(global_state_dict)
# 重新加载全局参数
model.load_state_dict(global_state_dict)
s.close()
log("训练完毕,关闭连接")
s.close()

2.中心节点部分代码

def socket_udp_server():
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# SOCK_STREAM指类型是UDP
host = '127.0.0.1'
# 监听指定的ip,host=''即监听所有的ip
port = 9999
# 绑定端口
s.bind((host, port))
# 开始监听:
s.listen(5)
# param : 等待连接的最大数量
print('waiting for connecting')
res, addrs = [], []
cnt = 1
while True:
log("第%d轮开始接收并计时" % cnt)
try:
s.settimeout(30000)
start = time.time()
# 接收操作
sock, addr = s.accept() #accept会等待并发返回一个客户端的连接
print(sock)
data = sock.recv(1024*100)
# 接收来自客户端的数据,最大(1k),阻塞式等待
print('Received from %s:%s' % addr)
# print('Received data:', data)
tmp = pickle.loads(data)
print(tmp['num'], cnt)
if tmp['num'] == cnt:
addrs.append(sock)
res.append(tmp['model'])
# print(res)
recv_time = time.time() - start
print(len(res))
if len(res) >= 2 or recv_time > 2000000:
log("第%d轮接收完毕 接收来自%d个节点的参数" % (cnt, len(res)))
# 处理操作
log("开始融合处理操作......")
# time.sleep(5)
# res = str(sum(res))
for m, n in zip(res[0].values(), res[1].values()):
if len(m.size()) == 1:
m1(m, n)
elif len(m.size()) == 2:
m2(m, n)
# print(res[0])
# res = pickle.dumps(res[0])
data = {}
data['num'] = cnt
data['model'] = res[0]
# 下发操作
log('第%d轮融合完毕,下发......' % cnt)
data = pickle.dumps(data)
# print(data)
for sock in (addrs):
sock.send(data)
sock.close()
# s.sendto(b'%s' % res.encode('utf-8'), addr)
# else:
#
res = '处理完毕,关闭连接'
#
for addr in (addrs):
#
s.sendto(b'%s' % res.encode('utf-8'), addr)
#
break
res, addrs = [], []
cnt += 1
if cnt > Epoch:
log('处理完毕,关闭连接')
break
else:
continue
except:
log("超时!!!")
cnt += 1
s.close()
  • 完整代码见FL节点流量预测

最后

以上就是大意老鼠为你收集整理的流量预测之联邦学习时间序列预测算法之联邦学习的全部内容,希望文章能够帮你解决流量预测之联邦学习时间序列预测算法之联邦学习所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(61)

评论列表共有 0 条评论

立即
投稿
返回
顶部