当前位置：网站首页 > 2022/6/29-日报
2022/6/29-日报
2022-07-05 06:18:00 【mentalps】
0 2022/6/29
1 代码
1.1 计算更新参数之间的余弦相似度
import numpy as np
def cos_sim(vector_a, vector_b):
    """Return the cosine similarity of two vectors, rescaled from [-1, 1] to [0, 1].

    Both inputs are converted to float arrays and flattened, so 1-D sequences
    and equally-sized N-D arrays are accepted.

    Args:
        vector_a: First vector (array-like).
        vector_b: Second vector (array-like); must have the same number of
            elements as ``vector_a``.

    Returns:
        ``0.5 + 0.5 * cos(a, b)`` as a Python float. Identical directions give
        1.0, orthogonal vectors 0.5, and opposite directions 0.0. If either
        vector has zero norm, the cosine is undefined and ``nan`` is returned.

    Raises:
        ValueError: if the two vectors have different numbers of elements.
    """
    # np.mat was removed in NumPy 2.0; plain flattened arrays + np.dot give the
    # same inner product without the matrix subclass.
    a = np.asarray(vector_a, dtype=float).ravel()
    b = np.asarray(vector_b, dtype=float).ravel()
    denom = float(np.linalg.norm(a) * np.linalg.norm(b))
    if denom == 0.0:
        # Cosine is undefined for a zero vector; return nan explicitly rather
        # than dividing by zero (which emits a RuntimeWarning).
        return float('nan')
    cos = float(np.dot(a, b)) / denom
    # Rounding can push |cos| marginally above 1; clip so the score stays in [0, 1].
    cos = min(1.0, max(-1.0, cos))
    return 0.5 + 0.5 * cos
def weights_reshape(weights):
    """Flatten a sequence of per-layer weight arrays into one 1-D vector.

    Every layer is flattened in row-major order, and the layers are
    concatenated in the order given. The original version only handled
    exactly four layers: it raised IndexError for fewer and silently dropped
    any extra ones.

    Args:
        weights: Iterable of array-likes (e.g. a list or object array of
            per-layer weight/bias arrays).

    Returns:
        1-D numpy array holding all weights. An empty input gives an empty
        array.
    """
    flat_layers = [np.asarray(layer).reshape(-1) for layer in weights]
    if not flat_layers:
        # np.concatenate / np.hstack raise on an empty list.
        return np.empty(0)
    return np.concatenate(flat_layers)
def client_cos_sim(client_weights):
    """Return the symmetric pairwise similarity matrix of the clients' vectors.

    Args:
        client_weights: Sequence of per-client vectors (typically the flattened
            weight updates produced by ``weights_reshape``).

    Returns:
        ``(n, n)`` float array where entry ``[i, j]`` is
        ``cos_sim(client_weights[i], client_weights[j])``, with ``n`` the number
        of clients. An empty input gives a ``(0, 0)`` array.
    """
    n = len(client_weights)
    all_cos = np.zeros((n, n), dtype=float)
    for i in range(n):
        # cos_sim is symmetric, so each unordered pair is computed once and mirrored.
        for j in range(i, n):
            all_cos[i, j] = all_cos[j, i] = cos_sim(client_weights[i], client_weights[j])
    return all_cos
1.2 把相似度高于一定值的客户端聚合在一起（注：下面的代码目前只计算了客户端更新之间的相似度矩阵 all_cos，按阈值聚合这一步尚未实现）
import numpy as np
import data
import model
import aggregator
import classification
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
ClientNum = 5  # number of simulated federated clients
EPOCH = 5  # number of federated training rounds
# Indices of clients whose shards are replaced via data.generate_malice_data()
# before training. Empty by default (benign run); e.g. (0, 1) poisons the first
# two clients.
MALICIOUS_CLIENTS = ()


def _layer_deltas(reference_weights, client_weights):
    """Return the per-layer difference ``reference - client`` as a 1-D object array.

    Layers have different shapes, so ``np.array(list_of_layers)`` is a ragged
    array: older NumPy silently built an object array, and NumPy >= 1.24 raises
    ValueError. This builds that object array explicitly, so downstream code
    (weights_reshape, aggregator.FedAvg) receives the same container type as
    before.
    """
    pairs = list(zip(reference_weights, client_weights))
    deltas = np.empty(len(pairs), dtype=object)
    for k, (ref, cur) in enumerate(pairs):
        deltas[k] = np.asarray(ref) - np.asarray(cur)
    return deltas


