当前位置:网站首页>2022/6/29-日报
2022/6/29-日报
2022-07-05 06:18:00 【mentalps】
0 2022/6/29
1 代码
1.1 计算客户端更新参数之间的余弦相似度(结果由 [-1, 1] 线性映射到 [0, 1]:sim = 0.5 + 0.5·cos)
import numpy as np
def cos_sim(vector_a, vector_b):
    """Return the cosine similarity of two vectors, rescaled from [-1, 1] to [0, 1].

    Both inputs are flattened before comparison, so any array-like with the
    same total number of elements is accepted (e.g. a flattened model update).

    Args:
        vector_a: First array-like of numbers.
        vector_b: Second array-like with the same number of elements.

    Returns:
        ``0.5 + 0.5 * cos(a, b)``: 1.0 for vectors pointing the same way,
        0.5 for orthogonal vectors, 0.0 for opposite directions. If either
        vector has zero norm the cosine is undefined; it is treated as 0
        (result 0.5) instead of producing NaN.

    Raises:
        ValueError: If the two inputs have different numbers of elements.
    """
    # Plain 1-D ndarrays instead of np.matrix: the matrix class is no longer
    # recommended, and float() on a 1x1 array is deprecated in recent NumPy.
    a = np.asarray(vector_a, dtype=float).ravel()
    b = np.asarray(vector_b, dtype=float).ravel()
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        # A zero update has no direction; treat it as orthogonal to everything.
        cos = 0.0
    else:
        cos = float(np.dot(a, b)) / denom
    return 0.5 + 0.5 * cos
def weights_reshape(weights):
    """Flatten a sequence of per-layer weight arrays into a single 1-D vector.

    Each layer is flattened (row-major) in order and the pieces are
    concatenated, so the result length equals the total number of parameters
    across *all* layers, whatever the number of layers is.

    Args:
        weights: Iterable of array-likes, one per layer (e.g. a list of
            ndarrays or a 1-D object array of ndarrays).

    Returns:
        A 1-D ndarray; an empty float array when ``weights`` is empty.
    """
    layers = [np.asarray(layer).reshape(-1) for layer in weights]
    if not layers:
        # np.concatenate rejects an empty list.
        return np.empty(0)
    return np.concatenate(layers)
def client_cos_sim(client_weights):
    """Build the pairwise similarity matrix between client update vectors.

    Args:
        client_weights: Sequence of update vectors, one per client (in the
            accompanying script, each is the output of ``weights_reshape``).

    Returns:
        An ``(n, n)`` float ndarray whose ``[i, j]`` entry is
        ``cos_sim(client_weights[i], client_weights[j])``.
    """
    n = len(client_weights)
    all_cos = np.zeros((n, n), dtype=float)
    for i, vec_i in enumerate(client_weights):
        for j in range(i, n):
            # Cosine similarity is symmetric in its arguments, so each
            # unordered pair is computed once and mirrored.
            all_cos[i, j] = all_cos[j, i] = cos_sim(vec_i, client_weights[j])
    return all_cos
1.2 把相似度高于一定值的客户端聚合在一起(目前代码只计算出客户端更新之间的相似度矩阵 all_cos,按阈值分组尚未实现,FedAvg 仍对全部客户端进行聚合)
import numpy as np
import data
import model
import aggregator
import classification
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
ClientNum = 5
EPOCH = 5


def _layer_difference(global_weights, local_weights):
    """Return ``global - local`` per layer, packed into a 1-D object ndarray.

    ``np.array(list_of_layers)`` on layers of different shapes is a ragged
    array: NumPy >= 1.24 raises ValueError for it, while older versions
    silently produced a 1-D object array. Building that object array
    explicitly keeps the older result (one ndarray per layer) on every NumPy.
    NOTE(review): assumes model.Model.getWeights returns a list of per-layer
    arrays (as Keras get_weights does) — confirm.
    """
    if len(global_weights) != len(local_weights):
        raise ValueError('global and local weights have different layer counts')
    diff = np.empty(len(global_weights), dtype=object)
    for k, (g, c) in enumerate(zip(global_weights, local_weights)):
        diff[k] = np.asarray(g) - np.asarray(c)
    return diff


