DQN PyTorch example
2022-07-26 01:58:00 【Nebula】
The agent is a letter o sitting on a platform made of many _ characters. The goal is to keep the o on the platform, with _ on both sides of it. To do that, the o can move left and right, and its actions are slightly stronger than the natural drift it experiences when it takes no action. The result is rendered as a simple text animation in the console. This is a very simple scenario.
PyTorch version: 1.11.0+cu113
Code
Because every episode starts from the same initial state, many identical transitions are generated. So we first define a hashable data structure that can be stored in a set, which deduplicates the experience pool automatically.
class Data:
    """A hashable (s, a, r, s') transition, so duplicates collapse when stored in a set."""
    def __init__(self, s: 'list[float]|tuple[float]', a: 'list[float]|tuple[float]', r: float, s_: 'list[float]|tuple[float]') -> None:
        self.tuple_ = (tuple(s), tuple(a), r, tuple(s_))

    @property
    def state(self):
        return list(self.tuple_[0])

    @property
    def action(self):
        return list(self.tuple_[1])

    @property
    def reward(self):
        return self.tuple_[2]

    @property
    def next_state(self):
        return list(self.tuple_[3])

    def __ne__(self, __o: object) -> bool:
        if type(__o) != Data:
            return True  # an object of a different type is never equal
        return self.tuple_ != __o.tuple_

    def __eq__(self, __o: object) -> bool:
        if type(__o) != Data:
            return False
        return self.tuple_ == __o.tuple_

    def __hash__(self) -> int:
        return hash(self.tuple_)
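A quick check of the deduplication behaviour (a minimal sketch; the numeric values are made up purely for illustration):

# Two transitions built from the same (s, a, r, s') collapse to one entry in a set.
d1 = Data([0.1, 0.0], [1.0, 0.0, 0.0], 1.0, [0.2, 0.1])
d2 = Data((0.1, 0.0), (1.0, 0.0, 0.0), 1.0, (0.2, 0.1))
pool = {d1, d2}
assert d1 == d2 and len(pool) == 1
print(d1.state, d1.action, d1.reward, d1.next_state)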
A helper that prints a progress bar; the filled part is drawn in the console as spaces with an inverted color.
import math

def print_bar(epoch, epochs, step, batch_size, etc=None, bar_size=50):
    process = math.ceil(bar_size * step / batch_size)
    strs = [
        f"Epoch {epoch}/{epochs}",
        f" |\033[1;30;47m{' ' * process}\033[0m{' ' * (bar_size - process)}| ",
    ]
    if etc is not None:
        strs.append(str(etc))
    if step:
        strs.insert(0, "\033[A")  # move the cursor up one line so the bar is redrawn in place
    print("".join(strs) + " ")
A parameter container that can be copied and that supports soft parameter updates between two networks.
from collections import OrderedDict
from copy import deepcopy

from torch import Tensor, nn


class Model(nn.Module):
    def __init__(self, layers: 'list[tuple[nn.Module, callable]]', device=None):
        super(Model, self).__init__()
        self.layers = [layer for layer, _ in layers]
        self.module_list = nn.ModuleList(self.layers).to(device)
        self.activations = [f for _, f in layers]
        self.deep = len(layers)

    def forward(self, x: Tensor) -> Tensor:
        a = x
        for i in range(self.deep):
            a = self.module_list[i](a)
            activation = self.activations[i]
            if activation:
                a = activation(a)
        return a

    def load_state_dict(self, model: 'Model', rate: float = .5):
        # rate >= 1 copies the parameters outright; otherwise mix them: local*(1-rate) + foreign*rate
        for i in range(self.deep):
            if rate >= 1.:
                self.layers[i].load_state_dict(model.layers[i].state_dict())
            else:
                local = self.layers[i].state_dict()
                foreign = model.layers[i].state_dict()
                mix = OrderedDict()
                for key in local.keys():
                    mix[key] = local.get(key) * (1 - rate) + foreign.get(key) * rate
                self.layers[i].load_state_dict(mix)

    def copy(self) -> 'Model':
        # deepcopy the layers so the copy does not share parameters with the original
        params = []
        for i in range(self.deep):
            params.append((deepcopy(self.layers[i]), self.activations[i]))
        model = Model(params)
        model.load_state_dict(self, 1)
        return model
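A small sanity check of the copy and soft update (a minimal sketch, assuming the Model class above; the soft update mixes parameters as target ← (1 − rate)·target + rate·source):

import torch
from torch import nn

net = Model([(nn.Linear(2, 8), torch.sigmoid), (nn.Linear(8, 3), torch.sigmoid)])
target = net.copy()                    # independent copy with identical parameters
with torch.no_grad():
    net.layers[0].weight.add_(1.0)     # perturb the original network
target.load_state_dict(net, rate=.1)   # target moves 10% of the way towards net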
The main training code.
# -*- coding: utf-8 -*-
from datetime import datetime
import random, time, torch, os
from torch import cuda, device, nn, optim, Tensor

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
GPU = device("cuda" if cuda.is_available() else "cpu")


def draw(length, location, new_line: bool = False, item=None):
    string = "_" * int(location) + "o" + "_" * (length - int(location) - 1)
    if not new_line:
        string = "\033[A" + string  # redraw over the previous frame
    print(string[:length] + f" {str(item)} ")
def env(length, location, speed, force=0) -> 'tuple[int,int,bool]':
    # The natural force pushes the agent away from the centre, so it has to act to stay on the platform.
    if location < length / 2:
        f = -1
    elif location > length / 2:
        f = 1
    else:
        f = random.randint(-1, 1)
    new_speed = speed + f + force
    if location < 0 or location >= length:
        living = False
    else:
        living = True
    return location + speed, new_speed, living


def init_data(width: int) -> 'tuple[int,int]':
    return (int(width / 2), 0)


def make_state(location, width, speed, speed_scale) -> 'list[float]':
    # Normalise the position to roughly [-1, 1] and the speed by the chosen scale
    return [(location - width / 2) / width * 2, speed / speed_scale]
def simulate(model: Model, batch_size: int, width: int, speed_scale: int, action_list: 'list', epsilon: float, reward_range: float = .7):
    """ Environment simulation: collect transitions into the experience pool. """
    if reward_range <= .5:
        raise ValueError('The arg reward_range must be greater than .5')
    action_count = len(action_list)
    location, speed = init_data(width)
    cache = set()
    live_time = 0
    score = 0
    max_score = max(1, batch_size * (batch_size + 1) / 2)  # the highest score reachable if the agent never falls
    for _ in range(batch_size):
        state = make_state(location, width, speed, speed_scale)
        if random.random() <= epsilon:  # decide whether to explore or exploit
            action_index = random.randint(0, action_count - 1)
        else:
            action_index = int(torch.argmax(model(torch.tensor(data=state, dtype=torch.float32, device=GPU))))
        a = action_list[action_index]
        location_, speed_, r = env(width, location, speed, a)  # compute sₜ₊₁
        data = Data(
            state,
            [float(action_index == i) for i in range(action_count)],  # one-hot encode the chosen action
            1. if (1 - reward_range) < location / width < reward_range else 0.,  # positive reward only near the centre, not merely for surviving
            make_state(location_, width, speed_, speed_scale)
        )
        cache.add(data)  # fill the experience pool
        location, speed = location_, speed_  # s = sₜ₊₁
        live_time += 1
        score += live_time
        if not r:
            location, speed = init_data(width)
    return cache, score / max_score
def train(width: int, speed_scale: int, action_list: 'tuple|list', model: Model, optimizer: optim.Optimizer, loss_func: nn.modules.loss._Loss, epochs: int, batch_size: int, gamma: float = .1, epsilon: float = .1, soft_update_rate=.1, target_accuracy=.99) -> Model:
    policy_net = model          # the optimizer created by the caller holds this network's parameters, so it is the one being trained
    target_net = model.copy()   # an independent copy used to compute targets
    policy_net.train(mode=True)
    target_net.train(mode=False)  # the target network is not trained directly
    action_count = len(action_list)
    cache = set()
    for epoch in range(epochs):
        cache_, accuracy = simulate(model=target_net, batch_size=batch_size, speed_scale=speed_scale, width=width, action_list=action_list, epsilon=epsilon)
        cache = cache | cache_
        states, actions, rewards, state_nexts = [], [], [], []
        real_batch_size = min(batch_size, len(cache))
        for data in random.sample(tuple(cache), real_batch_size):  # draw a random batch from the experience pool
            states.append(data.state)
            actions.append(data.action)
            rewards.append([data.reward, ])
            state_nexts.append(data.next_state)
        # Convert the lists into tensors
        states = torch.tensor(states, device=GPU)
        states = states + torch.randn_like(states) * .02  # add a little noise to the inputs
        actions = torch.tensor(actions, device=GPU)
        rewards = torch.tensor(rewards, device=GPU)
        state_nexts = torch.tensor(state_nexts, device=GPU)
        rewards = rewards.expand((real_batch_size, action_count))
        v = target_net(state_nexts).detach()
        y = v * (v.argsort(dim=-1, descending=False).eq(0)) * gamma + rewards * (1 - gamma)  # control the proportion of the immediate reward
        v_hat = policy_net(states) * actions  # actions is a one-hot mask, so unselected actions naturally become 0
        loss = loss_func(v_hat, y)
        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()
        target_net.load_state_dict(policy_net, soft_update_rate)  # soft update of the target network
        print_bar(epoch, epochs, epoch, epochs, ('{:.10f}'.format(loss.item()), '{:.10f}'.format(accuracy),))
        if accuracy >= target_accuracy:  # stop once the score meets the target
            break
    return target_net
if __name__ == "__main__":
    WIDTH = 30             # platform width
    SPEED_SCALE = 8        # speed sensitivity
    ACTIONS = (-2, 0, 2,)  # action list
    EPOCHS = 10000         # maximum number of iterations
    BATCH_SIZE = 64
    layers = [
        (torch.nn.Linear(2, 8), torch.sigmoid),
        (torch.nn.Linear(8, 3), torch.sigmoid),
    ]  # model definition
    model = Model(layers=layers, device=GPU)
    opt = optim.NAdam(model.parameters(), lr=.07)
    loss_func = nn.MSELoss()
    model = train(
        width=WIDTH, speed_scale=SPEED_SCALE, action_list=ACTIONS, model=model, optimizer=opt, loss_func=loss_func, epochs=EPOCHS, batch_size=BATCH_SIZE,
        gamma=.3, epsilon=.2, soft_update_rate=.3, target_accuracy=.93
    )
    model.to("cpu")
    print("\n\n")
    location, speed = init_data(WIDTH)
    for step in range(200):  # play an animation to show the training result
        state = torch.tensor(make_state(location, WIDTH, speed, SPEED_SCALE))
        action = model(state)  # Q-values for the current state
        a = ACTIONS[int(torch.argmax(action))]
        location_, speed_, r = env(WIDTH, location, speed, a)
        draw(WIDTH, location, not step, (a, action.tolist(), location, speed))
        if not r:
            location, speed = init_data(WIDTH)
        else:
            location, speed = location_, speed_
        time.sleep(.1)