Notes on Reinforcement Learning Basics
2022-07-06 13:52:00 【I like the strengthened Xiaobai in Curie】
Actor-Critic Reinforcement Learning Notes
Reinforcement learning algorithms fall roughly into three categories: value-based, policy-based, and Actor-Critic, which combines the two. This post briefly summarizes my recent experience learning AC.
1. Introduction to the Environment
The environment used here is Gym's 'CartPole-v1'. It is nearly identical to the 'CartPole-v0' environment from the previous post; the main differences lie in the maximum number of steps per episode and the reward definition, as shown in the figure below.
In this post I want to try an on-policy algorithm, so the maximum number of steps per episode is capped at 100.
A detailed introduction to the 'CartPole-v0' environment is linked below.
Link: Introduction to the OpenAI Gym classic control environment CartPole (inverted pendulum)
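Since the comparison figure from the original post is not reproduced here, the registered differences can also be checked programmatically. The snippet below is a small sketch (not from the original post); the values noted are the ones registered in classic Gym releases.

import gym

# Quick check of the registered differences between the two versions
# (a sketch, not from the original post).
for env_id in ['CartPole-v0', 'CartPole-v1']:
    spec = gym.spec(env_id)
    print(env_id, 'max_episode_steps =', spec.max_episode_steps,
          'reward_threshold =', spec.reward_threshold)
# Expected: CartPole-v0 -> 200 / 195.0, CartPole-v1 -> 500 / 475.0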
2. A Brief Introduction to the Algorithm
- Actor-Critic
  The algorithm consists of two parts: a policy network (the Actor) and a value network (the Critic). Since a stochastic policy is used here, the Actor normalizes the action probabilities with a softmax function, while the Critic estimates the state value v. In addition, the advantage function from A2C is used.
- On-Policy
  An on-policy scheme is used: each episode runs for 100 steps and produces 100 transitions. Once these transitions are stored, learning is performed directly on these 100 samples, after which they are discarded so that fresh samples can be collected in the next episode.
- AC (A2C) pseudocode: a sketch of the corresponding update is given right after this list.
- Implementation
  The implementation is based on an online tutorial whose source code only covers policy-gradient methods; it has been modified slightly here. In addition, the stochastic policy is inherently more exploratory than the deterministic policies used before; actions are drawn with torch's sampling utilities, whose details I have not studied closely. The results are shown in the figure below: after training, the reward basically converges to 100.
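Since the pseudocode figure is not reproduced here, the following is a minimal sketch, in my own words, of the one-step A2C update that the code below implements (the helper name a2c_losses is only illustrative):

import torch
import torch.nn.functional as F

def a2c_losses(v_s, v_s_next, log_pi_a, r, done_mask, gamma=0.98):
    # One-step bootstrap target and TD error; the TD error doubles as the
    # advantage estimate A(s, a) ~ r + gamma * V(s') - V(s).
    td_target = r + gamma * v_s_next * done_mask
    delta = td_target - v_s
    actor_loss = -(log_pi_a * delta.detach()).mean()          # policy-gradient term
    critic_loss = F.smooth_l1_loss(v_s, td_target.detach())   # value-regression term
    return actor_loss, critic_loss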
import gym
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import matplotlib.pyplot as plt
# Hyperparameters
learning_rate = 0.0002
gamma = 0.98          # discount factor
n_rollout = 100       # transitions collected per episode before one update
MAX_EPISODE = 20000
RENDER = False
env = gym.make('CartPole-v1')
env = env.unwrapped
env.seed(1)
torch.manual_seed(1)
#print("env.action_space :", env.action_space)
#print("env.observation_space :", env.observation_space)
n_features = env.observation_space.shape[0]
n_actions = env.action_space.n
class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.data = []  # buffer for the transitions of the current episode
        hidden_dims = 256
        self.feature_layer = nn.Sequential(nn.Linear(n_features, hidden_dims),
                                           nn.ReLU())
        self.fc_pi = nn.Linear(hidden_dims, n_actions)  # policy head (Actor)
        self.fc_v = nn.Linear(hidden_dims, 1)           # value head (Critic)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def pi(self, x):
        # Actor: action probabilities via softmax
        x = self.feature_layer(x)
        x = self.fc_pi(x)
        prob = F.softmax(x, dim=-1)
        return prob

    def v(self, x):
        # Critic: state-value estimate
        x = self.feature_layer(x)
        v = self.fc_v(x)
        return v

    def put_data(self, transition):
        self.data.append(transition)

    def make_batch(self):
        s_lst, a_lst, r_lst, s_next_lst, done_lst = [], [], [], [], []
        for transition in self.data:
            s, a, r, s_, done = transition
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r / 100.0])  # scale rewards for numerical stability
            s_next_lst.append(s_)
            done_mask = 0.0 if done else 1.0
            done_lst.append([done_mask])
        s_batch = torch.tensor(numpy.array(s_lst), dtype=torch.float)
        a_batch = torch.tensor(a_lst)
        r_batch = torch.tensor(numpy.array(r_lst), dtype=torch.float)
        s_next_batch = torch.tensor(numpy.array(s_next_lst), dtype=torch.float)
        done_batch = torch.tensor(numpy.array(done_lst), dtype=torch.float)
        self.data = []  # clear the buffer (on-policy: samples are used once)
        return s_batch, a_batch, r_batch, s_next_batch, done_batch

    def train_net(self):
        s, a, r, s_, done = self.make_batch()
        td_target = r + gamma * self.v(s_) * done
        delta = td_target - self.v(s)  # TD error, used as the advantage estimate

        def critic_learn():
            loss_func = nn.MSELoss()
            # detach the TD target so gradients do not flow through v(s_)
            loss1 = loss_func(self.v(s), td_target.detach())
            self.optimizer.zero_grad()
            loss1.backward()
            self.optimizer.step()

        def actor_learn():
            pi = self.pi(s)
            pi_a = pi.gather(1, a)
            loss = -torch.log(pi_a) * delta.detach() + F.smooth_l1_loss(self.v(s), td_target.detach())
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        critic_learn()
        actor_learn()
def main():
    model = ActorCritic()
    print_interval = 20
    score = 0.0
    avg_returns = []
    for n_epi in range(MAX_EPISODE):
        s = env.reset()
        # collect a fixed-length rollout of n_rollout transitions
        for t in range(n_rollout):
            prob = model.pi(torch.from_numpy(s).float())
            m = Categorical(prob)
            a = m.sample().item()  # sample an action from the stochastic policy
            s_next, r, done, info = env.step(a)
            model.put_data((s, a, r, s_next, done))
            s = s_next
            score += r
        # one on-policy update per episode, using the stored transitions
        model.train_net()
        if n_epi % print_interval == 0 and n_epi != 0:
            avg_score = score / print_interval
            print("# of episode :{}, avg score : {:.1f}".format(n_epi, avg_score))
            avg_returns.append(avg_score)
            score = 0.0
    env.close()

    plt.figure()
    plt.plot(range(len(avg_returns)), avg_returns)
    plt.xlabel('episodes')
    plt.ylabel('avg score')
    plt.savefig('./plt_ac.png', format='png')


if __name__ == '__main__':
    main()
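The script above only trains; one way to sanity-check the result afterwards is to run the learned policy greedily. The function below is a minimal sketch under the same imports and old Gym API used above (evaluate is a hypothetical helper, not part of the original code); it could be called, for example, as evaluate(model) at the end of main().

def evaluate(model, n_episodes=5, max_steps=100):
    # Roll out the trained policy greedily (argmax instead of sampling).
    env_eval = gym.make('CartPole-v1').unwrapped
    for ep in range(n_episodes):
        s = env_eval.reset()
        total = 0.0
        for _ in range(max_steps):
            with torch.no_grad():
                prob = model.pi(torch.from_numpy(s).float())
            a = torch.argmax(prob).item()
            s, r, done, _ = env_eval.step(a)
            total += r
            if done:
                break
        print('eval episode {}: return {:.1f}'.format(ep, total))
    env_eval.close()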