Value-Decomposition Networks for Cooperative Multi-Agent Learning: Principles and Code Reproduction
2022-07-29 14:23:00 【丰。。】
Concept Introduction
Algorithm Idea
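VDN can be regarded as the predecessor of the QMIX algorithm.
Its main idea is to decompose the total Q-value into a sum of per-agent Q-values, where each term is the action value of the corresponding agent.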
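Written out as formulas, the decomposition and its main practical consequence look as follows. The symbols are introduced here for illustration only: N agents, each with its own action-observation history τ^i and action u^i, and Q_tot denoting the team value as VDN represents it.

```latex
Q_{tot}(\boldsymbol{\tau}, \mathbf{u}) = \sum_{i=1}^{N} Q_i(\tau^i, u^i),
\qquad
\arg\max_{\mathbf{u}} Q_{tot}(\boldsymbol{\tau}, \mathbf{u})
  = \Big(\arg\max_{u^1} Q_1(\tau^1, u^1), \ldots, \arg\max_{u^N} Q_N(\tau^N, u^N)\Big)
```

Because each term depends only on its own agent's action, maximising the sum term by term maximises the whole, so at execution time every agent can act greedily on its own Q_i.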

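That is: treat the many agents as one.
This has a side effect, however: the summed Q is not a Q-value conditioned on a specific situation or specific conditions, so it has no concrete meaning of its own.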
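Algorithm Reproduction
First define the QLearner class, a single learner that trains all agents together. After the default arguments are stored, if the mixer is configured as VDN, self.mixer is initialised to VDNMixer() and its parameters are appended to the list of learnable parameters; self.optimiser then optimises these parameters with RMSprop, using the hyperparameters given in args.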
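import copy
import torch as th
from torch.optim import RMSprop
# Module paths follow the PyMARL layout (assumed; adjust to your project)
from components.episode_buffer import EpisodeBatch
from modules.mixers.vdn import VDNMixer

class QLearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger
        # Parameters of the agent network(s) held by the multi-agent controller
        self.params = list(mac.parameters())
        self.last_target_update_episode = 0

        self.mixer = None
        if args.mixer is not None:
            if args.mixer == "vdn":
                self.mixer = VDNMixer()
            else:
                raise ValueError("Mixer {} not recognised.".format(args.mixer))
            # A plain-sum VDN mixer has no weights, so this adds nothing; other mixers would
            self.params += list(self.mixer.parameters())
            self.target_mixer = copy.deepcopy(self.mixer)

        # Target copy of the agent controller, used in train() to compute the targets
        self.target_mac = copy.deepcopy(mac)

        self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps)
        # Start below zero so that stats are logged on the first training call
        self.log_stats_t = -self.args.learner_log_interval - 1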
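VDNMixer itself is not shown in the post. Conceptually it only sums the per-agent Q-values over the agent dimension. The sketch below mirrors that in NumPy (used instead of PyTorch so it runs standalone); the function name `vdn_mix` is hypothetical, and the [batch, time, n_agents] layout is assumed from the train code that follows.

```python
import numpy as np

def vdn_mix(agent_qs, states=None):
    """Hypothetical NumPy stand-in for a VDN mixer.

    agent_qs: array of shape [batch, time, n_agents], one Q-value per agent.
    states:   ignored; VDN does not condition on the global state.
    Returns Q_tot with shape [batch, time, 1].
    """
    # Q_tot = sum_i Q_i, keeping a trailing singleton dimension
    return agent_qs.sum(axis=2, keepdims=True)

# One episode, two time steps, three agents
qs = np.array([[[1.0, 2.0, 3.0],
                [0.5, -1.0, 2.5]]])
q_tot = vdn_mix(qs)
print(q_tot.shape)    # (1, 2, 1)
print(q_tot[..., 0])  # [[6. 2.]]
```

Since the mix is a parameter-free sum, a VDN mixer implemented this way contributes no trainable parameters to the optimiser, and the state passed to it is simply ignored.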
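In the training function, first read the relevant quantities from the batch, then compute the estimated Q-values: agent_outs for each time step is appended to mac_out, and after the loop the list is stacked along dim=1 (the time dimension) into a single tensor. Finally, the Q-value of the action each agent actually took is selected.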
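    def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
        # Get the relevant quantities; the reward is one team reward shared by all agents ("treat the many as one")
        rewards = batch["rewards"][:, :-1]
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        # A step is invalid if the episode had already terminated at the previous step
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"]

        # Calculate the estimated Q-values one time step at a time (the agent network keeps a hidden state)
        mac_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
        # Stack along the time dimension -> [batch, time, n_agents, n_actions]
        mac_out = th.stack(mac_out, dim=1)

        # Pick the Q-value of the action each agent actually took
        chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3)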
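The stacking and gathering steps above can be sketched with NumPy to make the tensor shapes concrete. This is an illustration under assumptions: random values stand in for network outputs, the sizes are arbitrary, and `np.take_along_axis` plays the role of `th.gather`.

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, N, A = 2, 4, 3, 5  # batch, time steps, agents, actions (illustrative sizes)

# Per-step controller outputs, each of shape [B, N, A]
per_step = [rng.normal(size=(B, N, A)) for _ in range(T)]
# Like th.stack(mac_out, dim=1): [B, T, N, A]
mac_out = np.stack(per_step, axis=1)

# Actions taken at steps 0..T-2, with a trailing singleton dim: [B, T-1, N, 1]
actions = rng.integers(0, A, size=(B, T - 1, N, 1))

# NumPy equivalent of th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3)
chosen_action_qvals = np.take_along_axis(mac_out[:, :-1], actions, axis=3).squeeze(3)
print(chosen_action_qvals.shape)  # (2, 3, 3)
```

The result has one Q-value per (episode, time step, agent), which is exactly the input the mixer sums over the agent dimension.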
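Next, compute the Q-values needed for the targets with the target network, giving target_mac_out. As above, the per-step outputs are stacked along dim=1, this time dropping the first time step: target_mac_out = th.stack(target_mac_out[1:], dim=1).
Unavailable actions are masked out with target_mac_out[avail_actions[:, 1:] == 0] = -9999999. If double Q-learning is enabled, the same masking is applied to mac_out, i.e. the online network's Q-values,
the index of the maximum is taken from it, and the corresponding values are extracted from target_mac_out: target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3).
Otherwise, target_max_qvals = target_mac_out.max(dim=3)[0].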
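        # Calculate the Q-values needed for the targets with the target network
        target_mac_out = []
        self.target_mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            target_agent_outs = self.target_mac.forward(batch, t=t)
            target_mac_out.append(target_agent_outs)
        # The first time step's estimate is not needed for the targets
        target_mac_out = th.stack(target_mac_out[1:], dim=1)
        # Mask out unavailable actions
        target_mac_out[avail_actions[:, 1:] == 0] = -9999999

        if self.args.double_q:
            # Double Q-learning: choose the greedy action with a detached copy of the online network...
            mac_out_detach = mac_out.clone().detach()
            mac_out_detach[avail_actions == 0] = -9999999
            cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1]
            # ...and evaluate that action with the target network
            target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3)
        else:
            target_max_qvals = target_mac_out.max(dim=3)[0]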
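The difference between the standard target and the double Q-learning target can be seen on a tiny hand-made example. This NumPy sketch assumes one episode, one agent, three time steps and three actions; all values are made up for illustration.

```python
import numpy as np

# Toy Q-values, shape [B=1, T=3, N=1, A=3]
online_q = np.array([[[[1.0, 5.0, 2.0]],
                      [[0.0, 1.0, 3.0]],
                      [[4.0, 2.0, 0.0]]]])
target_q = np.array([[[[0.0, 0.0, 0.0]],
                      [[2.0, 6.0, 1.0]],
                      [[1.0, 3.0, 7.0]]]])
# 1 = available; action 2 is unavailable at the last step
avail = np.ones_like(online_q, dtype=int)
avail[0, 2, 0, 2] = 0

# Drop the first step and mask unavailable actions, as in the code above
target_next = target_q[:, 1:].copy()
target_next[avail[:, 1:] == 0] = -9999999

# Standard target: max over the target network's own values
standard = target_next.max(axis=3)

# Double Q-learning: argmax from the masked online values, value from the target network
online_masked = online_q.copy()
online_masked[avail == 0] = -9999999
cur_max_actions = online_masked[:, 1:].argmax(axis=3)[..., None]  # keep the action dim
double = np.take_along_axis(target_next, cur_max_actions, axis=3).squeeze(3)

print(standard[0, :, 0])  # [6. 3.]
print(double[0, :, 0])    # [1. 1.]
```

The two targets differ whenever the online and target networks disagree on the best action; decoupling action selection from evaluation in this way is the usual motivation for double Q-learning, which aims to reduce overestimation of the max.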
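The chosen-action Q-values from the online network and the maximum Q-value estimates from the target network are passed through the mixer and the target mixer, respectively. Then compute the 1-step Q-learning target, targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals,
and the TD error, td_error = (chosen_action_qvals - targets.detach()). Targets that come from padded data are zeroed out with masked_td_error = td_error * mask, and finally the L2 loss is computed as the mean over the real (unpadded) data.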
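        # Mix the per-agent Q-values (for VDN: sum them over the agents)
        if self.mixer is not None:
            chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1])
            # Targets belong to the next time step, hence the next state
            target_max_qvals = self.target_mixer(target_max_qvals, batch["state"][:, 1:])

        # Calculate 1-step Q-learning targets
        targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals
        # TD error; no gradient flows through the targets
        td_error = (chosen_action_qvals - targets.detach())
        mask = mask.expand_as(td_error)
        # Zero out targets that come from padded data
        masked_td_error = td_error * mask
        # Normal L2 loss, averaged over the real data only
        loss = (masked_td_error ** 2).sum() / mask.sum()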
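A worked example of the mask, the 1-step target and the masked L2 loss, in NumPy. The team-level values are assumed to be already mixed (shape [B=1, T=4, 1]), and gamma = 0.5 is chosen only to keep the arithmetic round.

```python
import numpy as np

gamma = 0.5  # illustrative discount, not a recommended setting
rewards    = np.array([[[1.0], [0.0], [2.0], [0.0]]])
terminated = np.array([[[0.0], [0.0], [1.0], [0.0]]])  # episode ends at step 2
filled     = np.array([[[1.0], [1.0], [1.0], [1.0]]])
q_tot      = np.array([[[3.0], [2.0], [2.0], [5.0]]])  # mixed chosen-action Q-values
target_tot = np.array([[[2.0], [2.0], [4.0], [9.0]]])  # mixed max Q-values of the next step

# A step counts only if the episode had not terminated at the previous step
mask = filled.copy()
mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])

# 1-step target; (1 - terminated) stops bootstrapping past the terminal step
targets = rewards + gamma * (1 - terminated) * target_tot
td_error = q_tot - targets
masked_td_error = td_error * mask
loss = (masked_td_error ** 2).sum() / mask.sum()

print(mask[0, :, 0])            # [1. 1. 1. 0.]
print(round(float(loss), 4))    # 0.6667
```

Here the targets are 2, 1, 2 and 4.5, giving TD errors 1, 1, 0 and 0.5. The last step comes after termination, so the mask removes it and the loss is (1 + 1 + 0) / 3 rather than an average that also counts the invalid step.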