Current location: Home > June 29, 2022 Daily Report
June 29, 2022 Daily Report
2022-07-05 06:23:00 【mentalps】
0 2022/6/29
1 Code
1.1 Computing the cosine similarity between client parameter updates
import numpy as np
def cos_sim(vector_a, vector_b):
    """Return the cosine similarity of two vectors, rescaled from [-1, 1] to [0, 1].

    Both inputs are flattened to 1-D first, so plain lists, 1-D arrays and
    row or column matrices are all accepted.

    Parameters
    ----------
    vector_a, vector_b : array_like
        Vectors with the same total number of elements.

    Returns
    -------
    float
        ``0.5 + 0.5 * cos(theta)``: 1.0 for identical directions, 0.5 for
        orthogonal vectors and 0.0 for opposite directions.  If either vector
        has zero norm the angle is undefined; the pair is then treated as
        orthogonal (0.5) instead of propagating NaN into the similarity matrix.

    Raises
    ------
    ValueError
        If the vectors have a different number of elements.
    """
    a = np.asarray(vector_a, dtype=float).ravel()
    b = np.asarray(vector_b, dtype=float).ravel()
    # np.mat (used previously) was removed in NumPy 2.0; the dot product of
    # the flattened arrays is the same scalar as row-matrix * column-matrix.
    num = float(np.dot(a, b))
    denom = float(np.linalg.norm(a) * np.linalg.norm(b))
    if denom == 0.0:
        # Zero vector (e.g. a client whose update is all zeros).
        return 0.5
    cos = num / denom
    return 0.5 + 0.5 * cos
def weights_reshape(weights):
    """Flatten a sequence of per-layer weight arrays into one 1-D vector.

    Every layer is flattened and concatenated in order.  The previous version
    hard-coded exactly four layers: it raised IndexError for shallower models
    and silently ignored every layer after the fourth for deeper ones.

    Parameters
    ----------
    weights : sequence of array_like
        Per-layer weights (any shapes).

    Returns
    -------
    numpy.ndarray
        1-D array containing all layers' values in order.

    Raises
    ------
    ValueError
        If ``weights`` is empty.
    """
    layers = [np.ravel(layer) for layer in weights]
    if not layers:
        raise ValueError("weights must contain at least one layer")
    return np.concatenate(layers)
def client_cos_sim(client_weights):
    """Return the pairwise similarity matrix of the given client vectors.

    Entry ``[i, j]`` is ``cos_sim(client_weights[i], client_weights[j])``.
    The dot product and the norm product are both symmetric in their two
    arguments, so only the upper triangle (diagonal included) is computed
    and mirrored into the lower half.

    Parameters
    ----------
    client_weights : sequence of array_like
        One flattened vector per client.

    Returns
    -------
    numpy.ndarray
        ``(n, n)`` float array, where ``n = len(client_weights)``.
    """
    n = len(client_weights)
    all_cos = np.zeros((n, n), dtype=float)
    for i, vec_i in enumerate(client_weights):
        for j in range(i, n):
            sim = cos_sim(vec_i, client_weights[j])
            all_cos[i, j] = sim
            all_cos[j, i] = sim
    return all_cos
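1.2 Aggregating clients whose similarity exceeds a threshold
(Work in progress: the script below computes the pairwise similarity matrix `all_cos` but does not use it yet, so every client is still passed to FedAvg.)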
import numpy as np
import data
import model
import aggregator
import classification
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
ClientNum = 5
EPOCH = 5
if __name__ == '__main__':
data1, shards, x_test, y_test = data.generate_client_data(ClientNum)
model_client = []
for i in range(ClientNum):
model_client.append(model.Model())
global_model = model.Model()
# shards[0] = data.generate_malice_data(shards[0])
# shards[1] = data.generate_malice_data(shards[1])
all_weights = []
x_all = []
y_all = []
for i in range(ClientNum):
xx, yy = zip(*shards[i])
xx = np.array(xx)
yy = np.array(yy)
x_all.append(xx)
y_all.append(yy)
a_0 = np.argmax(global_model.predict(x_all[0][0:784]))
global_model.saveModel('sever_model')
for i in range(ClientNum):
np.argmax(model_client[i].predict(x_all[0][0:784]))
for i in range(EPOCH):
client_difference_value = []
client_difference_value_reshape = []
for j in range(ClientNum):
model_client[j].setWeights(global_model.getWeights())
for j in range(ClientNum):
model_client[j].run(x_all[j], y_all[j])
all_weights.append(model_client[j].getWeights())
for j in range(ClientNum):
client_difference_value.append(np.array(global_model.getWeights()) - np.array(model_client[j].getWeights()))
client_difference_value_reshape.append(classification.weights_reshape(client_difference_value[j]))
all_cos = classification.client_cos_sim(client_difference_value_reshape)
fedavg = aggregator.FedAvg(global_model, client_difference_value, ClientNum)
global_model.saveModel('sever_model')
x = []
y = []
for i in range(len(x_test)):
if y_test[i] == 1:
x.append(x_test[i])
y.append(y_test[i])
test_loss, test_acc = global_model.evaluate(x_test, y_test, verbose=2)
print('\nTest accuracy:', test_acc)
print(' Number of models :', len(all_weights))
Sidebar recommendations
- LeetCode 0107. Binary Tree Level Order Traversal II - another method
- Presentation of attribute value of an item
- Single chip computer engineering experience - layered idea
- 快速使用Amazon MemoryDB并构建你专属的Redis内存数据库
- The difference between CPU core and logical processor
- Leetcode-6109: number of people who know secrets
- [rust notes] 16 input and output (Part 2)
- Simple selection sort of selection sort
- 阿里新成员「瓴羊」正式亮相,由阿里副总裁朋新宇带队,集结多个核心部门技术团队
- Navicat连接Oracle数据库报错ORA-28547或ORA-03135
You may also like
Redis-01.初识Redis
Redis publish subscribe command line implementation
Network security skills competition in Secondary Vocational Schools -- a tutorial article on middleware penetration testing in Guangxi regional competition
Navicat连接Oracle数据库报错ORA-28547或ORA-03135
[2021]IBRNet: Learning Multi-View Image-Based Rendering Qianqian
groupByKey(), reduceByKey() and combineByKey() in Spark
Gaussian elimination: AcWing 884. Solving a system of linear XOR equations by Gaussian elimination
区间问题 AcWing 906. 区间分组
Sorting out the latest Android interview points in 2022 to help you easily win the offer - attached is the summary of Android intermediate and advanced interview questions in 2022
Real time clock (RTC)
Random recommendations
阿里新成员「瓴羊」正式亮相,由阿里副总裁朋新宇带队,集结多个核心部门技术团队
Dataframe (1): introduction and creation of dataframe
高斯消元 AcWing 884. 高斯消元解异或线性方程组
什么是套接字?Socket基本介绍
求组合数 AcWing 888. 求组合数 IV
求组合数 AcWing 889. 满足条件的01序列
What is socket? Basic introduction to socket
SPI details
[rust notes] 14 set (Part 2)
Modnet matting model reproduction
LeetCode 0107. Binary Tree Level Order Traversal II - another method
Network security skills competition in Secondary Vocational Schools -- a tutorial article on middleware penetration testing in Guangxi regional competition
The difference between CPU core and logical processor
博弈论 AcWing 894. 拆分-Nim游戏
求组合数 AcWing 887. 求组合数 III
How to generate an image from text on fly at runtime
LeetCode 31: Next Permutation
[BMZCTF-pwn] ectf-2014 seddit
2022/6/29-日报
MPLS experiment