if __name__ == '__main__':
    data1, shards, x_test, y_test = data.generate_client_data(ClientNum)
    model_client = [model.Model() for _ in range(ClientNum)]
    global_model = model.Model()
    # Uncomment to turn clients 0 and 1 into malicious clients.
    # shards[0] = data.generate_malice_data(shards[0])
    # shards[1] = data.generate_malice_data(shards[1])

    # One entry per client per epoch (ClientNum * EPOCH total); only its
    # length is reported at the end.
    all_weights = []
    x_all = []
    y_all = []
    for i in range(ClientNum):
        # Each shard element is unpacked as a (sample, label) pair.
        xx, yy = zip(*shards[i])
        x_all.append(np.array(xx))
        y_all.append(np.array(yy))

    # Run one prediction through every model before weights are read/set.
    # NOTE(review): the outputs are discarded; presumably this forces lazily
    # built models to create their weights — confirm against model.Model.
    global_model.predict(x_all[0][0:784])
    global_model.saveModel('sever_model')
    for client in model_client:
        client.predict(x_all[0][0:784])

    for epoch in range(EPOCH):
        # Every client starts the round from the current global weights.
        for client in model_client:
            client.setWeights(global_model.getWeights())
        for j, client in enumerate(model_client):
            client.run(x_all[j], y_all[j])
            all_weights.append(client.getWeights())

        # The global model is unchanged during local training, so fetch its
        # weights once for all clients.
        global_weights = global_model.getWeights()
        client_difference_value = [
            _layer_difference(global_weights, client.getWeights())
            for client in model_client
        ]
        client_difference_value_reshape = [
            classification.weights_reshape(diff) for diff in client_difference_value
        ]
        # Pairwise similarity of the client updates (not used further here).
        all_cos = classification.client_cos_sim(client_difference_value_reshape)
        fedavg = aggregator.FedAvg(global_model, client_difference_value, ClientNum)
        global_model.saveModel('sever_model')

    # Test samples labelled 1. NOTE(review): x and y are currently unused.
    x, y = [], []
    for xi, yi in zip(x_test, y_test):
        if yi == 1:
            x.append(xi)
            y.append(yi)
    test_loss, test_acc = global_model.evaluate(x_test, y_test, verbose=2)
    print('\nTest accuracy:', test_acc)
    print('模型个数:', len(all_weights))
边栏推荐
- Leetcode-6109: number of people who know secrets
- 快速使用Amazon MemoryDB并构建你专属的Redis内存数据库
- MySQL advanced part 2: SQL optimization
- Nested method, calculation attribute is not applicable, use methods
- MySQL advanced part 1: View
- 【LeetCode】Day94-重塑矩阵
- 求组合数 AcWing 888. 求组合数 IV
- 11-gorm-v2-03-basic query
- Error ora-28547 or ora-03135 when Navicat connects to Oracle Database
- [rust notes] 14 set (Part 2)
猜你喜欢
Navicat连接Oracle数据库报错ORA-28547或ORA-03135
Alibaba established the enterprise digital intelligence service company "Lingyang" to focus on enterprise digital growth
RGB LED infinite mirror controlled by Arduino
Sqlmap tutorial (II) practical skills I
MySQL advanced part 1: View
Sqlmap tutorial (1)
1.13 - RISC/CISC
Leetcode-6111: spiral matrix IV
LVS简介【暂未完成(半成品)】
SPI details
随机推荐
P2575 master fight
【Rust 笔记】15-字符串与文本(上)
做 SQL 性能优化真是让人干瞪眼
Winter messenger 2
What is socket? Basic introduction to socket
SQL三种连接:内连接、外连接、交叉连接
[rust notes] 17 concurrent (Part 2)
C - XOR to all (binary topic)
1.13 - RISC/CISC
Open source storage is so popular, why do we insist on self-development?
求组合数 AcWing 888. 求组合数 IV
MySQL advanced part 1: View
[rust notes] 15 string and text (Part 1)
Leetcode backtracking method
[rust notes] 16 input and output (Part 1)
[leetcode] day94 reshape matrix
How to generate an image from text on fly at runtime
Alibaba established the enterprise digital intelligence service company "Lingyang" to focus on enterprise digital growth
中国剩余定理 AcWing 204. 表达整数的奇怪方式
11-gorm-v2-02-create data