当前位置:网站首页>2022/6/29-日报
2022/6/29-日报
2022-07-05 06:18:00 【mentalps】
0 2022/6/29
1 代码
1.1 计算更新参数之间的余弦相似度
import numpy as np
def cos_sim(vector_a, vector_b):
    """Return the cosine similarity of two vectors, rescaled to [0, 1].

    Both inputs are flattened, so any array-like with the same total number
    of elements works (1-D lists, row vectors, multi-dimensional arrays).

    Args:
        vector_a: First array-like of numbers.
        vector_b: Second array-like with the same number of elements.

    Returns:
        ``0.5 + 0.5 * cos(theta)``: 1.0 for identical directions, 0.5 for
        orthogonal vectors, 0.0 for opposite directions. ``nan`` if either
        vector has zero norm, because the angle is then undefined.
    """
    a = np.ravel(np.asarray(vector_a, dtype=float))
    b = np.ravel(np.asarray(vector_b, dtype=float))
    # np.mat no longer exists in NumPy 2.0, and float() on a 1x1 matrix is
    # deprecated, so use a plain 1-D dot product instead.
    num = float(np.dot(a, b))
    denom = float(np.linalg.norm(a) * np.linalg.norm(b))
    if denom == 0.0:
        # A zero vector has no direction: report nan explicitly (the value
        # 0/0 would give) without triggering a RuntimeWarning.
        return float("nan")
    cos = num / denom
    # Map cosine from [-1, 1] onto [0, 1].
    return 0.5 + 0.5 * cos
def weights_reshape(weights):
    """Flatten a sequence of per-layer weight arrays into one 1-D vector.

    Every array in ``weights`` is flattened in C order, and the pieces are
    concatenated in sequence order. All layers are used, however many the
    model has.

    Args:
        weights: Iterable of array-likes (e.g. one entry per model layer).

    Returns:
        1-D ndarray of all weights; an empty float array if ``weights``
        is empty.
    """
    flat = [np.ravel(w) for w in weights]
    if not flat:
        return np.empty(0)
    return np.concatenate(flat)
def client_cos_sim(client_weights):
    """Build the pairwise similarity matrix between client update vectors.

    Args:
        client_weights: Sequence of update vectors, one per client (for
            example, the flattened output of ``weights_reshape``).

    Returns:
        Symmetric ``(n, n)`` float ndarray whose entry ``[i, j]`` is
        ``cos_sim(client_weights[i], client_weights[j])``; shape ``(0, 0)``
        for an empty sequence.
    """
    n = len(client_weights)
    all_cos = np.zeros((n, n), dtype=float)
    for i in range(n):
        # cos_sim(a, b) == cos_sim(b, a), so compute the upper triangle
        # (diagonal included) once and mirror it into the lower triangle.
        for j in range(i, n):
            sim = cos_sim(client_weights[i], client_weights[j])
            all_cos[i, j] = sim
            all_cos[j, i] = sim
    return all_cos
1.2 把相似度高于一定值的客户端聚合在一起
import numpy as np
import data
import model
import aggregator
import classification
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
ClientNum = 5
EPOCH = 5


def weights_difference(global_weights, client_weights):
    """Return ``global_weights - client_weights`` computed layer by layer.

    Args:
        global_weights: Sequence of per-layer arrays of the global model.
        client_weights: Sequence of per-layer arrays of one client, in the
            same layer order and with the same shapes.

    Returns:
        1-D object ndarray whose k-th element is the k-th layer difference.
        It is filled element by element because ``np.array`` on a ragged list
        of arrays raises ``ValueError`` on NumPy >= 1.24.

    Raises:
        ValueError: If the two sequences have a different number of layers.
    """
    global_weights = list(global_weights)
    client_weights = list(client_weights)
    if len(global_weights) != len(client_weights):
        raise ValueError(
            f"layer count mismatch: {len(global_weights)} != {len(client_weights)}"
        )
    diff = np.empty(len(global_weights), dtype=object)
    for k, (g, c) in enumerate(zip(global_weights, client_weights)):
        diff[k] = np.asarray(g) - np.asarray(c)
    return diff