def main():
    """Run EPOCH rounds of federated training over ClientNum clients.

    In each round every client is reset to the global weights and trained on
    its own shard. Its update (global - client, per layer) is then flattened,
    and the pairwise client similarity matrix is computed. The updates are
    passed to aggregator.FedAvg together with the global model, and the global
    model is saved. Finally the global model is evaluated on the test set.
    """
    _, shards, x_test, y_test = data.generate_client_data(ClientNum)
    model_client = [model.Model() for _ in range(ClientNum)]
    global_model = model.Model()
    for idx in MALICIOUS_CLIENTS:
        shards[idx] = data.generate_malice_data(shards[idx])

    # Each shard is a sequence of (x, y) pairs; split it into feature / label arrays.
    x_all = []
    y_all = []
    for i in range(ClientNum):
        xx, yy = zip(*shards[i])
        x_all.append(np.array(xx))
        y_all.append(np.array(yy))

    # NOTE(review): the predictions below are discarded. They presumably force
    # each model to build its weights before getWeights()/saveModel() — confirm.
    # x_all[0][0:784] is the first 784 *samples* of client 0; if 784 was meant
    # as the pixel count of one image, this slice is likely not what was intended.
    global_model.predict(x_all[0][0:784])
    global_model.saveModel('sever_model')
    for client in model_client:
        client.predict(x_all[0][0:784])

    all_weights = []
    for _ in range(EPOCH):
        for client in model_client:
            client.setWeights(global_model.getWeights())
        for j, client in enumerate(model_client):
            client.run(x_all[j], y_all[j])
            all_weights.append(client.getWeights())

        # The global model is unchanged until FedAvg below, so one snapshot serves all clients.
        global_weights = global_model.getWeights()
        client_difference_value = [
            _layer_deltas(global_weights, client.getWeights()) for client in model_client
        ]
        client_difference_value_reshape = [
            classification.weights_reshape(delta) for delta in client_difference_value
        ]
        # TODO: group clients by this similarity matrix (the section's stated
        # goal; KMeans is imported at the top of the file but not used yet).
        all_cos = classification.client_cos_sim(client_difference_value_reshape)
        aggregator.FedAvg(global_model, client_difference_value, ClientNum)
        global_model.saveModel('sever_model')

    test_loss, test_acc = global_model.evaluate(x_test, y_test, verbose=2)
    print('\nTest accuracy:', test_acc)
    print('模型个数:', len(all_weights))


if __name__ == '__main__':
    main()
边栏推荐
- WordPress switches the page, and the domain name changes back to the IP address
- Leetcode-6109: number of people who know secrets
- Leetcode dynamic programming
- There are three kinds of SQL connections: internal connection, external connection and cross connection
- [rust notes] 14 set (Part 1)
- 做 SQL 性能优化真是让人干瞪眼
- Leetcode backtracking method
- SPI details
- Daily question 1189 Maximum number of "balloons"
- Leetcode divide and conquer / dichotomy
猜你喜欢

WordPress switches the page, and the domain name changes back to the IP address

2021apmcm post game Summary - edge detection

SQL三种连接:内连接、外连接、交叉连接

区间问题 AcWing 906. 区间分组

求组合数 AcWing 888. 求组合数 IV

Appium基础 — 使用Appium的第一个Demo

MySQL advanced part 2: SQL optimization

Sqlmap tutorial (1)

QQ computer version cancels escape character input expression

Gaussian elimination acwing 884 Gauss elimination for solving XOR linear equations
随机推荐
MySQL advanced part 1: triggers
NotImplementedError: Cannot convert a symbolic Tensor (yolo_boxes_0/meshgrid/Size_1:0) to a numpy array
Sum of three terms (construction)
Niu Mei's math problems
阿里新成员「瓴羊」正式亮相,由阿里副总裁朋新宇带队,集结多个核心部门技术团队
MySQL advanced part 2: optimizing SQL steps
Leetcode-6111: spiral matrix IV
Data visualization chart summary (II)
927. 三等分 模拟
[leetcode] day94 reshape matrix
MySQL advanced part 2: storage engine
Network security skills competition in Secondary Vocational Schools -- a tutorial article on middleware penetration testing in Guangxi regional competition
[rust notes] 17 concurrent (Part 2)
redis发布订阅命令行实现
Navicat连接Oracle数据库报错ORA-28547或ORA-03135
C Primer Plus Chapter 15 (bit operation)
TypeScript 基础讲解
CPU内核和逻辑处理器的区别
实时时钟 (RTC)
数据可视化图表总结(一)