当前位置:网站首页>强化学习基础记录
强化学习基础记录
2022-07-06 09:22:00 【喜欢库里的强化小白】
DDPG(Deep Deterministic Policy Gradient),基于Actor-Critic框架,是为了解决连续动作控制问题而提出的。该算法针对确定性策略,即给定状态,选出确定动作,而不是像随机性策略那样,进行抽样。
一、环境介绍
这里使用的是gym环境的’Pendulum-v1’,做简要介绍,详细介绍附上链接。
链接: OpenAI Gym 经典控制环境介绍——Pendulum-v1
(1)游戏规则:杆子进行转动,通过训练,使杆子能够朝上。
(2)状态空间:
(3)动作空间:
这里我用了[-1,1]的动作空间,如果要改成和环境一样的配置,只需要将Actor网络实现中的tanh前乘2,并将choose_action中也进行相应修改即可。
(4)奖励:
(5)初始状态和终止状态:
二、算法简单介绍
Actor-Critic
DDPG可以简单的看成是DQN算法加上Actor-Critic框架,DDPG算法中所使用的AC框架是基于动作价值函数Q的框架。Actor学习策略函数,Critic学习动作价值函数Q;上图中的学习思路看着很绕,但是其实很好理解,举一个容易理解的比喻:Actor相当于一个运动员,Critic相当于一个裁判;运动员Actor要做动作,运动员肯定是想要动作做得越来越好,从而提高自己的技术,裁判给运动员打分,运动员根据这个打分来改进自己的动作,得到更高的分数,从而达到改进自己技术的目的;裁判Critic也要提高自己,让自己的打分越来越精准,这样才能让运动员的技术也越来越高,Critic是靠环境给的奖励reward来改进自己提高自己的水平。
链接: DDPG参考文章Off-Policy:
这里采用了Off-Policy的思想,采用了经验回放的思想,打破样本之间的关联性。DDPG的特点:
<1>经验回放的使用
<2>双目标网络的使用:
在DQN中,通过对Q网络设置目标网络和行为网络,来稳定学习。在DDPG中,无论是Actor网络还是Critic网络,都具有目标网络,这也是很多博主所说的四网络方式。
<3>软更新的使用:
与DQN直接给目标网络喂参数的方法不同,这里采用软更新的方式对目标网络进行更新,使得学习更加稳定。其中目标网络不进行反向传递,相当于一个参照,行为网络通过训练,进行反向传播,以接近目标网络。
<4>探索性:
由于确定性动作,使得学习的探索性会降低,因此在选择动作的时候,可以使用贪心策略来增加探索性,同时也可以通过加入高斯噪声等方式,增加随机性。在这里,仅仅采用贪心策略。伪代码
实现
实现参考了网上的代码,进行了修改,起初效果不好,奖励一直上不去,后来请教了实验室的师兄,找到了一些问题,最后改来改去解决了问题。
总结原因可能有以下几个:
<1>经验池存满batch_size就要进行更新,不是存满才更新;
<2>软更新一定要在两个网络进行训练之后进行;
<3>每局的步数太少。
通过曲线可以看到,奖励升高,逐渐趋于0,由于训练太慢,只进行了100次迭代,因此没有达到最理想的效果。
import random
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gym
batch_size = 32
lr_actor = 0.001
lr_critic = 0.001
gamma = 0.90
epsilon = 0.9
episodes = 100
memory_capacity = 10000
#target_replace_iter = 100
tau = 0.02
env = gym.make('Pendulum-v1')
n_actions = env.action_space.shape[0]
n_states = env.observation_space.shape[0]
class Actor(nn.Module):
def __init__(self):
super(Actor, self).__init__()
self.fc1 = nn.Linear(n_states,256)
self.fc2 = nn.Linear(256,256)
self.out = nn.Linear(256,n_actions)
def forward(self,x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
actions = torch.tanh(self.out(x))
return actions
class Critic(nn.Module):
def __init__(self):
super(Critic, self).__init__()
self.fc1 = nn.Linear(n_states + n_actions,256)
self.fc2 = nn.Linear(256,256)
self.q_out = nn.Linear(256,1)
def forward(self,state,aciton):
x = torch.cat([state,aciton],dim=1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
q_value = self.q_out(x)
return q_value
class AC():
def __init__(self):
self.eval_actor,self.target_actor = Actor(),Actor()
self.eval_critic,self.target_critic = Critic(),Critic()
#self.learn_step_counter = 0
self.memory_counter = 0
self.buffer = []
self.target_actor.load_state_dict(self.eval_actor.state_dict())
self.target_critic.load_state_dict(self.eval_critic.state_dict())
self.actor_optim = torch.optim.Adam(self.eval_actor.parameters(),lr = lr_actor)
self.critic_optim = torch.optim.Adam(self.eval_critic.parameters(),lr = lr_critic)
def choose_action(self,state):
if np.random.uniform() > epsilon:
action = np.random.uniform(-1,1,1)
else:
inputs = torch.tensor(state, dtype=torch.float).unsqueeze(0)
action = self.eval_actor(inputs).squeeze(0)
action = action.detach().numpy()
return action
def store_transition(self,*transition):
if len(self.buffer) == memory_capacity:
self.buffer.pop(0)
self.buffer.append(transition)
def learn(self):
if len(self.buffer) < batch_size:
return
samples = random.sample(self.buffer,batch_size)
s0, a0, r1, s1 = zip(*samples)
s0 = torch.tensor(s0, dtype=torch.float)
a0 = torch.tensor(a0, dtype=torch.float)
r1 = torch.tensor(r1, dtype=torch.float).view(batch_size, -1)
s1 = torch.tensor(s1, dtype=torch.float)
def critic_learn():
a1 = self.target_actor(s1).detach()
y_true = r1 + gamma * self.target_critic(s1, a1).detach()
y_pred = self.eval_critic(s0, a0)
loss_fn = nn.MSELoss()
loss = loss_fn(y_pred, y_true)
self.critic_optim.zero_grad()
loss.backward()
self.critic_optim.step()
def actor_learn():
loss = -torch.mean(self.eval_critic(s0, self.eval_actor(s0)))
self.actor_optim.zero_grad()
loss.backward()
self.actor_optim.step()
def soft_update(net_target, net, tau):
for target_param, param in zip(net_target.parameters(), net.parameters()):
target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
critic_learn()
actor_learn()
soft_update(self.target_critic, self.eval_critic, tau)
soft_update(self.target_actor, self.eval_actor, tau)
ac = AC()
returns = []
for i in range(episodes):
s = env.reset()
episode_reward_sum = 0
for step in range(500):
# 开始一个episode (每一个循环代表一步)
#env.render()
a = ac.choose_action(s)
s_, r, done, info = env.step(a)
ac.store_transition(s, a, r, s_)
episode_reward_sum += r
s = s_
ac.learn()
returns.append(episode_reward_sum)
print('episode%s---reward_sum: %s' % (i, round(episode_reward_sum, 2)))
plt.figure()
plt.plot(range(len(returns)),returns)
plt.xlabel('episodes')
plt.ylabel('avg score')
plt.savefig('./plt_ddpg.png',format= 'png')
边栏推荐
- [dark horse morning post] Shanghai Municipal Bureau of supervision responded that Zhong Xue had a high fever and did not melt; Michael admitted that two batches of pure milk were unqualified; Wechat i
- Thoroughly understand LRU algorithm - explain 146 questions in detail and eliminate LRU cache in redis
- [中国近代史] 第五章测验
- 甲、乙机之间采用方式 1 双向串行通信,具体要求如下: (1)甲机的 k1 按键可通过串行口控制乙机的 LEDI 点亮、LED2 灭,甲机的 k2 按键控制 乙机的 LED1
- String abc = new String(“abc“),到底创建了几个对象
- MySQL锁总结(全面简洁 + 图文详解)
- 【毕业季·进击的技术er】再见了,我的学生时代
- Implementation of count (*) in MySQL
- 6. Function recursion
- 记一次猫舍由外到内的渗透撞库操作提取-flag
猜你喜欢
记一次猫舍由外到内的渗透撞库操作提取-flag
C language Getting Started Guide
QT meta object qmetaobject indexofslot and other functions to obtain class methods attention
9. Pointer (upper)
A piece of music composed by buzzer (Chengdu)
fianl、finally、finalize三者的区别
4. Binary search
7. Relationship between array, pointer and array
Cookie和Session的区别
Redis的两种持久化机制RDB和AOF的原理和优缺点
随机推荐
Detailed explanation of redis' distributed lock principle
Redis的两种持久化机制RDB和AOF的原理和优缺点
JS several ways to judge whether an object is an array
(original) make an electronic clock with LCD1602 display to display the current time on the LCD. The display format is "hour: minute: Second: second". There are four function keys K1 ~ K4, and the fun
The latest tank battle 2022 - full development notes-3
FAQs and answers to the imitation Niuke technology blog project (I)
1. First knowledge of C language (1)
[graduation season · advanced technology Er] goodbye, my student days
7-9 制作门牌号3.0(PTA程序设计)
Cookie和Session的区别
(原创)制作一个采用 LCD1602 显示的电子钟,在 LCD 上显示当前的时间。显示格式为“时时:分分:秒秒”。设有 4 个功能键k1~k4,功能如下:(1)k1——进入时间修改。
9. Pointer (upper)
7-15 h0161. 求最大公约数和最小公倍数(PTA程序设计)
7-1 输出2到n之间的全部素数(PTA程序设计)
2022 Teddy cup data mining challenge question C idea and post game summary
Have you encountered ABA problems? Let's talk about the following in detail, how to avoid ABA problems
C语言入门指南
[the Nine Yang Manual] 2017 Fudan University Applied Statistics real problem + analysis
简单理解ES6的Promise
【毕业季·进击的技术er】再见了,我的学生时代