if __name__ == '__main__':
    _, shards, x_test, y_test = data.generate_client_data(ClientNum)
    model_client = [model.Model() for _ in range(ClientNum)]
    global_model = model.Model()
    # Uncomment to replace clients 0 and 1's shards with the output of
    # data.generate_malice_data (presumably poisoned data).
    # shards[0] = data.generate_malice_data(shards[0])
    # shards[1] = data.generate_malice_data(shards[1])

    all_weights = []  # every client's weights after every local training round
    x_all = []
    y_all = []
    for i in range(ClientNum):
        # Each shard is a sequence of (x, y) pairs; split it into two arrays.
        xx, yy = zip(*shards[i])
        x_all.append(np.array(xx))
        y_all.append(np.array(yy))

    # NOTE(review): the predictions are never used. These calls presumably
    # build the models' variables before saveModel/setWeights run; confirm
    # before removing them.
    global_model.predict(x_all[0][0:784])
    global_model.saveModel('sever_model')
    for client in model_client:
        client.predict(x_all[0][0:784])

    for epoch in range(EPOCH):
        # Start every client from the current global weights.
        for client in model_client:
            client.setWeights(global_model.getWeights())
        for j in range(ClientNum):
            model_client[j].run(x_all[j], y_all[j])
            all_weights.append(model_client[j].getWeights())

        # Per-client update expressed as (global - client), layer by layer.
        global_weights = global_model.getWeights()
        client_difference_value = [
            weights_difference(global_weights, client.getWeights())
            for client in model_client
        ]
        client_difference_value_reshape = [
            classification.weights_reshape(d) for d in client_difference_value
        ]
        # Pairwise similarity of the client updates. It is not used below
        # yet: every client is aggregated.
        all_cos = classification.client_cos_sim(client_difference_value_reshape)
        # The return value was never used; presumably FedAvg updates
        # global_model in place. Confirm against aggregator.
        aggregator.FedAvg(global_model, client_difference_value, ClientNum)
        global_model.saveModel('sever_model')

    test_loss, test_acc = global_model.evaluate(x_test, y_test, verbose=2)
    print('\nTest accuracy:', test_acc)
    # Number of collected client weight snapshots (EPOCH * ClientNum).
    print('模型个数:', len(all_weights))
边栏推荐
- 1040 Longest Symmetric String
- Leetcode-6109: number of people who know secrets
- Alibaba's new member "Lingyang" officially appeared, led by Peng Xinyu, Alibaba's vice president, and assembled a number of core department technical teams
- 求组合数 AcWing 889. 满足条件的01序列
- TypeScript 基础讲解
- Navicat連接Oracle數據庫報錯ORA-28547或ORA-03135
- How to generate an image from text on the fly at runtime
- Gaussian elimination: AcWing 884. Solving XOR linear equations with Gaussian elimination
- Doing SQL performance optimization is really eye-catching
- 阿里巴巴成立企业数智服务公司“瓴羊”,聚焦企业数字化增长
猜你喜欢

4. 对象映射 - Mapping.Mapster

NotImplementedError: Cannot convert a symbolic Tensor (yolo_boxes_0/meshgrid/Size_1:0) to a numpy array

SPI details

高斯消元 AcWing 884. 高斯消元解异或线性方程组

阿里巴巴成立企业数智服务公司“瓴羊”,聚焦企业数字化增长

博弈论 AcWing 894. 拆分-Nim游戏

MIT-6874-Deep Learning in the Life Sciences Week 7

RGB LED infinite mirror controlled by Arduino

QQ desktop version: disabling emoticon input via escape characters

Redis publish subscribe command line implementation
随机推荐
QQ电脑版取消转义符输入表情
Notes on configuring NCCL and Horovod over the past two days (original)
11-gorm-v2-03-basic query
LeetCode 0107. Binary Tree Level Order Traversal II - another method
Erreur de connexion de Navicat à la base de données Oracle : ORA-28547 ou ORA-03135
1040 Longest Symmetric String
做 SQL 性能优化真是让人干瞪眼
Leetcode-31: next permutation
P2575 master fight
高斯消元 AcWing 884. 高斯消元解异或线性方程组
Sum of three terms (construction)
How to set the drop-down arrow in a Spinner?
MatrixDB v4.5.0 重磅发布,全新推出 MARS2 存储引擎!
Daily question: LeetCode 1189. Maximum Number of Balloons
【LeetCode】Day94-重塑矩阵
Alibaba's new member "Lingyang" officially appeared, led by Peng Xinyu, Alibaba's vice president, and assembled a number of core department technical teams
MySQL advanced part 1: triggers
Determining the currently clicked button and the current mouse coordinates in a Qt interface
Gaussian elimination: AcWing 884. Solving XOR linear equations with Gaussian elimination
Appium automation test foundation - Summary of appium test environment construction