Reinforcement Learning Basics Notes
2022-07-06 13:53:00 【I like the strengthened Xiaobai in Curie】
DDPG Reinforcement Learning Notes
DDPG (Deep Deterministic Policy Gradient) is built on the Actor-Critic framework and was proposed to handle continuous action control. The algorithm learns a deterministic policy: given a state, it outputs one definite action, rather than sampling an action from a distribution as a stochastic policy does.
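As a minimal illustration of that difference (not the author's code; mu_net and policy_net are hypothetical networks, the second one assumed to output a mean and a standard deviation):

import torch

# Deterministic policy (DDPG): the network output *is* the action.
def deterministic_action(mu_net, state):
    return mu_net(state)                          # one fixed action for a given state

# Stochastic policy: the network outputs distribution parameters, the action is sampled.
def stochastic_action(policy_net, state):
    mean, std = policy_net(state)                 # hypothetical network returning (mean, std)
    dist = torch.distributions.Normal(mean, std)
    return dist.sample()                          # a random sample, not a fixed value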
1. Introduction to the environment
The environment used here is Gym's 'Pendulum-v1'. A brief introduction is given below; a detailed description is available through the link.
Link: OpenAI Gym classic control environments — Pendulum-v1
(1) Goal: a pendulum swings around a pivot; through training, the agent must swing it up and keep it pointing upward.
(2) State space: the observation is [cos(θ), sin(θ), θ̇], where the first two components lie in [-1, 1] and the angular velocity θ̇ lies in [-8, 8].
(3) Action space: a single continuous action, the torque applied to the pendulum, in [-2, 2].
I use a [-1, 1] action space here. To match the environment's native [-2, 2] range instead, you only need to multiply the tanh output of the Actor network by 2 and adjust choose_action accordingly; see the sketch after this list.
(4) Reward: at each step the reward is -(θ² + 0.1·θ̇² + 0.001·torque²), with θ normalized to [-π, π]; the maximum reward is 0, reached when the pendulum is upright and still.
(5) Initial state and termination state: each episode starts from a random angle in [-π, π] and a random angular velocity in [-1, 1]; there is no terminal state, the episode is simply truncated by the TimeLimit wrapper (200 steps by default).
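A minimal sketch of that change (a hypothetical variant of the Actor class shown in the code below, assuming the full [-2, 2] torque range is wanted):

# Hypothetical variant of Actor.forward for the native [-2, 2] torque range:
def forward(self, x):
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    return 2.0 * torch.tanh(self.out(x))   # scale the tanh output from [-1, 1] to [-2, 2]

# choose_action would then also draw its random exploration action from [-2, 2]:
#     action = np.random.uniform(-2, 2, 1)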
2. A brief introduction to the algorithm
Actor-Critic
DDPG can be roughly viewed as the DQN algorithm combined with the Actor-Critic framework. The AC framework used in DDPG is built around the action-value function Q: the Actor learns the policy, and the Critic learns the action-value function Q.

The learning loop of this framework can look convoluted at first, but it is easy to understand with a simple analogy: the Actor is an athlete and the Critic is a referee. The athlete (Actor) performs actions and wants to perform them better and better in order to improve its technique; the referee scores each action, and the athlete improves its movements according to that score to earn higher scores. The referee (Critic) also improves itself so that its scoring becomes more and more accurate, which is what allows the athlete's skill to keep rising; the Critic uses the reward given by the environment to improve its own judgement.
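Concretely, the two roles correspond to the two losses used later in learn() (μ is the Actor, Q the Critic, primed symbols denote target networks):

# Critic ("referee"): regress Q(s, a) toward a target built from the environment reward.
#     critic_loss = MSE( Q(s, a),  r + gamma * Q'(s', mu'(s')) )
# Actor ("athlete"): change the policy so that the Critic's score of its own action rises.
#     actor_loss  = - mean( Q(s, mu(s)) )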
Link: DDPG reference article

Off-Policy:
DDPG uses the off-policy idea: experience replay is adopted, which breaks the correlation between samples.

Characteristics of DDPG:
<1> Experience replay: transitions (s, a, r, s') are stored in a buffer, and random minibatches are drawn from it for training, which breaks the temporal correlation between consecutive samples (a minimal sketch follows).
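A sketch of the idea (the full implementation below simply stores transitions in a plain Python list with pop(0)):

import random
from collections import deque

replay_buffer = deque(maxlen=10000)      # the oldest transitions are dropped automatically

def store(s, a, r, s_next):
    replay_buffer.append((s, a, r, s_next))

def sample_batch(batch_size=32):
    # uniform random sampling breaks the temporal correlation between consecutive steps
    return random.sample(list(replay_buffer), batch_size)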

<2> Two sets of target networks:
In DQN, learning is stabilized by keeping a target network separate from the behaviour network for the Q function. In DDPG, both the Actor and the Critic have their own target network, which is why many bloggers describe DDPG as a "four-network" method.
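A sketch of the four-network setup (this is exactly what AC.__init__ does in the code below, using the Actor and Critic classes defined there):

# behaviour (online) networks, trained by backpropagation
eval_actor, eval_critic = Actor(), Critic()
# target networks, used only to compute stable TD targets
target_actor, target_critic = Actor(), Critic()
# start both target networks from the same weights as the behaviour networks
target_actor.load_state_dict(eval_actor.state_dict())
target_critic.load_state_dict(eval_critic.state_dict())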
<3> Soft updates:
Unlike DQN, which copies the behaviour network's parameters into the target network all at once, DDPG updates the target networks with a soft update, which makes learning more stable. The target networks are not trained by backpropagation; they serve as a slowly changing reference that is pulled a small step toward the behaviour networks after each training step, while the behaviour networks themselves are trained with backpropagation.
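The soft update rule is θ_target ← τ·θ + (1 − τ)·θ_target with a small τ (0.02 here); a sketch contrasting it with the DQN-style hard copy (it matches the soft_update helper in the code below):

# DQN-style hard update: copy all parameters at once every N steps
#     target_net.load_state_dict(eval_net.state_dict())

# DDPG-style soft update: after every training step, nudge the target slightly
def soft_update(net_target, net, tau=0.02):
    for target_param, param in zip(net_target.parameters(), net.parameters()):
        target_param.data.copy_((1.0 - tau) * target_param.data + tau * param.data)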
<4> Exploration:
Because the actions are deterministic, learning on its own is not very exploratory. When choosing actions you can therefore use an ε-greedy strategy to increase exploration, or add Gaussian noise to the deterministic action to increase randomness (a sketch of the noise variant is given below). Here, only the ε-greedy strategy is used.
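A sketch of the Gaussian-noise alternative (not used in the code below, which keeps the ε-greedy choice; sigma is a hypothetical noise scale):

import numpy as np
import torch

def choose_action_with_noise(actor, state, sigma=0.1):
    state_t = torch.tensor(state, dtype=torch.float).unsqueeze(0)
    action = actor(state_t).squeeze(0).detach().numpy()
    action = action + np.random.normal(0, sigma, size=action.shape)  # add Gaussian exploration noise
    return np.clip(action, -1, 1)   # keep the noisy action inside the [-1, 1] range used here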
Pseudocode
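An outline of the standard DDPG loop (Lillicrap et al., 2015), not the author's original figure:

# Initialize behaviour networks mu(s | theta_mu), Q(s, a | theta_q) and target copies mu', Q'.
# Initialize the replay buffer R.
# for each episode:
#     receive the initial state s1
#     for each step t:
#         a_t = mu(s_t) + exploration noise
#         execute a_t, observe r_t and s_{t+1}; store (s_t, a_t, r_t, s_{t+1}) in R
#         sample a random minibatch of N transitions from R
#         y_i = r_i + gamma * Q'(s_{i+1}, mu'(s_{i+1}))
#         update the critic by minimizing  L = (1/N) * sum_i (y_i - Q(s_i, a_i))^2
#         update the actor with the deterministic policy gradient of  (1/N) * sum_i Q(s_i, mu(s_i))
#         soft-update the targets:  theta' <- tau * theta + (1 - tau) * theta'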
3. Implementation
The implementation is based on code found online, with some modifications. At first it did not work well: the reward never went up. After consulting a senior student in my lab, I found several problems, and after changing things around the issues were finally resolved.
The problems can be summarized as follows:
<1> Updates should start as soon as the replay buffer holds batch_size transitions, not only once the buffer is completely full;
<2> The soft update must be performed only after both networks have been trained in that step;
<3> Each episode had too few steps.
From the curve you can see that the reward increases and gradually approaches 0. Because training is slow and only 100 episodes were run, the best possible result has not been reached yet.
import random
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gym
batch_size = 32          # minibatch size sampled from the replay buffer
lr_actor = 0.001         # Actor learning rate
lr_critic = 0.001        # Critic learning rate
gamma = 0.90             # discount factor
epsilon = 0.9            # probability of taking the greedy (network) action
episodes = 100
memory_capacity = 10000  # replay buffer size
#target_replace_iter = 100
tau = 0.02               # soft-update coefficient

env = gym.make('Pendulum-v1')
n_actions = env.action_space.shape[0]       # 1 (torque)
n_states = env.observation_space.shape[0]   # 3 (cos, sin, angular velocity)
# Actor (policy) network: maps a state to a deterministic action in [-1, 1]
class Actor(nn.Module):
    def __init__(self):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(n_states, 256)
        self.fc2 = nn.Linear(256, 256)
        self.out = nn.Linear(256, n_actions)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        actions = torch.tanh(self.out(x))  # tanh keeps the action in [-1, 1]
        return actions
# Critic network: maps a (state, action) pair to its estimated Q value
class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(n_states + n_actions, 256)
        self.fc2 = nn.Linear(256, 256)
        self.q_out = nn.Linear(256, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=1)  # concatenate state and action as the input
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        q_value = self.q_out(x)
        return q_value
# DDPG agent: behaviour/target Actor and Critic plus the replay buffer
class AC():
    def __init__(self):
        self.eval_actor, self.target_actor = Actor(), Actor()
        self.eval_critic, self.target_critic = Critic(), Critic()
        #self.learn_step_counter = 0
        self.memory_counter = 0
        self.buffer = []  # replay buffer
        # start the target networks from the same weights as the behaviour networks
        self.target_actor.load_state_dict(self.eval_actor.state_dict())
        self.target_critic.load_state_dict(self.eval_critic.state_dict())
        self.actor_optim = torch.optim.Adam(self.eval_actor.parameters(), lr=lr_actor)
        self.critic_optim = torch.optim.Adam(self.eval_critic.parameters(), lr=lr_critic)

    def choose_action(self, state):
        if np.random.uniform() > epsilon:
            # with probability 1 - epsilon, take a uniformly random action (exploration)
            action = np.random.uniform(-1, 1, 1)
        else:
            inputs = torch.tensor(state, dtype=torch.float).unsqueeze(0)
            action = self.eval_actor(inputs).squeeze(0)
            action = action.detach().numpy()
        return action

    def store_transition(self, *transition):
        if len(self.buffer) == memory_capacity:
            self.buffer.pop(0)  # drop the oldest transition when the buffer is full
        self.buffer.append(transition)
    def learn(self):
        # start learning as soon as the buffer holds one minibatch (not only when it is full)
        if len(self.buffer) < batch_size:
            return
        samples = random.sample(self.buffer, batch_size)
        s0, a0, r1, s1 = zip(*samples)
        s0 = torch.tensor(s0, dtype=torch.float)
        a0 = torch.tensor(a0, dtype=torch.float)
        r1 = torch.tensor(r1, dtype=torch.float).view(batch_size, -1)
        s1 = torch.tensor(s1, dtype=torch.float)

        def critic_learn():
            # TD target computed with the target networks: r + gamma * Q'(s', mu'(s'))
            a1 = self.target_actor(s1).detach()
            y_true = r1 + gamma * self.target_critic(s1, a1).detach()
            y_pred = self.eval_critic(s0, a0)
            loss_fn = nn.MSELoss()
            loss = loss_fn(y_pred, y_true)
            self.critic_optim.zero_grad()
            loss.backward()
            self.critic_optim.step()

        def actor_learn():
            # deterministic policy gradient: maximize Q(s, mu(s))
            loss = -torch.mean(self.eval_critic(s0, self.eval_actor(s0)))
            self.actor_optim.zero_grad()
            loss.backward()
            self.actor_optim.step()

        def soft_update(net_target, net, tau):
            for target_param, param in zip(net_target.parameters(), net.parameters()):
                target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)

        critic_learn()
        actor_learn()
        # soft-update the targets only after both networks have been trained this step
        soft_update(self.target_critic, self.eval_critic, tau)
        soft_update(self.target_actor, self.eval_actor, tau)
ac = AC()
returns = []
for i in range(episodes):
    s = env.reset()
    episode_reward_sum = 0
    for step in range(500):
        # each iteration of this loop is one environment step
        #env.render()
        a = ac.choose_action(s)
        s_, r, done, info = env.step(a)  # `done` is ignored: the pendulum task has no terminal state
        ac.store_transition(s, a, r, s_)
        episode_reward_sum += r
        s = s_
        ac.learn()  # update every step once the buffer holds a minibatch
    returns.append(episode_reward_sum)
    print('episode%s---reward_sum: %s' % (i, round(episode_reward_sum, 2)))

plt.figure()
plt.plot(range(len(returns)), returns)
plt.xlabel('episodes')
plt.ylabel('episode return')
plt.savefig('./plt_ddpg.png', format='png')
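One practical note (an assumption about your local setup, not part of the original code): this script uses the classic Gym API, where reset() returns only the observation and step() returns four values. On gym >= 0.26 or gymnasium, those two calls need a small adaptation, roughly:

s, _ = env.reset()                                  # the newer API returns (obs, info)
s_, r, terminated, truncated, info = env.step(a)    # `done` is split into two flags
done = terminated or truncated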
