June 29, 2022 daily
2022-07-05 06:23:00 【mentalps】
0 2022/6/29
1 Code
1.1 Calculate the cosine similarity between clients' parameter updates
import numpy as np

def cos_sim(vector_a, vector_b):
    # Cosine similarity of two vectors, rescaled from [-1, 1] to [0, 1]
    # (np.dot on flattened arrays replaces the deprecated np.mat product)
    vector_a = np.asarray(vector_a, dtype=float).ravel()
    vector_b = np.asarray(vector_b, dtype=float).ravel()
    num = float(np.dot(vector_a, vector_b))
    denom = np.linalg.norm(vector_a) * np.linalg.norm(vector_b)
    cos = num / denom
    # Map cos in [-1, 1] to a similarity score in [0, 1]
    sim = 0.5 + 0.5 * cos
    return sim

def weights_reshape(weights):
    # Flatten each layer's weight array and join them into one 1-D vector
    # (works for any number of layers)
    new_weights = np.hstack([w.reshape(-1) for w in weights])
    return new_weights

def client_cos_sim(client_weights):
    # Pairwise similarity matrix: all_cos[i][j] = cos_sim(client i, client j)
    n = len(client_weights)
    all_cos = np.zeros((n, n), dtype=float)
    for i in range(n):
        for j in range(n):
            all_cos[i][j] = cos_sim(client_weights[i], client_weights[j])
    return all_cos
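The double loop in client_cos_sim computes the n x n similarities one pair at a time, each as sim = 0.5 + 0.5 * cos(theta) with cos(theta) = a.b / (|a| |b|). As a sketch (not the author's code), the same matrix can be obtained with a single matrix product by first scaling every flattened update to unit length. The name `pairwise_sim` is hypothetical, and the inputs are assumed to be nonzero vectors of equal length.

```python
import numpy as np


def pairwise_sim(updates):
    """Vectorised 0.5 + 0.5 * cos for every pair of flattened client updates.

    `updates` is a list of equal-length, nonzero vectors (one per client).
    """
    m = np.stack([np.asarray(u, dtype=float).ravel() for u in updates])  # (n_clients, n_params)
    norms = np.linalg.norm(m, axis=1, keepdims=True)
    unit = m / norms                # each row scaled to unit length
    cos = unit @ unit.T             # cos[i, j] = <u_i, u_j> / (|u_i| |u_j|)
    cos = np.clip(cos, -1.0, 1.0)   # guard against rounding just outside [-1, 1]
    return 0.5 + 0.5 * cos          # same [0, 1] rescaling as cos_sim


updates = [np.array([1.0, 0.0]), np.array([0.0, 2.0]), np.array([-3.0, 0.0])]
sims = pairwise_sim(updates)
print(sims)  # diagonal is 1.0; opposite updates (rows 0 and 2) score 0.0
```

An all-zero update (a client that did not change its weights) makes its norm zero and produces NaN in its row, so such clients would need to be filtered out first.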
1.2 Aggregate clients whose similarity exceeds a threshold
import numpy as np
# Local helper modules (not shown in this post)
import data
import model
import aggregator
import classification  # presumably the functions from 1.1

ClientNum = 5
EPOCH = 5

if __name__ == '__main__':
    data1, shards, x_test, y_test = data.generate_client_data(ClientNum)
    model_client = [model.Model() for _ in range(ClientNum)]
    global_model = model.Model()
    # shards[0] = data.generate_malice_data(shards[0])
    # shards[1] = data.generate_malice_data(shards[1])
    all_weights = []
    x_all = []
    y_all = []
    for i in range(ClientNum):
        xx, yy = zip(*shards[i])
        x_all.append(np.array(xx))
        y_all.append(np.array(yy))
    # Forward passes whose outputs are unused (presumably to build the model weights)
    a_0 = np.argmax(global_model.predict(x_all[0][0:784]))
    global_model.saveModel('sever_model')
    for i in range(ClientNum):
        np.argmax(model_client[i].predict(x_all[0][0:784]))
    for i in range(EPOCH):
        client_difference_value = []
        client_difference_value_reshape = []
        # Send the current global weights to every client
        for j in range(ClientNum):
            model_client[j].setWeights(global_model.getWeights())
        # Local training on each client's shard
        for j in range(ClientNum):
            model_client[j].run(x_all[j], y_all[j])
            all_weights.append(model_client[j].getWeights())
        global_weights = global_model.getWeights()
        for j in range(ClientNum):
            # Per-layer update (global - local). Layers have different shapes, so store them
            # in a 1-D object array, filled one layer at a time to avoid ragged-array errors
            diff = np.empty(len(global_weights), dtype=object)
            for k, (g, c) in enumerate(zip(global_weights, model_client[j].getWeights())):
                diff[k] = g - c
            client_difference_value.append(diff)
            client_difference_value_reshape.append(classification.weights_reshape(diff))
        # Pairwise similarity of the flattened updates (computed, but not yet used to select clients)
        all_cos = classification.client_cos_sim(client_difference_value_reshape)
        # Aggregate all client updates into the global model
        fedavg = aggregator.FedAvg(global_model, client_difference_value, ClientNum)
        global_model.saveModel('sever_model')
    # Collect the test samples labelled 1 (not used by the evaluation below)
    x = []
    y = []
    for i in range(len(x_test)):
        if y_test[i] == 1:
            x.append(x_test[i])
            y.append(y_test[i])
    test_loss, test_acc = global_model.evaluate(x_test, y_test, verbose=2)
    print('\nTest accuracy:', test_acc)
    print('Number of models:', len(all_weights))
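The script above computes all_cos but still passes every client's update to aggregator.FedAvg. The following is a minimal sketch of the selection the 1.2 heading describes, under an assumed rule: a client is kept if its mean similarity to the other clients is at least `threshold`, and only the kept per-layer differences (global minus local, as in the script) are averaged and subtracted from the global weights. The functions `select_clients` and `aggregate_selected`, the threshold value, and the toy data are all hypothetical; this is not the author's aggregator.FedAvg.

```python
import numpy as np


def select_clients(all_cos, threshold):
    """Indices of clients whose mean similarity to the other clients is >= threshold."""
    all_cos = np.asarray(all_cos, dtype=float)
    n = all_cos.shape[0]
    # Drop the diagonal: self-similarity is always 1.0
    off_diag_mean = (all_cos.sum(axis=1) - np.diag(all_cos)) / (n - 1)
    return [i for i in range(n) if off_diag_mean[i] >= threshold]


def aggregate_selected(global_weights, client_diffs, selected):
    """Unweighted FedAvg over the selected clients.

    client_diffs[j][layer] = global - local for client j, so new = global - mean(diff).
    """
    if not selected:
        return [w.copy() for w in global_weights]  # nobody passed: keep the old model
    new_weights = []
    for layer, w in enumerate(global_weights):
        mean_diff = np.mean([client_diffs[j][layer] for j in selected], axis=0)
        new_weights.append(w - mean_diff)
    return new_weights


# Toy example: three clients with similar updates and one whose update points the other way
global_weights = [np.ones((2, 2)), np.zeros(2)]
diffs = [
    [np.full((2, 2), 0.1), np.full(2, 0.1)],
    [np.full((2, 2), 0.2), np.full(2, 0.2)],
    [np.full((2, 2), 0.3), np.full(2, 0.3)],
    [np.full((2, 2), -0.5), np.full(2, -0.5)],  # outlier
]
# Same 0.5 + 0.5 * cos matrix as client_cos_sim, computed on the flattened updates
flat = np.stack([np.concatenate([d.ravel() for d in diff]) for diff in diffs])
unit = flat / np.linalg.norm(flat, axis=1, keepdims=True)
all_cos = 0.5 + 0.5 * (unit @ unit.T)

selected = select_clients(all_cos, threshold=0.6)
print(selected)  # [0, 1, 2]
new_weights = aggregate_selected(global_weights, diffs, selected)
```

The mean off-diagonal similarity is only one possible rule; comparing each client against the median similarity, or clustering the rows of all_cos, are alternatives with different robustness to how many clients are malicious.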