当前位置:网站首页>2022/6/29-日报
2022/6/29-日报
2022-07-05 06:18:00 【mentalps】
0 2022/6/29
1 代码
1.1 计算客户端更新参数之间的余弦相似度（结果经 $0.5 + 0.5\cos\theta$ 映射到 $[0, 1]$ 区间）
import numpy as np
def cos_sim(vector_a, vector_b):
    """Return the cosine similarity of two vectors, rescaled from [-1, 1] to [0, 1].

    Both inputs are flattened first. Any array-like of numbers is accepted
    (a 1-D vector, a (1, N) row, an (N, 1) column, ...), as long as both
    contain the same number of elements.

    Args:
        vector_a: first array-like of numbers.
        vector_b: second array-like of numbers.

    Returns:
        ``0.5 + 0.5 * cos(theta)``: 1.0 for parallel vectors, 0.5 for
        orthogonal ones, 0.0 for opposite ones. If either vector has zero
        norm the cosine is undefined and the result is NaN (numpy emits a
        RuntimeWarning), exactly as before.
    """
    # np.mat was removed in NumPy 2.0. Plain flattened ndarrays plus np.dot give
    # the same inner product and also accept column-shaped inputs.
    # Casting to float avoids integer overflow in the dot product.
    a = np.asarray(vector_a, dtype=float).ravel()
    b = np.asarray(vector_b, dtype=float).ravel()
    num = float(np.dot(a, b))
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    cos = num / denom
    # Shift/scale so the similarity lies in [0, 1] instead of [-1, 1].
    return 0.5 + 0.5 * cos
def weights_reshape(weights, num_layers=4):
    """Flatten per-layer weight arrays and concatenate them into one 1-D vector.

    Args:
        weights: indexable sequence of array-likes (e.g. a list of numpy
            arrays, one per layer).
        num_layers: how many leading entries of ``weights`` to include. The
            default of 4 keeps the original behaviour. Pass ``None`` to use
            every entry, so that models with more than four weight arrays
            are not silently truncated.

    Returns:
        numpy.ndarray: 1-D array holding the flattened entries, in order.

    Raises:
        IndexError: if ``weights`` has fewer than ``num_layers`` entries.
    """
    count = len(weights) if num_layers is None else num_layers
    # Index explicitly (rather than slicing) so a too-short input still raises
    # IndexError, as the original hard-coded indexing did.
    parts = [np.asarray(weights[k]).reshape(-1) for k in range(count)]
    if not parts:
        # np.hstack refuses an empty list; an empty vector is the natural result.
        return np.empty(0)
    return np.hstack(parts)
def client_cos_sim(client_weights):
    """Return the pairwise ``cos_sim`` matrix of the given client vectors.

    Args:
        client_weights: sequence of flattened client vectors, for example the
            outputs of ``weights_reshape``.

    Returns:
        numpy.ndarray: float array of shape (n, n), where entry [i][j] is
        ``cos_sim(client_weights[i], client_weights[j])``.
    """
    n = len(client_weights)
    all_cos = np.zeros((n, n), dtype=float)
    for i in range(n):
        # cos_sim(a, b) == cos_sim(b, a): the dot product and the norm product
        # are the same either way round. So only the upper triangle (with the
        # diagonal) is computed, and each value is mirrored into the lower one.
        for j in range(i, n):
            value = cos_sim(client_weights[i], client_weights[j])
            all_cos[i, j] = value
            all_cos[j, i] = value
    return all_cos
1.2 把相似度高于一定值的客户端聚合在一起（注：下面的代码目前只计算了客户端之间的相似度矩阵 all_cos，尚未按阈值分组；FedAvg 仍然聚合全部客户端）
import numpy as np
import data
import model
import aggregator
import classification
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
ClientNum = 5
EPOCH = 5
if __name__ == '__main__':
data1, shards, x_test, y_test = data.generate_client_data(ClientNum)
model_client = []
for i in range(ClientNum):
model_client.append(model.Model())
global_model = model.Model()
# shards[0] = data.generate_malice_data(shards[0])
# shards[1] = data.generate_malice_data(shards[1])
all_weights = []
x_all = []
y_all = []
for i in range(ClientNum):
xx, yy = zip(*shards[i])
xx = np.array(xx)
yy = np.array(yy)
x_all.append(xx)
y_all.append(yy)
a_0 = np.argmax(global_model.predict(x_all[0][0:784]))
global_model.saveModel('sever_model')
for i in range(ClientNum):
np.argmax(model_client[i].predict(x_all[0][0:784]))
for i in range(EPOCH):
client_difference_value = []
client_difference_value_reshape = []
for j in range(ClientNum):
model_client[j].setWeights(global_model.getWeights())
for j in range(ClientNum):
model_client[j].run(x_all[j], y_all[j])
all_weights.append(model_client[j].getWeights())
for j in range(ClientNum):
client_difference_value.append(np.array(global_model.getWeights()) - np.array(model_client[j].getWeights()))
client_difference_value_reshape.append(classification.weights_reshape(client_difference_value[j]))
all_cos = classification.client_cos_sim(client_difference_value_reshape)
fedavg = aggregator.FedAvg(global_model, client_difference_value, ClientNum)
global_model.saveModel('sever_model')
x = []
y = []
for i in range(len(x_test)):
if y_test[i] == 1:
x.append(x_test[i])
y.append(y_test[i])
test_loss, test_acc = global_model.evaluate(x_test, y_test, verbose=2)
print('\nTest accuracy:', test_acc)
print('模型个数:', len(all_weights))
边栏推荐
- 2021 APMCM post-game summary - edge detection
- LeetCode-61
- 4. Object mapping Mapster
- Chart.js - Format Y axis - Chart.js - Formatting Y axis
- 【Rust 笔记】13-迭代器(中)
- Traversal of leetcode tree
- 11-gorm-v2-02-create data
- 1041 Be Unique
- [2021]GIRAFFE: Representing Scenes as Compositional Generative Neural Feature Fields
- 927. 三等分 模拟
猜你喜欢

Error ora-28547 or ora-03135 when Navicat connects to Oracle Database

Sqlmap tutorial (1)

MySQL advanced part 1: stored procedures and functions

求组合数 AcWing 887. 求组合数 III

Navicat连接Oracle数据库报错ORA-28547或ORA-03135

博弈论 AcWing 894. 拆分-Nim游戏

MatrixDB v4.5.0 重磅发布,全新推出 MARS2 存储引擎!

数据可视化图表总结(二)

QQ computer version cancels escape character input expression

Groupbykey() and reducebykey() and combinebykey() in spark
随机推荐
高斯消元 AcWing 884. 高斯消元解异或线性方程组
Leetcode backtracking method
11-gorm-v2-03-basic query
Arduino 控制的 RGB LED 无限镜
How to understand the definition of sequence limit?
Leetcode array operation
[rust notes] 14 set (Part 2)
Data visualization chart summary (I)
博弈论 AcWing 893. 集合-Nim游戏
栈 AcWing 3302. 表达式求值
Quickly use Amazon memorydb and build your own redis memory database
MySQL advanced part 1: stored procedures and functions
数据可视化图表总结(二)
Error ora-28547 or ora-03135 when Navicat connects to Oracle Database
A reason that is easy to be ignored when the printer is offline
Simple selection sort of selection sort
11-gorm-v2-02-create data
Chapter 6 relational database theory
Leetcode-1200: minimum absolute difference
求组合数 AcWing 889. 满足条件的01序列