当前位置：网站首页 > 2022/6/29-日报
2022/6/29-日报
2022-07-05 06:18:00 【mentalps】
0 2022/6/29
1 代码
1.1 计算更新参数之间的余弦相似度
import numpy as np
def cos_sim(vector_a, vector_b):
    """Return the cosine similarity of two vectors, rescaled from [-1, 1] to [0, 1].

    Both inputs are flattened first, so plain lists, 1-D arrays and 1xN row
    vectors are all accepted. The raw cosine ``c`` is mapped to
    ``0.5 + 0.5 * c``: 1.0 means same direction, 0.5 orthogonal, 0.0 opposite.

    Args:
        vector_a: array-like of numbers.
        vector_b: array-like with the same total number of elements.

    Returns:
        The rescaled similarity, or NaN if either vector has zero norm
        (the cosine is undefined there).

    Raises:
        ValueError: if the two vectors have different lengths.
    """
    # np.ravel replaces np.mat (removed in NumPy 2.0); the dot product of the
    # flattened vectors equals the 1x1 matrix product the old code computed.
    a = np.ravel(vector_a)
    b = np.ravel(vector_b)
    num = float(np.dot(a, b))
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        # 0/0 previously gave NaN plus a RuntimeWarning; keep NaN, but explicitly.
        return float("nan")
    cos = num / denom
    return 0.5 + 0.5 * cos
def weights_reshape(weights):
    """Flatten a sequence of per-layer weight arrays into one 1-D vector.

    Each layer is flattened in row-major order and the pieces are concatenated
    in the given order. Any number of layers is supported. The previous
    version read only ``weights[0]``..``weights[3]``, silently dropping any
    further layers, and raised IndexError for fewer than four.

    Args:
        weights: sequence of array-likes, e.g. a model's per-layer weights or
            the per-layer difference of two such weight lists.

    Returns:
        1-D numpy array with all weights; an empty array if ``weights`` is empty.
    """
    flat = [np.ravel(layer) for layer in weights]
    if not flat:
        # np.concatenate rejects an empty list; an empty vector is the natural result.
        return np.empty(0)
    return np.concatenate(flat)
def client_cos_sim(client_weights):
    """Build the pairwise similarity matrix between client vectors.

    Args:
        client_weights: sequence of flattened weight (or weight-update)
            vectors, one per client.

    Returns:
        An (n, n) float array whose entry [i, j] is
        ``cos_sim(client_weights[i], client_weights[j])``.
    """
    n = len(client_weights)
    all_cos = np.zeros((n, n), dtype=float)
    for i, vec_i in enumerate(client_weights):
        # cos_sim(a, b) == cos_sim(b, a), so evaluate only the upper triangle
        # (diagonal included) and mirror it instead of computing each pair twice.
        for j in range(i, n):
            all_cos[i, j] = all_cos[j, i] = cos_sim(vec_i, client_weights[j])
    return all_cos
1.2 把相似度高于一定值的客户端聚合在一起
import numpy as np
import data
import model
import aggregator
import classification
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
ClientNum = 5  # number of simulated clients
EPOCH = 5  # number of training rounds


def _layer_diff(weights_a, weights_b):
    """Return the per-layer difference ``weights_a - weights_b`` as a 1-D object array.

    ``np.array(list_of_layers)`` is avoided on purpose: for layers of different
    shapes NumPy >= 1.24 raises ValueError on such ragged input instead of
    building an object array. Filling a preallocated object array always
    yields one entry per layer, matching what the old call produced on older
    NumPy. NOTE(review): kept as an object array in case aggregator.FedAvg
    relies on element-wise arithmetic over it — confirm.

    Raises:
        ValueError: if the two weight lists have different numbers of layers.
    """
    if len(weights_a) != len(weights_b):
        raise ValueError(
            f"layer count mismatch: {len(weights_a)} vs {len(weights_b)}")
    diff = np.empty(len(weights_a), dtype=object)
    for k, (layer_a, layer_b) in enumerate(zip(weights_a, weights_b)):
        diff[k] = np.asarray(layer_a) - np.asarray(layer_b)
    return diff


