强化学习Stable Baseline3框架基本使用

Stable Baseline3是一个基于PyTorch的深度强化学习工具包,能够快速完成强化学习算法的搭建和评估,提供预训练的智能体,包括保存和录制视频等等,是一个功能非常强大的库。经常和gym搭配,被广泛应用于各种强化学习训练中

SB3提供了可以直接调用的RL算法模型,如A2C、DDPG、DQN、HER、PPO、SAC、TD3,通过引入已有算法+自定义环境,可以方便的应用RL算法

环境安装

如果使用Anaconda,可以新开一个环境

1
2
conda create -n py37 python=3.7
conda activate py37

进入环境后,安装相应的工具包

1
pip install stable-baselines3[extra]

该命令将自动下载所需要的依赖库

若以上命令安装过程中出现Failed building wheel for AutoROM.accept-rom-license的错误,需要手动安装AutoROM.accept-rom-license

1
pip install AutoROM.accept-rom-license==0.4.2

若要使用gym等环境,还需要安装其他相关依赖

1
2
conda install swig
conda install box2d-py

gym环境调用

SB3中对环境的引入及使用遵循gym的通用格式

1
2
3
4
5
6
7
8
9
10
11
12
import gym

env = gym.make("LunarLander-v2") # 引入环境
env.reset() # 重置环境

print("sample action:",env.action_space.sample()) # sample action: 3 对action进行一次采样

print("observation space shape",env.observation_space.sample()) # observation space shape (8,) 观测值的情况(对应下面的8个值)

print("sample observation",env.observation_space.sample())
# observation space shape [ 2.0437958 2.6050246 0.2743332 -0.7890023 0.44479468 -0.76262844 -1.8722864 0.09138725] 对观测值进行一次采样
env.close() # 关闭

gym环境使用

1
2
3
4
5
6
7
8
9
10
import gym

env = gym.make("LunarLander-v2")
env.reset()

for i in range(200):
env.render() # 渲染环境
obs, reward, done, info = env.step(env.action_space.sample()) # 在action中选择一个去执行,step执行结束后,返回相关信息

env.close()

SB3模型使用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import gym
from stable_baselines3 import A2C # 引入A2C算法

env = gym.make("LunarLander-v2")
env.reset()

model = A2C(policy="MlpPolicy", env=env, verbose=1) # 设置算法相关参数
model.learn(total_timesteps=10000) # 进行训练,执行10000次action后停止训练

# 训练结束后,使用该模型
episodes = 10
for ep in range(episodes):
obs = env.reset()
done = False # 本轮游戏是否结束
while not done:
env.render()
obs, reward, done, info = env.step(env.action_space.sample()) # 更新相关状态

env.close()
  • policy:选择网络类型,可选MlpPolicy,CnnPolicy,MultiInputPolicy
  • env:选择训练环境,即为引入的环境
  • verbose( int) – 详细级别:0 无输出,1 信息,2 调试,默认为0

模型保存与加载

模型保存

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import gym
from stable_baselines3 import PPO
import os

models_dir = "models/PPO"
log_dir = "logs"

if not os.path.exists(models_dir):
os.makedirs(models_dir)

if not os.path.exists(log_dir):
os.makedirs(log_dir)

env = gym.make("LunarLander-v2")
env.reset()

model = PPO(policy="MlpPolicy", env=env, verbose=1, tensorboard_log=log_dir) # 设置tensorboard记录存储目录,用于可视化查看

TIMESTEPS = 10000
for i in range(1,30):
model.learn(total_timesteps=TIMESTEPS, reset_num_timesteps=False, tb_log_name="PPO") # 训练模型
model.save(f"{models_dir}/{TIMESTEPS*i}") # 保存模型

env.close()
  • total_timesteps ( int) – 要训练的环境步数
  • reset_num_timesteps ( bool) – 是否重置当前时间步数(用于日志记录)
  • tb_log_name ( str) – TensorBoard 日志运行的名称

模型加载

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import gym
from stable_baselines3 import PPO

env = gym.make("LunarLander-v2")
env.reset()

models_dir = "models/PPO"
model_path = f"{models_dir}/160000.zip" # tensorboard查看并选出表现较好的模型

model = PPO.load(model_path,env=env) # 加载已保存的模型

episodes = 10
for ep in range(episodes): # 使用模型进行测试
obs = env.reset()
done = False
while not done:
env.render()
action, _ = model.predict(obs) # 根据obs选择action
obs, reward, done, info = env.step(action) # 根据action与环境交互得到反馈

env.close()

自定义环境编写

自定义环境中,主要完成环境的创建、初始化状态

step()函数用于执行每一步action并返回①对环境的观测、②本次奖励、③本回合是否结束、④其他相关信息

reset()函数用于初始化本轮训练

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import gym
from gym import spaces

class CustomEnv(gym.Env):
"""Custom Environment that follows gym interface"""
metadata = {"render.modes": ["human"]}
def __init__(self, arg1, arg2, ...):
super(CustomEnv, self).__init__()
# Define action and observation space
# They must be gym.spaces objects
# Example when using discrete actions:
self.action_space = spaces.Discrete(N_DISCRETE_ACTIONS)
# Example for using image as input (channel-first; channel-last also works):
self.observation_space = spaces.Box(low=0, high=255,
shape=(N_CHANNELS, HEIGHT, WIDTH), dtype=np.uint8)

def step(self, action):
...
return observation, reward, done, info
def reset(self):
...
return observation # reward, done, info can't be included
def render(self, mode="human"):
...
def close (self):
...

环境测试

(1)SB3提供了用于测试环境是否符合基本逻辑的工具

1
2
3
4
5
from stable_baselines3.common.env_checker import check_env
from snakeenv import SnekEnv # 从编写好的自定义环境中引入该类

env = SnekEnv()
check_env(env)

(2)二次自测

通过模拟实际训练的过程,测试环境是否可用

1
2
3
4
5
6
7
8
9
10
11
12
13
from snakeenv import SnekEnv

env = SnekEnv()
episodes = 50

for ep in range(episodes):
obs = env.reset()
done = False
while not done:
random_action = env.action_space.sample() # 采用随机采样即可
print("action",random_action)
obs, reward, done, info = env.step(random_action) # 测试环境是否给出形式正确的反馈
print("reward",reward)

流程总结

编写自定义环境 --> 测试环境 --> 调用RL算法进行训练并保存模型 --> 查看并选出表现较好的模型 --> 加载最佳模型用于测试集/实际使用


参考链接:

Reinforcement Learning with Stable Baselines 3 - Introduction (P.1) - YouTube

强化学习基础环境 Gym 简介 - 简书 (jianshu.com)

强化学习之stable_baseline3详细说明和各项功能的使用_微笑小星的博客-CSDN博客