Current location: Home > June 29, 2022 daily
June 29, 2022 daily
2022-07-05 06:23:00 【mentalps】
0 2022/6/29
1 Code
1.1 Calculate the cosine similarity between client parameter updates
import numpy as np
def cos_sim(vector_a, vector_b):
    """Return the cosine similarity of two vectors, rescaled to [0, 1].

    The raw cosine ``a.b / (|a| * |b|)`` lies in [-1, 1]. It is mapped
    linearly through ``0.5 + 0.5 * cos``, so 1.0 means "same direction",
    0.5 means "orthogonal" and 0.0 means "opposite direction".

    Args:
        vector_a: array-like of numbers. Any shape is accepted; it is
            flattened before comparison.
        vector_b: array-like with the same number of elements as vector_a.

    Returns:
        The rescaled similarity as a float. Returns ``nan`` when either
        vector has zero norm, because the cosine is undefined there.

    Raises:
        ValueError: if the two inputs have different element counts.
    """
    # Use plain 1-D ndarrays instead of np.mat, which NumPy 2.0 removed.
    # Flattening also lets 2-D weight tensors be compared directly.
    a = np.ravel(np.asarray(vector_a, dtype=float))
    b = np.ravel(np.asarray(vector_b, dtype=float))
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        # Avoid a 0/0 RuntimeWarning; the similarity is simply undefined.
        return float('nan')
    cos = float(np.dot(a, b)) / float(denom)
    return 0.5 + 0.5 * cos
def weights_reshape(weights):
    """Flatten a model's per-tensor weights into a single 1-D vector.

    Every tensor is included, in order. The tensor count is not fixed, so
    models with more or fewer weight tensors are handled correctly rather
    than truncated or rejected.

    Args:
        weights: a sequence (list, tuple or 1-D object ndarray) of
            array-likes, one per weight tensor.

    Returns:
        A 1-D ndarray concatenating the elements of every tensor in order.
        An empty input yields an empty float array.
    """
    flat_layers = [np.asarray(layer).reshape(-1) for layer in weights]
    if not flat_layers:
        # np.concatenate rejects an empty list.
        return np.array([], dtype=float)
    return np.concatenate(flat_layers)
def client_cos_sim(client_weights):
    """Build the pairwise similarity matrix of the clients' update vectors.

    Args:
        client_weights: a sequence of update vectors, one per client, each
            accepted by ``cos_sim``.

    Returns:
        An (n, n) float ndarray, where n = len(client_weights). Entry
        [i, j] is ``cos_sim(client_weights[i], client_weights[j])``.
    """
    n = len(client_weights)
    all_cos = np.zeros((n, n), dtype=float)
    for i in range(n):
        # Cosine similarity is symmetric, so each pair is computed once
        # (diagonal included) and mirrored into the lower triangle.
        for j in range(i, n):
            all_cos[i, j] = all_cos[j, i] = cos_sim(client_weights[i], client_weights[j])
    return all_cos
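1.2 Aggregate clients whose similarity exceeds a threshold (note: the code below computes the pairwise similarity matrix, but it does not yet filter clients — every update is still aggregated)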
import numpy as np
import data
import model
import aggregator
import classification
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
# Number of simulated federated-learning clients.
ClientNum = 5
# Number of federated training rounds run in the loop below.
EPOCH = 5


def _layerwise_difference(reference_weights, updated_weights):
    """Return ``reference - updated`` per weight tensor as a 1-D object ndarray.

    ``np.array(list_a) - np.array(list_b)`` raises ValueError on NumPy >= 1.24
    when the per-tensor arrays have different shapes, because ragged nested
    input is rejected. So the difference is taken tensor by tensor and
    stored in an explicit object array. Element ``k`` holds the difference
    for tensor ``k``.

    Raises:
        ValueError: if the two weight lists have different lengths.
    """
    if len(reference_weights) != len(updated_weights):
        raise ValueError(
            f"weight lists differ in length: "
            f"{len(reference_weights)} != {len(updated_weights)}"
        )
    packed = np.empty(len(reference_weights), dtype=object)
    for k, (ref, upd) in enumerate(zip(reference_weights, updated_weights)):
        packed[k] = np.asarray(ref) - np.asarray(upd)
    return packed


if __name__ == '__main__':
    data1, shards, x_test, y_test = data.generate_client_data(ClientNum)
    # Client models are created before the server model, keeping the
    # original construction order in case it consumes shared random state.
    model_client = [model.Model() for _ in range(ClientNum)]
    global_model = model.Model()
    # Uncomment to replace these clients' shards with the output of
    # data.generate_malice_data (presumably poisoned data for attack tests).
    # shards[0] = data.generate_malice_data(shards[0])
    # shards[1] = data.generate_malice_data(shards[1])

    # Split each client's shard of (x, y) pairs into separate arrays.
    x_all = []
    y_all = []
    for i in range(ClientNum):
        xx, yy = zip(*shards[i])
        x_all.append(np.array(xx))
        y_all.append(np.array(yy))

    # One prediction per model before any weights are read or set; the
    # outputs are discarded. NOTE(review): presumably kept because
    # model.Model builds its weights lazily on first call - confirm before
    # removing these calls.
    probe = x_all[0][0:784]
    global_model.predict(probe)
    global_model.saveModel('sever_model')
    for client in model_client:
        client.predict(probe)

    all_weights = []  # every client's weights after every round
    for _ in range(EPOCH):
        # Broadcast the current server weights to every client.
        for client in model_client:
            client.setWeights(global_model.getWeights())
        # Local training of each client on its own shard.
        for client, xx, yy in zip(model_client, x_all, y_all):
            client.run(xx, yy)
            all_weights.append(client.getWeights())
        # Per-client update = server weights - trained client weights.
        client_difference_value = []
        client_difference_value_reshape = []
        for client in model_client:
            diff = _layerwise_difference(global_model.getWeights(), client.getWeights())
            client_difference_value.append(diff)
            client_difference_value_reshape.append(classification.weights_reshape(diff))
        # Pairwise similarity of the clients' flattened updates.
        # TODO: not yet used to exclude clients - every update is aggregated.
        all_cos = classification.client_cos_sim(client_difference_value_reshape)
        aggregator.FedAvg(global_model, client_difference_value, ClientNum)
        global_model.saveModel('sever_model')

    test_loss, test_acc = global_model.evaluate(x_test, y_test, verbose=2)
    print('\nTest accuracy:', test_acc)
    print(' Number of models :', len(all_weights))
Sidebar recommendations
- 11-gorm-v2-02-create data
- Niu Mei's math problems
- 传统数据库逐渐“难适应”,云原生数据库脱颖而出
- LeetCode 0107. Sequence traversal of binary tree II - another method
- MPLS experiment
- [2020]GRAF: Generative Radiance Fields for 3D-Aware Image Synthesis
- 20220213 - CTF MISC - a_good_idea (using the Stegsolve tool) - 2017_Dating_in_Singapore
- 开源存储这么香,为何我们还要坚持自研?
- MySQL advanced part 2: the use of indexes
- Sum of three terms (construction)
You may also like
SPI 详解
快速使用Amazon MemoryDB并构建你专属的Redis内存数据库
安装OpenCV--conda建立虚拟环境并在jupyter中添加此环境的kernel
Sqlmap tutorial (1)
LeetCode-61
1.13 - RISC/CISC
MySQL advanced part 2: optimizing SQL steps
Is it impossible for LaMDA to wake up?
[2020]GRAF: Generative Radiance Fields for 3D-Aware Image Synthesis
Leetcode-6110: number of incremental paths in the grid graph
Random recommendations
论文阅读报告
1.15 - input and output system
Dataframe (1): introduction and creation of dataframe
Sqlmap tutorial (1)
【LeetCode】Day94-重塑矩阵
How to understand the definition of sequence limit?
Alibaba established the enterprise digital intelligence service company "Lingyang" to focus on enterprise digital growth
MySQL advanced part 2: SQL optimization
MySQL advanced part 1: stored procedures and functions
In-depth analysis of `for (var i = 0; i < 5; i++) { setTimeout(() => console.log(i), 1000) }`
区间问题 AcWing 906. 区间分组
P2575 master fight
Presentation of attribute value of an item
MySQL怎么运行的系列(八)14张图说明白MySQL事务原子性和undo日志原理
高斯消元 AcWing 884. 高斯消元解异或线性方程组
MySQL advanced part 1: triggers
LeetCode 0107. Sequence traversal of binary tree II - another method
One question per day 1020 Number of enclaves
The difference between CPU core and logical processor
Is it impossible for LaMDA to wake up?