Current location: Home > June 29, 2022 daily
June 29, 2022 daily
2022-07-05 06:23:00 【mentalps】
0 2022/6/29
1 Code
1.1 Calculate the cosine similarity between clients' parameter updates
import numpy as np
def cos_sim(vector_a, vector_b):
    """Return the cosine similarity of two vectors, rescaled to [0, 1].

    The raw cosine ``cos = a.b / (|a| * |b|)`` lies in [-1, 1]; it is mapped
    to ``0.5 + 0.5 * cos`` so that 1.0 means "same direction", 0.5 means
    "orthogonal" and 0.0 means "opposite direction".

    Args:
        vector_a: array-like of numbers; flattened before use, so 1-D vectors
            and 1xN / Nx1 shapes are all accepted.
        vector_b: array-like with the same number of elements as ``vector_a``.

    Returns:
        float in [0, 1]. When either vector has zero norm the cosine is
        undefined; 0.5 (treated as orthogonal) is returned instead of NaN.

    Raises:
        ValueError: if the two vectors have different numbers of elements.
    """
    # np.mat / np.matrix were removed in NumPy 2.0; plain 1-D arrays + dot
    # compute the same inner product on every NumPy version.
    a = np.ravel(np.asarray(vector_a, dtype=float))
    b = np.ravel(np.asarray(vector_b, dtype=float))
    if a.size != b.size:
        raise ValueError(f"vectors differ in length: {a.size} != {b.size}")
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        # A zero vector has no direction; avoid 0/0 -> NaN (and its warning).
        return 0.5
    # Rounding can push |cos| marginally above 1; clip keeps the result in [0, 1].
    cos = float(np.clip(np.dot(a, b) / denom, -1.0, 1.0))
    return 0.5 + 0.5 * cos
def weights_reshape(weights):
    """Flatten a model's per-layer weights into a single 1-D vector.

    Each layer is flattened in row-major order and the pieces are
    concatenated in layer order, so weight sets from the same architecture
    map to directly comparable vectors (e.g. as inputs to ``cos_sim``).

    Args:
        weights: sequence of array-likes, one per layer (a list, or a 1-D
            object array holding the layer arrays). Every layer is used; the
            previous version was hard-wired to exactly the first four.

    Returns:
        1-D ndarray of all parameters; an empty array if ``weights`` is empty.
    """
    flat_layers = [np.ravel(layer) for layer in weights]
    if not flat_layers:
        # np.concatenate / np.hstack reject an empty sequence.
        return np.empty(0)
    return np.concatenate(flat_layers)
def client_cos_sim(client_weights):
    """Return the pairwise ``cos_sim`` matrix for a list of client vectors.

    Args:
        client_weights: sequence of flattened client updates (for example the
            outputs of ``weights_reshape``), all with the same length.

    Returns:
        float ndarray of shape (n, n) whose entry [i, j] is
        ``cos_sim(client_weights[i], client_weights[j])``; shape (0, 0) when
        no clients are given.
    """
    n = len(client_weights)
    all_cos = np.zeros((n, n), dtype=float)
    # cos_sim is symmetric in its arguments, so evaluate each unordered pair
    # once (including the diagonal) and mirror it into the lower triangle.
    for i, vec_i in enumerate(client_weights):
        for j in range(i, n):
            all_cos[i, j] = all_cos[j, i] = cos_sim(vec_i, client_weights[j])
    return all_cos
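1.2 Aggregate clients whose similarity exceeds a threshold (note: the code below computes the pairwise similarity matrix but does not yet filter on it; every client is still aggregated)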
import numpy as np
import data
import model
import aggregator
import classification
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
ClientNum = 5  # number of simulated federated-learning clients
EPOCH = 5  # number of federated rounds (each client trains once per round)


def _layer_difference(global_weights, client_weights):
    """Return the per-layer update ``global - client`` as a 1-D object array.

    Subtracting ``np.array(...)`` of the two layer lists breaks once layers
    have different shapes (a kernel and its bias, for instance): NumPy < 1.24
    builds a ragged object array with a VisibleDeprecationWarning, and
    NumPy >= 1.24 raises ValueError. Building the object array explicitly
    gives one entry per layer on every NumPy version.
    NOTE(review): if every layer shared one shape, the old code produced a
    regular numeric array instead; values are identical either way.

    Raises:
        ValueError: if the two weight lists have different numbers of layers.
    """
    if len(global_weights) != len(client_weights):
        raise ValueError(
            f"layer count mismatch: {len(global_weights)} != {len(client_weights)}"
        )
    diff = np.empty(len(global_weights), dtype=object)
    for k, (g_layer, c_layer) in enumerate(zip(global_weights, client_weights)):
        diff[k] = np.asarray(g_layer) - np.asarray(c_layer)
    return diff


if __name__ == '__main__':
    data1, shards, x_test, y_test = data.generate_client_data(ClientNum)
    model_client = [model.Model() for _ in range(ClientNum)]
    global_model = model.Model()
    # Data-poisoning experiment: uncomment to make clients 0 and 1 malicious.
    # shards[0] = data.generate_malice_data(shards[0])
    # shards[1] = data.generate_malice_data(shards[1])
    all_weights = []  # every trained client weight set; only its count is reported

    # Each shard holds (x, y) pairs; unzip it into feature and label arrays.
    x_all = []
    y_all = []
    for i in range(ClientNum):
        xx, yy = zip(*shards[i])
        x_all.append(np.array(xx))
        y_all.append(np.array(yy))

    # NOTE(review): these predictions are discarded; presumably they make each
    # model build its weights before getWeights/setWeights — confirm in model.Model.
    global_model.predict(x_all[0][0:784])
    global_model.saveModel('sever_model')
    for client in model_client:
        client.predict(x_all[0][0:784])

    for _ in range(EPOCH):
        # Every round starts all clients from the current global model.
        for client in model_client:
            client.setWeights(global_model.getWeights())
        for j, client in enumerate(model_client):
            client.run(x_all[j], y_all[j])
            all_weights.append(client.getWeights())

        # Client updates as (global - client) per layer, plus a flattened
        # version of each for the similarity computation.
        global_weights = global_model.getWeights()
        client_difference_value = [
            _layer_difference(global_weights, client.getWeights())
            for client in model_client
        ]
        client_difference_value_reshape = [
            classification.weights_reshape(diff) for diff in client_difference_value
        ]
        # NOTE(review): the similarity matrix is not used yet — FedAvg below
        # still aggregates every client, not only those above a threshold.
        all_cos = classification.client_cos_sim(client_difference_value_reshape)
        # Return value was never used; presumably FedAvg updates global_model.
        aggregator.FedAvg(global_model, client_difference_value, ClientNum)
        global_model.saveModel('sever_model')

    test_loss, test_acc = global_model.evaluate(x_test, y_test, verbose=2)
    print('\nTest accuracy:', test_acc)
    print(' Number of models :', len(all_weights))
Sidebar recommendations
- Nested method, calculation attribute is not applicable, use methods
- 1.13 - RISC/CISC
- 阿里新成员「瓴羊」正式亮相,由阿里副总裁朋新宇带队,集结多个核心部门技术团队
- 背包问题 AcWing 9. 分组背包问题
- Record the process of configuring nccl and horovod in these two days (original)
- [rust notes] 15 string and text (Part 1)
- 4.Oracle-重做日志文件管理
- Alibaba established the enterprise digital intelligence service company "Lingyang" to focus on enterprise digital growth
- Sword finger offer II 058: schedule
- TypeScript 基础讲解
You may also like

There are three kinds of SQL connections: internal connection, external connection and cross connection

LeetCode 0107. Sequence traversal of binary tree II - another method

MySQL advanced part 1: stored procedures and functions

【LeetCode】Easy | 20. Valid parentheses

Leetcode stack related

1.手动创建Oracle数据库

Gauss Cancellation acwing 884. Solution d'un système d'équations Xor linéaires par élimination gaussienne

快速使用Amazon MemoryDB并构建你专属的Redis内存数据库

阿里新成员「瓴羊」正式亮相,由阿里副总裁朋新宇带队,集结多个核心部门技术团队

4. Object mapping Mapster
Random recommendations
Erreur de connexion Navicat à la base de données Oracle Ora - 28547 ou Ora - 03135
AE tutorial - path growth animation
FFmpeg build下载(包含old version)
LeetCode-54
传统数据库逐渐“难适应”,云原生数据库脱颖而出
Appium automation test foundation - Summary of appium test environment construction
P3265 [jloi2015] equipment purchase
2.Oracle-数据文件的添加及管理
[2020]GRAF: Generative Radiance Fields for 3D-Aware Image Synthesis
SQL三种连接:内连接、外连接、交叉连接
2022/6/29-日报
11-gorm-v2-02-create data
MySQL advanced part 2: storage engine
Presentation of attribute value of an item
Navicat连接Oracle数据库报错ORA-28547或ORA-03135
[rust notes] 16 input and output (Part 2)
MySQL advanced part 2: MySQL architecture
Leetcode-6108: decrypt messages
[rust notes] 17 concurrent (Part 1)
[2021]GIRAFFE: Representing Scenes as Compositional Generative Neural Feature Fields