diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 0e4fe1022..c8d5838a8 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -24,7 +24,7 @@ jobs: - name: Test with pytest # ignore test/throughput which only profiles the code run: | - pytest test --ignore-glob='*profile.py' --cov=tianshou --cov-report=xml --durations=0 -v + pytest test --ignore-glob='*profile.py' --cov=tianshou --cov-report=xml --cov-report=term-missing --durations=0 -v - name: Upload coverage to Codecov uses: codecov/codecov-action@v1 with: diff --git a/README.md b/README.md index 295ca96db..e519d2051 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ [![PyPI](https://img.shields.io/pypi/v/tianshou)](https://pypi.org/project/tianshou/) [![Conda](https://img.shields.io/conda/vn/conda-forge/tianshou)](https://github.com/conda-forge/tianshou-feedstock) -[![Read the Docs](https://img.shields.io/readthedocs/tianshou)](https://tianshou.readthedocs.io/en/latest) +[![Read the Docs](https://img.shields.io/readthedocs/tianshou)](https://tianshou.readthedocs.io/en/master) [![Read the Docs](https://img.shields.io/readthedocs/tianshou-docs-zh-cn?label=%E4%B8%AD%E6%96%87%E6%96%87%E6%A1%A3)](https://tianshou.readthedocs.io/zh/latest/) [![Unittest](https://github.com/thu-ml/tianshou/workflows/Unittest/badge.svg?branch=master)](https://github.com/thu-ml/tianshou/actions) [![codecov](https://img.shields.io/codecov/c/gh/thu-ml/tianshou)](https://codecov.io/gh/thu-ml/tianshou) @@ -14,7 +14,6 @@ [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) [![GitHub forks](https://img.shields.io/github/forks/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/network) [![GitHub license](https://img.shields.io/github/license/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/blob/master/LICENSE) -[![Gitter](https://badges.gitter.im/thu-ml/tianshou.svg)](https://gitter.im/thu-ml/tianshou?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) **Tianshou** ([天授](https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88)) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly API, or slow-speed, Tianshou provides a fast-speed modularized framework and pythonic API for building the deep reinforcement learning agent with the least number of lines of code. The supported interface algorithms currently include: diff --git a/examples/vizdoom/.gitignore b/examples/vizdoom/.gitignore new file mode 100644 index 000000000..b8bde54ff --- /dev/null +++ b/examples/vizdoom/.gitignore @@ -0,0 +1 @@ +_vizdoom.ini diff --git a/examples/vizdoom/README.md b/examples/vizdoom/README.md new file mode 100644 index 000000000..ed68a46c9 --- /dev/null +++ b/examples/vizdoom/README.md @@ -0,0 +1,66 @@ +# ViZDoom + +[ViZDoom](https://github.com/mwydmuch/ViZDoom) is a popular RL env for a famous first-person shooting game Doom. Here we provide some results and intuitions for this scenario. + +## Train + +To train an agent: + +```bash +python3 vizdoom_c51.py --task {D1_basic|D3_battle|D4_battle2} +``` + +D1 (health gathering) should finish training (no death) in less than 500k env step (5 epochs); + +D3 can reach 1600+ reward (75+ killcount in 5 minutes); + +D4 can reach 700+ reward. 
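+Checkpoints are written by `save_fn` to `log/{task}/c51/policy.pth` (with the default `--logdir log`), so an interrupted run can be continued from the last saved policy; note that only the policy weights are reloaded, while the optimizer state and the epsilon schedule restart. A sketch for D3:
+
+```bash
+python3 vizdoom_c51.py --task D3_battle --resume-path log/D3_battle/c51/policy.pth
+```
+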
Here is the result: + +(episode length, the maximum length is 2625 because we use frameskip=4, that is 10500/4=2625) + +![](results/c51/length.png) + +(episode reward) + +![](results/c51/reward.png) + +To evaluate an agent's performance: + +```bash +python3 vizdoom_c51.py --test-num 100 --resume-path policy.pth --watch --task {D1_basic|D3_battle|D4_battle2} +``` + +To save `.lmp` files for recording: + +```bash +python3 vizdoom_c51.py --save-lmp --test-num 100 --resume-path policy.pth --watch --task {D1_basic|D3_battle|D4_battle2} +``` + +it will store `lmp` file in `lmps/` directory. To watch these `lmp` files (for example, d3 lmp): + +```bash +python3 replay.py maps/D3_battle.cfg episode_8_25.lmp +``` + +We provide two lmp files (d3 best and d4 best) under `results/c51`, you can use the following command to enjoy: + +```bash +python3 replay.py maps/D3_battle.cfg results/c51/d3.lmp +python3 replay.py maps/D4_battle2.cfg results/c51/d4.lmp +``` + +## Maps + +See [maps/README.md](maps/README.md) + +## Algorithms + +The setting is exactly the same as Atari. You can definitely try more algorithms listed in Atari example. + +## Reward + +1. living reward is bad +2. combo-action is really important +3. negative reward for health and ammo2 is really helpful for d3/d4 +4. only with positive reward for health is really helpful for d1 +5. remove MOVE_BACKWARD may converge faster but the final performance may be lower diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py new file mode 100644 index 000000000..017ab7750 --- /dev/null +++ b/examples/vizdoom/env.py @@ -0,0 +1,129 @@ +import os +import cv2 +import gym +import numpy as np +import vizdoom as vzd + + +def normal_button_comb(): + actions = [] + m_forward = [[0.0], [1.0]] + t_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]] + for i in m_forward: + for j in t_left_right: + actions.append(i + j) + return actions + + +def battle_button_comb(): + actions = [] + m_forward_backward = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]] + m_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]] + t_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]] + attack = [[0.0], [1.0]] + speed = [[0.0], [1.0]] + + for m in attack: + for n in speed: + for j in m_left_right: + for i in m_forward_backward: + for k in t_left_right: + actions.append(i + j + k + m + n) + return actions + + +class Env(gym.Env): + def __init__( + self, cfg_path, frameskip=4, res=(4, 40, 60), save_lmp=False + ): + super().__init__() + self.save_lmp = save_lmp + self.health_setting = "battle" in cfg_path + if save_lmp: + os.makedirs("lmps", exist_ok=True) + self.res = res + self.skip = frameskip + self.observation_space = gym.spaces.Box( + low=0, high=255, shape=res, dtype=np.float32 + ) + self.game = vzd.DoomGame() + self.game.load_config(cfg_path) + self.game.init() + if "battle" in cfg_path: + self.available_actions = battle_button_comb() + else: + self.available_actions = normal_button_comb() + self.action_num = len(self.available_actions) + self.action_space = gym.spaces.Discrete(self.action_num) + self.spec = gym.envs.registration.EnvSpec("vizdoom-v0") + self.count = 0 + + def get_obs(self): + state = self.game.get_state() + if state is None: + return + obs = state.screen_buffer + self.obs_buffer[:-1] = self.obs_buffer[1:] + self.obs_buffer[-1] = cv2.resize(obs, (self.res[-1], self.res[-2])) + + def reset(self): + if self.save_lmp: + self.game.new_episode(f"lmps/episode_{self.count}.lmp") + else: + self.game.new_episode() + self.count += 1 + self.obs_buffer = np.zeros(self.res, dtype=np.uint8) 
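+        # obs_buffer stacks the last `res[0]` grayscale frames: get_obs() shifts
+        # the stack left by one slot and writes the newest resized frame into the
+        # last slot, so every episode starts from an all-zero stack.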
+ self.get_obs() + self.health = self.game.get_game_variable(vzd.GameVariable.HEALTH) + self.killcount = self.game.get_game_variable( + vzd.GameVariable.KILLCOUNT) + self.ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2) + return self.obs_buffer + + def step(self, action): + self.game.make_action(self.available_actions[action], self.skip) + reward = 0.0 + self.get_obs() + health = self.game.get_game_variable(vzd.GameVariable.HEALTH) + if self.health_setting: + reward += health - self.health + elif health > self.health: # positive health reward only for d1/d2 + reward += health - self.health + self.health = health + killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT) + reward += 20 * (killcount - self.killcount) + self.killcount = killcount + ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2) + # if ammo2 > self.ammo2: + reward += ammo2 - self.ammo2 + self.ammo2 = ammo2 + done = False + info = {} + if self.game.is_player_dead() or self.game.get_state() is None: + done = True + elif self.game.is_episode_finished(): + done = True + info["TimeLimit.truncated"] = True + return self.obs_buffer, reward, done, info + + def render(self): + pass + + def close(self): + self.game.close() + + +if __name__ == '__main__': + # env = Env("maps/D1_basic.cfg", 4, (4, 84, 84)) + env = Env("maps/D3_battle.cfg", 4, (4, 84, 84)) + print(env.available_actions) + action_num = env.action_space.n + obs = env.reset() + print(env.spec.reward_threshold) + print(obs.shape, action_num) + for i in range(4000): + obs, rew, done, info = env.step(0) + if done: + env.reset() + print(obs.shape, rew, done) + cv2.imwrite("test.png", obs.transpose(1, 2, 0)[..., :3]) diff --git a/examples/vizdoom/maps/D1_basic.cfg b/examples/vizdoom/maps/D1_basic.cfg new file mode 100644 index 000000000..1c7431e02 --- /dev/null +++ b/examples/vizdoom/maps/D1_basic.cfg @@ -0,0 +1,39 @@ +# Lines starting with # are treated as comments (or with whitespaces+#). +# It doesn't matter if you use capital letters or not. +# It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. + +doom_scenario_path = D1_basic.wad +doom_map = map01 + +# Rewards + +# Each step is good for you! +living_reward = 0 +# And death is not! +death_penalty = 0 + +# Rendering options +screen_resolution = RES_160X120 +screen_format = GRAY8 +render_hud = false +render_crosshair = false +render_weapon = false +render_decals = false +render_particles = false +window_visible = false + +# make episodes finish after 10500 actions (tics) +episode_timeout = 10500 + +# Available buttons +available_buttons = +{ + MOVE_FORWARD + TURN_LEFT + TURN_RIGHT +} + +# Game variables that will be in the state +available_game_variables = { HEALTH } + +mode = PLAYER diff --git a/examples/vizdoom/maps/D1_basic.wad b/examples/vizdoom/maps/D1_basic.wad new file mode 100644 index 000000000..51034a24a Binary files /dev/null and b/examples/vizdoom/maps/D1_basic.wad differ diff --git a/examples/vizdoom/maps/D2_navigation.cfg b/examples/vizdoom/maps/D2_navigation.cfg new file mode 100644 index 000000000..402424e17 --- /dev/null +++ b/examples/vizdoom/maps/D2_navigation.cfg @@ -0,0 +1,39 @@ +# Lines starting with # are treated as comments (or with whitespaces+#). +# It doesn't matter if you use capital letters or not. +# It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 
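+# As in D1, living_reward and death_penalty stay 0 here; env.py shapes the
+# reward from the HEALTH game variable instead (only positive health changes
+# are rewarded on the non-battle maps).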
+ +doom_scenario_path = D2_navigation.wad +doom_map = map01 + +# Rewards + +# Each step is good for you! +living_reward = 0 +# And death is not! +death_penalty = 0 + +# Rendering options +screen_resolution = RES_160X120 +screen_format = GRAY8 +render_hud = false +render_crosshair = false +render_weapon = false +render_decals = false +render_particles = false +window_visible = false + +# make episodes finish after 10500 actions (tics) +episode_timeout = 10500 + +# Available buttons +available_buttons = +{ + MOVE_FORWARD + TURN_LEFT + TURN_RIGHT +} + +# Game variables that will be in the state +available_game_variables = { HEALTH } + +mode = PLAYER diff --git a/examples/vizdoom/maps/D2_navigation.wad b/examples/vizdoom/maps/D2_navigation.wad new file mode 100644 index 000000000..b4327d67c Binary files /dev/null and b/examples/vizdoom/maps/D2_navigation.wad differ diff --git a/examples/vizdoom/maps/D3_battle.cfg b/examples/vizdoom/maps/D3_battle.cfg new file mode 100644 index 000000000..9d05781f4 --- /dev/null +++ b/examples/vizdoom/maps/D3_battle.cfg @@ -0,0 +1,48 @@ +# Lines starting with # are treated as comments (or with whitespaces+#). +# It doesn't matter if you use capital letters or not. +# It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. + +doom_scenario_path = D3_battle.wad +doom_map = map01 + +# Rewards + +living_reward = 0 +death_penalty = 100 + +# Rendering options +screen_resolution = RES_160X120 +screen_format = GRAY8 +render_hud = false +render_crosshair = true +render_weapon = true +render_decals = false +render_particles = false +window_visible = false + +# make episodes finish after 10500 actions (tics) +episode_timeout = 10500 + +# Available buttons +available_buttons = +{ + MOVE_FORWARD + MOVE_BACKWARD + MOVE_LEFT + MOVE_RIGHT + TURN_LEFT + TURN_RIGHT + ATTACK + SPEED +} + +# Game variables that will be in the state +available_game_variables = +{ + KILLCOUNT + AMMO2 + HEALTH +} + +mode = PLAYER +doom_skill = 2 diff --git a/examples/vizdoom/maps/D3_battle.wad b/examples/vizdoom/maps/D3_battle.wad new file mode 100644 index 000000000..de7877ef2 Binary files /dev/null and b/examples/vizdoom/maps/D3_battle.wad differ diff --git a/examples/vizdoom/maps/D4_battle2.cfg b/examples/vizdoom/maps/D4_battle2.cfg new file mode 100644 index 000000000..8f7d7798c --- /dev/null +++ b/examples/vizdoom/maps/D4_battle2.cfg @@ -0,0 +1,48 @@ +# Lines starting with # are treated as comments (or with whitespaces+#). +# It doesn't matter if you use capital letters or not. +# It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 
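+# D4 (battle2) is configured identically to D3_battle: same buttons, game
+# variables, death penalty and doom_skill; only the .wad map differs.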
+ +doom_scenario_path = D4_battle2.wad +doom_map = map01 + +# Rewards + +living_reward = 0 +death_penalty = 100 + +# Rendering options +screen_resolution = RES_160X120 +screen_format = GRAY8 +render_hud = false +render_crosshair = true +render_weapon = true +render_decals = false +render_particles = false +window_visible = false + +# make episodes finish after 10500 actions (tics) +episode_timeout = 10500 + +# Available buttons +available_buttons = +{ + MOVE_FORWARD + MOVE_BACKWARD + MOVE_LEFT + MOVE_RIGHT + TURN_LEFT + TURN_RIGHT + ATTACK + SPEED +} + +# Game variables that will be in the state +available_game_variables = +{ + KILLCOUNT + AMMO2 + HEALTH +} + +mode = PLAYER +doom_skill = 2 diff --git a/examples/vizdoom/maps/D4_battle2.wad b/examples/vizdoom/maps/D4_battle2.wad new file mode 100644 index 000000000..864306956 Binary files /dev/null and b/examples/vizdoom/maps/D4_battle2.wad differ diff --git a/examples/vizdoom/maps/README.md b/examples/vizdoom/maps/README.md new file mode 100644 index 000000000..3e4769018 --- /dev/null +++ b/examples/vizdoom/maps/README.md @@ -0,0 +1,3 @@ +D1-D4 maps are from https://github.com/intel-isl/DirectFuturePrediction/ + +More maps and cfgs: https://github.com/mwydmuch/ViZDoom/tree/master/scenarios diff --git a/examples/vizdoom/maps/spectator.py b/examples/vizdoom/maps/spectator.py new file mode 100644 index 000000000..d4d7e8c7c --- /dev/null +++ b/examples/vizdoom/maps/spectator.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 + +##################################################################### +# This script presents SPECTATOR mode. In SPECTATOR mode you play and +# your agent can learn from it. +# Configuration is loaded from "../../scenarios/.cfg" file. +# +# To see the scenario description go to "../../scenarios/README.md" +##################################################################### + +from __future__ import print_function + +from time import sleep +import vizdoom as vzd +from argparse import ArgumentParser +# import cv2 + + +if __name__ == "__main__": + parser = ArgumentParser("ViZDoom example showing how to use SPECTATOR mode.") + parser.add_argument('-c', type=str, dest="config", default="D3_battle.cfg") + parser.add_argument('-w', type=str, dest="wad_file", default="D3_battle.wad") + args = parser.parse_args() + game = vzd.DoomGame() + + # Choose scenario config file you wish to watch. + # Don't load two configs cause the second will overrite the first one. + # Multiple config files are ok but combining these ones doesn't make much sense. + + game.load_config(args.config) + game.set_doom_scenario_path(args.wad_file) + # Enables freelook in engine + game.add_game_args("+freelook 1") + + game.set_screen_resolution(vzd.ScreenResolution.RES_640X480) + + # Enables spectator mode, so you can play. + # Sounds strange but it is the agent who is supposed to watch not you. 
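+    # In SPECTATOR mode the human supplies the actions via keyboard/mouse; the
+    # loop below only reads them back with advance_action()/get_last_action(),
+    # so a played episode can be inspected or used as a demonstration.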
+ game.set_window_visible(True) + game.set_mode(vzd.Mode.SPECTATOR) + + game.init() + + episodes = 1 + + for i in range(episodes): + print("Episode #" + str(i + 1)) + + game.new_episode() + while not game.is_episode_finished(): + state = game.get_state() + print(state.screen_buffer.dtype, state.screen_buffer.shape) + # cv2.imwrite(f'imgs/{state.number}.png', state.screen_buffer) + + # game.make_action([0, 0, 0]) + game.advance_action() + last_action = game.get_last_action() + reward = game.get_last_reward() + + print("State #" + str(state.number)) + print("Game variables: ", state.game_variables) + print("Action:", last_action) + print("Reward:", reward) + print("=====================") + + print("Episode finished!") + print("Total reward:", game.get_total_reward()) + print("************************") + sleep(2.0) + + game.close() diff --git a/examples/vizdoom/network.py b/examples/vizdoom/network.py new file mode 120000 index 000000000..a0c543acb --- /dev/null +++ b/examples/vizdoom/network.py @@ -0,0 +1 @@ +../atari/atari_network.py \ No newline at end of file diff --git a/examples/vizdoom/replay.py b/examples/vizdoom/replay.py new file mode 100755 index 000000000..a1e556fce --- /dev/null +++ b/examples/vizdoom/replay.py @@ -0,0 +1,35 @@ +# import cv2 +import sys +import time +import tqdm +import vizdoom as vzd + + +def main(cfg_path="maps/D3_battle.cfg", lmp_path="test.lmp"): + game = vzd.DoomGame() + game.load_config(cfg_path) + game.set_screen_format(vzd.ScreenFormat.CRCGCB) + game.set_screen_resolution(vzd.ScreenResolution.RES_1024X576) + game.set_window_visible(True) + game.set_render_hud(True) + game.init() + game.replay_episode(lmp_path) + + killcount = 0 + with tqdm.trange(10500) as tq: + while not game.is_episode_finished(): + game.advance_action() + state = game.get_state() + if state is None: + break + killcount = game.get_game_variable(vzd.GameVariable.KILLCOUNT) + time.sleep(1 / 35) + # cv2.imwrite(f"imgs/{tq.n}.png", + # state.screen_buffer.transpose(1, 2, 0)[..., ::-1]) + tq.update(1) + game.close() + print("killcount:", killcount) + + +if __name__ == '__main__': + main(*sys.argv[-2:]) diff --git a/examples/vizdoom/results/c51/d3.lmp b/examples/vizdoom/results/c51/d3.lmp new file mode 100644 index 000000000..5e0e31d6c Binary files /dev/null and b/examples/vizdoom/results/c51/d3.lmp differ diff --git a/examples/vizdoom/results/c51/d4.lmp b/examples/vizdoom/results/c51/d4.lmp new file mode 100644 index 000000000..358fbfe22 Binary files /dev/null and b/examples/vizdoom/results/c51/d4.lmp differ diff --git a/examples/vizdoom/results/c51/length.png b/examples/vizdoom/results/c51/length.png new file mode 100644 index 000000000..e6315ba93 Binary files /dev/null and b/examples/vizdoom/results/c51/length.png differ diff --git a/examples/vizdoom/results/c51/reward.png b/examples/vizdoom/results/c51/reward.png new file mode 100644 index 000000000..3a93f28ea Binary files /dev/null and b/examples/vizdoom/results/c51/reward.png differ diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py new file mode 100644 index 000000000..7391de087 --- /dev/null +++ b/examples/vizdoom/vizdoom_c51.py @@ -0,0 +1,177 @@ +import os +import torch +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.policy import C51Policy +from tianshou.utils import BasicLogger +from tianshou.env import SubprocVectorEnv +from tianshou.trainer import offpolicy_trainer +from tianshou.data import Collector, VectorReplayBuffer + +from 
env import Env +from network import C51 + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='D1_basic') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--eps-test', type=float, default=0.005) + parser.add_argument('--eps-train', type=float, default=1.) + parser.add_argument('--eps-train-final', type=float, default=0.05) + parser.add_argument('--buffer-size', type=int, default=2000000) + parser.add_argument('--lr', type=float, default=0.0001) + parser.add_argument('--gamma', type=float, default=0.99) + parser.add_argument('--num-atoms', type=int, default=51) + parser.add_argument('--v-min', type=float, default=-10.) + parser.add_argument('--v-max', type=float, default=10.) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--target-update-freq', type=int, default=500) + parser.add_argument('--epoch', type=int, default=300) + parser.add_argument('--step-per-epoch', type=int, default=100000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) + parser.add_argument( + '--device', type=str, + default='cuda' if torch.cuda.is_available() else 'cpu') + parser.add_argument('--frames-stack', type=int, default=4) + parser.add_argument('--skip-num', type=int, default=4) + parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument('--watch', default=False, action='store_true', + help='watch the play of pre-trained policy only') + parser.add_argument('--save-lmp', default=False, action='store_true', + help='save lmp file for replay whole episode') + parser.add_argument('--save-buffer-name', type=str, default=None) + return parser.parse_args() + + +def test_c51(args=get_args()): + args.cfg_path = f"maps/{args.task}.cfg" + args.wad_path = f"maps/{args.task}.wad" + args.res = (args.skip_num, 84, 84) + env = Env(args.cfg_path, args.frames_stack, args.res) + args.state_shape = args.res + args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + # make environments + train_envs = SubprocVectorEnv([ + lambda: Env(args.cfg_path, args.frames_stack, args.res) + for _ in range(args.training_num)]) + test_envs = SubprocVectorEnv([ + lambda: Env(args.cfg_path, args.frames_stack, + args.res, args.save_lmp) + for _ in range(min(os.cpu_count() - 1, args.test_num))]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # define model + net = C51(*args.state_shape, args.action_shape, + args.num_atoms, args.device) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + # define policy + policy = C51Policy( + net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max, + args.n_step, target_update_freq=args.target_update_freq + ).to(args.device) + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together + # when 
you have enough RAM + buffer = VectorReplayBuffer( + args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True, + save_only_last_obs=True, stack_num=args.frames_stack) + # collector + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + # log + log_path = os.path.join(args.logdir, args.task, 'c51') + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + if env.spec.reward_threshold: + return mean_rewards >= env.spec.reward_threshold + elif 'Pong' in args.task: + return mean_rewards >= 20 + else: + return False + + def train_fn(epoch, env_step): + # nature DQN setting, linear decay in the first 1M steps + if env_step <= 1e6: + eps = args.eps_train - env_step / 1e6 * \ + (args.eps_train - args.eps_train_final) + else: + eps = args.eps_train_final + policy.set_eps(eps) + logger.write('train/eps', env_step, eps) + + def test_fn(epoch, env_step): + policy.set_eps(args.eps_test) + + # watch agent's performance + def watch(): + print("Setup test envs ...") + policy.eval() + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + if args.save_buffer_name: + print(f"Generate buffer with size {args.buffer_size}") + buffer = VectorReplayBuffer( + args.buffer_size, buffer_num=len(test_envs), + ignore_obs_next=True, save_only_last_obs=True, + stack_num=args.frames_stack) + collector = Collector(policy, test_envs, buffer, + exploration_noise=True) + result = collector.collect(n_step=args.buffer_size) + print(f"Save buffer into {args.save_buffer_name}") + # Unfortunately, pickle will cause oom with 1M buffer size + buffer.save_hdf5(args.save_buffer_name) + else: + print("Testing agent ...") + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, + render=args.render) + rew = result["rews"].mean() + lens = result["lens"].mean() * args.skip_num + print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + print(f'Mean length (over {result["n/ep"]} episodes): {lens}') + + if args.watch: + watch() + exit(0) + + # test train_collector and start filling replay buffer + train_collector.collect(n_step=args.batch_size * args.training_num) + # trainer + result = offpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.step_per_collect, args.test_num, + args.batch_size, train_fn=train_fn, test_fn=test_fn, + stop_fn=stop_fn, save_fn=save_fn, logger=logger, + update_per_step=args.update_per_step, test_in_train=False) + + pprint.pprint(result) + watch() + + +if __name__ == '__main__': + test_c51(get_args()) diff --git a/setup.py b/setup.py index 284ae0f56..8104892f0 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ def get_version() -> str: ], extras_require={ "dev": [ - "Sphinx", + "sphinx<4", "sphinx_rtd_theme", "sphinxcontrib-bibtex", "flake8", diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index d89a7f4bc..ffeb6b911 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -22,7 +22,7 @@ def get_args(): parser.add_argument('--step-per-epoch', type=int, default=1000) parser.add_argument('--episode-per-collect', type=int, default=1) parser.add_argument('--training-num', type=int, default=1) - parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--test-num', type=int, 
default=10) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.0) parser.add_argument('--rew-mean-prior', type=float, default=0.0) @@ -36,12 +36,12 @@ def get_args(): def test_psrl(args=get_args()): env = gym.make(args.task) if args.task == "NChain-v0": - env.spec.reward_threshold = 3647 # described in PSRL paper + env.spec.reward_threshold = 3400 + # env.spec.reward_threshold = 3647 # described in PSRL paper print("reward threshold:", env.spec.reward_threshold) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n # train_envs = gym.make(args.task) - # train_envs = gym.make(args.task) train_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.training_num)]) # test_envs = gym.make(args.task) diff --git a/tianshou/__init__.py b/tianshou/__init__.py index 84cc0077f..fb362d054 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,7 +1,7 @@ from tianshou import data, env, utils, policy, trainer, exploration -__version__ = "0.4.1" +__version__ = "0.4.2" __all__ = [ "env", diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 55a00dacd..f0ce76c2b 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -526,7 +526,12 @@ def stack_(self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> None elif all(isinstance(e, (Batch, dict)) for e in v): # third often self.__dict__[k] = Batch.stack(v, axis) else: # most often case is np.ndarray - self.__dict__[k] = _to_array_with_correct_type(np.stack(v, axis)) + try: + self.__dict__[k] = _to_array_with_correct_type(np.stack(v, axis)) + except ValueError: + warnings.warn("You are using tensors with different shape," + " fallback to dtype=object by default.") + self.__dict__[k] = np.array(v, dtype=object) # all the keys keys_total = set.union(*[set(b.keys()) for b in batches]) # keys that are reserved in all batches diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py index 4189207a9..1c8fc617a 100644 --- a/tianshou/data/buffer/base.py +++ b/tianshou/data/buffer/base.py @@ -53,6 +53,7 @@ def __init__( self._save_only_last_obs = save_only_last_obs self._sample_avail = sample_avail self._meta: Batch = Batch() + self._ep_rew: Union[float, np.ndarray] self.reset() def __len__(self) -> int: diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 14f6f5616..f9fede557 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -56,7 +56,8 @@ def __init__( exploration_noise: bool = False, ) -> None: super().__init__() - if not isinstance(env, BaseVectorEnv): + if isinstance(env, gym.Env) and not hasattr(env, "__len__"): + warnings.warn("Single environment detected, wrap to DummyVectorEnv.") env = DummyVectorEnv([lambda: env]) self.env = env self.env_num = len(env) @@ -223,7 +224,8 @@ def collect( # get bounded and remapped actions first (not saved into buffer) action_remap = self.policy.map_action(self.data.act) # step in env - obs_next, rew, done, info = self.env.step(action_remap, id=ready_env_ids) + obs_next, rew, done, info = self.env.step( + action_remap, ready_env_ids) # type: ignore self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) if self.preprocess_fn: @@ -426,7 +428,8 @@ def collect( # get bounded and remapped actions first (not saved into buffer) action_remap = self.policy.map_action(self.data.act) # step in env - obs_next, rew, done, info = self.env.step(action_remap, 
id=ready_env_ids) + obs_next, rew, done, info = self.env.step( + action_remap, ready_env_ids) # type: ignore # change self.data here because ready_env_ids has changed ready_env_ids = np.array([i["env_id"] for i in info]) diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index f5b203f3e..05e4b2655 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -75,8 +75,8 @@ def __init__( self._min_q_weight = min_q_weight def sync_weight(self) -> None: - self.actor_old.load_state_dict(self.actor.state_dict()) # type: ignore - self.critic_old.load_state_dict(self.critic.state_dict()) # type: ignore + self.actor_old.load_state_dict(self.actor.state_dict()) + self.critic_old.load_state_dict(self.critic.state_dict()) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: # type: ignore if self._target and self._iter % self._freq == 0: diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index e52af9e6e..929512b86 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -72,7 +72,7 @@ def train(self, mode: bool = True) -> "DQNPolicy": def sync_weight(self) -> None: """Synchronize the weight for the target network.""" - self.model_old.load_state_dict(self.model.state_dict()) # type: ignore + self.model_old.load_state_dict(self.model.state_dict()) def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor: batch = buffer[indice] # batch.obs_next: s_{t+n} diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index a77817352..b0ba63f11 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -64,6 +64,7 @@ def __init__( self._max_backtracks = max_backtracks self._delta = max_kl self._backtrack_coeff = backtrack_coeff + self._optim_critic_iters: int def learn( # type: ignore self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any diff --git a/tianshou/utils/log_tools.py b/tianshou/utils/log_tools.py index 2a8d3a251..5a25f1394 100644 --- a/tianshou/utils/log_tools.py +++ b/tianshou/utils/log_tools.py @@ -88,7 +88,7 @@ class BasicLogger(BaseLogger): You can also rewrite write() func to use your own writer. :param SummaryWriter writer: the writer to log data. - :param int train_interval: the log interval in log_train_data(). Default to 1. + :param int train_interval: the log interval in log_train_data(). Default to 1000. :param int test_interval: the log interval in log_test_data(). Default to 1. :param int update_interval: the log interval in log_update_data(). Default to 1000. :param int save_interval: the save interval in save_data(). Default to 1 (save at @@ -98,7 +98,7 @@ class BasicLogger(BaseLogger): def __init__( self, writer: SummaryWriter, - train_interval: int = 1, + train_interval: int = 1000, test_interval: int = 1, update_interval: int = 1000, save_interval: int = 1,
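With this change the default `train_interval` of `BasicLogger` matches `update_interval` (both log every 1000 steps). A minimal sketch of keeping the old per-step training logs by passing the argument explicitly, assuming only the `BasicLogger(writer, ...)` signature shown in this hunk:

```python
from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import BasicLogger

writer = SummaryWriter("log/example")
# pass train_interval explicitly to keep the pre-0.4.2 behaviour of
# logging training data every step instead of every 1000 steps
logger = BasicLogger(writer, train_interval=1)
```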