# NOTE(review): the indentation of this script was lost in the original paste.
# Loop boundaries below are reconstructed; in particular the per-round loop is
# assumed to end after saveModel() — confirm whether evaluation was meant to
# run every round instead of once at the end.
if __name__ == '__main__':
    data1, shards, x_test, y_test = data.generate_client_data(ClientNum)
    model_client = [model.Model() for _ in range(ClientNum)]
    global_model = model.Model()
    # Presumably enables a malicious-client experiment on the first two shards.
    # shards[0] = data.generate_malice_data(shards[0])
    # shards[1] = data.generate_malice_data(shards[1])
    all_weights = []  # every client's weights after every local run
    x_all = []
    y_all = []
    for i in range(ClientNum):
        # Each shard unpacks into (x, y) pairs; transpose into two arrays.
        xx, yy = zip(*shards[i])
        x_all.append(np.array(xx))
        y_all.append(np.array(yy))
    # NOTE(review): these predict() results are discarded; presumably the calls
    # exist to initialise the models before getWeights()/saveModel() — confirm.
    # x_all[0][0:784] takes the first 784 *samples* of client 0.
    a_0 = np.argmax(global_model.predict(x_all[0][0:784]))
    global_model.saveModel('sever_model')
    for i in range(ClientNum):
        np.argmax(model_client[i].predict(x_all[0][0:784]))
    for i in range(EPOCH):
        client_difference_value = []
        client_difference_value_reshape = []
        # Every client starts the round from the current global weights.
        for j in range(ClientNum):
            model_client[j].setWeights(global_model.getWeights())
        # Local run on each client's own data (presumably training).
        for j in range(ClientNum):
            model_client[j].run(x_all[j], y_all[j])
            all_weights.append(model_client[j].getWeights())
        # Per-client update as (global - local), plus its flattened form.
        for j in range(ClientNum):
            client_difference_value.append(
                _layer_diff(global_model.getWeights(), model_client[j].getWeights()))
            client_difference_value_reshape.append(
                classification.weights_reshape(client_difference_value[j]))
        # Pairwise similarity of client updates; not used further in this script.
        all_cos = classification.client_cos_sim(client_difference_value_reshape)
        # Return value unused; presumably FedAvg updates global_model itself.
        fedavg = aggregator.FedAvg(global_model, client_difference_value, ClientNum)
        global_model.saveModel('sever_model')
    # Test samples labelled 1 (x and y are not used below).
    x = []
    y = []
    for xi, yi in zip(x_test, y_test):
        if yi == 1:
            x.append(xi)
            y.append(yi)
    test_loss, test_acc = global_model.evaluate(x_test, y_test, verbose=2)
    print('\nTest accuracy:', test_acc)
    # Prints the number of client weight snapshots collected ("number of models").
    print('模型个数:', len(all_weights))
边栏推荐
- 实时时钟 (RTC)
- C Primer Plus Chapter 15 (bit operation)
- Alibaba's new member "Lingyang" officially appeared, led by Peng Xinyu, Alibaba's vice president, and assembled a number of core department technical teams
- Gaussian elimination acwing 884 Gauss elimination for solving XOR linear equations
- [leetcode] day95 effective Sudoku & matrix zeroing
- MatrixDB v4.5.0 重磅发布,全新推出 MARS2 存储引擎!
- Leetcode divide and conquer / dichotomy
- Leetcode stack related
- 【Rust 笔记】13-迭代器(中)
- Single chip computer engineering experience - layered idea
猜你喜欢
阿里新成员「瓴羊」正式亮相,由阿里副总裁朋新宇带队,集结多个核心部门技术团队
数据可视化图表总结(二)
MIT-6874-Deep Learning in the Life Sciences Week 7
Gaussian elimination acwing 884 Gauss elimination for solving XOR linear equations
QQ computer version cancels escape character input expression
P2575 master fight
Operator priority, one catch, no doubt
1.15 - input and output system
容斥原理 AcWing 890. 能被整除的数
Open source storage is so popular, why do we insist on self-development?
随机推荐
背包问题 AcWing 9. 分组背包问题
11-gorm-v2-03-basic query
MySQL advanced part 1: index
中国剩余定理 AcWing 204. 表达整数的奇怪方式
927. Trisection simulation
Data visualization chart summary (II)
博弈论 AcWing 893. 集合-Nim游戏
【LeetCode】Day94-重塑矩阵
1.14 - assembly line
NotImplementedError: Cannot convert a symbolic Tensor (yolo_boxes_0/meshgrid/Size_1:0) to a numpy array
【Rust 笔记】13-迭代器(中)
TypeScript 基础讲解
The difference between CPU core and logical processor
MySQL advanced part 2: SQL optimization
Leetcode dynamic programming
[leetcode] day95 effective Sudoku & matrix zeroing
阿里巴巴成立企业数智服务公司“瓴羊”,聚焦企业数字化增长
4. 对象映射 - Mapping.Mapster
Niu Mei's math problems
Chapter 6 relational database theory