Reinforcement learning basics notes
2022-07-06 13:53:00 【I like the strengthened Xiaobai in Curie】
DDPG reinforcement learning notes
DDPG (Deep Deterministic Policy Gradient) is built on the Actor-Critic framework and was proposed to handle continuous action control. It learns a deterministic policy: given a state, it outputs one definite action instead of sampling an action from a distribution as a stochastic policy does.
1. Environment overview
The environment used here is gym's 'Pendulum-v1'. A brief overview follows; see the link for details.
Link: OpenAI Gym classic control environments: Pendulum-v1
(1) Rules of the game: the pendulum swings around a fixed pivot, and through training the agent learns to apply torque so that the pole swings up and stays upright.
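(2) State space: a 3-dimensional observation [cos θ, sin θ, θ̇], where θ is the pendulum angle (θ = 0 is upright). cos θ and sin θ lie in [-1, 1], and the angular velocity θ̇ lies in [-8, 8].
(3) Action space: a single continuous torque in [-2, 2].
I used a [-1, 1] action space here. To match the environment's configuration, multiply the output of the Actor network's tanh by 2 and adjust choose_action accordingly.
(4) Reward: each step gives r = -(θ² + 0.1·θ̇² + 0.001·u²), where θ is normalized to [-π, π] and u is the applied torque. The reward therefore ranges from about -16.27 to 0, and it is 0 only when the pendulum is upright, motionless, and no torque is applied.
(5) Initial and termination states: an episode starts from a random angle in [-π, π] and a random angular velocity in [-1, 1]. There is no true terminal state; gym truncates the episode after 200 steps.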
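For reference, the reward in (4) and the action rescaling described in (3) can be written as plain Python. This is a minimal illustrative re-implementation of the documented formula, not gym's own code: the function names angle_normalize, pendulum_reward, and scale_action are hypothetical, and gym additionally clips the torque to [-2, 2] before computing the cost.

```python
import math

MAX_TORQUE = 2.0  # Pendulum-v1 torque bound: the action space is [-2, 2]


def angle_normalize(theta: float) -> float:
    """Wrap an angle into [-pi, pi)."""
    return ((theta + math.pi) % (2 * math.pi)) - math.pi


def pendulum_reward(theta: float, theta_dot: float, torque: float) -> float:
    """Per-step reward: -(theta^2 + 0.1 * theta_dot^2 + 0.001 * torque^2)."""
    th = angle_normalize(theta)
    return -(th ** 2 + 0.1 * theta_dot ** 2 + 0.001 * torque ** 2)


def scale_action(a: float) -> float:
    """Map an actor output in tanh's range [-1, 1] to the env's torque range."""
    return MAX_TORQUE * a
```

The worst case (hanging straight down, θ̇ = 8, full torque) gives about -16.27, and the upright, motionless, zero-torque state gives 0, which is why the learning curve can only climb toward 0.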
2. Algorithm overview
DDPG can be roughly viewed as DQN combined with the Actor-Critic framework, where the AC framework used in DDPG is built on the action-value function Q. The Actor learns the policy function and the Critic learns the action-value function Q. The learning scheme in the figure above looks convoluted, but an analogy makes it easy to follow: the Actor is an athlete and the Critic is a referee. The athlete (Actor) performs actions and wants to perform better and better, so it adjusts its movements according to the referee's score in order to earn higher scores. The referee (Critic) must also improve so that its scoring becomes more accurate, because only accurate scores help the athlete improve; the Critic improves itself using the rewards returned by the environment.
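In symbols, the athlete/referee loop amounts to the objectives below. The notation is the standard one, chosen here to line up with critic_learn, actor_learn, and soft_update in the code later in this post: μ is the actor, Q the critic, primes mark target networks, N is batch_size, γ is gamma, and τ is tau. Like that code, the target has no terminal-state mask, which suits Pendulum because it has no true terminal states.

$$
y_i = r_i + \gamma \, Q'\!\left(s_{i+1}, \mu'(s_{i+1})\right)
$$

$$
L_{\text{critic}} = \frac{1}{N}\sum_{i=1}^{N}\left(y_i - Q(s_i, a_i)\right)^2,
\qquad
L_{\text{actor}} = -\frac{1}{N}\sum_{i=1}^{N} Q\!\left(s_i, \mu(s_i)\right)
$$

$$
\theta' \leftarrow \tau\,\theta + (1-\tau)\,\theta'
$$

Minimizing the actor loss pushes μ(s) toward actions the critic scores highly (the athlete chasing the referee's score), while the critic loss fits Q to bootstrapped targets built from environment rewards (the referee calibrating itself).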
Link: DDPG reference article
Off-Policy:
DDPG is off-policy: it uses experience replay, which breaks the correlation between consecutive samples.
DDPG characteristics:
<1> Use of experience replay
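The experience replay described under Off-Policy can be sketched as a small fixed-capacity buffer. This is a minimal sketch, not the post's code: the class name ReplayBuffer is hypothetical, and it uses collections.deque with maxlen, which drops the oldest transition automatically, in place of the list-plus-pop(0) approach in the training script.

```python
import random
from collections import deque


class ReplayBuffer:
    """Fixed-capacity FIFO store of (s, a, r, s_next) transitions."""

    def __init__(self, capacity: int):
        # A deque with maxlen discards the oldest item once it is full.
        self.buffer = deque(maxlen=capacity)

    def push(self, *transition):
        self.buffer.append(transition)

    def sample(self, batch_size: int):
        # Uniform sampling without replacement mixes transitions from
        # different steps and episodes, breaking temporal correlation.
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)
```

The deque avoids the O(n) cost of list.pop(0) on every insert once the buffer is full; random indexing into a deque is slower than into a list, but for a 10000-entry buffer either choice is workable.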
<2> Use of two target networks:
In DQN, learning is stabilized by giving the Q network a target network in addition to the behavior network. In DDPG, both the Actor network and the Critic network have target networks, which is why many bloggers describe DDPG as using four networks.
<3> Soft updates:
Unlike DQN, which copies the parameters directly into the target network, DDPG updates its target networks with soft updates, which makes learning more stable. The target networks do not take part in backpropagation; they act as a slowly moving reference. The behavior networks are trained by backpropagation toward targets computed with the target networks, and each soft update moves the target networks a small step (tau) toward the behavior networks.
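A minimal sketch of the target-network setup and soft update, using plain Python lists instead of PyTorch tensors so the arithmetic is visible. The dictionaries online and target are hypothetical stand-ins for one network's parameters; the rule is the same one the soft_update function in the training code applies, with tau = 0.02 as in the post.

```python
import copy


def soft_update(target: dict, online: dict, tau: float) -> None:
    """Polyak averaging: target <- tau * online + (1 - tau) * target."""
    for name, values in online.items():
        target[name] = [tau * v + (1.0 - tau) * t
                        for v, t in zip(values, target[name])]


# Hypothetical weights standing in for one network's parameters.
online = {"w": [1.0, 2.0], "b": [0.5]}
# The target network starts as an exact copy of the online network.
target = copy.deepcopy(online)

# Pretend a training step changed the online weights.
online["w"] = [2.0, 2.0]
soft_update(target, online, tau=0.02)
# target["w"][0] has moved only 2% of the way from 1.0 toward 2.0.
```

With tau = 0.02, the gap between target and online weights shrinks by a factor of 0.98 per update, so if the online weights stayed fixed for 200 updates, about 0.98^200 ≈ 0.018 of the original gap would remain.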
<4> Exploration:
Because the actions are deterministic, the agent explores little on its own, so exploration has to be added when choosing actions. An ε-greedy-style rule (occasionally taking a random action) can be used, and Gaussian noise can also be added to the action to increase randomness. Here, only the ε-greedy-style rule is used.
[Figure: DDPG pseudocode]
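The two exploration options in <4> can be sketched as one helper: the ε-greedy-style rule used in this post (with probability 1 - epsilon, take a uniform random action, mirroring choose_action in the training code), plus optional Gaussian noise on the deterministic action, clipped to the action bounds. The function name explore and the noise_std parameter are illustrative assumptions, not part of the original code.

```python
import random


def explore(policy_action: float, epsilon: float = 0.9, noise_std: float = 0.0,
            low: float = -1.0, high: float = 1.0) -> float:
    """Return an exploratory action around a deterministic policy output.

    With probability 1 - epsilon, return a uniform random action in
    [low, high]. Otherwise return the policy's action, optionally
    perturbed by zero-mean Gaussian noise, clipped to [low, high].
    """
    if random.random() > epsilon:
        return random.uniform(low, high)
    a = policy_action
    if noise_std > 0:
        a += random.gauss(0.0, noise_std)
    return max(low, min(high, a))
```

With noise_std = 0 this reduces to the post's scheme; setting epsilon = 1.0 with a positive noise_std gives pure Gaussian-noise exploration instead.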
The implementation is adapted from code found online, with modifications. At first it performed poorly and the reward never went up. After consulting a senior labmate, I found several problems and eventually fixed them through repeated changes.
The likely causes can be summarized as follows:
<1> Updates should start as soon as the experience pool holds batch_size transitions, not only once the pool is completely full;
<2> The soft update must be performed after both networks have been trained;
<3> Each episode had too few steps.
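Point <1> is easy to quantify. Assuming one transition is stored per environment step (as in the training loop) and using the post's batch_size = 32 and memory_capacity = 10000, the hypothetical helper first_learning_episode returns the 0-based episode in which learning can first begin.

```python
def first_learning_episode(buffer_threshold: int, steps_per_episode: int) -> int:
    """0-based index of the episode in which the buffer first reaches
    buffer_threshold transitions, storing one transition per step."""
    # Ceiling division: episodes needed to collect buffer_threshold transitions.
    episodes_needed = -(-buffer_threshold // steps_per_episode)
    return episodes_needed - 1


print(first_learning_episode(32, 500))      # start at batch_size
print(first_learning_episode(10000, 500))   # start only when the buffer is full
print(first_learning_episode(10000, 200))   # same, with 200-step episodes
```

With 500 steps per episode, waiting for a full buffer means the first 19 episodes train nothing, nearly a fifth of the 100-episode budget; with Pendulum's 200-step episodes it would be 49 episodes, about half the budget. Starting at batch_size lets learning begin in episode 0.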
The curve shows the reward increasing and gradually approaching 0. Because training is slow, only 100 episodes were run, so the best achievable performance was not reached.
import random
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gym

# Hyperparameters
batch_size = 32
lr_actor = 0.001
lr_critic = 0.001
gamma = 0.90
epsilon = 0.9            # probability of using the actor's action (otherwise random)
episodes = 100
memory_capacity = 10000
#target_replace_iter = 100
tau = 0.02               # soft-update rate for the target networks

# Written for the pre-0.26 gym API: reset() returns obs, step() returns 4 values
env = gym.make('Pendulum-v1')
n_actions = env.action_space.shape[0]
n_states = env.observation_space.shape[0]


class Actor(nn.Module):
    """Deterministic policy: maps a state to an action in [-1, 1]."""
    def __init__(self):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(n_states, 256)
        self.fc2 = nn.Linear(256, 256)
        self.out = nn.Linear(256, n_actions)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        actions = torch.tanh(self.out(x))
        return actions


class Critic(nn.Module):
    """Action-value function Q(s, a)."""
    def __init__(self):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(n_states + n_actions, 256)
        self.fc2 = nn.Linear(256, 256)
        self.q_out = nn.Linear(256, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        q_value = self.q_out(x)
        return q_value


class AC():
    def __init__(self):
        # Four networks: behavior (eval) and target copies of the actor and critic
        self.eval_actor, self.target_actor = Actor(), Actor()
        self.eval_critic, self.target_critic = Critic(), Critic()
        # Start each target network from its behavior network's weights
        self.target_actor.load_state_dict(self.eval_actor.state_dict())
        self.target_critic.load_state_dict(self.eval_critic.state_dict())
        #self.learn_step_counter = 0
        self.memory_counter = 0
        self.buffer = []
        self.actor_optim = torch.optim.Adam(self.eval_actor.parameters(), lr=lr_actor)
        self.critic_optim = torch.optim.Adam(self.eval_critic.parameters(), lr=lr_critic)

    def choose_action(self, state):
        # epsilon-greedy-style exploration: random action with probability 1 - epsilon
        if np.random.uniform() > epsilon:
            action = np.random.uniform(-1, 1, 1)
        else:
            inputs = torch.tensor(state, dtype=torch.float).unsqueeze(0)
            action = self.eval_actor(inputs).squeeze(0)
            action = action.detach().numpy()
        return action

    def store_transition(self, *transition):
        # FIFO replay buffer: drop the oldest transition once capacity is reached
        if len(self.buffer) == memory_capacity:
            self.buffer.pop(0)
        self.buffer.append(transition)

    def learn(self):
        # Learn as soon as one batch is available, not only when the buffer is full
        if len(self.buffer) < batch_size:
            return
        samples = random.sample(self.buffer, batch_size)
        s0, a0, r1, s1 = zip(*samples)
        s0 = torch.tensor(np.array(s0), dtype=torch.float)
        a0 = torch.tensor(np.array(a0), dtype=torch.float)
        r1 = torch.tensor(np.array(r1), dtype=torch.float).view(batch_size, -1)
        s1 = torch.tensor(np.array(s1), dtype=torch.float)

        def critic_learn():
            # TD target built from the target actor and target critic
            a1 = self.target_actor(s1).detach()
            y_true = r1 + gamma * self.target_critic(s1, a1).detach()
            y_pred = self.eval_critic(s0, a0)
            loss_fn = nn.MSELoss()
            loss = loss_fn(y_pred, y_true)
            self.critic_optim.zero_grad()
            loss.backward()
            self.critic_optim.step()

        def actor_learn():
            # Maximize Q(s, mu(s)) by minimizing its negative mean
            loss = -torch.mean(self.eval_critic(s0, self.eval_actor(s0)))
            self.actor_optim.zero_grad()
            loss.backward()
            self.actor_optim.step()

        def soft_update(net_target, net, tau):
            for target_param, param in zip(net_target.parameters(), net.parameters()):
                target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)

        critic_learn()
        actor_learn()
        # Soft-update the targets only after both networks have been trained
        soft_update(self.target_critic, self.eval_critic, tau)
        soft_update(self.target_actor, self.eval_actor, tau)


ac = AC()
returns = []
for i in range(episodes):
    s = env.reset()
    episode_reward_sum = 0
    for step in range(500):
        # Each iteration of this loop is one environment step
        a = ac.choose_action(s)
        s_, r, done, info = env.step(a)
        ac.store_transition(s, a, r, s_)
        ac.learn()
        episode_reward_sum += r
        s = s_
    returns.append(episode_reward_sum)
    print('episode%s---reward_sum: %s' % (i, round(episode_reward_sum, 2)))

plt.plot(returns)
plt.xlabel('episode')
plt.ylabel('avg score')
plt.savefig('./plt_ddpg.png', format='png')
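Per-episode returns are noisy, and a trailing moving average is one common way to smooth them before plotting. A minimal sketch: moving_average and its window parameter are illustrative, and the post does not say whether the original curve was smoothed.

```python
def moving_average(values, window: int = 10):
    """Trailing mean over up to `window` most recent values; same length as input."""
    out, running = [], 0.0
    for i, v in enumerate(values):
        running += v
        if i >= window:
            # Drop the value that just left the window.
            running -= values[i - window]
        out.append(running / min(i + 1, window))
    return out
```

For example, passing moving_average(returns, window=10) to plt.plot draws the smoothed curve; the first few points average over fewer episodes because the window is still filling.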