From 1cdc203c077c22732e97c03c607fe32b0fe50554 Mon Sep 17 00:00:00 2001 From: Jialu Zhu Date: Mon, 14 Dec 2020 20:36:45 -0800 Subject: [PATCH 01/23] work --- test/discrete/test_bcq.py | 140 +++++++++++++++++++ test/discrete/test_dqn.py | 3 +- tianshou/data/buffer.py | 1 + tianshou/policy/__init__.py | 3 + tianshou/policy/base.py | 1 + tianshou/policy/modelfree/bcq.py | 223 ++++++++++++++++++++++++++++++ tianshou/policy/modelfree/dqn.py | 5 + tianshou/trainer/__init__.py | 2 + tianshou/trainer/offline.py | 145 +++++++++++++++++++ zjlenv/bin/activate | 78 +++++++++++ zjlenv/bin/activate.csh | 42 ++++++ zjlenv/bin/activate.fish | 101 ++++++++++++++ zjlenv/bin/activate.ps1 | 60 ++++++++ zjlenv/bin/activate.xsh | 39 ++++++ zjlenv/bin/activate_this.py | 46 ++++++ zjlenv/bin/chardetect | 8 ++ zjlenv/bin/convert-caffe2-to-onnx | 8 ++ zjlenv/bin/convert-onnx-to-caffe2 | 8 ++ zjlenv/bin/coverage | 8 ++ zjlenv/bin/coverage-3.7 | 8 ++ zjlenv/bin/coverage3 | 8 ++ zjlenv/bin/dmypy | 8 ++ zjlenv/bin/doc8 | 8 ++ zjlenv/bin/easy_install | 10 ++ zjlenv/bin/easy_install-3.7 | 10 ++ zjlenv/bin/f2py | 8 ++ zjlenv/bin/f2py3 | 8 ++ zjlenv/bin/f2py3.7 | 8 ++ zjlenv/bin/flake8 | 8 ++ zjlenv/bin/futurize | 8 ++ zjlenv/bin/google-oauthlib-tool | 8 ++ zjlenv/bin/jsonschema | 8 ++ zjlenv/bin/markdown_py | 8 ++ zjlenv/bin/mypy | 8 ++ zjlenv/bin/mypyc | 55 ++++++++ zjlenv/bin/numba | 8 ++ zjlenv/bin/pasteurize | 8 ++ zjlenv/bin/pbr | 8 ++ zjlenv/bin/pip | 10 ++ zjlenv/bin/pip3 | 10 ++ zjlenv/bin/pip3.7 | 10 ++ zjlenv/bin/py.test | 8 ++ zjlenv/bin/pybabel | 8 ++ zjlenv/bin/pybtex | 8 ++ zjlenv/bin/pybtex-convert | 8 ++ zjlenv/bin/pybtex-format | 8 ++ zjlenv/bin/pycc | 3 + zjlenv/bin/pycodestyle | 8 ++ zjlenv/bin/pydocstyle | 8 ++ zjlenv/bin/pyflakes | 8 ++ zjlenv/bin/pygmentize | 8 ++ zjlenv/bin/pyrsa-decrypt | 8 ++ zjlenv/bin/pyrsa-encrypt | 8 ++ zjlenv/bin/pyrsa-keygen | 8 ++ zjlenv/bin/pyrsa-priv2pub | 8 ++ zjlenv/bin/pyrsa-sign | 8 ++ zjlenv/bin/pyrsa-verify | 8 ++ zjlenv/bin/pytest | 8 ++ zjlenv/bin/python | 1 + zjlenv/bin/python-config | 78 +++++++++++ zjlenv/bin/python3 | 1 + zjlenv/bin/python3.7 | Bin 0 -> 8632 bytes zjlenv/bin/ray | 8 ++ zjlenv/bin/restructuredtext-lint | 8 ++ zjlenv/bin/rllib | 8 ++ zjlenv/bin/rst-lint | 8 ++ zjlenv/bin/rst2html.py | 23 +++ zjlenv/bin/rst2html4.py | 26 ++++ zjlenv/bin/rst2html5.py | 35 +++++ zjlenv/bin/rst2latex.py | 26 ++++ zjlenv/bin/rst2man.py | 26 ++++ zjlenv/bin/rst2odt.py | 30 ++++ zjlenv/bin/rst2odt_prepstyles.py | 67 +++++++++ zjlenv/bin/rst2pseudoxml.py | 23 +++ zjlenv/bin/rst2s5.py | 24 ++++ zjlenv/bin/rst2xetex.py | 27 ++++ zjlenv/bin/rst2xml.py | 23 +++ zjlenv/bin/rstpep2html.py | 25 ++++ zjlenv/bin/sphinx-apidoc | 8 ++ zjlenv/bin/sphinx-autogen | 8 ++ zjlenv/bin/sphinx-build | 8 ++ zjlenv/bin/sphinx-quickstart | 8 ++ zjlenv/bin/stubgen | 8 ++ zjlenv/bin/stubtest | 8 ++ zjlenv/bin/tensorboard | 8 ++ zjlenv/bin/tqdm | 8 ++ zjlenv/bin/tune | 8 ++ zjlenv/bin/wheel | 10 ++ zjlenv/include/python3.7m | 1 + 89 files changed, 1834 insertions(+), 1 deletion(-) create mode 100644 test/discrete/test_bcq.py create mode 100644 tianshou/policy/modelfree/bcq.py create mode 100644 tianshou/trainer/offline.py create mode 100644 zjlenv/bin/activate create mode 100644 zjlenv/bin/activate.csh create mode 100644 zjlenv/bin/activate.fish create mode 100644 zjlenv/bin/activate.ps1 create mode 100644 zjlenv/bin/activate.xsh create mode 100644 zjlenv/bin/activate_this.py create mode 100755 zjlenv/bin/chardetect create mode 100755 zjlenv/bin/convert-caffe2-to-onnx create mode 100755 
zjlenv/bin/convert-onnx-to-caffe2 create mode 100755 zjlenv/bin/coverage create mode 100755 zjlenv/bin/coverage-3.7 create mode 100755 zjlenv/bin/coverage3 create mode 100755 zjlenv/bin/dmypy create mode 100755 zjlenv/bin/doc8 create mode 100755 zjlenv/bin/easy_install create mode 100755 zjlenv/bin/easy_install-3.7 create mode 100755 zjlenv/bin/f2py create mode 100755 zjlenv/bin/f2py3 create mode 100755 zjlenv/bin/f2py3.7 create mode 100755 zjlenv/bin/flake8 create mode 100755 zjlenv/bin/futurize create mode 100755 zjlenv/bin/google-oauthlib-tool create mode 100755 zjlenv/bin/jsonschema create mode 100755 zjlenv/bin/markdown_py create mode 100755 zjlenv/bin/mypy create mode 100755 zjlenv/bin/mypyc create mode 100755 zjlenv/bin/numba create mode 100755 zjlenv/bin/pasteurize create mode 100755 zjlenv/bin/pbr create mode 100755 zjlenv/bin/pip create mode 100755 zjlenv/bin/pip3 create mode 100755 zjlenv/bin/pip3.7 create mode 100755 zjlenv/bin/py.test create mode 100755 zjlenv/bin/pybabel create mode 100755 zjlenv/bin/pybtex create mode 100755 zjlenv/bin/pybtex-convert create mode 100755 zjlenv/bin/pybtex-format create mode 100755 zjlenv/bin/pycc create mode 100755 zjlenv/bin/pycodestyle create mode 100755 zjlenv/bin/pydocstyle create mode 100755 zjlenv/bin/pyflakes create mode 100755 zjlenv/bin/pygmentize create mode 100755 zjlenv/bin/pyrsa-decrypt create mode 100755 zjlenv/bin/pyrsa-encrypt create mode 100755 zjlenv/bin/pyrsa-keygen create mode 100755 zjlenv/bin/pyrsa-priv2pub create mode 100755 zjlenv/bin/pyrsa-sign create mode 100755 zjlenv/bin/pyrsa-verify create mode 100755 zjlenv/bin/pytest create mode 120000 zjlenv/bin/python create mode 100755 zjlenv/bin/python-config create mode 120000 zjlenv/bin/python3 create mode 100755 zjlenv/bin/python3.7 create mode 100755 zjlenv/bin/ray create mode 100755 zjlenv/bin/restructuredtext-lint create mode 100755 zjlenv/bin/rllib create mode 100755 zjlenv/bin/rst-lint create mode 100755 zjlenv/bin/rst2html.py create mode 100755 zjlenv/bin/rst2html4.py create mode 100755 zjlenv/bin/rst2html5.py create mode 100755 zjlenv/bin/rst2latex.py create mode 100755 zjlenv/bin/rst2man.py create mode 100755 zjlenv/bin/rst2odt.py create mode 100755 zjlenv/bin/rst2odt_prepstyles.py create mode 100755 zjlenv/bin/rst2pseudoxml.py create mode 100755 zjlenv/bin/rst2s5.py create mode 100755 zjlenv/bin/rst2xetex.py create mode 100755 zjlenv/bin/rst2xml.py create mode 100755 zjlenv/bin/rstpep2html.py create mode 100755 zjlenv/bin/sphinx-apidoc create mode 100755 zjlenv/bin/sphinx-autogen create mode 100755 zjlenv/bin/sphinx-build create mode 100755 zjlenv/bin/sphinx-quickstart create mode 100755 zjlenv/bin/stubgen create mode 100755 zjlenv/bin/stubtest create mode 100755 zjlenv/bin/tensorboard create mode 100755 zjlenv/bin/tqdm create mode 100755 zjlenv/bin/tune create mode 100755 zjlenv/bin/wheel create mode 120000 zjlenv/include/python3.7m diff --git a/test/discrete/test_bcq.py b/test/discrete/test_bcq.py new file mode 100644 index 000000000..47fe9ddcb --- /dev/null +++ b/test/discrete/test_bcq.py @@ -0,0 +1,140 @@ +from tianshou.policy import BCQPolicy +from tianshou.policy import BCQN +import os +import gym +import torch +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter +import random +from tianshou.policy import DQNPolicy +from tianshou.env import DummyVectorEnv +from tianshou.utils.net.common import Net +from tianshou.trainer import offline_trainer +from tianshou.data import Collector, ReplayBuffer, 
PrioritizedReplayBuffer + + + +parser = argparse.ArgumentParser() +parser.add_argument('--task', type=str, default='CartPole-v0') +parser.add_argument('--seed', type=int, default=1626) +parser.add_argument('--batch-size', type=int, default=2) +parser.add_argument('--epoch', type=int, default=10) +parser.add_argument('--test-num', type=int, default=1) +parser.add_argument('--logdir', type=str, default='log') +parser.add_argument('--buffer-size', type=int, default=10) +parser.add_argument('--hidden-dim', type=int, default=5) +parser.add_argument('--test-frequency', type=int, default=5) +parser.add_argument('--target-update-frequency', type=int, default=5) +parser.add_argument('--episode-per-test', type=int, default=5) +parser.add_argument('--tao', type=float, default=0.8) + +args = parser.parse_known_args()[0] + +env = gym.make(args.task) +state_shape = env.observation_space.shape or env.observation_space.n +state_shape = state_shape[0] +action_shape = env.action_space.shape or env.action_space.n + +# print(state_shape) +# print(action_shape) +# exit() + +model = BCQN(state_shape,action_shape,args.hidden_dim,args.hidden_dim) +optim = torch.optim.Adam(model.parameters(), lr=0.5) +policy = BCQPolicy(model, optim, args.tao, args.target_update_frequency, 'cpu') + +buffer = ReplayBuffer(size=args.buffer_size) +for i in range(args.buffer_size): + buffer.add(obs=torch.rand(state_shape), act=random.randint(0, action_shape-1), rew=1, done=False, obs_next=torch.rand(state_shape), info={}) +# buf.add(obs=[1.0, 1.0, 1.0, 1.0], act=[1], rew=1.0, done=False, obs_next=[1.0, 1.0, 1.0, 1.0], info={}) + + +# buffer.add(obs=torch.Tensor([1.0, 1.0, 1.0, 1.0]), act=1, rew=1, done=False, obs_next=torch.Tensor([2.0, 2.0, 2.0, 2.0]), info={}) +# buffer.add(obs=torch.Tensor([2.0, 2.0, 2.0, 2.0]), act=2, rew=2, done=True, obs_next=torch.Tensor([-1.0, -1.0, -1.0, -1.0]), info={}) + +test_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)]) +# seed +np.random.seed(args.seed) +torch.manual_seed(args.seed) +# train_envs.seed(args.seed) +test_envs.seed(args.seed) + +test_collector = Collector(policy, test_envs) + +log_path = os.path.join(args.logdir, 'writer') +writer = SummaryWriter(log_path) + +best_policy_path = os.path.join(args.logdir, 'best_policy') + +res = offline_trainer(policy, buffer, test_collector, args.epoch, args.batch_size, args.episode_per_test, writer, args.test_frequency) +print('zjlbest_reward', res['best_reward']) + +# TODOzjl save policy torch.save(policy.state_dict(), os.path.join(best_policy_save_dir, 'policy.pth')) + + +# # batch = buffer.sample(1) +# print(buffer.obs) +# print(buffer.rew) +# buffer.rew = torch.Tensor([1,2]) +# print(buffer.rew) + + +# buffer.obs = torch.Tensor([[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]]) +# buffer.act = torch.Tensor([1, 2]) +# buffer.rew = torch.Tensor([10, 20]) +# buffer.done = torch.Tensor([False, True]) +# buffer.obs_next = torch.Tensor([[1.0, 1.0, 1.0, 1.0], [-1.0, -1.0, -1.0, -1.0]]) + +# print(buffer.obs) +# print(buffer) +# buffer.add( +# obs=torch.Tensor([[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]]), +# act=torch.Tensor([1, 2]), +# rew=torch.Tensor([1.0, 2.0]), +# done=torch.Tensor([False, True]), +# obs_next=torch.Tensor([[1.0, 1.0, 1.0, 1.0], [-1.0, -1.0, -1.0, -1.0]]), +# ) +# print(buffer) +# print(buffer) +# batch = buffer.sample(2)[0] +# print(batch) +# batch.to_torch() +# print(batch) +# batch = policy.forward(batch) +# print(batch) + +# loss = policy.learn(batch) +# print(loss) +########################### +# 
best_reward = -1 +# best_policy = policy + +# global_iter = 0 + +# total_iter = len(buffer) // args.batch_size +# for epoch in range(1, 1 + args.epoch): +# for iter in range(total_iter): +# global_iter += 1 +# loss = policy.update(args.batch_size, buffer) +# # batch = buffer.sample(args.batch_size)[0] +# # batch.to_torch() +# if global_iter % log_frequency == 0: +# writer.add_scalar( +# "train/loss", loss['loss'], global_step=global_iter) +# test_collector = Collector(policy, test_envs) + +# test_result = test_episode( +# policy, test_collector, None, +# epoch, args.episode_per_test, writer, global_iter) +# # for k in result.keys(): +# # writer.add_scalar( +# # "train/" + k, result[k], global_step=env_step) + +# if best_reward < result["rew"]: +# best_reward = result["rew"] +# best_policy = policy +# # epoch, args.episode_per_test, writer, env_step) + \ No newline at end of file diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 8ec9fa380..ecab9c017 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -27,7 +27,8 @@ def get_args(): parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=1000) parser.add_argument('--collect-per-step', type=int, default=10) - parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--batch-size', type=int, default=2) + # parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--layer-num', type=int, default=3) parser.add_argument('--training-num', type=int, default=8) parser.add_argument('--test-num', type=int, default=100) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index c8c572a22..038620dcd 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -431,6 +431,7 @@ def add( **kwargs: Any, ) -> None: """Add a batch of data into replay buffer.""" + print('zjlobs',obs) if weight is None: weight = self._max_prio else: diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 456993d5d..13e539ed2 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -9,6 +9,8 @@ from tianshou.policy.modelfree.td3 import TD3Policy from tianshou.policy.modelfree.sac import SACPolicy from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy +from tianshou.policy.modelfree.bcq import BCQPolicy +from tianshou.policy.modelfree.bcq import BCQN from tianshou.policy.modelbase.psrl import PSRLPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager @@ -27,4 +29,5 @@ "DiscreteSACPolicy", "PSRLPolicy", "MultiAgentPolicyManager", + "BCQManager", ] diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 809751fe7..c9d33c754 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -165,6 +165,7 @@ def update( """ if buffer is None: return {} + print(buffer) batch, indice = buffer.sample(sample_size) self.updating = True batch = self.process_fn(batch, buffer, indice) diff --git a/tianshou/policy/modelfree/bcq.py b/tianshou/policy/modelfree/bcq.py new file mode 100644 index 000000000..78d4839b0 --- /dev/null +++ b/tianshou/policy/modelfree/bcq.py @@ -0,0 +1,223 @@ +import torch +import numpy as np +import torch.nn.functional as F +from typing import Any, Dict, Union, Optional + +from tianshou.data import Batch, to_torch +from tianshou.policy import BasePolicy +import torch.nn as nn +from copy import deepcopy + +from inspect import currentframe, getframeinfo + +frameinfo = getframeinfo(currentframe()) + +# def p(x): +# 
print(frameinfo.filename, frameinfo.lineno, x) + +class BCQN(nn.Module): + """BCQ NN for dialogue policy. It includes a net for imitation and a net for Q-value""" + + def __init__( + self, input_size, n_actions, imitation_model_hidden_dim, policy_model_hidden_dim + ): + super(BCQN, self).__init__() + self.q1 = nn.Linear(input_size, policy_model_hidden_dim) + self.q2 = nn.Linear(policy_model_hidden_dim, policy_model_hidden_dim) + self.q3 = nn.Linear(policy_model_hidden_dim, n_actions) + + self.i1 = nn.Linear(input_size, imitation_model_hidden_dim) + self.i2 = nn.Linear(imitation_model_hidden_dim, imitation_model_hidden_dim) + self.i3 = nn.Linear(imitation_model_hidden_dim, n_actions) + + def forward(self, state): + q = F.relu(self.q1(state)) + q = F.relu(self.q2(q)) + + i = F.relu(self.i1(state)) + i = F.relu(self.i2(i)) + i = F.relu(self.i3(i)) + return self.q3(q), F.log_softmax(i, dim=1), i + + +class BCQPolicy(BasePolicy): + """Implementation of vanilla imitation learning. + + :param torch.nn.Module model: a model following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> a) + :param torch.optim.Optimizer optim: for optimizing the model. + :param str mode: indicate the imitation type ("continuous" or "discrete" + action space), defaults to "continuous". + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__( + self, + model: torch.nn.Module, + optim: torch.optim.Optimizer, + tao: float, + target_update_frequency: int, + device: str, + gamma: float = 0.9, + imitation_logits_penalty: float=0.1, + # mode: str = "continuous", + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self._policy_net = model + self._optimizer = optim + self._cnt = 0 + self._device = device + self._gamma = gamma + # assert ( + # mode in ["continuous", "discrete"] + # ), f"Mode {mode} is not in ['continuous', 'discrete']." + # self.mode = mode + + #TODOzjl chuan can shu + self._tao = tao + self._target_net = deepcopy(self._policy_net) + self._target_net.eval() + self._target_update_frequency = target_update_frequency + self._imitation_logits_penalty = imitation_logits_penalty + + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + # input: str = "obs", + **kwargs: Any, + ) -> Batch: + # logits, h = self.model(batch.obs, state=state, info=batch.info) + # if self.mode == "discrete": + # a = logits.max(dim=1)[1] + # else: + # a = logits + batch.to_torch() + state = batch.obs + # state = batch['obs'] + q, imt, _ = self._policy_net(state.float()) + imt = imt.exp() + imt = (imt / imt.max(1, keepdim=True)[0] > self._tao).float() + # Use large negative number to mask actions from argmax + action = (imt * q + (1.0 - imt) * -1e8).argmax(1) + return Batch(act=action, state=state, qvalue=q) + + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: + + batch.to_torch() + + if self._cnt % self._target_update_frequency == 0: + # self._logger.info("Updating target network...") + self._target_net.load_state_dict(self._policy_net.state_dict()) + self._target_net.eval() + + # state = self(batch, input='obs_next') + # qvalue = batch. 
+ + non_final_mask = torch.tensor( + tuple(map(lambda s: not s, batch.done)), + device=self._device, + dtype=torch.bool, + ) + # p(non_final_mask) + try: + non_final_next_states = torch.cat( + [s.obs.unsqueeze(0) for s in batch if not s.done], dim=0 + ) + except Exception: + non_final_next_states = None + + # print(non_final_next_states) + + # # Compute the target Q value + with torch.no_grad(): + expected_state_action_values = batch.rew.float() + + # Add target Q value for non-final next_state + if non_final_next_states is not None: + q, imt, _ = self._policy_net(non_final_next_states) + imt = imt.exp() + imt = (imt / imt.max(1, keepdim=True)[0] > self._tao).float() + # print(imt) + # print(q) + # Use large negative number to mask actions from argmax + next_action = (imt * q + (1 - imt) * -1e8).argmax(1, keepdim=True) + + print('zjlnext_action', next_action) + + q, _, _ = self._target_net(non_final_next_states) + q = q.gather(1, next_action).reshape(-1, 1) + + print('zjlq', q) + + # print('zjlaaa', torch.zeros(len(batch), device=self._device).float()) + + next_state_values = ( + torch.zeros(len(batch), device=self._device) + # torch.zeros(self._batch_size, device=self._device) + .float() + # .unsqueeze(1) + ) + print('zjl1', next_state_values) + print('zjlnon_final_mask', non_final_mask) + print('zjlq', q) + next_state_values[non_final_mask] = q.squeeze() + print('zjl2', next_state_values) + + print(expected_state_action_values) + # print(next_state_values) + expected_state_action_values += next_state_values * self._gamma + + # # Get current Q estimate + current_Q, imt, i = self._policy_net(batch.obs) + + print('zjlcurrent_Q', current_Q) + print('zjlbatch.act', batch.act) + current_Q = current_Q.gather(1, batch.act.unsqueeze(1)).squeeze() + + print('zjlfinal', current_Q, expected_state_action_values) + + # # Compute Q loss + q_loss = F.smooth_l1_loss(current_Q, expected_state_action_values) + i_loss = F.nll_loss(imt, batch.act.reshape(-1)) + + Q_loss = q_loss + i_loss + self._imitation_logits_penalty * i.pow(2).mean() + + self._optimizer.zero_grad() + Q_loss.backward() + self._optimizer.step() + + return {"loss": Q_loss.item()} + + + # if self._target and self._cnt % self._freq == 0: + # self.sync_weight() + # self.optim.zero_grad() + # weight = batch.pop("weight", 1.0) + # q = self(batch).logits + # q = q[np.arange(len(q)), batch.act] + # r = to_torch_as(batch.returns.flatten(), q) + # td = r - q + # loss = (td.pow(2) * weight).mean() + # batch.weight = td # prio-buffer + # loss.backward() + # self.optim.step() + # self._cnt += 1 + # return {"loss": loss.item()} + + # if self.mode == "continuous": # regression + # a = self(batch).act + # a_ = to_torch(batch.act, dtype=torch.float32, device=a.device) + # loss = F.mse_loss(a, a_) # type: ignore + # elif self.mode == "discrete": # classification + # a = self(batch).logits + # a_ = to_torch(batch.act, dtype=torch.long, device=a.device) + # loss = F.nll_loss(a, a_) # type: ignore + # loss.backward() + # self.optim.step() + # return {"loss": loss.item()} diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 91cca6139..ac87d1ce7 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -165,14 +165,19 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: if self._target and self._cnt % self._freq == 0: self.sync_weight() self.optim.zero_grad() + print('zjlbatch', batch) weight = batch.pop("weight", 1.0) q = self(batch).logits + print('zjlq1', q) q = 
q[np.arange(len(q)), batch.act] + print('zjlq2', q) r = to_torch_as(batch.returns.flatten(), q) + print('zjlr', r) td = r - q loss = (td.pow(2) * weight).mean() batch.weight = td # prio-buffer loss.backward() self.optim.step() self._cnt += 1 + exit() return {"loss": loss.item()} diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index 36a8ed487..22fc1eea1 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -1,10 +1,12 @@ from tianshou.trainer.utils import test_episode, gather_info from tianshou.trainer.onpolicy import onpolicy_trainer from tianshou.trainer.offpolicy import offpolicy_trainer +from tianshou.trainer.offline import offline_trainer __all__ = [ "gather_info", "test_episode", "onpolicy_trainer", "offpolicy_trainer", + "offline_trainer", ] diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py new file mode 100644 index 000000000..34376ba82 --- /dev/null +++ b/tianshou/trainer/offline.py @@ -0,0 +1,145 @@ +import time +import tqdm +from torch.utils.tensorboard import SummaryWriter +from typing import Dict, List, Union, Callable, Optional + +from tianshou.data import Collector, ReplayBuffer +from tianshou.policy import BasePolicy +from tianshou.utils import tqdm_config, MovAvg +from tianshou.trainer import test_episode, gather_info + + +def offline_trainer( + policy: BasePolicy, + buffer: ReplayBuffer, + # train_collector: Collector, + test_collector: Collector, + # max_epoch: int, + epochs: int, + # step_per_epoch: int, + # collect_per_step: int, + # episode_per_test: Union[int, List[int]], + batch_size: int, + episode_per_test: int, + # best_policy_save_dir: Optional[str], + # update_per_step: int = 1, + # train_fn: Optional[Callable[[int, int], None]] = None, + # test_fn: Optional[Callable[[int, Optional[int]], None]] = None, + # stop_fn: Optional[Callable[[float], bool]] = None, + # save_fn: Optional[Callable[[BasePolicy], None]] = None, + writer: Optional[SummaryWriter] = None, + test_frequency: int = 1, + # verbose: bool = True, + # test_in_train: bool = True, +) -> Dict[str, Union[float, str]]: + + best_reward = -1 + best_policy = policy + + total_iter = 0 + + iter_per_epoch = len(buffer) // batch_size + for epoch in range(1, 1 + epochs): + for iter in range(iter_per_epoch): + total_iter += 1 + loss = policy.update(batch_size, buffer) + # batch = buffer.sample(args.batch_size)[0] + # batch.to_torch() + if total_iter % test_frequency == 0: + writer.add_scalar( + "train/loss", loss['loss'], global_step=total_iter) + # test_collector = Collector(policy, test_envs) + + test_result = test_episode( + policy, test_collector, None, + epoch, episode_per_test, writer, total_iter) + # for k in result.keys(): + # writer.add_scalar( + # "train/" + k, result[k], global_step=env_step) + + if best_reward < test_result["rew"]: + best_reward = test_result["rew"] + best_policy = policy + # epoch, args.episode_per_test, writer, env_step) + + + return {'best_reward': best_reward, 'best_policy': best_policy} + + # env_step, gradient_step = 0, 0 + # best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 + # stat: Dict[str, MovAvg] = {} + # start_time = time.time() + # train_collector.reset_stat() + # test_collector.reset_stat() + # test_in_train = test_in_train and train_collector.policy == policy + # for epoch in range(1, 1 + max_epoch): + # # train + # policy.train() + # with tqdm.tqdm( + # total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config + # ) as t: + # while t.n < t.total: + # if train_fn: + # train_fn(epoch, 
env_step) + # result = train_collector.collect(n_step=collect_per_step) + # env_step += int(result["n/st"]) + # data = { + # "env_step": str(env_step), + # "rew": f"{result['rew']:.2f}", + # "len": str(int(result["len"])), + # "n/ep": str(int(result["n/ep"])), + # "n/st": str(int(result["n/st"])), + # "v/ep": f"{result['v/ep']:.2f}", + # "v/st": f"{result['v/st']:.2f}", + # } + # if writer and env_step % log_interval == 0: + # for k in result.keys(): + # writer.add_scalar( + # "train/" + k, result[k], global_step=env_step) + # if test_in_train and stop_fn and stop_fn(result["rew"]): + # test_result = test_episode( + # policy, test_collector, test_fn, + # epoch, episode_per_test, writer, env_step) + # if stop_fn(test_result["rew"]): + # if save_fn: + # save_fn(policy) + # for k in result.keys(): + # data[k] = f"{result[k]:.2f}" + # t.set_postfix(**data) + # return gather_info( + # start_time, train_collector, test_collector, + # test_result["rew"], test_result["rew_std"]) + # else: + # policy.train() + # for i in range(update_per_step * min( + # result["n/st"] // collect_per_step, t.total - t.n)): + # gradient_step += 1 + # losses = policy.update(batch_size, train_collector.buffer) + # for k in losses.keys(): + # if stat.get(k) is None: + # stat[k] = MovAvg() + # stat[k].add(losses[k]) + # data[k] = f"{stat[k].get():.6f}" + # if writer and gradient_step % log_interval == 0: + # writer.add_scalar( + # k, stat[k].get(), global_step=gradient_step) + # t.update(1) + # t.set_postfix(**data) + # if t.n <= t.total: + # t.update() + # # test + # result = test_episode(policy, test_collector, test_fn, epoch, + # episode_per_test, writer, env_step) + # if best_epoch == -1 or best_reward < result["rew"]: + # best_reward, best_reward_std = result["rew"], result["rew_std"] + # best_epoch = epoch + # if save_fn: + # save_fn(policy) + # if verbose: + # print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± " + # f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± " + # f"{best_reward_std:.6f} in #{best_epoch}") + # if stop_fn and stop_fn(best_reward): + # break + # return gather_info(start_time, train_collector, test_collector, + # best_reward, best_reward_std) diff --git a/zjlenv/bin/activate b/zjlenv/bin/activate new file mode 100644 index 000000000..2174f277c --- /dev/null +++ b/zjlenv/bin/activate @@ -0,0 +1,78 @@ +# This file must be used with "source bin/activate" *from bash* +# you cannot run it directly + +deactivate () { + unset -f pydoc >/dev/null 2>&1 + + # reset old environment variables + # ! [ -z ${VAR+_} ] returns true if VAR is declared at all + if ! [ -z "${_OLD_VIRTUAL_PATH+_}" ] ; then + PATH="$_OLD_VIRTUAL_PATH" + export PATH + unset _OLD_VIRTUAL_PATH + fi + if ! [ -z "${_OLD_VIRTUAL_PYTHONHOME+_}" ] ; then + PYTHONHOME="$_OLD_VIRTUAL_PYTHONHOME" + export PYTHONHOME + unset _OLD_VIRTUAL_PYTHONHOME + fi + + # This should detect bash and zsh, which have a hash command that must + # be called to get it to forget past commands. Without forgetting + # past commands the $PATH changes we made may not be respected + if [ -n "${BASH-}" ] || [ -n "${ZSH_VERSION-}" ] ; then + hash -r 2>/dev/null + fi + + if ! [ -z "${_OLD_VIRTUAL_PS1+_}" ] ; then + PS1="$_OLD_VIRTUAL_PS1" + export PS1 + unset _OLD_VIRTUAL_PS1 + fi + + unset VIRTUAL_ENV + if [ ! "${1-}" = "nondestructive" ] ; then + # Self destruct! 
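
The patch above wires the new discrete BCQ pieces together in test/discrete/test_bcq.py. The following is a minimal usage sketch that condenses that test, assuming the modules behave exactly as defined in this patch. It is illustrative only: the environment, hidden sizes, learning rate, buffer size and other hyper-parameters are placeholder values rather than part of the patch, and the randomly filled buffer merely stands in for real logged transitions.

import random

import gym
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, ReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import BCQN, BCQPolicy
from tianshou.trainer import offline_trainer

env = gym.make('CartPole-v0')
state_shape = env.observation_space.shape[0]
action_shape = env.action_space.n

# Two-head network added by this patch: one head estimates Q-values,
# the other imitates the behaviour policy seen in the logged data.
model = BCQN(state_shape, action_shape, 64, 64)  # hidden sizes are placeholders
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
# `tao` is the BCQ action-filtering threshold (spelled this way in the patch).
policy = BCQPolicy(model, optim, tao=0.8, target_update_frequency=5, device='cpu')

# Offline RL: the policy learns only from a pre-filled buffer; the collector
# below is used purely for periodic evaluation. Random transitions stand in
# for real logged data here, just as in test_bcq.py.
buffer = ReplayBuffer(size=1000)
for _ in range(1000):
    buffer.add(obs=torch.rand(state_shape), act=random.randint(0, action_shape - 1),
               rew=1, done=False, obs_next=torch.rand(state_shape), info={})

test_envs = DummyVectorEnv([lambda: gym.make('CartPole-v0')])
test_collector = Collector(policy, test_envs)
writer = SummaryWriter('log/bcq_example')

result = offline_trainer(policy, buffer, test_collector, epochs=10,
                         batch_size=64, episode_per_test=5,
                         writer=writer, test_frequency=5)
print(result['best_reward'])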
+ +""" +man.py +====== + +This module provides a simple command line interface that uses the +man page writer to output from ReStructuredText source. +""" + +import locale +try: + locale.setlocale(locale.LC_ALL, '') +except: + pass + +from docutils.core import publish_cmdline, default_description +from docutils.writers import manpage + +description = ("Generates plain unix manual documents. " + default_description) + +publish_cmdline(writer=manpage.Writer(), description=description) diff --git a/zjlenv/bin/rst2odt.py b/zjlenv/bin/rst2odt.py new file mode 100755 index 000000000..f8f9b9e8b --- /dev/null +++ b/zjlenv/bin/rst2odt.py @@ -0,0 +1,30 @@ +#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 + +# $Id: rst2odt.py 5839 2009-01-07 19:09:28Z dkuhlman $ +# Author: Dave Kuhlman +# Copyright: This module has been placed in the public domain. + +""" +A front end to the Docutils Publisher, producing OpenOffice documents. +""" + +import sys +try: + import locale + locale.setlocale(locale.LC_ALL, '') +except: + pass + +from docutils.core import publish_cmdline_to_binary, default_description +from docutils.writers.odf_odt import Writer, Reader + + +description = ('Generates OpenDocument/OpenOffice/ODF documents from ' + 'standalone reStructuredText sources. ' + default_description) + + +writer = Writer() +reader = Reader() +output = publish_cmdline_to_binary(reader=reader, writer=writer, + description=description) + diff --git a/zjlenv/bin/rst2odt_prepstyles.py b/zjlenv/bin/rst2odt_prepstyles.py new file mode 100755 index 000000000..656c33e93 --- /dev/null +++ b/zjlenv/bin/rst2odt_prepstyles.py @@ -0,0 +1,67 @@ +#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 + +# $Id: rst2odt_prepstyles.py 8346 2019-08-26 12:11:32Z milde $ +# Author: Dave Kuhlman +# Copyright: This module has been placed in the public domain. + +""" +Fix a word-processor-generated styles.odt for odtwriter use: Drop page size +specifications from styles.xml in STYLE_FILE.odt. 
+""" + +# Author: Michael Schutte + +from __future__ import print_function + +from lxml import etree +import sys +import zipfile +from tempfile import mkstemp +import shutil +import os + +NAMESPACES = { + "style": "urn:oasis:names:tc:opendocument:xmlns:style:1.0", + "fo": "urn:oasis:names:tc:opendocument:xmlns:xsl-fo-compatible:1.0" +} + + +def prepstyle(filename): + + zin = zipfile.ZipFile(filename) + styles = zin.read("styles.xml") + + root = etree.fromstring(styles) + for el in root.xpath("//style:page-layout-properties", + namespaces=NAMESPACES): + for attr in el.attrib: + if attr.startswith("{%s}" % NAMESPACES["fo"]): + del el.attrib[attr] + + tempname = mkstemp() + zout = zipfile.ZipFile(os.fdopen(tempname[0], "w"), "w", + zipfile.ZIP_DEFLATED) + + for item in zin.infolist(): + if item.filename == "styles.xml": + zout.writestr(item, etree.tostring(root)) + else: + zout.writestr(item, zin.read(item.filename)) + + zout.close() + zin.close() + shutil.move(tempname[1], filename) + + +def main(): + args = sys.argv[1:] + if len(args) != 1: + print(__doc__, file=sys.stderr) + print("Usage: %s STYLE_FILE.odt\n" % sys.argv[0], file=sys.stderr) + sys.exit(1) + filename = args[0] + prepstyle(filename) + + +if __name__ == '__main__': + main() diff --git a/zjlenv/bin/rst2pseudoxml.py b/zjlenv/bin/rst2pseudoxml.py new file mode 100755 index 000000000..fefd51aa3 --- /dev/null +++ b/zjlenv/bin/rst2pseudoxml.py @@ -0,0 +1,23 @@ +#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 + +# $Id: rst2pseudoxml.py 4564 2006-05-21 20:44:42Z wiemann $ +# Author: David Goodger +# Copyright: This module has been placed in the public domain. + +""" +A minimal front end to the Docutils Publisher, producing pseudo-XML. +""" + +try: + import locale + locale.setlocale(locale.LC_ALL, '') +except: + pass + +from docutils.core import publish_cmdline, default_description + + +description = ('Generates pseudo-XML from standalone reStructuredText ' + 'sources (for testing purposes). ' + default_description) + +publish_cmdline(description=description) diff --git a/zjlenv/bin/rst2s5.py b/zjlenv/bin/rst2s5.py new file mode 100755 index 000000000..66a257d1c --- /dev/null +++ b/zjlenv/bin/rst2s5.py @@ -0,0 +1,24 @@ +#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 + +# $Id: rst2s5.py 4564 2006-05-21 20:44:42Z wiemann $ +# Author: Chris Liechti +# Copyright: This module has been placed in the public domain. + +""" +A minimal front end to the Docutils Publisher, producing HTML slides using +the S5 template system. +""" + +try: + import locale + locale.setlocale(locale.LC_ALL, '') +except: + pass + +from docutils.core import publish_cmdline, default_description + + +description = ('Generates S5 (X)HTML slideshow documents from standalone ' + 'reStructuredText sources. ' + default_description) + +publish_cmdline(writer_name='s5', description=description) diff --git a/zjlenv/bin/rst2xetex.py b/zjlenv/bin/rst2xetex.py new file mode 100755 index 000000000..835e34da8 --- /dev/null +++ b/zjlenv/bin/rst2xetex.py @@ -0,0 +1,27 @@ +#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 + +# $Id: rst2xetex.py 7847 2015-03-17 17:30:47Z milde $ +# Author: Guenter Milde +# Copyright: This module has been placed in the public domain. + +""" +A minimal front end to the Docutils Publisher, producing Lua/XeLaTeX code. 
+""" + +try: + import locale + locale.setlocale(locale.LC_ALL, '') +except: + pass + +from docutils.core import publish_cmdline + +description = ('Generates LaTeX documents from standalone reStructuredText ' + 'sources for compilation with the Unicode-aware TeX variants ' + 'XeLaTeX or LuaLaTeX. ' + 'Reads from (default is stdin) and writes to ' + ' (default is stdout). See ' + ' for ' + 'the full reference.') + +publish_cmdline(writer_name='xetex', description=description) diff --git a/zjlenv/bin/rst2xml.py b/zjlenv/bin/rst2xml.py new file mode 100755 index 000000000..bc61abe8c --- /dev/null +++ b/zjlenv/bin/rst2xml.py @@ -0,0 +1,23 @@ +#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 + +# $Id: rst2xml.py 4564 2006-05-21 20:44:42Z wiemann $ +# Author: David Goodger +# Copyright: This module has been placed in the public domain. + +""" +A minimal front end to the Docutils Publisher, producing Docutils XML. +""" + +try: + import locale + locale.setlocale(locale.LC_ALL, '') +except: + pass + +from docutils.core import publish_cmdline, default_description + + +description = ('Generates Docutils-native XML from standalone ' + 'reStructuredText sources. ' + default_description) + +publish_cmdline(writer_name='xml', description=description) diff --git a/zjlenv/bin/rstpep2html.py b/zjlenv/bin/rstpep2html.py new file mode 100755 index 000000000..0f9d6c4f6 --- /dev/null +++ b/zjlenv/bin/rstpep2html.py @@ -0,0 +1,25 @@ +#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 + +# $Id: rstpep2html.py 4564 2006-05-21 20:44:42Z wiemann $ +# Author: David Goodger +# Copyright: This module has been placed in the public domain. + +""" +A minimal front end to the Docutils Publisher, producing HTML from PEP +(Python Enhancement Proposal) documents. +""" + +try: + import locale + locale.setlocale(locale.LC_ALL, '') +except: + pass + +from docutils.core import publish_cmdline, default_description + + +description = ('Generates (X)HTML from reStructuredText-format PEP files. 
' + + default_description) + +publish_cmdline(reader_name='pep', writer_name='pep_html', + description=description) diff --git a/zjlenv/bin/sphinx-apidoc b/zjlenv/bin/sphinx-apidoc new file mode 100755 index 000000000..4eed5f711 --- /dev/null +++ b/zjlenv/bin/sphinx-apidoc @@ -0,0 +1,8 @@ +#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 +# -*- coding: utf-8 -*- +import re +import sys +from sphinx.ext.apidoc import main +if __name__ == '__main__': + sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) + sys.exit(main()) diff --git a/zjlenv/bin/sphinx-autogen b/zjlenv/bin/sphinx-autogen new file mode 100755 index 000000000..8bd5818af --- /dev/null +++ b/zjlenv/bin/sphinx-autogen @@ -0,0 +1,8 @@ +#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 +# -*- coding: utf-8 -*- +import re +import sys +from sphinx.ext.autosummary.generate import main +if __name__ == '__main__': + sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) + sys.exit(main()) diff --git a/zjlenv/bin/sphinx-build b/zjlenv/bin/sphinx-build new file mode 100755 index 000000000..b76dbb1b3 --- /dev/null +++ b/zjlenv/bin/sphinx-build @@ -0,0 +1,8 @@ +#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 +# -*- coding: utf-8 -*- +import re +import sys +from sphinx.cmd.build import main +if __name__ == '__main__': + sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) + sys.exit(main()) diff --git a/zjlenv/bin/sphinx-quickstart b/zjlenv/bin/sphinx-quickstart new file mode 100755 index 000000000..a0ef78ffb --- /dev/null +++ b/zjlenv/bin/sphinx-quickstart @@ -0,0 +1,8 @@ +#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 +# -*- coding: utf-8 -*- +import re +import sys +from sphinx.cmd.quickstart import main +if __name__ == '__main__': + sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) + sys.exit(main()) diff --git a/zjlenv/bin/stubgen b/zjlenv/bin/stubgen new file mode 100755 index 000000000..d73e4a697 --- /dev/null +++ b/zjlenv/bin/stubgen @@ -0,0 +1,8 @@ +#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 +# -*- coding: utf-8 -*- +import re +import sys +from mypy.stubgen import main +if __name__ == '__main__': + sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) + sys.exit(main()) diff --git a/zjlenv/bin/stubtest b/zjlenv/bin/stubtest new file mode 100755 index 000000000..fc6ce4a03 --- /dev/null +++ b/zjlenv/bin/stubtest @@ -0,0 +1,8 @@ +#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 +# -*- coding: utf-8 -*- +import re +import sys +from mypy.stubtest import main +if __name__ == '__main__': + sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) + sys.exit(main()) diff --git a/zjlenv/bin/tensorboard b/zjlenv/bin/tensorboard new file mode 100755 index 000000000..c976b11fd --- /dev/null +++ b/zjlenv/bin/tensorboard @@ -0,0 +1,8 @@ +#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 +# -*- coding: utf-8 -*- +import re +import sys +from tensorboard.main import run_main +if __name__ == '__main__': + sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) + sys.exit(run_main()) diff --git a/zjlenv/bin/tqdm b/zjlenv/bin/tqdm new file mode 100755 index 000000000..85f1b2ae5 --- /dev/null +++ b/zjlenv/bin/tqdm @@ -0,0 +1,8 @@ +#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 +# -*- coding: utf-8 -*- +import re +import sys +from tqdm.cli import main +if __name__ == '__main__': + sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) + sys.exit(main()) diff --git a/zjlenv/bin/tune b/zjlenv/bin/tune 
new file mode 100755 index 000000000..62a89f244 --- /dev/null +++ b/zjlenv/bin/tune @@ -0,0 +1,8 @@ +#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 +# -*- coding: utf-8 -*- +import re +import sys +from ray.tune.scripts import cli +if __name__ == '__main__': + sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) + sys.exit(cli()) diff --git a/zjlenv/bin/wheel b/zjlenv/bin/wheel new file mode 100755 index 000000000..8e18047d6 --- /dev/null +++ b/zjlenv/bin/wheel @@ -0,0 +1,10 @@ +#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 +# -*- coding: utf-8 -*- +import re +import sys + +from wheel.cli import main + +if __name__ == '__main__': + sys.argv[0] = re.sub(r'(-script\.pyw?|\.exe)?$', '', sys.argv[0]) + sys.exit(main()) diff --git a/zjlenv/include/python3.7m b/zjlenv/include/python3.7m new file mode 120000 index 000000000..7e2d975bd --- /dev/null +++ b/zjlenv/include/python3.7m @@ -0,0 +1 @@ +/Library/Frameworks/Python.framework/Versions/3.7/include/python3.7m \ No newline at end of file From ae8c12e3b12ec23d2d5dd80fc80e7afbfbd70c14 Mon Sep 17 00:00:00 2001 From: Jialu Zhu Date: Mon, 14 Dec 2020 20:45:21 -0800 Subject: [PATCH 02/23] removing_comments --- test/discrete/test_bcq.py | 15 +++------ tianshou/policy/modelfree/bcq.py | 57 +++++--------------------------- tianshou/trainer/offline.py | 25 +++----------- 3 files changed, 18 insertions(+), 79 deletions(-) diff --git a/test/discrete/test_bcq.py b/test/discrete/test_bcq.py index 47fe9ddcb..23a08e6ff 100644 --- a/test/discrete/test_bcq.py +++ b/test/discrete/test_bcq.py @@ -37,9 +37,6 @@ state_shape = state_shape[0] action_shape = env.action_space.shape or env.action_space.n -# print(state_shape) -# print(action_shape) -# exit() model = BCQN(state_shape,action_shape,args.hidden_dim,args.hidden_dim) optim = torch.optim.Adam(model.parameters(), lr=0.5) @@ -48,26 +45,22 @@ buffer = ReplayBuffer(size=args.buffer_size) for i in range(args.buffer_size): buffer.add(obs=torch.rand(state_shape), act=random.randint(0, action_shape-1), rew=1, done=False, obs_next=torch.rand(state_shape), info={}) -# buf.add(obs=[1.0, 1.0, 1.0, 1.0], act=[1], rew=1.0, done=False, obs_next=[1.0, 1.0, 1.0, 1.0], info={}) - - -# buffer.add(obs=torch.Tensor([1.0, 1.0, 1.0, 1.0]), act=1, rew=1, done=False, obs_next=torch.Tensor([2.0, 2.0, 2.0, 2.0]), info={}) -# buffer.add(obs=torch.Tensor([2.0, 2.0, 2.0, 2.0]), act=2, rew=2, done=True, obs_next=torch.Tensor([-1.0, -1.0, -1.0, -1.0]), info={}) test_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) -# train_envs.seed(args.seed) test_envs.seed(args.seed) test_collector = Collector(policy, test_envs) log_path = os.path.join(args.logdir, 'writer') writer = SummaryWriter(log_path) - -best_policy_path = os.path.join(args.logdir, 'best_policy') +if not os.path.exists(log_path): + os.makedirs(log_path) +# best_policy_path = os.path.join(args.logdir, 'best_policy') res = offline_trainer(policy, buffer, test_collector, args.epoch, args.batch_size, args.episode_per_test, writer, args.test_frequency) print('zjlbest_reward', res['best_reward']) diff --git a/tianshou/policy/modelfree/bcq.py b/tianshou/policy/modelfree/bcq.py index 78d4839b0..6050b18fc 100644 --- a/tianshou/policy/modelfree/bcq.py +++ b/tianshou/policy/modelfree/bcq.py @@ -41,7 +41,8 @@ def forward(self, state): class BCQPolicy(BasePolicy): - """Implementation of vanilla imitation learning. + """Implementation discrete BCQ algorithm. 
Some code is from + https://github.com/sfujim/BCQ/tree/master/discrete_BCQ :param torch.nn.Module model: a model following the rules in :class:`~tianshou.policy.BasePolicy`. (s -> a) @@ -64,7 +65,6 @@ def __init__( device: str, gamma: float = 0.9, imitation_logits_penalty: float=0.1, - # mode: str = "continuous", **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -72,13 +72,8 @@ def __init__( self._optimizer = optim self._cnt = 0 self._device = device + # TODOzjl shanchu moren self._gamma = gamma - # assert ( - # mode in ["continuous", "discrete"] - # ), f"Mode {mode} is not in ['continuous', 'discrete']." - # self.mode = mode - - #TODOzjl chuan can shu self._tao = tao self._target_net = deepcopy(self._policy_net) self._target_net.eval() @@ -89,17 +84,10 @@ def forward( self, batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, - # input: str = "obs", **kwargs: Any, ) -> Batch: - # logits, h = self.model(batch.obs, state=state, info=batch.info) - # if self.mode == "discrete": - # a = logits.max(dim=1)[1] - # else: - # a = logits batch.to_torch() state = batch.obs - # state = batch['obs'] q, imt, _ = self._policy_net(state.float()) imt = imt.exp() imt = (imt / imt.max(1, keepdim=True)[0] > self._tao).float() @@ -108,23 +96,17 @@ def forward( return Batch(act=action, state=state, qvalue=q) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: - batch.to_torch() - if self._cnt % self._target_update_frequency == 0: - # self._logger.info("Updating target network...") self._target_net.load_state_dict(self._policy_net.state_dict()) self._target_net.eval() - # state = self(batch, input='obs_next') - # qvalue = batch. - non_final_mask = torch.tensor( tuple(map(lambda s: not s, batch.done)), device=self._device, dtype=torch.bool, ) - # p(non_final_mask) + try: non_final_next_states = torch.cat( [s.obs.unsqueeze(0) for s in batch if not s.done], dim=0 @@ -132,9 +114,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: except Exception: non_final_next_states = None - # print(non_final_next_states) - - # # Compute the target Q value + # Compute the target Q value with torch.no_grad(): expected_state_action_values = batch.rew.float() @@ -143,46 +123,27 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: q, imt, _ = self._policy_net(non_final_next_states) imt = imt.exp() imt = (imt / imt.max(1, keepdim=True)[0] > self._tao).float() - # print(imt) - # print(q) + # Use large negative number to mask actions from argmax next_action = (imt * q + (1 - imt) * -1e8).argmax(1, keepdim=True) - print('zjlnext_action', next_action) - q, _, _ = self._target_net(non_final_next_states) q = q.gather(1, next_action).reshape(-1, 1) - print('zjlq', q) - - # print('zjlaaa', torch.zeros(len(batch), device=self._device).float()) - next_state_values = ( torch.zeros(len(batch), device=self._device) - # torch.zeros(self._batch_size, device=self._device) .float() - # .unsqueeze(1) ) - print('zjl1', next_state_values) - print('zjlnon_final_mask', non_final_mask) - print('zjlq', q) - next_state_values[non_final_mask] = q.squeeze() - print('zjl2', next_state_values) - print(expected_state_action_values) - # print(next_state_values) + next_state_values[non_final_mask] = q.squeeze() expected_state_action_values += next_state_values * self._gamma - # # Get current Q estimate + # Get current Q estimate current_Q, imt, i = self._policy_net(batch.obs) - print('zjlcurrent_Q', current_Q) - print('zjlbatch.act', batch.act) current_Q = current_Q.gather(1, 
batch.act.unsqueeze(1)).squeeze() - print('zjlfinal', current_Q, expected_state_action_values) - - # # Compute Q loss + # Compute Q loss q_loss = F.smooth_l1_loss(current_Q, expected_state_action_values) i_loss = F.nll_loss(imt, batch.act.reshape(-1)) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 34376ba82..550339662 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -12,25 +12,12 @@ def offline_trainer( policy: BasePolicy, buffer: ReplayBuffer, - # train_collector: Collector, test_collector: Collector, - # max_epoch: int, epochs: int, - # step_per_epoch: int, - # collect_per_step: int, - # episode_per_test: Union[int, List[int]], batch_size: int, episode_per_test: int, - # best_policy_save_dir: Optional[str], - # update_per_step: int = 1, - # train_fn: Optional[Callable[[int, int], None]] = None, - # test_fn: Optional[Callable[[int, Optional[int]], None]] = None, - # stop_fn: Optional[Callable[[float], bool]] = None, - # save_fn: Optional[Callable[[BasePolicy], None]] = None, writer: Optional[SummaryWriter] = None, test_frequency: int = 1, - # verbose: bool = True, - # test_in_train: bool = True, ) -> Dict[str, Union[float, str]]: best_reward = -1 @@ -43,24 +30,22 @@ def offline_trainer( for iter in range(iter_per_epoch): total_iter += 1 loss = policy.update(batch_size, buffer) - # batch = buffer.sample(args.batch_size)[0] - # batch.to_torch() if total_iter % test_frequency == 0: writer.add_scalar( "train/loss", loss['loss'], global_step=total_iter) - # test_collector = Collector(policy, test_envs) test_result = test_episode( policy, test_collector, None, epoch, episode_per_test, writer, total_iter) - # for k in result.keys(): - # writer.add_scalar( - # "train/" + k, result[k], global_step=env_step) if best_reward < test_result["rew"]: best_reward = test_result["rew"] best_policy = policy - # epoch, args.episode_per_test, writer, env_step) + + print(loss['loss']) + print(test_result) + print(best_reward) + print('---------------') return {'best_reward': best_reward, 'best_policy': best_policy} From b5eba6c4eaf2c8f4c66e272f74699e73d5a47b0d Mon Sep 17 00:00:00 2001 From: Jialu Zhu Date: Mon, 14 Dec 2020 22:10:54 -0800 Subject: [PATCH 03/23] removing_comments --- test/discrete/test_bcq.py | 236 ++++++++++++++++--------------- tianshou/data/buffer.py | 1 - tianshou/policy/base.py | 1 - tianshou/policy/modelfree/bcq.py | 15 +- tianshou/policy/modelfree/dqn.py | 5 - 5 files changed, 127 insertions(+), 131 deletions(-) diff --git a/test/discrete/test_bcq.py b/test/discrete/test_bcq.py index 23a08e6ff..eb14f3d02 100644 --- a/test/discrete/test_bcq.py +++ b/test/discrete/test_bcq.py @@ -15,119 +15,125 @@ from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer - -parser = argparse.ArgumentParser() -parser.add_argument('--task', type=str, default='CartPole-v0') -parser.add_argument('--seed', type=int, default=1626) -parser.add_argument('--batch-size', type=int, default=2) -parser.add_argument('--epoch', type=int, default=10) -parser.add_argument('--test-num', type=int, default=1) -parser.add_argument('--logdir', type=str, default='log') -parser.add_argument('--buffer-size', type=int, default=10) -parser.add_argument('--hidden-dim', type=int, default=5) -parser.add_argument('--test-frequency', type=int, default=5) -parser.add_argument('--target-update-frequency', type=int, default=5) -parser.add_argument('--episode-per-test', type=int, default=5) -parser.add_argument('--tao', type=float, default=0.8) - -args = 
parser.parse_known_args()[0] - -env = gym.make(args.task) -state_shape = env.observation_space.shape or env.observation_space.n -state_shape = state_shape[0] -action_shape = env.action_space.shape or env.action_space.n - - -model = BCQN(state_shape,action_shape,args.hidden_dim,args.hidden_dim) -optim = torch.optim.Adam(model.parameters(), lr=0.5) -policy = BCQPolicy(model, optim, args.tao, args.target_update_frequency, 'cpu') - -buffer = ReplayBuffer(size=args.buffer_size) -for i in range(args.buffer_size): - buffer.add(obs=torch.rand(state_shape), act=random.randint(0, action_shape-1), rew=1, done=False, obs_next=torch.rand(state_shape), info={}) - -test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) - -# seed -np.random.seed(args.seed) -torch.manual_seed(args.seed) -test_envs.seed(args.seed) - -test_collector = Collector(policy, test_envs) - -log_path = os.path.join(args.logdir, 'writer') -writer = SummaryWriter(log_path) -if not os.path.exists(log_path): - os.makedirs(log_path) -# best_policy_path = os.path.join(args.logdir, 'best_policy') - -res = offline_trainer(policy, buffer, test_collector, args.epoch, args.batch_size, args.episode_per_test, writer, args.test_frequency) -print('zjlbest_reward', res['best_reward']) - -# TODOzjl save policy torch.save(policy.state_dict(), os.path.join(best_policy_save_dir, 'policy.pth')) - - -# # batch = buffer.sample(1) -# print(buffer.obs) -# print(buffer.rew) -# buffer.rew = torch.Tensor([1,2]) -# print(buffer.rew) - - -# buffer.obs = torch.Tensor([[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]]) -# buffer.act = torch.Tensor([1, 2]) -# buffer.rew = torch.Tensor([10, 20]) -# buffer.done = torch.Tensor([False, True]) -# buffer.obs_next = torch.Tensor([[1.0, 1.0, 1.0, 1.0], [-1.0, -1.0, -1.0, -1.0]]) - -# print(buffer.obs) -# print(buffer) -# buffer.add( -# obs=torch.Tensor([[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]]), -# act=torch.Tensor([1, 2]), -# rew=torch.Tensor([1.0, 2.0]), -# done=torch.Tensor([False, True]), -# obs_next=torch.Tensor([[1.0, 1.0, 1.0, 1.0], [-1.0, -1.0, -1.0, -1.0]]), -# ) -# print(buffer) -# print(buffer) -# batch = buffer.sample(2)[0] -# print(batch) -# batch.to_torch() -# print(batch) -# batch = policy.forward(batch) -# print(batch) - -# loss = policy.learn(batch) -# print(loss) -########################### -# best_reward = -1 -# best_policy = policy - -# global_iter = 0 - -# total_iter = len(buffer) // args.batch_size -# for epoch in range(1, 1 + args.epoch): -# for iter in range(total_iter): -# global_iter += 1 -# loss = policy.update(args.batch_size, buffer) -# # batch = buffer.sample(args.batch_size)[0] -# # batch.to_torch() -# if global_iter % log_frequency == 0: -# writer.add_scalar( -# "train/loss", loss['loss'], global_step=global_iter) -# test_collector = Collector(policy, test_envs) - -# test_result = test_episode( -# policy, test_collector, None, -# epoch, args.episode_per_test, writer, global_iter) -# # for k in result.keys(): -# # writer.add_scalar( -# # "train/" + k, result[k], global_step=env_step) +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--batch-size', type=int, default=2) + parser.add_argument('--epoch', type=int, default=10) + parser.add_argument('--test-num', type=int, default=1) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--buffer-size', type=int, default=10) + 
parser.add_argument('--hidden-dim', type=int, default=5) + parser.add_argument('--test-frequency', type=int, default=5) + parser.add_argument('--target-update-frequency', type=int, default=5) + parser.add_argument('--episode-per-test', type=int, default=5) + parser.add_argument('--tao', type=float, default=0.8) + parser.add_argument('--gamma', type=float, default=0.9) + parser.add_argument('--imitation_logits_penalty', type=float, default=0.1) + parser.add_argument( + '--device', type=str, + default='cuda' if torch.cuda.is_available() else 'cpu') + args = parser.parse_known_args()[0] + return args + +def test_bcq(args=get_args()): + env = gym.make(args.task) + state_shape = env.observation_space.shape or env.observation_space.n + state_shape = state_shape[0] + action_shape = env.action_space.shape or env.action_space.n + + model = BCQN(state_shape,action_shape,args.hidden_dim,args.hidden_dim) + optim = torch.optim.Adam(model.parameters(), lr=0.5) + policy = BCQPolicy(model, optim, args.tao, args.target_update_frequency, args.device, args.gamma, args.imitation_logits_penalty) + + buffer = ReplayBuffer(size=args.buffer_size) + for i in range(args.buffer_size): + buffer.add(obs=torch.rand(state_shape), act=random.randint(0, action_shape-1), rew=1, done=False, obs_next=torch.rand(state_shape), info={}) + + test_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)]) + + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + test_envs.seed(args.seed) + + test_collector = Collector(policy, test_envs) + + log_path = os.path.join(args.logdir, 'writer') + writer = SummaryWriter(log_path) + if not os.path.exists(log_path): + os.makedirs(log_path) + + res = offline_trainer(policy, buffer, test_collector, args.epoch, args.batch_size, args.episode_per_test, writer, args.test_frequency) + print('final best_reward', res['best_reward']) -# if best_reward < result["rew"]: -# best_reward = result["rew"] -# best_policy = policy -# # epoch, args.episode_per_test, writer, env_step) - \ No newline at end of file +if __name__ == '__main__': + test_bcq(get_args()) + # TODOzjl save policy torch.save(policy.state_dict(), os.path.join(best_policy_save_dir, 'policy.pth')) + + + # # batch = buffer.sample(1) + # print(buffer.obs) + # print(buffer.rew) + # buffer.rew = torch.Tensor([1,2]) + # print(buffer.rew) + + + # buffer.obs = torch.Tensor([[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]]) + # buffer.act = torch.Tensor([1, 2]) + # buffer.rew = torch.Tensor([10, 20]) + # buffer.done = torch.Tensor([False, True]) + # buffer.obs_next = torch.Tensor([[1.0, 1.0, 1.0, 1.0], [-1.0, -1.0, -1.0, -1.0]]) + + # print(buffer.obs) + # print(buffer) + # buffer.add( + # obs=torch.Tensor([[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]]), + # act=torch.Tensor([1, 2]), + # rew=torch.Tensor([1.0, 2.0]), + # done=torch.Tensor([False, True]), + # obs_next=torch.Tensor([[1.0, 1.0, 1.0, 1.0], [-1.0, -1.0, -1.0, -1.0]]), + # ) + # print(buffer) + # print(buffer) + # batch = buffer.sample(2)[0] + # print(batch) + # batch.to_torch() + # print(batch) + # batch = policy.forward(batch) + # print(batch) + + # loss = policy.learn(batch) + # print(loss) + ########################### + # best_reward = -1 + # best_policy = policy + + # global_iter = 0 + + # total_iter = len(buffer) // args.batch_size + # for epoch in range(1, 1 + args.epoch): + # for iter in range(total_iter): + # global_iter += 1 + # loss = policy.update(args.batch_size, buffer) + # # batch = buffer.sample(args.batch_size)[0] + # # batch.to_torch() 
+ # if global_iter % log_frequency == 0: + # writer.add_scalar( + # "train/loss", loss['loss'], global_step=global_iter) + # test_collector = Collector(policy, test_envs) + + # test_result = test_episode( + # policy, test_collector, None, + # epoch, args.episode_per_test, writer, global_iter) + # # for k in result.keys(): + # # writer.add_scalar( + # # "train/" + k, result[k], global_step=env_step) + + # if best_reward < result["rew"]: + # best_reward = result["rew"] + # best_policy = policy + # # epoch, args.episode_per_test, writer, env_step) + \ No newline at end of file diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 038620dcd..c8c572a22 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -431,7 +431,6 @@ def add( **kwargs: Any, ) -> None: """Add a batch of data into replay buffer.""" - print('zjlobs',obs) if weight is None: weight = self._max_prio else: diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index c9d33c754..809751fe7 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -165,7 +165,6 @@ def update( """ if buffer is None: return {} - print(buffer) batch, indice = buffer.sample(sample_size) self.updating = True batch = self.process_fn(batch, buffer, indice) diff --git a/tianshou/policy/modelfree/bcq.py b/tianshou/policy/modelfree/bcq.py index 6050b18fc..829a205cf 100644 --- a/tianshou/policy/modelfree/bcq.py +++ b/tianshou/policy/modelfree/bcq.py @@ -15,6 +15,7 @@ # def p(x): # print(frameinfo.filename, frameinfo.lineno, x) + class BCQN(nn.Module): """BCQ NN for dialogue policy. It includes a net for imitation and a net for Q-value""" @@ -41,7 +42,7 @@ def forward(self, state): class BCQPolicy(BasePolicy): - """Implementation discrete BCQ algorithm. Some code is from + """Implementation discrete BCQ algorithm. 
Some code is from https://github.com/sfujim/BCQ/tree/master/discrete_BCQ :param torch.nn.Module model: a model following the rules in @@ -63,8 +64,8 @@ def __init__( tao: float, target_update_frequency: int, device: str, - gamma: float = 0.9, - imitation_logits_penalty: float=0.1, + gamma: float, + imitation_logits_penalty: float, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -123,17 +124,14 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: q, imt, _ = self._policy_net(non_final_next_states) imt = imt.exp() imt = (imt / imt.max(1, keepdim=True)[0] > self._tao).float() - + # Use large negative number to mask actions from argmax next_action = (imt * q + (1 - imt) * -1e8).argmax(1, keepdim=True) q, _, _ = self._target_net(non_final_next_states) q = q.gather(1, next_action).reshape(-1, 1) - next_state_values = ( - torch.zeros(len(batch), device=self._device) - .float() - ) + next_state_values = torch.zeros(len(batch), device=self._device).float() next_state_values[non_final_mask] = q.squeeze() expected_state_action_values += next_state_values * self._gamma @@ -155,7 +153,6 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: return {"loss": Q_loss.item()} - # if self._target and self._cnt % self._freq == 0: # self.sync_weight() # self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index ac87d1ce7..91cca6139 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -165,19 +165,14 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: if self._target and self._cnt % self._freq == 0: self.sync_weight() self.optim.zero_grad() - print('zjlbatch', batch) weight = batch.pop("weight", 1.0) q = self(batch).logits - print('zjlq1', q) q = q[np.arange(len(q)), batch.act] - print('zjlq2', q) r = to_torch_as(batch.returns.flatten(), q) - print('zjlr', r) td = r - q loss = (td.pow(2) * weight).mean() batch.weight = td # prio-buffer loss.backward() self.optim.step() self._cnt += 1 - exit() return {"loss": loss.item()} From 5e6873e6d5f833fc841bf95dca00d10d5cedab62 Mon Sep 17 00:00:00 2001 From: Jialu Zhu Date: Mon, 14 Dec 2020 22:12:13 -0800 Subject: [PATCH 04/23] format --- test/discrete/test_bcq.py | 82 ++++++++++++++++++++++++------------- tianshou/trainer/offline.py | 26 +++++++----- 2 files changed, 68 insertions(+), 40 deletions(-) diff --git a/test/discrete/test_bcq.py b/test/discrete/test_bcq.py index eb14f3d02..0a8498051 100644 --- a/test/discrete/test_bcq.py +++ b/test/discrete/test_bcq.py @@ -17,42 +17,59 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v0') - parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--batch-size', type=int, default=2) - parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--test-num', type=int, default=1) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--buffer-size', type=int, default=10) - parser.add_argument('--hidden-dim', type=int, default=5) - parser.add_argument('--test-frequency', type=int, default=5) - parser.add_argument('--target-update-frequency', type=int, default=5) - parser.add_argument('--episode-per-test', type=int, default=5) - parser.add_argument('--tao', type=float, default=0.8) - parser.add_argument('--gamma', type=float, default=0.9) - parser.add_argument('--imitation_logits_penalty', type=float, default=0.1) + parser.add_argument("--task", type=str, 
default="CartPole-v0") + parser.add_argument("--seed", type=int, default=1626) + parser.add_argument("--batch-size", type=int, default=2) + parser.add_argument("--epoch", type=int, default=10) + parser.add_argument("--test-num", type=int, default=1) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--buffer-size", type=int, default=10) + parser.add_argument("--hidden-dim", type=int, default=5) + parser.add_argument("--test-frequency", type=int, default=5) + parser.add_argument("--target-update-frequency", type=int, default=5) + parser.add_argument("--episode-per-test", type=int, default=5) + parser.add_argument("--tao", type=float, default=0.8) + parser.add_argument("--gamma", type=float, default=0.9) + parser.add_argument("--imitation_logits_penalty", type=float, default=0.1) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) args = parser.parse_known_args()[0] return args + def test_bcq(args=get_args()): env = gym.make(args.task) state_shape = env.observation_space.shape or env.observation_space.n state_shape = state_shape[0] action_shape = env.action_space.shape or env.action_space.n - model = BCQN(state_shape,action_shape,args.hidden_dim,args.hidden_dim) + model = BCQN(state_shape, action_shape, args.hidden_dim, args.hidden_dim) optim = torch.optim.Adam(model.parameters(), lr=0.5) - policy = BCQPolicy(model, optim, args.tao, args.target_update_frequency, args.device, args.gamma, args.imitation_logits_penalty) + policy = BCQPolicy( + model, + optim, + args.tao, + args.target_update_frequency, + args.device, + args.gamma, + args.imitation_logits_penalty, + ) buffer = ReplayBuffer(size=args.buffer_size) for i in range(args.buffer_size): - buffer.add(obs=torch.rand(state_shape), act=random.randint(0, action_shape-1), rew=1, done=False, obs_next=torch.rand(state_shape), info={}) + buffer.add( + obs=torch.rand(state_shape), + act=random.randint(0, action_shape - 1), + rew=1, + done=False, + obs_next=torch.rand(state_shape), + info={}, + ) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) @@ -61,26 +78,34 @@ def test_bcq(args=get_args()): test_collector = Collector(policy, test_envs) - log_path = os.path.join(args.logdir, 'writer') + log_path = os.path.join(args.logdir, "writer") writer = SummaryWriter(log_path) if not os.path.exists(log_path): os.makedirs(log_path) - res = offline_trainer(policy, buffer, test_collector, args.epoch, args.batch_size, args.episode_per_test, writer, args.test_frequency) - print('final best_reward', res['best_reward']) - -if __name__ == '__main__': + res = offline_trainer( + policy, + buffer, + test_collector, + args.epoch, + args.batch_size, + args.episode_per_test, + writer, + args.test_frequency, + ) + print("final best_reward", res["best_reward"]) + + +if __name__ == "__main__": test_bcq(get_args()) # TODOzjl save policy torch.save(policy.state_dict(), os.path.join(best_policy_save_dir, 'policy.pth')) - # # batch = buffer.sample(1) # print(buffer.obs) # print(buffer.rew) # buffer.rew = torch.Tensor([1,2]) # print(buffer.rew) - # buffer.obs = torch.Tensor([[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]]) # buffer.act = torch.Tensor([1, 2]) # buffer.rew = torch.Tensor([10, 20]) @@ -131,9 +156,8 @@ def test_bcq(args=get_args()): # # for k in result.keys(): # # 
writer.add_scalar( # # "train/" + k, result[k], global_step=env_step) - + # if best_reward < result["rew"]: # best_reward = result["rew"] # best_policy = policy # # epoch, args.episode_per_test, writer, env_step) - \ No newline at end of file diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 550339662..46a64d80c 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -31,24 +31,28 @@ def offline_trainer( total_iter += 1 loss = policy.update(batch_size, buffer) if total_iter % test_frequency == 0: - writer.add_scalar( - "train/loss", loss['loss'], global_step=total_iter) + writer.add_scalar("train/loss", loss["loss"], global_step=total_iter) test_result = test_episode( - policy, test_collector, None, - epoch, episode_per_test, writer, total_iter) - + policy, + test_collector, + None, + epoch, + episode_per_test, + writer, + total_iter, + ) + if best_reward < test_result["rew"]: best_reward = test_result["rew"] best_policy = policy - - print(loss['loss']) + + print(loss["loss"]) print(test_result) print(best_reward) - print('---------------') - - - return {'best_reward': best_reward, 'best_policy': best_policy} + print("---------------") + + return {"best_reward": best_reward, "best_policy": best_policy} # env_step, gradient_step = 0, 0 # best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 From 8035e8be0e0db6c4473ab122e01f50539d4b8fbd Mon Sep 17 00:00:00 2001 From: Jialu Zhu Date: Tue, 15 Dec 2020 00:05:16 -0800 Subject: [PATCH 05/23] cleaning --- test/discrete/test_bcq.py | 78 ++------------------- tianshou/policy/__init__.py | 2 +- tianshou/policy/modelfree/bcq.py | 116 ++++++++++--------------------- tianshou/trainer/offline.py | 87 ++--------------------- 4 files changed, 50 insertions(+), 233 deletions(-) diff --git a/test/discrete/test_bcq.py b/test/discrete/test_bcq.py index 0a8498051..fa258d6f4 100644 --- a/test/discrete/test_bcq.py +++ b/test/discrete/test_bcq.py @@ -1,16 +1,13 @@ from tianshou.policy import BCQPolicy -from tianshou.policy import BCQN import os import gym import torch import pprint import argparse +import random import numpy as np from torch.utils.tensorboard import SummaryWriter -import random -from tianshou.policy import DQNPolicy from tianshou.env import DummyVectorEnv -from tianshou.utils.net.common import Net from tianshou.trainer import offline_trainer from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer @@ -28,8 +25,9 @@ def get_args(): parser.add_argument("--test-frequency", type=int, default=5) parser.add_argument("--target-update-frequency", type=int, default=5) parser.add_argument("--episode-per-test", type=int, default=5) - parser.add_argument("--tao", type=float, default=0.8) + parser.add_argument("--tau", type=float, default=0.8) parser.add_argument("--gamma", type=float, default=0.9) + parser.add_argument("--learning-rate", type=float, default=0.01) parser.add_argument("--imitation_logits_penalty", type=float, default=0.1) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" @@ -43,19 +41,20 @@ def test_bcq(args=get_args()): state_shape = env.observation_space.shape or env.observation_space.n state_shape = state_shape[0] action_shape = env.action_space.shape or env.action_space.n + model = BCQPolicy.BCQN(state_shape, action_shape, args.hidden_dim, args.hidden_dim) + optim = torch.optim.Adam(model.parameters(), lr=args.learning_rate) - model = BCQN(state_shape, action_shape, args.hidden_dim, args.hidden_dim) - optim = 
torch.optim.Adam(model.parameters(), lr=0.5) policy = BCQPolicy( model, optim, - args.tao, + args.tau, args.target_update_frequency, args.device, args.gamma, args.imitation_logits_penalty, ) + # Make up some dummy training data in replay buffer buffer = ReplayBuffer(size=args.buffer_size) for i in range(args.buffer_size): buffer.add( @@ -98,66 +97,3 @@ def test_bcq(args=get_args()): if __name__ == "__main__": test_bcq(get_args()) - # TODOzjl save policy torch.save(policy.state_dict(), os.path.join(best_policy_save_dir, 'policy.pth')) - - # # batch = buffer.sample(1) - # print(buffer.obs) - # print(buffer.rew) - # buffer.rew = torch.Tensor([1,2]) - # print(buffer.rew) - - # buffer.obs = torch.Tensor([[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]]) - # buffer.act = torch.Tensor([1, 2]) - # buffer.rew = torch.Tensor([10, 20]) - # buffer.done = torch.Tensor([False, True]) - # buffer.obs_next = torch.Tensor([[1.0, 1.0, 1.0, 1.0], [-1.0, -1.0, -1.0, -1.0]]) - - # print(buffer.obs) - # print(buffer) - # buffer.add( - # obs=torch.Tensor([[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]]), - # act=torch.Tensor([1, 2]), - # rew=torch.Tensor([1.0, 2.0]), - # done=torch.Tensor([False, True]), - # obs_next=torch.Tensor([[1.0, 1.0, 1.0, 1.0], [-1.0, -1.0, -1.0, -1.0]]), - # ) - # print(buffer) - # print(buffer) - # batch = buffer.sample(2)[0] - # print(batch) - # batch.to_torch() - # print(batch) - # batch = policy.forward(batch) - # print(batch) - - # loss = policy.learn(batch) - # print(loss) - ########################### - # best_reward = -1 - # best_policy = policy - - # global_iter = 0 - - # total_iter = len(buffer) // args.batch_size - # for epoch in range(1, 1 + args.epoch): - # for iter in range(total_iter): - # global_iter += 1 - # loss = policy.update(args.batch_size, buffer) - # # batch = buffer.sample(args.batch_size)[0] - # # batch.to_torch() - # if global_iter % log_frequency == 0: - # writer.add_scalar( - # "train/loss", loss['loss'], global_step=global_iter) - # test_collector = Collector(policy, test_envs) - - # test_result = test_episode( - # policy, test_collector, None, - # epoch, args.episode_per_test, writer, global_iter) - # # for k in result.keys(): - # # writer.add_scalar( - # # "train/" + k, result[k], global_step=env_step) - - # if best_reward < result["rew"]: - # best_reward = result["rew"] - # best_policy = policy - # # epoch, args.episode_per_test, writer, env_step) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 13e539ed2..d21efdd73 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -10,7 +10,6 @@ from tianshou.policy.modelfree.sac import SACPolicy from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy from tianshou.policy.modelfree.bcq import BCQPolicy -from tianshou.policy.modelfree.bcq import BCQN from tianshou.policy.modelbase.psrl import PSRLPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager @@ -30,4 +29,5 @@ "PSRLPolicy", "MultiAgentPolicyManager", "BCQManager", + "BCQPolicy", ] diff --git a/tianshou/policy/modelfree/bcq.py b/tianshou/policy/modelfree/bcq.py index 829a205cf..abd7ccd46 100644 --- a/tianshou/policy/modelfree/bcq.py +++ b/tianshou/policy/modelfree/bcq.py @@ -8,60 +8,45 @@ import torch.nn as nn from copy import deepcopy -from inspect import currentframe, getframeinfo - -frameinfo = getframeinfo(currentframe()) - -# def p(x): -# print(frameinfo.filename, frameinfo.lineno, x) - - -class BCQN(nn.Module): - """BCQ NN for dialogue policy. 
It includes a net for imitation and a net for Q-value""" - - def __init__( - self, input_size, n_actions, imitation_model_hidden_dim, policy_model_hidden_dim - ): - super(BCQN, self).__init__() - self.q1 = nn.Linear(input_size, policy_model_hidden_dim) - self.q2 = nn.Linear(policy_model_hidden_dim, policy_model_hidden_dim) - self.q3 = nn.Linear(policy_model_hidden_dim, n_actions) - - self.i1 = nn.Linear(input_size, imitation_model_hidden_dim) - self.i2 = nn.Linear(imitation_model_hidden_dim, imitation_model_hidden_dim) - self.i3 = nn.Linear(imitation_model_hidden_dim, n_actions) - - def forward(self, state): - q = F.relu(self.q1(state)) - q = F.relu(self.q2(q)) - - i = F.relu(self.i1(state)) - i = F.relu(self.i2(i)) - i = F.relu(self.i3(i)) - return self.q3(q), F.log_softmax(i, dim=1), i - class BCQPolicy(BasePolicy): """Implementation discrete BCQ algorithm. Some code is from https://github.com/sfujim/BCQ/tree/master/discrete_BCQ - :param torch.nn.Module model: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> a) - :param torch.optim.Optimizer optim: for optimizing the model. - :param str mode: indicate the imitation type ("continuous" or "discrete" - action space), defaults to "continuous". + """ - .. seealso:: + class BCQN(nn.Module): + """BCQ NN for dialogue policy. It includes a net for imitation and a net for Q-value""" - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. - """ + def __init__( + self, input_size, n_actions, imitation_model_hidden_dim, policy_model_hidden_dim + ): + super(BCQPolicy.BCQN, self).__init__() + self.q1 = nn.Linear(input_size, policy_model_hidden_dim) + self.q2 = nn.Linear(policy_model_hidden_dim, policy_model_hidden_dim) + self.q3 = nn.Linear(policy_model_hidden_dim, n_actions) + + self.i1 = nn.Linear(input_size, imitation_model_hidden_dim) + self.i2 = nn.Linear(imitation_model_hidden_dim, imitation_model_hidden_dim) + self.i3 = nn.Linear(imitation_model_hidden_dim, n_actions) + + def forward(self, state): + q = F.relu(self.q1(state)) + q = F.relu(self.q2(q)) + + i = F.relu(self.i1(state)) + i = F.relu(self.i2(i)) + i = F.relu(self.i3(i)) + return self.q3(q), F.log_softmax(i, dim=1), i def __init__( self, model: torch.nn.Module, + # state_dim: int, + # action_dim: int, + # hidden_dim: int, optim: torch.optim.Optimizer, - tao: float, + tau: float, target_update_frequency: int, device: str, gamma: float, @@ -73,9 +58,8 @@ def __init__( self._optimizer = optim self._cnt = 0 self._device = device - # TODOzjl shanchu moren self._gamma = gamma - self._tao = tao + self._tau = tau self._target_net = deepcopy(self._policy_net) self._target_net.eval() self._target_update_frequency = target_update_frequency @@ -88,19 +72,19 @@ def forward( **kwargs: Any, ) -> Batch: batch.to_torch() + state = batch.obs q, imt, _ = self._policy_net(state.float()) imt = imt.exp() - imt = (imt / imt.max(1, keepdim=True)[0] > self._tao).float() + imt = (imt / imt.max(1, keepdim=True)[0] > self._tau).float() + # Use large negative number to mask actions from argmax action = (imt * q + (1.0 - imt) * -1e8).argmax(1) + return Batch(act=action, state=state, qvalue=q) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: batch.to_torch() - if self._cnt % self._target_update_frequency == 0: - self._target_net.load_state_dict(self._policy_net.state_dict()) - self._target_net.eval() non_final_mask = torch.tensor( tuple(map(lambda s: not s, batch.done)), @@ -123,7 +107,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> 
Dict[str, float]: if non_final_next_states is not None: q, imt, _ = self._policy_net(non_final_next_states) imt = imt.exp() - imt = (imt / imt.max(1, keepdim=True)[0] > self._tao).float() + imt = (imt / imt.max(1, keepdim=True)[0] > self._tau).float() # Use large negative number to mask actions from argmax next_action = (imt * q + (1 - imt) * -1e8).argmax(1, keepdim=True) @@ -132,13 +116,12 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: q = q.gather(1, next_action).reshape(-1, 1) next_state_values = torch.zeros(len(batch), device=self._device).float() - next_state_values[non_final_mask] = q.squeeze() + expected_state_action_values += next_state_values * self._gamma # Get current Q estimate current_Q, imt, i = self._policy_net(batch.obs) - current_Q = current_Q.gather(1, batch.act.unsqueeze(1)).squeeze() # Compute Q loss @@ -151,31 +134,8 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: Q_loss.backward() self._optimizer.step() - return {"loss": Q_loss.item()} + if self._cnt % self._target_update_frequency == 0: + self._target_net.load_state_dict(self._policy_net.state_dict()) + self._target_net.eval() - # if self._target and self._cnt % self._freq == 0: - # self.sync_weight() - # self.optim.zero_grad() - # weight = batch.pop("weight", 1.0) - # q = self(batch).logits - # q = q[np.arange(len(q)), batch.act] - # r = to_torch_as(batch.returns.flatten(), q) - # td = r - q - # loss = (td.pow(2) * weight).mean() - # batch.weight = td # prio-buffer - # loss.backward() - # self.optim.step() - # self._cnt += 1 - # return {"loss": loss.item()} - - # if self.mode == "continuous": # regression - # a = self(batch).act - # a_ = to_torch(batch.act, dtype=torch.float32, device=a.device) - # loss = F.mse_loss(a, a_) # type: ignore - # elif self.mode == "discrete": # classification - # a = self(batch).logits - # a_ = to_torch(batch.act, dtype=torch.long, device=a.device) - # loss = F.nll_loss(a, a_) # type: ignore - # loss.backward() - # self.optim.step() - # return {"loss": loss.item()} + return {"loss": Q_loss.item()} diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 46a64d80c..df0eb4810 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -47,88 +47,9 @@ def offline_trainer( best_reward = test_result["rew"] best_policy = policy - print(loss["loss"]) - print(test_result) - print(best_reward) - print("---------------") + print("loss:", loss["loss"]) + print("test_result:", test_result) + print("best_reward:", best_reward) + print(f"------- epoch: {epoch}, iter: {total_iter} --------") return {"best_reward": best_reward, "best_policy": best_policy} - - # env_step, gradient_step = 0, 0 - # best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 - # stat: Dict[str, MovAvg] = {} - # start_time = time.time() - # train_collector.reset_stat() - # test_collector.reset_stat() - # test_in_train = test_in_train and train_collector.policy == policy - # for epoch in range(1, 1 + max_epoch): - # # train - # policy.train() - # with tqdm.tqdm( - # total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config - # ) as t: - # while t.n < t.total: - # if train_fn: - # train_fn(epoch, env_step) - # result = train_collector.collect(n_step=collect_per_step) - # env_step += int(result["n/st"]) - # data = { - # "env_step": str(env_step), - # "rew": f"{result['rew']:.2f}", - # "len": str(int(result["len"])), - # "n/ep": str(int(result["n/ep"])), - # "n/st": str(int(result["n/st"])), - # "v/ep": f"{result['v/ep']:.2f}", - # "v/st": 
f"{result['v/st']:.2f}", - # } - # if writer and env_step % log_interval == 0: - # for k in result.keys(): - # writer.add_scalar( - # "train/" + k, result[k], global_step=env_step) - # if test_in_train and stop_fn and stop_fn(result["rew"]): - # test_result = test_episode( - # policy, test_collector, test_fn, - # epoch, episode_per_test, writer, env_step) - # if stop_fn(test_result["rew"]): - # if save_fn: - # save_fn(policy) - # for k in result.keys(): - # data[k] = f"{result[k]:.2f}" - # t.set_postfix(**data) - # return gather_info( - # start_time, train_collector, test_collector, - # test_result["rew"], test_result["rew_std"]) - # else: - # policy.train() - # for i in range(update_per_step * min( - # result["n/st"] // collect_per_step, t.total - t.n)): - # gradient_step += 1 - # losses = policy.update(batch_size, train_collector.buffer) - # for k in losses.keys(): - # if stat.get(k) is None: - # stat[k] = MovAvg() - # stat[k].add(losses[k]) - # data[k] = f"{stat[k].get():.6f}" - # if writer and gradient_step % log_interval == 0: - # writer.add_scalar( - # k, stat[k].get(), global_step=gradient_step) - # t.update(1) - # t.set_postfix(**data) - # if t.n <= t.total: - # t.update() - # # test - # result = test_episode(policy, test_collector, test_fn, epoch, - # episode_per_test, writer, env_step) - # if best_epoch == -1 or best_reward < result["rew"]: - # best_reward, best_reward_std = result["rew"], result["rew_std"] - # best_epoch = epoch - # if save_fn: - # save_fn(policy) - # if verbose: - # print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± " - # f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± " - # f"{best_reward_std:.6f} in #{best_epoch}") - # if stop_fn and stop_fn(best_reward): - # break - # return gather_info(start_time, train_collector, test_collector, - # best_reward, best_reward_std) From 5c52a18281def45dceec2a0ca6c2c1682a3acbbc Mon Sep 17 00:00:00 2001 From: Jialu Zhu Date: Tue, 15 Dec 2020 00:16:40 -0800 Subject: [PATCH 06/23] almost --- test/discrete/test_bcq.py | 8 +-- test/discrete/test_dqn.py | 3 +- tianshou/policy/modelfree/bcq.py | 3 - tianshou/trainer/offline.py | 8 +-- zjlenv/bin/activate | 78 ----------------------- zjlenv/bin/activate.csh | 42 ------------- zjlenv/bin/activate.fish | 101 ------------------------------ zjlenv/bin/activate.ps1 | 60 ------------------ zjlenv/bin/activate.xsh | 39 ------------ zjlenv/bin/activate_this.py | 46 -------------- zjlenv/bin/chardetect | 8 --- zjlenv/bin/convert-caffe2-to-onnx | 8 --- zjlenv/bin/convert-onnx-to-caffe2 | 8 --- zjlenv/bin/coverage | 8 --- zjlenv/bin/coverage-3.7 | 8 --- zjlenv/bin/coverage3 | 8 --- zjlenv/bin/dmypy | 8 --- zjlenv/bin/doc8 | 8 --- zjlenv/bin/easy_install | 10 --- zjlenv/bin/easy_install-3.7 | 10 --- zjlenv/bin/f2py | 8 --- zjlenv/bin/f2py3 | 8 --- zjlenv/bin/f2py3.7 | 8 --- zjlenv/bin/flake8 | 8 --- zjlenv/bin/futurize | 8 --- zjlenv/bin/google-oauthlib-tool | 8 --- zjlenv/bin/jsonschema | 8 --- zjlenv/bin/markdown_py | 8 --- zjlenv/bin/mypy | 8 --- zjlenv/bin/mypyc | 55 ---------------- zjlenv/bin/numba | 8 --- zjlenv/bin/pasteurize | 8 --- zjlenv/bin/pbr | 8 --- zjlenv/bin/pip | 10 --- zjlenv/bin/pip3 | 10 --- zjlenv/bin/pip3.7 | 10 --- zjlenv/bin/py.test | 8 --- zjlenv/bin/pybabel | 8 --- zjlenv/bin/pybtex | 8 --- zjlenv/bin/pybtex-convert | 8 --- zjlenv/bin/pybtex-format | 8 --- zjlenv/bin/pycc | 3 - zjlenv/bin/pycodestyle | 8 --- zjlenv/bin/pydocstyle | 8 --- zjlenv/bin/pyflakes | 8 --- zjlenv/bin/pygmentize | 8 --- zjlenv/bin/pyrsa-decrypt | 8 --- 
zjlenv/bin/pyrsa-encrypt | 8 --- zjlenv/bin/pyrsa-keygen | 8 --- zjlenv/bin/pyrsa-priv2pub | 8 --- zjlenv/bin/pyrsa-sign | 8 --- zjlenv/bin/pyrsa-verify | 8 --- zjlenv/bin/pytest | 8 --- zjlenv/bin/python | 1 - zjlenv/bin/python-config | 78 ----------------------- zjlenv/bin/python3 | 1 - zjlenv/bin/python3.7 | Bin 8632 -> 0 bytes zjlenv/bin/ray | 8 --- zjlenv/bin/restructuredtext-lint | 8 --- zjlenv/bin/rllib | 8 --- zjlenv/bin/rst-lint | 8 --- zjlenv/bin/rst2html.py | 23 ------- zjlenv/bin/rst2html4.py | 26 -------- zjlenv/bin/rst2html5.py | 35 ----------- zjlenv/bin/rst2latex.py | 26 -------- zjlenv/bin/rst2man.py | 26 -------- zjlenv/bin/rst2odt.py | 30 --------- zjlenv/bin/rst2odt_prepstyles.py | 67 -------------------- zjlenv/bin/rst2pseudoxml.py | 23 ------- zjlenv/bin/rst2s5.py | 24 ------- zjlenv/bin/rst2xetex.py | 27 -------- zjlenv/bin/rst2xml.py | 23 ------- zjlenv/bin/rstpep2html.py | 25 -------- zjlenv/bin/sphinx-apidoc | 8 --- zjlenv/bin/sphinx-autogen | 8 --- zjlenv/bin/sphinx-build | 8 --- zjlenv/bin/sphinx-quickstart | 8 --- zjlenv/bin/stubgen | 8 --- zjlenv/bin/stubtest | 8 --- zjlenv/bin/tensorboard | 8 --- zjlenv/bin/tqdm | 8 --- zjlenv/bin/tune | 8 --- zjlenv/bin/wheel | 10 --- zjlenv/include/python3.7m | 1 - 84 files changed, 8 insertions(+), 1326 deletions(-) delete mode 100644 zjlenv/bin/activate delete mode 100644 zjlenv/bin/activate.csh delete mode 100644 zjlenv/bin/activate.fish delete mode 100644 zjlenv/bin/activate.ps1 delete mode 100644 zjlenv/bin/activate.xsh delete mode 100644 zjlenv/bin/activate_this.py delete mode 100755 zjlenv/bin/chardetect delete mode 100755 zjlenv/bin/convert-caffe2-to-onnx delete mode 100755 zjlenv/bin/convert-onnx-to-caffe2 delete mode 100755 zjlenv/bin/coverage delete mode 100755 zjlenv/bin/coverage-3.7 delete mode 100755 zjlenv/bin/coverage3 delete mode 100755 zjlenv/bin/dmypy delete mode 100755 zjlenv/bin/doc8 delete mode 100755 zjlenv/bin/easy_install delete mode 100755 zjlenv/bin/easy_install-3.7 delete mode 100755 zjlenv/bin/f2py delete mode 100755 zjlenv/bin/f2py3 delete mode 100755 zjlenv/bin/f2py3.7 delete mode 100755 zjlenv/bin/flake8 delete mode 100755 zjlenv/bin/futurize delete mode 100755 zjlenv/bin/google-oauthlib-tool delete mode 100755 zjlenv/bin/jsonschema delete mode 100755 zjlenv/bin/markdown_py delete mode 100755 zjlenv/bin/mypy delete mode 100755 zjlenv/bin/mypyc delete mode 100755 zjlenv/bin/numba delete mode 100755 zjlenv/bin/pasteurize delete mode 100755 zjlenv/bin/pbr delete mode 100755 zjlenv/bin/pip delete mode 100755 zjlenv/bin/pip3 delete mode 100755 zjlenv/bin/pip3.7 delete mode 100755 zjlenv/bin/py.test delete mode 100755 zjlenv/bin/pybabel delete mode 100755 zjlenv/bin/pybtex delete mode 100755 zjlenv/bin/pybtex-convert delete mode 100755 zjlenv/bin/pybtex-format delete mode 100755 zjlenv/bin/pycc delete mode 100755 zjlenv/bin/pycodestyle delete mode 100755 zjlenv/bin/pydocstyle delete mode 100755 zjlenv/bin/pyflakes delete mode 100755 zjlenv/bin/pygmentize delete mode 100755 zjlenv/bin/pyrsa-decrypt delete mode 100755 zjlenv/bin/pyrsa-encrypt delete mode 100755 zjlenv/bin/pyrsa-keygen delete mode 100755 zjlenv/bin/pyrsa-priv2pub delete mode 100755 zjlenv/bin/pyrsa-sign delete mode 100755 zjlenv/bin/pyrsa-verify delete mode 100755 zjlenv/bin/pytest delete mode 120000 zjlenv/bin/python delete mode 100755 zjlenv/bin/python-config delete mode 120000 zjlenv/bin/python3 delete mode 100755 zjlenv/bin/python3.7 delete mode 100755 zjlenv/bin/ray delete mode 100755 zjlenv/bin/restructuredtext-lint delete mode 
100755 zjlenv/bin/rllib delete mode 100755 zjlenv/bin/rst-lint delete mode 100755 zjlenv/bin/rst2html.py delete mode 100755 zjlenv/bin/rst2html4.py delete mode 100755 zjlenv/bin/rst2html5.py delete mode 100755 zjlenv/bin/rst2latex.py delete mode 100755 zjlenv/bin/rst2man.py delete mode 100755 zjlenv/bin/rst2odt.py delete mode 100755 zjlenv/bin/rst2odt_prepstyles.py delete mode 100755 zjlenv/bin/rst2pseudoxml.py delete mode 100755 zjlenv/bin/rst2s5.py delete mode 100755 zjlenv/bin/rst2xetex.py delete mode 100755 zjlenv/bin/rst2xml.py delete mode 100755 zjlenv/bin/rstpep2html.py delete mode 100755 zjlenv/bin/sphinx-apidoc delete mode 100755 zjlenv/bin/sphinx-autogen delete mode 100755 zjlenv/bin/sphinx-build delete mode 100755 zjlenv/bin/sphinx-quickstart delete mode 100755 zjlenv/bin/stubgen delete mode 100755 zjlenv/bin/stubtest delete mode 100755 zjlenv/bin/tensorboard delete mode 100755 zjlenv/bin/tqdm delete mode 100755 zjlenv/bin/tune delete mode 100755 zjlenv/bin/wheel delete mode 120000 zjlenv/include/python3.7m diff --git a/test/discrete/test_bcq.py b/test/discrete/test_bcq.py index fa258d6f4..0b06aa8dd 100644 --- a/test/discrete/test_bcq.py +++ b/test/discrete/test_bcq.py @@ -16,13 +16,13 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--batch-size", type=int, default=2) - parser.add_argument("--epoch", type=int, default=10) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--test-num", type=int, default=1) parser.add_argument("--logdir", type=str, default="log") - parser.add_argument("--buffer-size", type=int, default=10) + parser.add_argument("--buffer-size", type=int, default=1000) parser.add_argument("--hidden-dim", type=int, default=5) - parser.add_argument("--test-frequency", type=int, default=5) + parser.add_argument("--test-frequency", type=int, default=10) parser.add_argument("--target-update-frequency", type=int, default=5) parser.add_argument("--episode-per-test", type=int, default=5) parser.add_argument("--tau", type=float, default=0.8) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index ecab9c017..8ec9fa380 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -27,8 +27,7 @@ def get_args(): parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=1000) parser.add_argument('--collect-per-step', type=int, default=10) - parser.add_argument('--batch-size', type=int, default=2) - # parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--layer-num', type=int, default=3) parser.add_argument('--training-num', type=int, default=8) parser.add_argument('--test-num', type=int, default=100) diff --git a/tianshou/policy/modelfree/bcq.py b/tianshou/policy/modelfree/bcq.py index abd7ccd46..ec84a019e 100644 --- a/tianshou/policy/modelfree/bcq.py +++ b/tianshou/policy/modelfree/bcq.py @@ -42,9 +42,6 @@ def forward(self, state): def __init__( self, model: torch.nn.Module, - # state_dim: int, - # action_dim: int, - # hidden_dim: int, optim: torch.optim.Optimizer, tau: float, target_update_frequency: int, diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index df0eb4810..4899ed32e 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -19,13 
+19,11 @@ def offline_trainer( writer: Optional[SummaryWriter] = None, test_frequency: int = 1, ) -> Dict[str, Union[float, str]]: - best_reward = -1 best_policy = policy - total_iter = 0 - iter_per_epoch = len(buffer) // batch_size + for epoch in range(1, 1 + epochs): for iter in range(iter_per_epoch): total_iter += 1 @@ -45,11 +43,11 @@ def offline_trainer( if best_reward < test_result["rew"]: best_reward = test_result["rew"] - best_policy = policy + best_policy.load_state_dict(policy.state_dict()) + print(f"------- epoch: {epoch}, iter: {total_iter} --------") print("loss:", loss["loss"]) print("test_result:", test_result) print("best_reward:", best_reward) - print(f"------- epoch: {epoch}, iter: {total_iter} --------") return {"best_reward": best_reward, "best_policy": best_policy} diff --git a/zjlenv/bin/activate b/zjlenv/bin/activate deleted file mode 100644 index 2174f277c..000000000 --- a/zjlenv/bin/activate +++ /dev/null @@ -1,78 +0,0 @@ -# This file must be used with "source bin/activate" *from bash* -# you cannot run it directly - -deactivate () { - unset -f pydoc >/dev/null 2>&1 - - # reset old environment variables - # ! [ -z ${VAR+_} ] returns true if VAR is declared at all - if ! [ -z "${_OLD_VIRTUAL_PATH+_}" ] ; then - PATH="$_OLD_VIRTUAL_PATH" - export PATH - unset _OLD_VIRTUAL_PATH - fi - if ! [ -z "${_OLD_VIRTUAL_PYTHONHOME+_}" ] ; then - PYTHONHOME="$_OLD_VIRTUAL_PYTHONHOME" - export PYTHONHOME - unset _OLD_VIRTUAL_PYTHONHOME - fi - - # This should detect bash and zsh, which have a hash command that must - # be called to get it to forget past commands. Without forgetting - # past commands the $PATH changes we made may not be respected - if [ -n "${BASH-}" ] || [ -n "${ZSH_VERSION-}" ] ; then - hash -r 2>/dev/null - fi - - if ! [ -z "${_OLD_VIRTUAL_PS1+_}" ] ; then - PS1="$_OLD_VIRTUAL_PS1" - export PS1 - unset _OLD_VIRTUAL_PS1 - fi - - unset VIRTUAL_ENV - if [ ! "${1-}" = "nondestructive" ] ; then - # Self destruct! - unset -f deactivate - fi -} - -# unset irrelevant variables -deactivate nondestructive - -VIRTUAL_ENV="/Users/jialu.zhu/Desktop/tianshou/zjlenv" -export VIRTUAL_ENV - -_OLD_VIRTUAL_PATH="$PATH" -PATH="$VIRTUAL_ENV/bin:$PATH" -export PATH - -# unset PYTHONHOME if set -if ! [ -z "${PYTHONHOME+_}" ] ; then - _OLD_VIRTUAL_PYTHONHOME="$PYTHONHOME" - unset PYTHONHOME -fi - -if [ -z "${VIRTUAL_ENV_DISABLE_PROMPT-}" ] ; then - _OLD_VIRTUAL_PS1="${PS1-}" - if [ "x" != x ] ; then - PS1="${PS1-}" - else - PS1="(`basename \"$VIRTUAL_ENV\"`) ${PS1-}" - fi - export PS1 -fi - -# Make sure to unalias pydoc if it's already there -alias pydoc 2>/dev/null >/dev/null && unalias pydoc || true - -pydoc () { - python -m pydoc "$@" -} - -# This should detect bash and zsh, which have a hash command that must -# be called to get it to forget past commands. Without forgetting -# past commands the $PATH changes we made may not be respected -if [ -n "${BASH-}" ] || [ -n "${ZSH_VERSION-}" ] ; then - hash -r 2>/dev/null -fi diff --git a/zjlenv/bin/activate.csh b/zjlenv/bin/activate.csh deleted file mode 100644 index eb7c950f9..000000000 --- a/zjlenv/bin/activate.csh +++ /dev/null @@ -1,42 +0,0 @@ -# This file must be used with "source bin/activate.csh" *from csh*. -# You cannot run it directly. -# Created by Davide Di Blasi . 
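# --- Illustrative sketch, not part of the diff above ---
# Why the offline trainer switches from `best_policy = policy` to copying the state
# dict: plain assignment only rebinds a name to the same nn.Module, so the "best"
# snapshot keeps changing as training continues; copying the weights freezes them at
# that point. The tiny Linear model below is an assumption made just for this demo.
import torch
from torch import nn

policy = nn.Linear(4, 2)

best_by_reference = policy                     # same object, not a snapshot
best_snapshot = nn.Linear(4, 2)                # independent module with the same shape
best_snapshot.load_state_dict(policy.state_dict())

with torch.no_grad():
    policy.weight.add_(1.0)                    # simulate further training updates

print(torch.equal(best_by_reference.weight, policy.weight))  # True  - followed the update
print(torch.equal(best_snapshot.weight, policy.weight))      # False - stayed frozen
# --- end of sketch ---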
- -set newline='\ -' - -alias deactivate 'test $?_OLD_VIRTUAL_PATH != 0 && setenv PATH "$_OLD_VIRTUAL_PATH:q" && unset _OLD_VIRTUAL_PATH; rehash; test $?_OLD_VIRTUAL_PROMPT != 0 && set prompt="$_OLD_VIRTUAL_PROMPT:q" && unset _OLD_VIRTUAL_PROMPT; unsetenv VIRTUAL_ENV; test "\!:*" != "nondestructive" && unalias deactivate && unalias pydoc' - -# Unset irrelevant variables. -deactivate nondestructive - -setenv VIRTUAL_ENV "/Users/jialu.zhu/Desktop/tianshou/zjlenv" - -set _OLD_VIRTUAL_PATH="$PATH:q" -setenv PATH "$VIRTUAL_ENV:q/bin:$PATH:q" - - - -if ("" != "") then - set env_name = "" -else - set env_name = "$VIRTUAL_ENV:t:q" -endif - -# Could be in a non-interactive environment, -# in which case, $prompt is undefined and we wouldn't -# care about the prompt anyway. -if ( $?prompt ) then - set _OLD_VIRTUAL_PROMPT="$prompt:q" -if ( "$prompt:q" =~ *"$newline:q"* ) then - : -else - set prompt = "[$env_name:q] $prompt:q" -endif -endif - -unset env_name - -alias pydoc python -m pydoc - -rehash diff --git a/zjlenv/bin/activate.fish b/zjlenv/bin/activate.fish deleted file mode 100644 index 92f3e7df2..000000000 --- a/zjlenv/bin/activate.fish +++ /dev/null @@ -1,101 +0,0 @@ -# This file must be used using `source bin/activate.fish` *within a running fish ( http://fishshell.com ) session*. -# Do not run it directly. - -function _bashify_path -d "Converts a fish path to something bash can recognize" - set fishy_path $argv - set bashy_path $fishy_path[1] - for path_part in $fishy_path[2..-1] - set bashy_path "$bashy_path:$path_part" - end - echo $bashy_path -end - -function _fishify_path -d "Converts a bash path to something fish can recognize" - echo $argv | tr ':' '\n' -end - -function deactivate -d 'Exit virtualenv mode and return to the normal environment.' - # reset old environment variables - if test -n "$_OLD_VIRTUAL_PATH" - # https://github.com/fish-shell/fish-shell/issues/436 altered PATH handling - if test (echo $FISH_VERSION | tr "." "\n")[1] -lt 3 - set -gx PATH (_fishify_path $_OLD_VIRTUAL_PATH) - else - set -gx PATH $_OLD_VIRTUAL_PATH - end - set -e _OLD_VIRTUAL_PATH - end - - if test -n "$_OLD_VIRTUAL_PYTHONHOME" - set -gx PYTHONHOME $_OLD_VIRTUAL_PYTHONHOME - set -e _OLD_VIRTUAL_PYTHONHOME - end - - if test -n "$_OLD_FISH_PROMPT_OVERRIDE" - # Set an empty local `$fish_function_path` to allow the removal of `fish_prompt` using `functions -e`. - set -l fish_function_path - - # Erase virtualenv's `fish_prompt` and restore the original. - functions -e fish_prompt - functions -c _old_fish_prompt fish_prompt - functions -e _old_fish_prompt - set -e _OLD_FISH_PROMPT_OVERRIDE - end - - set -e VIRTUAL_ENV - - if test "$argv[1]" != 'nondestructive' - # Self-destruct! - functions -e pydoc - functions -e deactivate - functions -e _bashify_path - functions -e _fishify_path - end -end - -# Unset irrelevant variables. -deactivate nondestructive - -set -gx VIRTUAL_ENV "/Users/jialu.zhu/Desktop/tianshou/zjlenv" - -# https://github.com/fish-shell/fish-shell/issues/436 altered PATH handling -if test (echo $FISH_VERSION | tr "." "\n")[1] -lt 3 - set -gx _OLD_VIRTUAL_PATH (_bashify_path $PATH) -else - set -gx _OLD_VIRTUAL_PATH $PATH -end -set -gx PATH "$VIRTUAL_ENV/bin" $PATH - -# Unset `$PYTHONHOME` if set. -if set -q PYTHONHOME - set -gx _OLD_VIRTUAL_PYTHONHOME $PYTHONHOME - set -e PYTHONHOME -end - -function pydoc - python -m pydoc $argv -end - -if test -z "$VIRTUAL_ENV_DISABLE_PROMPT" - # Copy the current `fish_prompt` function as `_old_fish_prompt`. 
- functions -c fish_prompt _old_fish_prompt - - function fish_prompt - # Save the current $status, for fish_prompts that display it. - set -l old_status $status - - # Prompt override provided? - # If not, just prepend the environment name. - if test -n "" - printf '%s%s' "" (set_color normal) - else - printf '%s(%s) ' (set_color normal) (basename "$VIRTUAL_ENV") - end - - # Restore the original $status - echo "exit $old_status" | source - _old_fish_prompt - end - - set -gx _OLD_FISH_PROMPT_OVERRIDE "$VIRTUAL_ENV" -end diff --git a/zjlenv/bin/activate.ps1 b/zjlenv/bin/activate.ps1 deleted file mode 100644 index 6d8ae2aa4..000000000 --- a/zjlenv/bin/activate.ps1 +++ /dev/null @@ -1,60 +0,0 @@ -# This file must be dot sourced from PoSh; you cannot run it directly. Do this: . ./activate.ps1 - -$script:THIS_PATH = $myinvocation.mycommand.path -$script:BASE_DIR = split-path (resolve-path "$THIS_PATH/..") -Parent - -function global:deactivate([switch] $NonDestructive) -{ - if (test-path variable:_OLD_VIRTUAL_PATH) - { - $env:PATH = $variable:_OLD_VIRTUAL_PATH - remove-variable "_OLD_VIRTUAL_PATH" -scope global - } - - if (test-path function:_old_virtual_prompt) - { - $function:prompt = $function:_old_virtual_prompt - remove-item function:\_old_virtual_prompt - } - - if ($env:VIRTUAL_ENV) - { - $old_env = split-path $env:VIRTUAL_ENV -leaf - remove-item env:VIRTUAL_ENV -erroraction silentlycontinue - } - - if (!$NonDestructive) - { - # Self destruct! - remove-item function:deactivate - remove-item function:pydoc - } -} - -function global:pydoc -{ - python -m pydoc $args -} - -# unset irrelevant variables -deactivate -nondestructive - -$VIRTUAL_ENV = $BASE_DIR -$env:VIRTUAL_ENV = $VIRTUAL_ENV - -$global:_OLD_VIRTUAL_PATH = $env:PATH -$env:PATH = "$env:VIRTUAL_ENV/bin:" + $env:PATH -if (!$env:VIRTUAL_ENV_DISABLE_PROMPT) -{ - function global:_old_virtual_prompt - { - "" - } - $function:_old_virtual_prompt = $function:prompt - function global:prompt - { - # Add a prefix to the current prompt, but don't discard it. - write-host "($( split-path $env:VIRTUAL_ENV -leaf )) " -nonewline - & $function:_old_virtual_prompt - } -} diff --git a/zjlenv/bin/activate.xsh b/zjlenv/bin/activate.xsh deleted file mode 100644 index 2dde4b647..000000000 --- a/zjlenv/bin/activate.xsh +++ /dev/null @@ -1,39 +0,0 @@ -"""Xonsh activate script for virtualenv""" -from xonsh.tools import get_sep as _get_sep - -def _deactivate(args): - if "pydoc" in aliases: - del aliases["pydoc"] - - if ${...}.get("_OLD_VIRTUAL_PATH", ""): - $PATH = $_OLD_VIRTUAL_PATH - del $_OLD_VIRTUAL_PATH - - if ${...}.get("_OLD_VIRTUAL_PYTHONHOME", ""): - $PYTHONHOME = $_OLD_VIRTUAL_PYTHONHOME - del $_OLD_VIRTUAL_PYTHONHOME - - if "VIRTUAL_ENV" in ${...}: - del $VIRTUAL_ENV - - if "nondestructive" not in args: - # Self destruct! 
- del aliases["deactivate"] - - -# unset irrelevant variables -_deactivate(["nondestructive"]) -aliases["deactivate"] = _deactivate - -$VIRTUAL_ENV = r"/Users/jialu.zhu/Desktop/tianshou/zjlenv" - -$_OLD_VIRTUAL_PATH = $PATH -$PATH = $PATH[:] -$PATH.add($VIRTUAL_ENV + _get_sep() + "bin", front=True, replace=True) - -if ${...}.get("PYTHONHOME", ""): - # unset PYTHONHOME if set - $_OLD_VIRTUAL_PYTHONHOME = $PYTHONHOME - del $PYTHONHOME - -aliases["pydoc"] = ["python", "-m", "pydoc"] diff --git a/zjlenv/bin/activate_this.py b/zjlenv/bin/activate_this.py deleted file mode 100644 index 59b5d7242..000000000 --- a/zjlenv/bin/activate_this.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Activate virtualenv for current interpreter: - -Use exec(open(this_file).read(), {'__file__': this_file}). - -This can be used when you must use an existing Python interpreter, not the virtualenv bin/python. -""" -import os -import site -import sys - -try: - __file__ -except NameError: - raise AssertionError("You must use exec(open(this_file).read(), {'__file__': this_file}))") - -# prepend bin to PATH (this file is inside the bin directory) -bin_dir = os.path.dirname(os.path.abspath(__file__)) -os.environ["PATH"] = os.pathsep.join([bin_dir] + os.environ.get("PATH", "").split(os.pathsep)) - -base = os.path.dirname(bin_dir) - -# virtual env is right above bin directory -os.environ["VIRTUAL_ENV"] = base - -# add the virtual environments site-package to the host python import mechanism -IS_PYPY = hasattr(sys, "pypy_version_info") -IS_JYTHON = sys.platform.startswith("java") -if IS_JYTHON: - site_packages = os.path.join(base, "Lib", "site-packages") -elif IS_PYPY: - site_packages = os.path.join(base, "site-packages") -else: - IS_WIN = sys.platform == "win32" - if IS_WIN: - site_packages = os.path.join(base, "Lib", "site-packages") - else: - site_packages = os.path.join(base, "lib", "python{}".format(sys.version[:3]), "site-packages") - -prev = set(sys.path) -site.addsitedir(site_packages) -sys.real_prefix = sys.prefix -sys.prefix = base - -# Move the added items to the front of the path, in place -new = list(sys.path) -sys.path[:] = [i for i in new if i not in prev] + [i for i in new if i in prev] diff --git a/zjlenv/bin/chardetect b/zjlenv/bin/chardetect deleted file mode 100755 index 39be4a317..000000000 --- a/zjlenv/bin/chardetect +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from chardet.cli.chardetect import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/convert-caffe2-to-onnx b/zjlenv/bin/convert-caffe2-to-onnx deleted file mode 100755 index 8da84784a..000000000 --- a/zjlenv/bin/convert-caffe2-to-onnx +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from caffe2.python.onnx.bin.conversion import caffe2_to_onnx -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(caffe2_to_onnx()) diff --git a/zjlenv/bin/convert-onnx-to-caffe2 b/zjlenv/bin/convert-onnx-to-caffe2 deleted file mode 100755 index 0a3e867c7..000000000 --- a/zjlenv/bin/convert-onnx-to-caffe2 +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from caffe2.python.onnx.bin.conversion import onnx_to_caffe2 -if __name__ == '__main__': - sys.argv[0] = 
re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(onnx_to_caffe2()) diff --git a/zjlenv/bin/coverage b/zjlenv/bin/coverage deleted file mode 100755 index 1d31b1d4d..000000000 --- a/zjlenv/bin/coverage +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from coverage.cmdline import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/coverage-3.7 b/zjlenv/bin/coverage-3.7 deleted file mode 100755 index 1d31b1d4d..000000000 --- a/zjlenv/bin/coverage-3.7 +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from coverage.cmdline import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/coverage3 b/zjlenv/bin/coverage3 deleted file mode 100755 index 1d31b1d4d..000000000 --- a/zjlenv/bin/coverage3 +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from coverage.cmdline import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/dmypy b/zjlenv/bin/dmypy deleted file mode 100755 index 8c7401b5a..000000000 --- a/zjlenv/bin/dmypy +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from mypy.dmypy.client import console_entry -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(console_entry()) diff --git a/zjlenv/bin/doc8 b/zjlenv/bin/doc8 deleted file mode 100755 index d6cd48c32..000000000 --- a/zjlenv/bin/doc8 +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from doc8.main import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/easy_install b/zjlenv/bin/easy_install deleted file mode 100755 index 11711685f..000000000 --- a/zjlenv/bin/easy_install +++ /dev/null @@ -1,10 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys - -from setuptools.command.easy_install import main - -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw?|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/easy_install-3.7 b/zjlenv/bin/easy_install-3.7 deleted file mode 100755 index 11711685f..000000000 --- a/zjlenv/bin/easy_install-3.7 +++ /dev/null @@ -1,10 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys - -from setuptools.command.easy_install import main - -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw?|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/f2py b/zjlenv/bin/f2py deleted file mode 100755 index 3ec4f2c7a..000000000 --- a/zjlenv/bin/f2py +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from numpy.f2py.f2py2e import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/f2py3 b/zjlenv/bin/f2py3 deleted 
file mode 100755 index 3ec4f2c7a..000000000 --- a/zjlenv/bin/f2py3 +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from numpy.f2py.f2py2e import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/f2py3.7 b/zjlenv/bin/f2py3.7 deleted file mode 100755 index 3ec4f2c7a..000000000 --- a/zjlenv/bin/f2py3.7 +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from numpy.f2py.f2py2e import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/flake8 b/zjlenv/bin/flake8 deleted file mode 100755 index 2a55e0048..000000000 --- a/zjlenv/bin/flake8 +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from flake8.main.cli import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/futurize b/zjlenv/bin/futurize deleted file mode 100755 index cc2e6b567..000000000 --- a/zjlenv/bin/futurize +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from libfuturize.main import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/google-oauthlib-tool b/zjlenv/bin/google-oauthlib-tool deleted file mode 100755 index 705e66710..000000000 --- a/zjlenv/bin/google-oauthlib-tool +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from google_auth_oauthlib.tool.__main__ import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/jsonschema b/zjlenv/bin/jsonschema deleted file mode 100755 index 22598c645..000000000 --- a/zjlenv/bin/jsonschema +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from jsonschema.cli import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/markdown_py b/zjlenv/bin/markdown_py deleted file mode 100755 index 837d31afd..000000000 --- a/zjlenv/bin/markdown_py +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from markdown.__main__ import run -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(run()) diff --git a/zjlenv/bin/mypy b/zjlenv/bin/mypy deleted file mode 100755 index 2b9ec040c..000000000 --- a/zjlenv/bin/mypy +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from mypy.__main__ import console_entry -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(console_entry()) diff --git a/zjlenv/bin/mypyc b/zjlenv/bin/mypyc deleted file mode 100755 index e9ff9b11a..000000000 --- a/zjlenv/bin/mypyc +++ /dev/null @@ -1,55 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 
-"""Mypyc command-line tool. - -Usage: - - $ mypyc foo.py [...] - $ python3 -c 'import foo' # Uses compiled 'foo' - - -This is just a thin wrapper that generates a setup.py file that uses -mypycify, suitable for prototyping and testing. -""" - -import os -import os.path -import subprocess -import sys -import tempfile -import time - -base_path = os.path.join(os.path.dirname(__file__), '..') - -setup_format = """\ -from distutils.core import setup -from mypyc.build import mypycify - -setup(name='mypyc_output', - ext_modules=mypycify({}, opt_level="{}"), -) -""" - -def main() -> None: - build_dir = 'build' # can this be overridden?? - try: - os.mkdir(build_dir) - except FileExistsError: - pass - - opt_level = os.getenv("MYPYC_OPT_LEVEL", '3') - - setup_file = os.path.join(build_dir, 'setup.py') - with open(setup_file, 'w') as f: - f.write(setup_format.format(sys.argv[1:], opt_level)) - - # We don't use run_setup (like we do in the test suite) because it throws - # away the error code from distutils, and we don't care about the slight - # performance loss here. - env = os.environ.copy() - base_path = os.path.join(os.path.dirname(__file__), '..') - env['PYTHONPATH'] = base_path + os.pathsep + env.get('PYTHONPATH', '') - cmd = subprocess.run([sys.executable, setup_file, 'build_ext', '--inplace'], env=env) - sys.exit(cmd.returncode) - -if __name__ == '__main__': - main() diff --git a/zjlenv/bin/numba b/zjlenv/bin/numba deleted file mode 100755 index 3eedf343e..000000000 --- a/zjlenv/bin/numba +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: UTF-8 -*- -from __future__ import print_function, division, absolute_import - -from numba.misc.numba_entry import main - -if __name__ == "__main__": - main() diff --git a/zjlenv/bin/pasteurize b/zjlenv/bin/pasteurize deleted file mode 100755 index 971710717..000000000 --- a/zjlenv/bin/pasteurize +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from libpasteurize.main import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/pbr b/zjlenv/bin/pbr deleted file mode 100755 index 63eb221d5..000000000 --- a/zjlenv/bin/pbr +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from pbr.cmd.main import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/pip b/zjlenv/bin/pip deleted file mode 100755 index 90b0b9d2b..000000000 --- a/zjlenv/bin/pip +++ /dev/null @@ -1,10 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys - -from pip._internal.cli.main import main - -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw?|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/pip3 b/zjlenv/bin/pip3 deleted file mode 100755 index 90b0b9d2b..000000000 --- a/zjlenv/bin/pip3 +++ /dev/null @@ -1,10 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys - -from pip._internal.cli.main import main - -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw?|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/pip3.7 b/zjlenv/bin/pip3.7 deleted file mode 100755 index 90b0b9d2b..000000000 --- 
a/zjlenv/bin/pip3.7 +++ /dev/null @@ -1,10 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys - -from pip._internal.cli.main import main - -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw?|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/py.test b/zjlenv/bin/py.test deleted file mode 100755 index 76ae5d160..000000000 --- a/zjlenv/bin/py.test +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from pytest import console_main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(console_main()) diff --git a/zjlenv/bin/pybabel b/zjlenv/bin/pybabel deleted file mode 100755 index 6de123547..000000000 --- a/zjlenv/bin/pybabel +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from babel.messages.frontend import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/pybtex b/zjlenv/bin/pybtex deleted file mode 100755 index 3a2fe04c1..000000000 --- a/zjlenv/bin/pybtex +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from pybtex.__main__ import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/pybtex-convert b/zjlenv/bin/pybtex-convert deleted file mode 100755 index 22a0ff550..000000000 --- a/zjlenv/bin/pybtex-convert +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from pybtex.database.convert.__main__ import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/pybtex-format b/zjlenv/bin/pybtex-format deleted file mode 100755 index f4982fc8e..000000000 --- a/zjlenv/bin/pybtex-format +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from pybtex.database.format.__main__ import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/pycc b/zjlenv/bin/pycc deleted file mode 100755 index fd239d851..000000000 --- a/zjlenv/bin/pycc +++ /dev/null @@ -1,3 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -from numba.pycc import main -main() diff --git a/zjlenv/bin/pycodestyle b/zjlenv/bin/pycodestyle deleted file mode 100755 index c20b84a57..000000000 --- a/zjlenv/bin/pycodestyle +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from pycodestyle import _main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(_main()) diff --git a/zjlenv/bin/pydocstyle b/zjlenv/bin/pydocstyle deleted file mode 100755 index aab2f9446..000000000 --- a/zjlenv/bin/pydocstyle +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from pydocstyle.cli import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - 
sys.exit(main()) diff --git a/zjlenv/bin/pyflakes b/zjlenv/bin/pyflakes deleted file mode 100755 index de8fd38b8..000000000 --- a/zjlenv/bin/pyflakes +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from pyflakes.api import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/pygmentize b/zjlenv/bin/pygmentize deleted file mode 100755 index 811516c4e..000000000 --- a/zjlenv/bin/pygmentize +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from pygments.cmdline import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/pyrsa-decrypt b/zjlenv/bin/pyrsa-decrypt deleted file mode 100755 index a56dfc1e3..000000000 --- a/zjlenv/bin/pyrsa-decrypt +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from rsa.cli import decrypt -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(decrypt()) diff --git a/zjlenv/bin/pyrsa-encrypt b/zjlenv/bin/pyrsa-encrypt deleted file mode 100755 index 40244720b..000000000 --- a/zjlenv/bin/pyrsa-encrypt +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from rsa.cli import encrypt -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(encrypt()) diff --git a/zjlenv/bin/pyrsa-keygen b/zjlenv/bin/pyrsa-keygen deleted file mode 100755 index 17d3b5200..000000000 --- a/zjlenv/bin/pyrsa-keygen +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from rsa.cli import keygen -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(keygen()) diff --git a/zjlenv/bin/pyrsa-priv2pub b/zjlenv/bin/pyrsa-priv2pub deleted file mode 100755 index 2a0676a54..000000000 --- a/zjlenv/bin/pyrsa-priv2pub +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from rsa.util import private_to_public -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(private_to_public()) diff --git a/zjlenv/bin/pyrsa-sign b/zjlenv/bin/pyrsa-sign deleted file mode 100755 index 057548544..000000000 --- a/zjlenv/bin/pyrsa-sign +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from rsa.cli import sign -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(sign()) diff --git a/zjlenv/bin/pyrsa-verify b/zjlenv/bin/pyrsa-verify deleted file mode 100755 index cf4ba7bc1..000000000 --- a/zjlenv/bin/pyrsa-verify +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from rsa.cli import verify -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(verify()) diff --git a/zjlenv/bin/pytest b/zjlenv/bin/pytest deleted file mode 100755 index 76ae5d160..000000000 --- 
a/zjlenv/bin/pytest +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from pytest import console_main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(console_main()) diff --git a/zjlenv/bin/python b/zjlenv/bin/python deleted file mode 120000 index 940bee389..000000000 --- a/zjlenv/bin/python +++ /dev/null @@ -1 +0,0 @@ -python3.7 \ No newline at end of file diff --git a/zjlenv/bin/python-config b/zjlenv/bin/python-config deleted file mode 100755 index e33168611..000000000 --- a/zjlenv/bin/python-config +++ /dev/null @@ -1,78 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python - -import sys -import getopt -import sysconfig - -valid_opts = ['prefix', 'exec-prefix', 'includes', 'libs', 'cflags', - 'ldflags', 'help'] - -if sys.version_info >= (3, 2): - valid_opts.insert(-1, 'extension-suffix') - valid_opts.append('abiflags') -if sys.version_info >= (3, 3): - valid_opts.append('configdir') - - -def exit_with_usage(code=1): - sys.stderr.write("Usage: {0} [{1}]\n".format( - sys.argv[0], '|'.join('--'+opt for opt in valid_opts))) - sys.exit(code) - -try: - opts, args = getopt.getopt(sys.argv[1:], '', valid_opts) -except getopt.error: - exit_with_usage() - -if not opts: - exit_with_usage() - -pyver = sysconfig.get_config_var('VERSION') -getvar = sysconfig.get_config_var - -opt_flags = [flag for (flag, val) in opts] - -if '--help' in opt_flags: - exit_with_usage(code=0) - -for opt in opt_flags: - if opt == '--prefix': - print(sysconfig.get_config_var('prefix')) - - elif opt == '--exec-prefix': - print(sysconfig.get_config_var('exec_prefix')) - - elif opt in ('--includes', '--cflags'): - flags = ['-I' + sysconfig.get_path('include'), - '-I' + sysconfig.get_path('platinclude')] - if opt == '--cflags': - flags.extend(getvar('CFLAGS').split()) - print(' '.join(flags)) - - elif opt in ('--libs', '--ldflags'): - abiflags = getattr(sys, 'abiflags', '') - libs = ['-lpython' + pyver + abiflags] - libs += getvar('LIBS').split() - libs += getvar('SYSLIBS').split() - # add the prefix/lib/pythonX.Y/config dir, but only if there is no - # shared library in prefix/lib/. 
- if opt == '--ldflags': - if not getvar('Py_ENABLE_SHARED'): - libs.insert(0, '-L' + getvar('LIBPL')) - if not getvar('PYTHONFRAMEWORK'): - libs.extend(getvar('LINKFORSHARED').split()) - print(' '.join(libs)) - - elif opt == '--extension-suffix': - ext_suffix = sysconfig.get_config_var('EXT_SUFFIX') - if ext_suffix is None: - ext_suffix = sysconfig.get_config_var('SO') - print(ext_suffix) - - elif opt == '--abiflags': - if not getattr(sys, 'abiflags', None): - exit_with_usage() - print(sys.abiflags) - - elif opt == '--configdir': - print(sysconfig.get_config_var('LIBPL')) diff --git a/zjlenv/bin/python3 b/zjlenv/bin/python3 deleted file mode 120000 index 940bee389..000000000 --- a/zjlenv/bin/python3 +++ /dev/null @@ -1 +0,0 @@ -python3.7 \ No newline at end of file diff --git a/zjlenv/bin/python3.7 b/zjlenv/bin/python3.7 deleted file mode 100755 index 8e1f2b49fe2438795c05cacff8e2a132f794f570..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8632 zcmeHMOK%fN5bhyBvS5M3uC%Z$$RI#GRx=BPhgN%tS{sIy=_yOGExTh5t9yzSEXAYR}>*-ErOoF%~bxGy!>Z<;_=Ib7-?Rxd` zzfZ$LI733rE)znmgU+rH;(;ijC&UKm38<8l#p9(brSoSwniVH&E@~0y9|TG{Rk}Qt z4N>#+*)gGQG>^<8hBD zv@>|g@cgHgK~>(2{6>%r4dQ*!c{$Gp9;N zPnt7aNN`@5#bHy7O&sIm_*T4Lzq30L?+YPCN4IwX#>~)+?|n8==hO)X2-#&40Jr032t;^zg!Pw z%a5y`>w4~FH?B5AVU}C~-tQjrv=dkq@^lz#@^_r4_?mGRxC4L5qqs+0@K`HzmJmCj z|4JK@Zh2qkeR>Qn^`5?4v(MiwIow0@>K6fw1V;R% zx-jmZ7tT!H$>)Ygdg>;w9lGfOYmwn>khLrWmI2FvWxz6E8L$jk1}p=X0n318z%pPN z_^ucj-I_nggnt}ou}>AlWf}Em6tB-kq9>yKx^dM0yK*|L&7AgYA;y$AQZ}kGsiFm_ zcl(_3^OB{<|2HSoKK`3yIiIC1)XS9ohEMaC1Lt3d@^6fXh()9R;-H@;Ui#PHXf|BG z)v5<>wj>DOke-9Gqrx -# Copyright: This module has been placed in the public domain. - -""" -A minimal front end to the Docutils Publisher, producing HTML. -""" - -try: - import locale - locale.setlocale(locale.LC_ALL, '') -except: - pass - -from docutils.core import publish_cmdline, default_description - - -description = ('Generates (X)HTML documents from standalone reStructuredText ' - 'sources. ' + default_description) - -publish_cmdline(writer_name='html', description=description) diff --git a/zjlenv/bin/rst2html4.py b/zjlenv/bin/rst2html4.py deleted file mode 100755 index 29a95c807..000000000 --- a/zjlenv/bin/rst2html4.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 - -# $Id: rst2html4.py 7994 2016-12-10 17:41:45Z milde $ -# Author: David Goodger -# Copyright: This module has been placed in the public domain. - -""" -A minimal front end to the Docutils Publisher, producing (X)HTML. - -The output conforms to XHTML 1.0 transitional -and almost to HTML 4.01 transitional (except for closing empty tags). -""" - -try: - import locale - locale.setlocale(locale.LC_ALL, '') -except: - pass - -from docutils.core import publish_cmdline, default_description - - -description = ('Generates (X)HTML documents from standalone reStructuredText ' - 'sources. ' + default_description) - -publish_cmdline(writer_name='html4', description=description) diff --git a/zjlenv/bin/rst2html5.py b/zjlenv/bin/rst2html5.py deleted file mode 100755 index e8a9f2af3..000000000 --- a/zjlenv/bin/rst2html5.py +++ /dev/null @@ -1,35 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf8 -*- -# :Copyright: © 2015 Günter Milde. 
-# :License: Released under the terms of the `2-Clause BSD license`_, in short: -# -# Copying and distribution of this file, with or without modification, -# are permitted in any medium without royalty provided the copyright -# notice and this notice are preserved. -# This file is offered as-is, without any warranty. -# -# .. _2-Clause BSD license: http://www.spdx.org/licenses/BSD-2-Clause -# -# Revision: $Revision: 8410 $ -# Date: $Date: 2019-11-04 22:14:43 +0100 (Mo, 04. Nov 2019) $ - -""" -A minimal front end to the Docutils Publisher, producing HTML 5 documents. - -The output also conforms to XHTML 1.0 transitional -(except for the doctype declaration). -""" - -try: - import locale # module missing in Jython - locale.setlocale(locale.LC_ALL, '') -except locale.Error: - pass - -from docutils.core import publish_cmdline, default_description - -description = (u'Generates HTML 5 documents from standalone ' - u'reStructuredText sources ' - + default_description) - -publish_cmdline(writer_name='html5', description=description) diff --git a/zjlenv/bin/rst2latex.py b/zjlenv/bin/rst2latex.py deleted file mode 100755 index c7eef261d..000000000 --- a/zjlenv/bin/rst2latex.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 - -# $Id: rst2latex.py 5905 2009-04-16 12:04:49Z milde $ -# Author: David Goodger -# Copyright: This module has been placed in the public domain. - -""" -A minimal front end to the Docutils Publisher, producing LaTeX. -""" - -try: - import locale - locale.setlocale(locale.LC_ALL, '') -except: - pass - -from docutils.core import publish_cmdline - -description = ('Generates LaTeX documents from standalone reStructuredText ' - 'sources. ' - 'Reads from (default is stdin) and writes to ' - ' (default is stdout). See ' - ' for ' - 'the full reference.') - -publish_cmdline(writer_name='latex', description=description) diff --git a/zjlenv/bin/rst2man.py b/zjlenv/bin/rst2man.py deleted file mode 100755 index aaccd9274..000000000 --- a/zjlenv/bin/rst2man.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 - -# Author: -# Contact: grubert@users.sf.net -# Copyright: This module has been placed in the public domain. - -""" -man.py -====== - -This module provides a simple command line interface that uses the -man page writer to output from ReStructuredText source. -""" - -import locale -try: - locale.setlocale(locale.LC_ALL, '') -except: - pass - -from docutils.core import publish_cmdline, default_description -from docutils.writers import manpage - -description = ("Generates plain unix manual documents. " + default_description) - -publish_cmdline(writer=manpage.Writer(), description=description) diff --git a/zjlenv/bin/rst2odt.py b/zjlenv/bin/rst2odt.py deleted file mode 100755 index f8f9b9e8b..000000000 --- a/zjlenv/bin/rst2odt.py +++ /dev/null @@ -1,30 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 - -# $Id: rst2odt.py 5839 2009-01-07 19:09:28Z dkuhlman $ -# Author: Dave Kuhlman -# Copyright: This module has been placed in the public domain. - -""" -A front end to the Docutils Publisher, producing OpenOffice documents. -""" - -import sys -try: - import locale - locale.setlocale(locale.LC_ALL, '') -except: - pass - -from docutils.core import publish_cmdline_to_binary, default_description -from docutils.writers.odf_odt import Writer, Reader - - -description = ('Generates OpenDocument/OpenOffice/ODF documents from ' - 'standalone reStructuredText sources. 
' + default_description) - - -writer = Writer() -reader = Reader() -output = publish_cmdline_to_binary(reader=reader, writer=writer, - description=description) - diff --git a/zjlenv/bin/rst2odt_prepstyles.py b/zjlenv/bin/rst2odt_prepstyles.py deleted file mode 100755 index 656c33e93..000000000 --- a/zjlenv/bin/rst2odt_prepstyles.py +++ /dev/null @@ -1,67 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 - -# $Id: rst2odt_prepstyles.py 8346 2019-08-26 12:11:32Z milde $ -# Author: Dave Kuhlman -# Copyright: This module has been placed in the public domain. - -""" -Fix a word-processor-generated styles.odt for odtwriter use: Drop page size -specifications from styles.xml in STYLE_FILE.odt. -""" - -# Author: Michael Schutte - -from __future__ import print_function - -from lxml import etree -import sys -import zipfile -from tempfile import mkstemp -import shutil -import os - -NAMESPACES = { - "style": "urn:oasis:names:tc:opendocument:xmlns:style:1.0", - "fo": "urn:oasis:names:tc:opendocument:xmlns:xsl-fo-compatible:1.0" -} - - -def prepstyle(filename): - - zin = zipfile.ZipFile(filename) - styles = zin.read("styles.xml") - - root = etree.fromstring(styles) - for el in root.xpath("//style:page-layout-properties", - namespaces=NAMESPACES): - for attr in el.attrib: - if attr.startswith("{%s}" % NAMESPACES["fo"]): - del el.attrib[attr] - - tempname = mkstemp() - zout = zipfile.ZipFile(os.fdopen(tempname[0], "w"), "w", - zipfile.ZIP_DEFLATED) - - for item in zin.infolist(): - if item.filename == "styles.xml": - zout.writestr(item, etree.tostring(root)) - else: - zout.writestr(item, zin.read(item.filename)) - - zout.close() - zin.close() - shutil.move(tempname[1], filename) - - -def main(): - args = sys.argv[1:] - if len(args) != 1: - print(__doc__, file=sys.stderr) - print("Usage: %s STYLE_FILE.odt\n" % sys.argv[0], file=sys.stderr) - sys.exit(1) - filename = args[0] - prepstyle(filename) - - -if __name__ == '__main__': - main() diff --git a/zjlenv/bin/rst2pseudoxml.py b/zjlenv/bin/rst2pseudoxml.py deleted file mode 100755 index fefd51aa3..000000000 --- a/zjlenv/bin/rst2pseudoxml.py +++ /dev/null @@ -1,23 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 - -# $Id: rst2pseudoxml.py 4564 2006-05-21 20:44:42Z wiemann $ -# Author: David Goodger -# Copyright: This module has been placed in the public domain. - -""" -A minimal front end to the Docutils Publisher, producing pseudo-XML. -""" - -try: - import locale - locale.setlocale(locale.LC_ALL, '') -except: - pass - -from docutils.core import publish_cmdline, default_description - - -description = ('Generates pseudo-XML from standalone reStructuredText ' - 'sources (for testing purposes). ' + default_description) - -publish_cmdline(description=description) diff --git a/zjlenv/bin/rst2s5.py b/zjlenv/bin/rst2s5.py deleted file mode 100755 index 66a257d1c..000000000 --- a/zjlenv/bin/rst2s5.py +++ /dev/null @@ -1,24 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 - -# $Id: rst2s5.py 4564 2006-05-21 20:44:42Z wiemann $ -# Author: Chris Liechti -# Copyright: This module has been placed in the public domain. - -""" -A minimal front end to the Docutils Publisher, producing HTML slides using -the S5 template system. -""" - -try: - import locale - locale.setlocale(locale.LC_ALL, '') -except: - pass - -from docutils.core import publish_cmdline, default_description - - -description = ('Generates S5 (X)HTML slideshow documents from standalone ' - 'reStructuredText sources. 
' + default_description) - -publish_cmdline(writer_name='s5', description=description) diff --git a/zjlenv/bin/rst2xetex.py b/zjlenv/bin/rst2xetex.py deleted file mode 100755 index 835e34da8..000000000 --- a/zjlenv/bin/rst2xetex.py +++ /dev/null @@ -1,27 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 - -# $Id: rst2xetex.py 7847 2015-03-17 17:30:47Z milde $ -# Author: Guenter Milde -# Copyright: This module has been placed in the public domain. - -""" -A minimal front end to the Docutils Publisher, producing Lua/XeLaTeX code. -""" - -try: - import locale - locale.setlocale(locale.LC_ALL, '') -except: - pass - -from docutils.core import publish_cmdline - -description = ('Generates LaTeX documents from standalone reStructuredText ' - 'sources for compilation with the Unicode-aware TeX variants ' - 'XeLaTeX or LuaLaTeX. ' - 'Reads from (default is stdin) and writes to ' - ' (default is stdout). See ' - ' for ' - 'the full reference.') - -publish_cmdline(writer_name='xetex', description=description) diff --git a/zjlenv/bin/rst2xml.py b/zjlenv/bin/rst2xml.py deleted file mode 100755 index bc61abe8c..000000000 --- a/zjlenv/bin/rst2xml.py +++ /dev/null @@ -1,23 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 - -# $Id: rst2xml.py 4564 2006-05-21 20:44:42Z wiemann $ -# Author: David Goodger -# Copyright: This module has been placed in the public domain. - -""" -A minimal front end to the Docutils Publisher, producing Docutils XML. -""" - -try: - import locale - locale.setlocale(locale.LC_ALL, '') -except: - pass - -from docutils.core import publish_cmdline, default_description - - -description = ('Generates Docutils-native XML from standalone ' - 'reStructuredText sources. ' + default_description) - -publish_cmdline(writer_name='xml', description=description) diff --git a/zjlenv/bin/rstpep2html.py b/zjlenv/bin/rstpep2html.py deleted file mode 100755 index 0f9d6c4f6..000000000 --- a/zjlenv/bin/rstpep2html.py +++ /dev/null @@ -1,25 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 - -# $Id: rstpep2html.py 4564 2006-05-21 20:44:42Z wiemann $ -# Author: David Goodger -# Copyright: This module has been placed in the public domain. - -""" -A minimal front end to the Docutils Publisher, producing HTML from PEP -(Python Enhancement Proposal) documents. -""" - -try: - import locale - locale.setlocale(locale.LC_ALL, '') -except: - pass - -from docutils.core import publish_cmdline, default_description - - -description = ('Generates (X)HTML from reStructuredText-format PEP files. 
' - + default_description) - -publish_cmdline(reader_name='pep', writer_name='pep_html', - description=description) diff --git a/zjlenv/bin/sphinx-apidoc b/zjlenv/bin/sphinx-apidoc deleted file mode 100755 index 4eed5f711..000000000 --- a/zjlenv/bin/sphinx-apidoc +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from sphinx.ext.apidoc import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/sphinx-autogen b/zjlenv/bin/sphinx-autogen deleted file mode 100755 index 8bd5818af..000000000 --- a/zjlenv/bin/sphinx-autogen +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from sphinx.ext.autosummary.generate import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/sphinx-build b/zjlenv/bin/sphinx-build deleted file mode 100755 index b76dbb1b3..000000000 --- a/zjlenv/bin/sphinx-build +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from sphinx.cmd.build import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/sphinx-quickstart b/zjlenv/bin/sphinx-quickstart deleted file mode 100755 index a0ef78ffb..000000000 --- a/zjlenv/bin/sphinx-quickstart +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from sphinx.cmd.quickstart import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/stubgen b/zjlenv/bin/stubgen deleted file mode 100755 index d73e4a697..000000000 --- a/zjlenv/bin/stubgen +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from mypy.stubgen import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/stubtest b/zjlenv/bin/stubtest deleted file mode 100755 index fc6ce4a03..000000000 --- a/zjlenv/bin/stubtest +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from mypy.stubtest import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/bin/tensorboard b/zjlenv/bin/tensorboard deleted file mode 100755 index c976b11fd..000000000 --- a/zjlenv/bin/tensorboard +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from tensorboard.main import run_main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(run_main()) diff --git a/zjlenv/bin/tqdm b/zjlenv/bin/tqdm deleted file mode 100755 index 85f1b2ae5..000000000 --- a/zjlenv/bin/tqdm +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from tqdm.cli import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git 
a/zjlenv/bin/tune b/zjlenv/bin/tune deleted file mode 100755 index 62a89f244..000000000 --- a/zjlenv/bin/tune +++ /dev/null @@ -1,8 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys -from ray.tune.scripts import cli -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(cli()) diff --git a/zjlenv/bin/wheel b/zjlenv/bin/wheel deleted file mode 100755 index 8e18047d6..000000000 --- a/zjlenv/bin/wheel +++ /dev/null @@ -1,10 +0,0 @@ -#!/Users/jialu.zhu/Desktop/tianshou/zjlenv/bin/python3.7 -# -*- coding: utf-8 -*- -import re -import sys - -from wheel.cli import main - -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw?|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/zjlenv/include/python3.7m b/zjlenv/include/python3.7m deleted file mode 120000 index 7e2d975bd..000000000 --- a/zjlenv/include/python3.7m +++ /dev/null @@ -1 +0,0 @@ -/Library/Frameworks/Python.framework/Versions/3.7/include/python3.7m \ No newline at end of file From 9bf02bedacbab269a8adf5675d2ab755bfe2ed52 Mon Sep 17 00:00:00 2001 From: Jialu Zhu Date: Wed, 6 Jan 2021 06:47:40 +0800 Subject: [PATCH 07/23] feedback --- test/discrete/test_bcq.py | 3 ++- tianshou/policy/__init__.py | 1 - tianshou/policy/modelfree/bcq.py | 40 +++++++------------------------- tianshou/utils/net/common.py | 26 +++++++++++++++++++++ 4 files changed, 37 insertions(+), 33 deletions(-) diff --git a/test/discrete/test_bcq.py b/test/discrete/test_bcq.py index 0b06aa8dd..9029ddadf 100644 --- a/test/discrete/test_bcq.py +++ b/test/discrete/test_bcq.py @@ -10,6 +10,7 @@ from tianshou.env import DummyVectorEnv from tianshou.trainer import offline_trainer from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer +from tianshou.utils.net.common import BCQN def get_args(): @@ -41,7 +42,7 @@ def test_bcq(args=get_args()): state_shape = env.observation_space.shape or env.observation_space.n state_shape = state_shape[0] action_shape = env.action_space.shape or env.action_space.n - model = BCQPolicy.BCQN(state_shape, action_shape, args.hidden_dim, args.hidden_dim) + model = BCQN(state_shape, action_shape, args.hidden_dim, args.hidden_dim) optim = torch.optim.Adam(model.parameters(), lr=args.learning_rate) policy = BCQPolicy( diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index d21efdd73..28cff3076 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -28,6 +28,5 @@ "DiscreteSACPolicy", "PSRLPolicy", "MultiAgentPolicyManager", - "BCQManager", "BCQPolicy", ] diff --git a/tianshou/policy/modelfree/bcq.py b/tianshou/policy/modelfree/bcq.py index ec84a019e..c1594e34b 100644 --- a/tianshou/policy/modelfree/bcq.py +++ b/tianshou/policy/modelfree/bcq.py @@ -1,6 +1,6 @@ import torch -import numpy as np import torch.nn.functional as F +import numpy as np from typing import Any, Dict, Union, Optional from tianshou.data import Batch, to_torch @@ -15,30 +15,6 @@ class BCQPolicy(BasePolicy): """ - class BCQN(nn.Module): - """BCQ NN for dialogue policy. 
It includes a net for imitation and a net for Q-value""" - - def __init__( - self, input_size, n_actions, imitation_model_hidden_dim, policy_model_hidden_dim - ): - super(BCQPolicy.BCQN, self).__init__() - self.q1 = nn.Linear(input_size, policy_model_hidden_dim) - self.q2 = nn.Linear(policy_model_hidden_dim, policy_model_hidden_dim) - self.q3 = nn.Linear(policy_model_hidden_dim, n_actions) - - self.i1 = nn.Linear(input_size, imitation_model_hidden_dim) - self.i2 = nn.Linear(imitation_model_hidden_dim, imitation_model_hidden_dim) - self.i3 = nn.Linear(imitation_model_hidden_dim, n_actions) - - def forward(self, state): - q = F.relu(self.q1(state)) - q = F.relu(self.q2(q)) - - i = F.relu(self.i1(state)) - i = F.relu(self.i2(i)) - i = F.relu(self.i3(i)) - return self.q3(q), F.log_softmax(i, dim=1), i - def __init__( self, model: torch.nn.Module, @@ -52,6 +28,7 @@ def __init__( ) -> None: super().__init__(**kwargs) self._policy_net = model + self._policy_net.to(device) self._optimizer = optim self._cnt = 0 self._device = device @@ -59,6 +36,7 @@ def __init__( self._tau = tau self._target_net = deepcopy(self._policy_net) self._target_net.eval() + self._target_net.to(device) self._target_update_frequency = target_update_frequency self._imitation_logits_penalty = imitation_logits_penalty @@ -70,7 +48,7 @@ def forward( ) -> Batch: batch.to_torch() - state = batch.obs + state = batch.obs.to(self._device) q, imt, _ = self._policy_net(state.float()) imt = imt.exp() imt = (imt / imt.max(1, keepdim=True)[0] > self._tau).float() @@ -93,12 +71,13 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: non_final_next_states = torch.cat( [s.obs.unsqueeze(0) for s in batch if not s.done], dim=0 ) + non_final_next_states = non_final_next_states.to(self._device) except Exception: non_final_next_states = None # Compute the target Q value with torch.no_grad(): - expected_state_action_values = batch.rew.float() + expected_state_action_values = batch.rew.float().to(self._device) # Add target Q value for non-final next_state if non_final_next_states is not None: @@ -108,7 +87,6 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: # Use large negative number to mask actions from argmax next_action = (imt * q + (1 - imt) * -1e8).argmax(1, keepdim=True) - q, _, _ = self._target_net(non_final_next_states) q = q.gather(1, next_action).reshape(-1, 1) @@ -118,12 +96,12 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: expected_state_action_values += next_state_values * self._gamma # Get current Q estimate - current_Q, imt, i = self._policy_net(batch.obs) - current_Q = current_Q.gather(1, batch.act.unsqueeze(1)).squeeze() + current_Q, imt, i = self._policy_net(batch.obs.to(self._device)) + current_Q = current_Q.gather(1, batch.act.unsqueeze(1).to(self._device)).squeeze() # Compute Q loss q_loss = F.smooth_l1_loss(current_Q, expected_state_action_values) - i_loss = F.nll_loss(imt, batch.act.reshape(-1)) + i_loss = F.nll_loss(imt, batch.act.reshape(-1).to(self._device)) Q_loss = q_loss + i_loss + self._imitation_logits_penalty * i.pow(2).mean() diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 8b52a87a0..6ce6ed9f4 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -1,6 +1,7 @@ import torch import numpy as np from torch import nn +import torch.nn.functional as F from typing import Any, Dict, List, Tuple, Union, Callable, Optional, Sequence from tianshou.data import to_torch @@ -159,3 +160,28 @@ def forward( # please 
ensure the first dim is batch size: [bsz, len, ...] return s, {"h": h.transpose(0, 1).detach(), "c": c.transpose(0, 1).detach()} + + +class BCQN(nn.Module): + """BCQ NN for dialogue policy. It includes a net for imitation and a net for Q-value""" + + def __init__( + self, input_size, n_actions, imitation_model_hidden_dim, policy_model_hidden_dim + ): + super().__init__() + self.q1 = nn.Linear(input_size, policy_model_hidden_dim) + self.q2 = nn.Linear(policy_model_hidden_dim, policy_model_hidden_dim) + self.q3 = nn.Linear(policy_model_hidden_dim, n_actions) + + self.i1 = nn.Linear(input_size, imitation_model_hidden_dim) + self.i2 = nn.Linear(imitation_model_hidden_dim, imitation_model_hidden_dim) + self.i3 = nn.Linear(imitation_model_hidden_dim, n_actions) + + def forward(self, state): + q = F.relu(self.q1(state)) + q = F.relu(self.q2(q)) + + i = F.relu(self.i1(state)) + i = F.relu(self.i2(i)) + i = F.relu(self.i3(i)) + return self.q3(q), F.log_softmax(i, dim=1), i From 288c154d67358b897c1c64bf2ece1d1eccdecd37 Mon Sep 17 00:00:00 2001 From: Jialu Zhu Date: Wed, 6 Jan 2021 06:58:34 +0800 Subject: [PATCH 08/23] lint --- test/discrete/test_bcq.py | 7 ++++--- tianshou/policy/modelfree/bcq.py | 19 +++++++++++++------ tianshou/trainer/offline.py | 13 +++++++------ tianshou/utils/net/common.py | 13 ++++++++++--- 4 files changed, 34 insertions(+), 18 deletions(-) diff --git a/test/discrete/test_bcq.py b/test/discrete/test_bcq.py index 9029ddadf..1c633be97 100644 --- a/test/discrete/test_bcq.py +++ b/test/discrete/test_bcq.py @@ -2,14 +2,13 @@ import os import gym import torch -import pprint import argparse import random import numpy as np from torch.utils.tensorboard import SummaryWriter from tianshou.env import DummyVectorEnv from tianshou.trainer import offline_trainer -from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer +from tianshou.data import Collector, ReplayBuffer from tianshou.utils.net.common import BCQN @@ -31,7 +30,9 @@ def get_args(): parser.add_argument("--learning-rate", type=float, default=0.01) parser.add_argument("--imitation_logits_penalty", type=float, default=0.1) parser.add_argument( - "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", ) args = parser.parse_known_args()[0] return args diff --git a/tianshou/policy/modelfree/bcq.py b/tianshou/policy/modelfree/bcq.py index c1594e34b..c85865c61 100644 --- a/tianshou/policy/modelfree/bcq.py +++ b/tianshou/policy/modelfree/bcq.py @@ -3,9 +3,8 @@ import numpy as np from typing import Any, Dict, Union, Optional -from tianshou.data import Batch, to_torch +from tianshou.data import Batch from tianshou.policy import BasePolicy -import torch.nn as nn from copy import deepcopy @@ -86,24 +85,32 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: imt = (imt / imt.max(1, keepdim=True)[0] > self._tau).float() # Use large negative number to mask actions from argmax - next_action = (imt * q + (1 - imt) * -1e8).argmax(1, keepdim=True) + next_action = (imt * q + (1 - imt) * -1e8).argmax( + 1, keepdim=True + ) q, _, _ = self._target_net(non_final_next_states) q = q.gather(1, next_action).reshape(-1, 1) - next_state_values = torch.zeros(len(batch), device=self._device).float() + next_state_values = torch.zeros( + len(batch), device=self._device + ).float() next_state_values[non_final_mask] = q.squeeze() expected_state_action_values += next_state_values * self._gamma # Get current Q estimate current_Q, 
imt, i = self._policy_net(batch.obs.to(self._device)) - current_Q = current_Q.gather(1, batch.act.unsqueeze(1).to(self._device)).squeeze() + current_Q = current_Q.gather( + 1, batch.act.unsqueeze(1).to(self._device) + ).squeeze() # Compute Q loss q_loss = F.smooth_l1_loss(current_Q, expected_state_action_values) i_loss = F.nll_loss(imt, batch.act.reshape(-1).to(self._device)) - Q_loss = q_loss + i_loss + self._imitation_logits_penalty * i.pow(2).mean() + Q_loss = ( + q_loss + i_loss + self._imitation_logits_penalty * i.pow(2).mean() + ) self._optimizer.zero_grad() Q_loss.backward() diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 4899ed32e..20e95295d 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -1,12 +1,9 @@ -import time -import tqdm from torch.utils.tensorboard import SummaryWriter -from typing import Dict, List, Union, Callable, Optional +from typing import Dict, Union, Optional from tianshou.data import Collector, ReplayBuffer from tianshou.policy import BasePolicy -from tianshou.utils import tqdm_config, MovAvg -from tianshou.trainer import test_episode, gather_info +from tianshou.trainer import test_episode def offline_trainer( @@ -29,7 +26,11 @@ def offline_trainer( total_iter += 1 loss = policy.update(batch_size, buffer) if total_iter % test_frequency == 0: - writer.add_scalar("train/loss", loss["loss"], global_step=total_iter) + writer.add_scalar( + "train/loss", + loss["loss"], + global_step=total_iter, + ) test_result = test_episode( policy, diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 6ce6ed9f4..502943213 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -163,10 +163,14 @@ def forward( class BCQN(nn.Module): - """BCQ NN for dialogue policy. It includes a net for imitation and a net for Q-value""" + """A net imitation and a net for Q-value""" def __init__( - self, input_size, n_actions, imitation_model_hidden_dim, policy_model_hidden_dim + self, + input_size: int, + n_actions: int, + imitation_model_hidden_dim: int, + policy_model_hidden_dim: int, ): super().__init__() self.q1 = nn.Linear(input_size, policy_model_hidden_dim) @@ -174,7 +178,10 @@ def __init__( self.q3 = nn.Linear(policy_model_hidden_dim, n_actions) self.i1 = nn.Linear(input_size, imitation_model_hidden_dim) - self.i2 = nn.Linear(imitation_model_hidden_dim, imitation_model_hidden_dim) + self.i2 = nn.Linear( + imitation_model_hidden_dim, + imitation_model_hidden_dim, + ) self.i3 = nn.Linear(imitation_model_hidden_dim, n_actions) def forward(self, state): From 48f00128a8fa1f2297bf3370d2a6cf36e0f14111 Mon Sep 17 00:00:00 2001 From: Jialu Zhu Date: Wed, 6 Jan 2021 12:24:32 +0800 Subject: [PATCH 09/23] . 
--- tianshou/policy/modelfree/bcq.py | 3 ++- tianshou/trainer/offline.py | 15 ++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tianshou/policy/modelfree/bcq.py b/tianshou/policy/modelfree/bcq.py index c85865c61..f8f32de9b 100644 --- a/tianshou/policy/modelfree/bcq.py +++ b/tianshou/policy/modelfree/bcq.py @@ -66,13 +66,14 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: dtype=torch.bool, ) + non_final_next_states = None try: non_final_next_states = torch.cat( [s.obs.unsqueeze(0) for s in batch if not s.done], dim=0 ) non_final_next_states = non_final_next_states.to(self._device) except Exception: - non_final_next_states = None + pass # Compute the target Q value with torch.no_grad(): diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 20e95295d..6f967376f 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -15,8 +15,8 @@ def offline_trainer( episode_per_test: int, writer: Optional[SummaryWriter] = None, test_frequency: int = 1, -) -> Dict[str, Union[float, str]]: - best_reward = -1 +) -> Dict[str, Union[float, BasePolicy]]: + best_reward = -1.0 best_policy = policy total_iter = 0 iter_per_epoch = len(buffer) // batch_size @@ -26,11 +26,12 @@ def offline_trainer( total_iter += 1 loss = policy.update(batch_size, buffer) if total_iter % test_frequency == 0: - writer.add_scalar( - "train/loss", - loss["loss"], - global_step=total_iter, - ) + if writer is not None: + writer.add_scalar( + "train/loss", + loss["loss"], + global_step=total_iter, + ) test_result = test_episode( policy, test_collector, None, epoch, episode_per_test, writer, total_iter, ) From f4b9aa6469bd479efce460b663178dc27306850a Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 12 Jan 2021 10:04:34 +0800 Subject: [PATCH 10/23] resolve #269, #270 --- setup.py | 2 +- tianshou/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 5af0b206f..38c73b55f 100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,7 @@ def get_version() -> str: "tensorboard", "torch>=1.4.0", "numba>=0.51.0", - "h5py>=3.1.0" + "h5py>=2.10.0", # to match tensorflow's minimal requirements ], extras_require={ "dev": [ diff --git a/tianshou/__init__.py b/tianshou/__init__.py index cc37f2773..0f7a59c10 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,7 +1,7 @@ from tianshou import data, env, utils, policy, trainer, exploration -__version__ = "0.3.0" +__version__ = "0.3.1" __all__ = [ "env", From 36a137c89511c2da7b22b6b343fff87b8e534e53 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 12 Jan 2021 21:17:20 +0800 Subject: [PATCH 11/23] update BCQPolicy and BCQN --- test/discrete/test_bcq.py | 7 +- tianshou/policy/__init__.py | 2 +- tianshou/policy/imitation/bcq.py | 130 +++++++++++++++++++++++++++++++ tianshou/policy/modelfree/bcq.py | 124 ----------------------------- tianshou/policy/modelfree/c51.py | 4 +- tianshou/policy/modelfree/dqn.py | 6 +- tianshou/policy/modelfree/sac.py | 1 - tianshou/utils/net/common.py | 72 +++++++++-------- 8 files changed, 180 insertions(+), 166 deletions(-) create mode 100644 tianshou/policy/imitation/bcq.py delete mode 100644 tianshou/policy/modelfree/bcq.py diff --git a/test/discrete/test_bcq.py b/test/discrete/test_bcq.py index 1c633be97..997863c68 100644 --- a/test/discrete/test_bcq.py +++ b/test/discrete/test_bcq.py @@ -1,15 +1,16 @@ -from tianshou.policy import BCQPolicy import os import gym import torch -import argparse import random +import argparse import numpy as np from torch.utils.tensorboard import SummaryWriter +
+from tianshou.policy import BCQPolicy from tianshou.env import DummyVectorEnv +from tianshou.utils.net.common import BCQN from tianshou.trainer import offline_trainer from tianshou.data import Collector, ReplayBuffer -from tianshou.utils.net.common import BCQN def get_args(): diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index f80fe67d1..a5a323c9e 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -10,7 +10,7 @@ from tianshou.policy.modelfree.td3 import TD3Policy from tianshou.policy.modelfree.sac import SACPolicy from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy -from tianshou.policy.modelfree.bcq import BCQPolicy +from tianshou.policy.imitation.bcq import BCQPolicy from tianshou.policy.modelbase.psrl import PSRLPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py new file mode 100644 index 000000000..4b939ec9e --- /dev/null +++ b/tianshou/policy/imitation/bcq.py @@ -0,0 +1,130 @@ +import torch +import numpy as np +from copy import deepcopy +import torch.nn.functional as F +from typing import Any, Dict, Union, Optional + +from tianshou.policy import BasePolicy +from tianshou.data import Batch, ReplayBuffer, to_torch_as + + +class BCQPolicy(BasePolicy): + """Implementation for discrete BCQ algorithm.""" + + def __init__( + self, + model: torch.nn.Module, + optim: torch.optim.Optimizer, + discount_factor: float = 0.99, + estimation_step: int = 1, + target_update_freq: int = 8000, + eval_eps: float = 1e-3, + unlikely_action_threshold: float = 0.3, + imitation_logits_penalty: float = 1e-2, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + # init model + self.model = model + self.optim = optim + self.model_old = deepcopy(self.model) + self.model_old.eval() + self._iter = 0 + # init hparam + assert ( + 0.0 <= discount_factor <= 1.0 + ), "discount factor should be in [0, 1]" + self._gamma = discount_factor + assert ( + 0.0 <= unlikely_action_threshold < 1.0 + ), "unlikely_action_threshold should be in [0, 1)" + self._thres = unlikely_action_threshold + assert estimation_step > 0, "estimation_step should be greater than 0" + self._n_step = estimation_step + self._eps = eval_eps + self._freq = target_update_freq + self._w_imitation = imitation_logits_penalty + + def train(self, mode: bool = True) -> "BCQPolicy": + """Set the module in training mode, except for the target network.""" + self.training = mode + self.model.train(mode) + return self + + def sync_weight(self) -> None: + """Synchronize the weight for the target network.""" + self.model_old.load_state_dict(self.model.state_dict()) + + def _target_q( + self, buffer: ReplayBuffer, indice: np.ndarray + ) -> torch.Tensor: + batch = buffer[indice] # batch.obs_next: s_{t+n} + # target_Q = Q_old(s_, argmax(Q_new(s_, *))) + with torch.no_grad(): + act = self(batch, input="obs_next", eps=0.0).act + target_q = self( + batch, model="model_old", input="obs_next", eps=0.0 + ).logits + target_q = target_q[np.arange(len(act)), act] + return target_q + + def process_fn( + self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray + ) -> Batch: + """Compute the n-step return for Q-learning targets. + + More details can be found at + :meth:`~tianshou.policy.BasePolicy.compute_nstep_return`. 
+ """ + batch = self.compute_nstep_return( + batch, buffer, indice, self._target_q, + self._gamma, self._n_step, False) + return batch + + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + model: str = "model", + input: str = "obs", + eps: Optional[float] = None, + **kwargs: Any, + ) -> Batch: + if eps is None: + eps = self._eps + obs = batch[input] + (q, imt, i), state = self.model(obs, state=state, info=batch.info) + imt = imt.exp() + imt = (imt / imt.max(1, keepdim=True)[0] > self._thres).float() + # Use large negative number to mask actions from argmax + action = (imt * q + (1.0 - imt) * -np.inf).argmax(-1) + assert len(action.shape) == 1 + + # add eps to act + if not np.isclose(eps, 0.0) and np.random.rand() < eps: + bsz, action_num = q.shape + action = np.random.randint(action_num, size=bsz) + + return Batch(logits=q, act=action, state=state) + + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: + if self._iter % self._freq == 0: + self.sync_weight() + + target_q = batch.returns.flatten() + + (current_q, imt, i), _ = self.model(batch.obs) + current_q = current_q[np.arange(len(target_q)), batch.act] + + act = to_torch_as(batch.act, target_q) + q_loss = F.smooth_l1_loss(current_q, target_q) + i_loss = F.nll_loss(imt, act) # type: ignore + + loss = q_loss + i_loss + self._w_imitation * i.pow(2).mean() + + self.optim.zero_grad() + loss.backward() + self.optim.step() + + self._iter += 1 + return {"loss": loss.item()} diff --git a/tianshou/policy/modelfree/bcq.py b/tianshou/policy/modelfree/bcq.py deleted file mode 100644 index f8f32de9b..000000000 --- a/tianshou/policy/modelfree/bcq.py +++ /dev/null @@ -1,124 +0,0 @@ -import torch -import torch.nn.functional as F -import numpy as np -from typing import Any, Dict, Union, Optional - -from tianshou.data import Batch -from tianshou.policy import BasePolicy -from copy import deepcopy - - -class BCQPolicy(BasePolicy): - """Implementation discrete BCQ algorithm. 
Some code is from - https://github.com/sfujim/BCQ/tree/master/discrete_BCQ - - """ - - def __init__( - self, - model: torch.nn.Module, - optim: torch.optim.Optimizer, - tau: float, - target_update_frequency: int, - device: str, - gamma: float, - imitation_logits_penalty: float, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - self._policy_net = model - self._policy_net.to(device) - self._optimizer = optim - self._cnt = 0 - self._device = device - self._gamma = gamma - self._tau = tau - self._target_net = deepcopy(self._policy_net) - self._target_net.eval() - self._target_net.to(device) - self._target_update_frequency = target_update_frequency - self._imitation_logits_penalty = imitation_logits_penalty - - def forward( - self, - batch: Batch, - state: Optional[Union[dict, Batch, np.ndarray]] = None, - **kwargs: Any, - ) -> Batch: - batch.to_torch() - - state = batch.obs.to(self._device) - q, imt, _ = self._policy_net(state.float()) - imt = imt.exp() - imt = (imt / imt.max(1, keepdim=True)[0] > self._tau).float() - - # Use large negative number to mask actions from argmax - action = (imt * q + (1.0 - imt) * -1e8).argmax(1) - - return Batch(act=action, state=state, qvalue=q) - - def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: - batch.to_torch() - - non_final_mask = torch.tensor( - tuple(map(lambda s: not s, batch.done)), - device=self._device, - dtype=torch.bool, - ) - - non_final_next_states = None - try: - non_final_next_states = torch.cat( - [s.obs.unsqueeze(0) for s in batch if not s.done], dim=0 - ) - non_final_next_states = non_final_next_states.to(self._device) - except Exception: - pass - - # Compute the target Q value - with torch.no_grad(): - expected_state_action_values = batch.rew.float().to(self._device) - - # Add target Q value for non-final next_state - if non_final_next_states is not None: - q, imt, _ = self._policy_net(non_final_next_states) - imt = imt.exp() - imt = (imt / imt.max(1, keepdim=True)[0] > self._tau).float() - - # Use large negative number to mask actions from argmax - next_action = (imt * q + (1 - imt) * -1e8).argmax( - 1, keepdim=True - ) - q, _, _ = self._target_net(non_final_next_states) - q = q.gather(1, next_action).reshape(-1, 1) - - next_state_values = torch.zeros( - len(batch), device=self._device - ).float() - next_state_values[non_final_mask] = q.squeeze() - - expected_state_action_values += next_state_values * self._gamma - - # Get current Q estimate - current_Q, imt, i = self._policy_net(batch.obs.to(self._device)) - current_Q = current_Q.gather( - 1, batch.act.unsqueeze(1).to(self._device) - ).squeeze() - - # Compute Q loss - q_loss = F.smooth_l1_loss(current_Q, expected_state_action_values) - i_loss = F.nll_loss(imt, batch.act.reshape(-1).to(self._device)) - - Q_loss = ( - q_loss + i_loss + self._imitation_logits_penalty * i.pow(2).mean() - ) - - self._optimizer.zero_grad() - Q_loss.backward() - self._optimizer.step() - - if self._cnt % self._target_update_frequency == 0: - self._target_net.load_state_dict(self._policy_net.state_dict()) - self._target_net.eval() - - return {"loss": Q_loss.item()} diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 96cc6801a..706872efb 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -124,7 +124,7 @@ def _target_dist(self, batch: Batch) -> torch.Tensor: return target_dist.sum(-1) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: - if self._target and self._cnt % self._freq == 0: + if self._target 
and self._iter % self._freq == 0: self.sync_weight() self.optim.zero_grad() with torch.no_grad(): @@ -139,5 +139,5 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: batch.weight = cross_entropy.detach() # prio-buffer loss.backward() self.optim.step() - self._cnt += 1 + self._iter += 1 return {"loss": loss.item()} diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 91cca6139..3c72911a7 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -54,7 +54,7 @@ def __init__( self._n_step = estimation_step self._target = target_update_freq > 0 self._freq = target_update_freq - self._cnt = 0 + self._iter = 0 if self._target: self.model_old = deepcopy(self.model) self.model_old.eval() @@ -162,7 +162,7 @@ def forward( return Batch(logits=q, act=act, state=h) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: - if self._target and self._cnt % self._freq == 0: + if self._target and self._iter % self._freq == 0: self.sync_weight() self.optim.zero_grad() weight = batch.pop("weight", 1.0) @@ -174,5 +174,5 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: batch.weight = td # prio-buffer loss.backward() self.optim.step() - self._cnt += 1 + self._iter += 1 return {"loss": loss.item()} diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 014fbe6d5..983df7158 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -143,7 +143,6 @@ def _target_q( with torch.no_grad(): obs_next_result = self(batch, input='obs_next') a_ = obs_next_result.act - batch.act = to_torch_as(batch.act, a_) target_q = torch.min( self.critic1_old(batch.obs_next, a_), self.critic2_old(batch.obs_next, a_), diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index e77ea79a4..3b0d5f48f 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -113,6 +113,46 @@ def forward( return logits, state +class BCQN(nn.Module): + """A double-head MLP (imitation and q-value) for BCQ algorithm.""" + + def __init__( + self, + state_shape: tuple, + action_shape: Union[tuple, int], + policy_model_hidden_dim: List[int] = [256, 256], + imitation_model_hidden_dim: List[int] = [256, 256], + device: Union[str, int, torch.device] = "cpu", + norm_layer: Optional[Callable[[int], nn.modules.Module]] = None, + ) -> None: + super().__init__() + self.device = device + policy_dim = [np.prod(state_shape)] + policy_model_hidden_dim + imitation_dim = [np.prod(state_shape)] + \ + imitation_model_hidden_dim + [np.prod(action_shape)] + self.Q = nn.Sequential(*[ + layer + for (inp, oup) in zip(policy_dim[:-1], policy_dim[1:]) + for layer in miniblock(inp, oup, norm_layer) + ], nn.Linear(policy_dim[-1], np.prod(action_shape))) + self.imitation = nn.Sequential(*[ + layer + for (inp, oup) in zip(imitation_dim[:-1], imitation_dim[1:]) + for layer in miniblock(inp, oup, norm_layer) + ]) + + def forward( + self, + s: Union[np.ndarray, torch.Tensor], + state: Optional[Dict[str, torch.Tensor]] = None, + info: Dict[str, Any] = {}, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + s = to_torch(s, device=self.device, dtype=torch.float32) + s = s.reshape(s.size(0), -1) + i = self.imitation(s) + return (self.Q(s), F.log_softmax(i, dim=1), i), state + + class Recurrent(nn.Module): """Simple Recurrent network based on LSTM. @@ -172,35 +212,3 @@ def forward( # please ensure the first dim is batch size: [bsz, len, ...] 
return s, {"h": h.transpose(0, 1).detach(), "c": c.transpose(0, 1).detach()} - - -class BCQN(nn.Module): - """A net imitation and a net for Q-value""" - - def __init__( - self, - input_size: int, - n_actions: int, - imitation_model_hidden_dim: int, - policy_model_hidden_dim: int, - ): - super().__init__() - self.q1 = nn.Linear(input_size, policy_model_hidden_dim) - self.q2 = nn.Linear(policy_model_hidden_dim, policy_model_hidden_dim) - self.q3 = nn.Linear(policy_model_hidden_dim, n_actions) - - self.i1 = nn.Linear(input_size, imitation_model_hidden_dim) - self.i2 = nn.Linear( - imitation_model_hidden_dim, - imitation_model_hidden_dim, - ) - self.i3 = nn.Linear(imitation_model_hidden_dim, n_actions) - - def forward(self, state): - q = F.relu(self.q1(state)) - q = F.relu(self.q2(q)) - - i = F.relu(self.i1(state)) - i = F.relu(self.i2(i)) - i = F.relu(self.i3(i)) - return self.q3(q), F.log_softmax(i, dim=1), i From 2e22daf1c2478ae5cac4bed4d5d79d3cae010720 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Wed, 13 Jan 2021 20:59:13 +0800 Subject: [PATCH 12/23] runnable --- .gitignore | 1 + test/discrete/test_bcq.py | 102 --------------- test/discrete/test_dqn.py | 7 ++ test/discrete/test_il_bcq.py | 113 +++++++++++++++++ tianshou/policy/__init__.py | 4 +- .../imitation/{bcq.py => discrete_bcq.py} | 57 ++------- tianshou/policy/modelfree/dqn.py | 13 +- tianshou/trainer/offline.py | 118 ++++++++++++------ tianshou/trainer/offpolicy.py | 5 +- tianshou/trainer/onpolicy.py | 5 +- tianshou/trainer/utils.py | 23 ++-- 11 files changed, 234 insertions(+), 214 deletions(-) delete mode 100644 test/discrete/test_bcq.py create mode 100644 test/discrete/test_il_bcq.py rename tianshou/policy/imitation/{bcq.py => discrete_bcq.py} (61%) diff --git a/.gitignore b/.gitignore index 082dcef12..769ca8bd6 100644 --- a/.gitignore +++ b/.gitignore @@ -145,3 +145,4 @@ MUJOCO_LOG.TXT *.zip *.pstats *.swp +*.pkl diff --git a/test/discrete/test_bcq.py b/test/discrete/test_bcq.py deleted file mode 100644 index 997863c68..000000000 --- a/test/discrete/test_bcq.py +++ /dev/null @@ -1,102 +0,0 @@ -import os -import gym -import torch -import random -import argparse -import numpy as np -from torch.utils.tensorboard import SummaryWriter - -from tianshou.policy import BCQPolicy -from tianshou.env import DummyVectorEnv -from tianshou.utils.net.common import BCQN -from tianshou.trainer import offline_trainer -from tianshou.data import Collector, ReplayBuffer - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") - parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--test-num", type=int, default=1) - parser.add_argument("--logdir", type=str, default="log") - parser.add_argument("--buffer-size", type=int, default=1000) - parser.add_argument("--hidden-dim", type=int, default=5) - parser.add_argument("--test-frequency", type=int, default=10) - parser.add_argument("--target-update-frequency", type=int, default=5) - parser.add_argument("--episode-per-test", type=int, default=5) - parser.add_argument("--tau", type=float, default=0.8) - parser.add_argument("--gamma", type=float, default=0.9) - parser.add_argument("--learning-rate", type=float, default=0.01) - parser.add_argument("--imitation_logits_penalty", type=float, default=0.1) - parser.add_argument( - "--device", - type=str, - default="cuda" if torch.cuda.is_available() else 
"cpu", - ) - args = parser.parse_known_args()[0] - return args - - -def test_bcq(args=get_args()): - env = gym.make(args.task) - state_shape = env.observation_space.shape or env.observation_space.n - state_shape = state_shape[0] - action_shape = env.action_space.shape or env.action_space.n - model = BCQN(state_shape, action_shape, args.hidden_dim, args.hidden_dim) - optim = torch.optim.Adam(model.parameters(), lr=args.learning_rate) - - policy = BCQPolicy( - model, - optim, - args.tau, - args.target_update_frequency, - args.device, - args.gamma, - args.imitation_logits_penalty, - ) - - # Make up some dummy training data in replay buffer - buffer = ReplayBuffer(size=args.buffer_size) - for i in range(args.buffer_size): - buffer.add( - obs=torch.rand(state_shape), - act=random.randint(0, action_shape - 1), - rew=1, - done=False, - obs_next=torch.rand(state_shape), - info={}, - ) - - test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)] - ) - - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) - test_envs.seed(args.seed) - - test_collector = Collector(policy, test_envs) - - log_path = os.path.join(args.logdir, "writer") - writer = SummaryWriter(log_path) - if not os.path.exists(log_path): - os.makedirs(log_path) - - res = offline_trainer( - policy, - buffer, - test_collector, - args.epoch, - args.batch_size, - args.episode_per_test, - writer, - args.test_frequency, - ) - print("final best_reward", res["best_reward"]) - - -if __name__ == "__main__": - test_bcq(get_args()) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 8ec9fa380..88ecd3cd1 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -1,6 +1,7 @@ import os import gym import torch +import pickle import pprint import argparse import numpy as np @@ -36,6 +37,9 @@ def get_args(): parser.add_argument('--prioritized-replay', type=int, default=0) parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--beta', type=float, default=0.4) + parser.add_argument( + '--save-buffer-name', type=str, + default="./expert_DQN_CartPole-v0.pkl") parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') @@ -110,6 +114,9 @@ def test_fn(epoch, env_step): stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) + # save buffer in pickle format, for imitation learning unittest + pickle.dump(buf, open(args.save_buffer_name, "wb")) + if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! 
diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py new file mode 100644 index 000000000..7b9fe77d4 --- /dev/null +++ b/test/discrete/test_il_bcq.py @@ -0,0 +1,113 @@ +import os +import gym +import torch +import pickle +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector +from tianshou.env import DummyVectorEnv +from tianshou.utils.net.common import BCQN +from tianshou.trainer import offline_trainer +from tianshou.policy import DiscreteBCQPolicy + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--seed", type=int, default=1626) + parser.add_argument("--eps-test", type=float, default=0.001) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--gamma", type=float, default=0.9) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--unlikely-action-threshold", type=float, default=0.3) + parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) + parser.add_argument("--epoch", type=int, default=5) + parser.add_argument("--step-per-epoch", type=int, default=1000) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument( + "--policy-model-hidden-dim", type=int, nargs="*", + default=[256, 256], + ) + parser.add_argument( + "--imitation-model-hidden-dim", type=int, nargs="*", + default=[256, 256], + ) + parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.) + parser.add_argument( + "--load-buffer-name", type=str, + default="./expert_DQN_CartPole-v0.pkl", + ) + parser.add_argument( + "--device", type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + ) + args = parser.parse_known_args()[0] + return args + + +def test_discrete_bcq(args=get_args()): + # envs + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + test_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + test_envs.seed(args.seed) + # model + net = BCQN( + args.state_shape, args.action_shape, args.policy_model_hidden_dim, + args.imitation_model_hidden_dim, args.device, + ).to(args.device) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + + policy = DiscreteBCQPolicy( + net, optim, args.gamma, args.n_step, args.target_update_freq, + args.eps_test, args.unlikely_action_threshold, + args.imitation_logits_penalty, + ) + # buffer + assert os.path.exists(args.load_buffer_name), \ + "Please run test_dqn.py first to get expert data buffer." 
+ buffer = pickle.load(open(args.load_buffer_name, "rb")) + + # collector + test_collector = Collector(policy, test_envs) + + log_path = os.path.join(args.logdir, args.task, 'discrete_bcq') + writer = SummaryWriter(log_path) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + result = offline_trainer( + policy, buffer, test_collector, + args.epoch, args.step_per_epoch, args.test_num, args.batch_size, + stop_fn=stop_fn, save_fn=save_fn, writer=writer, + ) + assert stop_fn(result['best_reward']) + + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! + env = gym.make(args.task) + policy.eval() + policy.set_eps(args.eps_test) + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + print(f'Final reward: {result["rew"]}, length: {result["len"]}') + + +if __name__ == "__main__": + test_discrete_bcq(get_args()) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index a5a323c9e..968aaf69b 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -10,7 +10,7 @@ from tianshou.policy.modelfree.td3 import TD3Policy from tianshou.policy.modelfree.sac import SACPolicy from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy -from tianshou.policy.imitation.bcq import BCQPolicy +from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy from tianshou.policy.modelbase.psrl import PSRLPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager @@ -28,7 +28,7 @@ "TD3Policy", "SACPolicy", "DiscreteSACPolicy", + "DiscreteBCQPolicy", "PSRLPolicy", "MultiAgentPolicyManager", - "BCQPolicy", ] diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/discrete_bcq.py similarity index 61% rename from tianshou/policy/imitation/bcq.py rename to tianshou/policy/imitation/discrete_bcq.py index 4b939ec9e..69d65aabd 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -1,15 +1,14 @@ import torch import numpy as np -from copy import deepcopy import torch.nn.functional as F from typing import Any, Dict, Union, Optional -from tianshou.policy import BasePolicy -from tianshou.data import Batch, ReplayBuffer, to_torch_as +from tianshou.policy import DQNPolicy +from tianshou.data import Batch, ReplayBuffer, to_torch -class BCQPolicy(BasePolicy): - """Implementation for discrete BCQ algorithm.""" +class DiscreteBCQPolicy(DQNPolicy): + """Implementation of discrete BCQ algorithm. 
arXiv:1812.02900.""" def __init__( self, @@ -23,38 +22,16 @@ def __init__( imitation_logits_penalty: float = 1e-2, **kwargs: Any, ) -> None: - super().__init__(**kwargs) - # init model - self.model = model - self.optim = optim - self.model_old = deepcopy(self.model) - self.model_old.eval() + super().__init__(model, optim, discount_factor, estimation_step, + target_update_freq, **kwargs) self._iter = 0 - # init hparam - assert ( - 0.0 <= discount_factor <= 1.0 - ), "discount factor should be in [0, 1]" - self._gamma = discount_factor assert ( 0.0 <= unlikely_action_threshold < 1.0 ), "unlikely_action_threshold should be in [0, 1)" self._thres = unlikely_action_threshold - assert estimation_step > 0, "estimation_step should be greater than 0" - self._n_step = estimation_step self._eps = eval_eps - self._freq = target_update_freq self._w_imitation = imitation_logits_penalty - def train(self, mode: bool = True) -> "BCQPolicy": - """Set the module in training mode, except for the target network.""" - self.training = mode - self.model.train(mode) - return self - - def sync_weight(self) -> None: - """Synchronize the weight for the target network.""" - self.model_old.load_state_dict(self.model.state_dict()) - def _target_q( self, buffer: ReplayBuffer, indice: np.ndarray ) -> torch.Tensor: @@ -68,19 +45,6 @@ def _target_q( target_q = target_q[np.arange(len(act)), act] return target_q - def process_fn( - self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray - ) -> Batch: - """Compute the n-step return for Q-learning targets. - - More details can be found at - :meth:`~tianshou.policy.BasePolicy.compute_nstep_return`. - """ - batch = self.compute_nstep_return( - batch, buffer, indice, self._target_q, - self._gamma, self._n_step, False) - return batch - def forward( self, batch: Batch, @@ -96,9 +60,8 @@ def forward( (q, imt, i), state = self.model(obs, state=state, info=batch.info) imt = imt.exp() imt = (imt / imt.max(1, keepdim=True)[0] > self._thres).float() - # Use large negative number to mask actions from argmax + # mask actions for argmax action = (imt * q + (1.0 - imt) * -np.inf).argmax(-1) - assert len(action.shape) == 1 # add eps to act if not np.isclose(eps, 0.0) and np.random.rand() < eps: @@ -110,21 +73,19 @@ def forward( def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: if self._iter % self._freq == 0: self.sync_weight() + self._iter += 1 target_q = batch.returns.flatten() - (current_q, imt, i), _ = self.model(batch.obs) current_q = current_q[np.arange(len(target_q)), batch.act] - act = to_torch_as(batch.act, target_q) + act = to_torch(batch.act, dtype=torch.long, device=target_q.device) q_loss = F.smooth_l1_loss(current_q, target_q) i_loss = F.nll_loss(imt, act) # type: ignore - loss = q_loss + i_loss + self._w_imitation * i.pow(2).mean() self.optim.zero_grad() loss.backward() self.optim.step() - self._iter += 1 return {"loss": loss.item()} diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 3c72911a7..a8f705b81 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -78,16 +78,15 @@ def _target_q( self, buffer: ReplayBuffer, indice: np.ndarray ) -> torch.Tensor: batch = buffer[indice] # batch.obs_next: s_{t+n} - if self._target: - # target_Q = Q_old(s_, argmax(Q_new(s_, *))) - a = self(batch, input="obs_next").act - with torch.no_grad(): + # target_Q = Q_old(s_, argmax(Q_new(s_, *))) + with torch.no_grad(): + if self._target: + a = self(batch, input="obs_next").act target_q = self( batch, 
model="model_old", input="obs_next" ).logits - target_q = target_q[np.arange(len(a)), a] - else: - with torch.no_grad(): + target_q = target_q[np.arange(len(a)), a] + else: target_q = self(batch, input="obs_next").logits.max(dim=1)[0] return target_q diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 6f967376f..da5c661a7 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -1,55 +1,93 @@ +import time +import tqdm +from collections import defaultdict from torch.utils.tensorboard import SummaryWriter -from typing import Dict, Union, Optional +from typing import Dict, List, Union, Callable, Optional -from tianshou.data import Collector, ReplayBuffer from tianshou.policy import BasePolicy -from tianshou.trainer import test_episode +from tianshou.utils import tqdm_config, MovAvg +from tianshou.data import Collector, ReplayBuffer +from tianshou.trainer import test_episode, gather_info def offline_trainer( policy: BasePolicy, buffer: ReplayBuffer, test_collector: Collector, - epochs: int, + max_epoch: int, + step_per_epoch: int, + episode_per_test: Union[int, List[int]], batch_size: int, - episode_per_test: int, + test_fn: Optional[Callable[[int, Optional[int]], None]] = None, + stop_fn: Optional[Callable[[float], bool]] = None, + save_fn: Optional[Callable[[BasePolicy], None]] = None, writer: Optional[SummaryWriter] = None, - test_frequency: int = 1, -) -> Dict[str, Union[float, BasePolicy]]: - best_reward = -1.0 - best_policy = policy - total_iter = 0 - iter_per_epoch = len(buffer) // batch_size + log_interval: int = 1, + verbose: bool = True, +) -> Dict[str, Union[float, str]]: + """A wrapper for offline trainer procedure. - for epoch in range(1, 1 + epochs): - for iter in range(iter_per_epoch): - total_iter += 1 - loss = policy.update(batch_size, buffer) - if total_iter % test_frequency == 0: - if writer is not None: - writer.add_scalar( - "train/loss", - loss["loss"], - global_step=total_iter, - ) + The "step" in trainer means a policy network update. - test_result = test_episode( - policy, - test_collector, - None, - epoch, - episode_per_test, - writer, - total_iter, - ) + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` + class. + :param test_collector: the collector used for testing. + :type test_collector: :class:`~tianshou.data.Collector` + :param int max_epoch: the maximum of epochs for training. The training + process might be finished before reaching the ``max_epoch``. + :param int step_per_epoch: the number of step for updating policy network + in one epoch. + :param episode_per_test: the number of episodes for one policy evaluation. + :param int batch_size: the batch size of sample data, which is going to + feed in the policy network. + :param function test_fn: a function receives the current number of epoch + and step index, and performs some operations at the beginning of + testing in this epoch. + :param function save_fn: a function for saving policy when the undiscounted + average mean reward in evaluation phase gets better. + :param function stop_fn: a function receives the average undiscounted + returns of the testing result, return a boolean which indicates whether + reaching the goal. + :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard + SummaryWriter. + :param int log_interval: the log interval of the writer. + :param bool verbose: whether to print the information. 
- if best_reward < test_result["rew"]: - best_reward = test_result["rew"] - best_policy.load_state_dict(policy.state_dict()) + :return: See :func:`~tianshou.trainer.gather_info`. + """ + gradient_step = 0 + best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 + stat: Dict[str, MovAvg] = defaultdict(MovAvg) + start_time = time.time() + test_collector.reset_stat() - print(f"------- epoch: {epoch}, iter: {total_iter} --------") - print("loss:", loss["loss"]) - print("test_result:", test_result) - print("best_reward:", best_reward) - - return {"best_reward": best_reward, "best_policy": best_policy} + for epoch in range(1, 1 + max_epoch): + policy.train() + for i in tqdm.trange( + step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config + ): + gradient_step += 1 + losses = policy.update(batch_size, buffer) + data = {"gradient_step": str(gradient_step)} + for k in losses.keys(): + stat[k].add(losses[k]) + data[k] = f"{stat[k].get():.6f}" + if writer and gradient_step % log_interval == 0: + writer.add_scalar( + k, stat[k].get(), global_step=gradient_step) + # test + result = test_episode(policy, test_collector, test_fn, epoch, + episode_per_test, writer, gradient_step) + if best_epoch == -1 or best_reward < result["rew"]: + best_reward, best_reward_std = result["rew"], result["rew_std"] + best_epoch = epoch + if save_fn: + save_fn(policy) + if verbose: + print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± " + f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± " + f"{best_reward_std:.6f} in #{best_epoch}") + if stop_fn and stop_fn(best_reward): + break + return gather_info(start_time, None, test_collector, + best_reward, best_reward_std) diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index fb4b6f453..2642be734 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -1,5 +1,6 @@ import time import tqdm +from collections import defaultdict from torch.utils.tensorboard import SummaryWriter from typing import Dict, List, Union, Callable, Optional @@ -73,7 +74,7 @@ def offpolicy_trainer( """ env_step, gradient_step = 0, 0 best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 - stat: Dict[str, MovAvg] = {} + stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() test_collector.reset_stat() @@ -122,8 +123,6 @@ def offpolicy_trainer( gradient_step += 1 losses = policy.update(batch_size, train_collector.buffer) for k in losses.keys(): - if stat.get(k) is None: - stat[k] = MovAvg() stat[k].add(losses[k]) data[k] = f"{stat[k].get():.6f}" if writer and gradient_step % log_interval == 0: diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 5aff68beb..d2f2b38cf 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -1,5 +1,6 @@ import time import tqdm +from collections import defaultdict from torch.utils.tensorboard import SummaryWriter from typing import Dict, List, Union, Callable, Optional @@ -73,7 +74,7 @@ def onpolicy_trainer( """ env_step, gradient_step = 0, 0 best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 - stat: Dict[str, MovAvg] = {} + stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() test_collector.reset_stat() @@ -125,8 +126,6 @@ def onpolicy_trainer( len(v) for v in losses.values() if isinstance(v, list)]) gradient_step += step for k in losses.keys(): - if stat.get(k) is None: - stat[k] = MovAvg() stat[k].add(losses[k]) data[k] = f"{stat[k].get():.6f}" if writer and 
gradient_step % log_interval == 0: diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index da9dea8ec..3af6101ce 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -36,7 +36,7 @@ def test_episode( def gather_info( start_time: float, - train_c: Collector, + train_c: Optional[Collector], test_c: Collector, best_reward: float, best_reward_std: float, @@ -59,15 +59,9 @@ def gather_info( * ``duration`` the total elapsed time. """ duration = time.time() - start_time - model_time = duration - train_c.collect_time - test_c.collect_time - train_speed = train_c.collect_step / (duration - test_c.collect_time) + model_time = duration - test_c.collect_time test_speed = test_c.collect_step / test_c.collect_time - return { - "train_step": train_c.collect_step, - "train_episode": train_c.collect_episode, - "train_time/collector": f"{train_c.collect_time:.2f}s", - "train_time/model": f"{model_time:.2f}s", - "train_speed": f"{train_speed:.2f} step/s", + result: Dict[str, Union[float, str]] = { "test_step": test_c.collect_step, "test_episode": test_c.collect_episode, "test_time": f"{test_c.collect_time:.2f}s", @@ -76,3 +70,14 @@ def gather_info( "best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}", "duration": f"{duration:.2f}s", } + if train_c is not None: + model_time -= train_c.collect_time + train_speed = train_c.collect_step / (duration - test_c.collect_time) + result.update({ + "train_step": train_c.collect_step, + "train_episode": train_c.collect_episode, + "train_time/collector": f"{train_c.collect_time:.2f}s", + "train_time/model": f"{model_time:.2f}s", + "train_speed": f"{train_speed:.2f} step/s", + }) + return result From e63cb501d8cca57191db6ed0ce083e44f1875086 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Wed, 13 Jan 2021 23:06:38 +0800 Subject: [PATCH 13/23] polish --- test/discrete/test_dqn.py | 8 ++++++-- test/discrete/test_il_bcq.py | 2 +- tianshou/policy/imitation/discrete_bcq.py | 2 -- tianshou/trainer/offline.py | 24 ++++++++++++----------- 4 files changed, 20 insertions(+), 16 deletions(-) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 88ecd3cd1..9ccc59630 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -114,8 +114,6 @@ def test_fn(epoch, env_step): stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) - # save buffer in pickle format, for imitation learning unittest - pickle.dump(buf, open(args.save_buffer_name, "wb")) if __name__ == '__main__': pprint.pprint(result) @@ -127,6 +125,12 @@ def test_fn(epoch, env_step): result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') + # save buffer in pickle format, for imitation learning unittest + buf = ReplayBuffer(args.buffer_size) + collector = Collector(policy, test_envs, buf) + collector.collect(n_step=args.buffer_size) + pickle.dump(buf, open(args.save_buffer_name, "wb")) + def test_pdqn(args=get_args()): args.prioritized_replay = 1 diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py index 7b9fe77d4..e2e9a248c 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/discrete/test_il_bcq.py @@ -19,7 +19,7 @@ def get_args(): parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps-test", type=float, default=0.001) - parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--lr", type=float, default=3e-4) 
parser.add_argument("--gamma", type=float, default=0.9) parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--target-update-freq", type=int, default=320) diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 69d65aabd..a1ec5a1ce 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -24,7 +24,6 @@ def __init__( ) -> None: super().__init__(model, optim, discount_factor, estimation_step, target_update_freq, **kwargs) - self._iter = 0 assert ( 0.0 <= unlikely_action_threshold < 1.0 ), "unlikely_action_threshold should be in [0, 1)" @@ -78,7 +77,6 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: target_q = batch.returns.flatten() (current_q, imt, i), _ = self.model(batch.obs) current_q = current_q[np.arange(len(target_q)), batch.act] - act = to_torch(batch.act, dtype=torch.long, device=target_q.device) q_loss = F.smooth_l1_loss(current_q, target_q) i_loss = F.nll_loss(imt, act) # type: ignore diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index da5c661a7..6b8175744 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -63,18 +63,20 @@ def offline_trainer( for epoch in range(1, 1 + max_epoch): policy.train() - for i in tqdm.trange( + with tqdm.trange( step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config - ): - gradient_step += 1 - losses = policy.update(batch_size, buffer) - data = {"gradient_step": str(gradient_step)} - for k in losses.keys(): - stat[k].add(losses[k]) - data[k] = f"{stat[k].get():.6f}" - if writer and gradient_step % log_interval == 0: - writer.add_scalar( - k, stat[k].get(), global_step=gradient_step) + ) as t: + for i in t: + gradient_step += 1 + losses = policy.update(batch_size, buffer) + data = {"gradient_step": str(gradient_step)} + for k in losses.keys(): + stat[k].add(losses[k]) + data[k] = f"{stat[k].get():.6f}" + if writer and gradient_step % log_interval == 0: + writer.add_scalar( + k, stat[k].get(), global_step=gradient_step) + t.set_postfix(**data) # test result = test_episode(policy, test_collector, test_fn, epoch, episode_per_test, writer, gradient_step) From 21705b7258ddca05eaa631c8d4cae609e90eca8a Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 14 Jan 2021 11:14:19 +0800 Subject: [PATCH 14/23] fix unacessary relu layer in network --- tianshou/policy/imitation/discrete_bcq.py | 12 +++++++++--- tianshou/utils/net/common.py | 5 ++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index a1ec5a1ce..7f3b232e5 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -8,7 +8,7 @@ class DiscreteBCQPolicy(DQNPolicy): - """Implementation of discrete BCQ algorithm. arXiv:1812.02900.""" + """Implementation of discrete BCQ algorithm. 
arXiv:1910.01708.""" def __init__( self, @@ -80,10 +80,16 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: act = to_torch(batch.act, dtype=torch.long, device=target_q.device) q_loss = F.smooth_l1_loss(current_q, target_q) i_loss = F.nll_loss(imt, act) # type: ignore - loss = q_loss + i_loss + self._w_imitation * i.pow(2).mean() + reg_loss = i.pow(2).mean() + loss = q_loss + i_loss + self._w_imitation * reg_loss self.optim.zero_grad() loss.backward() self.optim.step() - return {"loss": loss.item()} + return { + "loss": loss.item(), + "q_loss": q_loss.item(), + "i_loss": i_loss.item(), + "reg_loss": reg_loss.item(), + } diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 3b0d5f48f..7ec37963e 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -128,8 +128,7 @@ def __init__( super().__init__() self.device = device policy_dim = [np.prod(state_shape)] + policy_model_hidden_dim - imitation_dim = [np.prod(state_shape)] + \ - imitation_model_hidden_dim + [np.prod(action_shape)] + imitation_dim = [np.prod(state_shape)] + imitation_model_hidden_dim self.Q = nn.Sequential(*[ layer for (inp, oup) in zip(policy_dim[:-1], policy_dim[1:]) @@ -139,7 +138,7 @@ def __init__( layer for (inp, oup) in zip(imitation_dim[:-1], imitation_dim[1:]) for layer in miniblock(inp, oup, norm_layer) - ]) + ], nn.Linear(imitation_dim[-1], np.prod(action_shape))) def forward( self, From d8be9ed0ceaedde88aae6f824a7e9b505e57bae8 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 14 Jan 2021 16:42:57 +0800 Subject: [PATCH 15/23] final --- README.md | 1 + docs/index.rst | 3 +- docs/tutorials/concepts.rst | 2 +- docs/tutorials/dqn.rst | 4 +- test/discrete/test_il_bcq.py | 33 +++++------ tianshou/policy/imitation/base.py | 2 +- tianshou/policy/imitation/discrete_bcq.py | 70 +++++++++++++++++------ tianshou/trainer/utils.py | 1 + tianshou/utils/net/common.py | 40 ------------- 9 files changed, 77 insertions(+), 79 deletions(-) diff --git a/README.md b/README.md index 833a8562a..f071fe6a3 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf) - [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf) - Vanilla Imitation Learning +- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf) - [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf) - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf) - [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf) diff --git a/docs/index.rst b/docs/index.rst index 3b1fe023c..72704f49c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -20,8 +20,9 @@ Welcome to Tianshou! 
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG `_ * :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic `_ * :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic `_ -* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning `_ * :class:`~tianshou.policy.ImitationPolicy` Imitation Learning +* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning `_ +* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning `_ * :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay `_ * :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator `_ diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index 7d644587d..749ef4b03 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -201,7 +201,7 @@ Trainer Once you have a collector and a policy, you can start writing the training method for your RL agent. Trainer, to be honest, is a simple wrapper. It helps you save energy for writing the training loop. You can also construct your own trainer: :ref:`customized_trainer`. -Tianshou has two types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` and :func:`~tianshou.trainer.offpolicy_trainer`, corresponding to on-policy algorithms (such as Policy Gradient) and off-policy algorithms (such as DQN). Please check out :doc:`/api/tianshou.trainer` for the usage. +Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.trainer.offpolicy_trainer`, and :func:`~tianshou.trainer.offline_trainer`, corresponding to on-policy algorithms (such as Policy Gradient), off-policy algorithms (such as DQN), and offline algorithms (such as imitation learning or BCQ). Please check out :doc:`/api/tianshou.trainer` for the usage. .. _pseudocode: diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index b1971ea35..faea6a869 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -120,7 +120,7 @@ In each step, the collector will let the policy perform (at least) a specified n Train Policy with a Trainer --------------------------- -Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tianshou.trainer.offpolicy_trainer`. The trainer will automatically stop training when the policy reach the stop condition ``stop_fn`` on test collector. Since DQN is an off-policy algorithm, we use the :class:`~tianshou.trainer.offpolicy_trainer` as follows: +Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.trainer.offpolicy_trainer`, and :func:`~tianshou.trainer.offline_trainer`. The trainer will automatically stop training when the policy reach the stop condition ``stop_fn`` on test collector. Since DQN is an off-policy algorithm, we use the :func:`~tianshou.trainer.offpolicy_trainer` as follows: :: result = ts.trainer.offpolicy_trainer( @@ -133,7 +133,7 @@ Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tians writer=None) print(f'Finished training! Use {result["duration"]}') -The meaning of each parameter is as follows (full description can be found at :meth:`~tianshou.trainer.offpolicy_trainer`): +The meaning of each parameter is as follows (full description can be found at :func:`~tianshou.trainer.offpolicy_trainer`): * ``max_epoch``: The maximum of epochs for training. 
The training process might be finished before reaching the ``max_epoch``; * ``step_per_epoch``: The number of step for updating policy network in one epoch; diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py index e2e9a248c..7ee5bfd9d 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/discrete/test_il_bcq.py @@ -9,7 +9,7 @@ from tianshou.data import Collector from tianshou.env import DummyVectorEnv -from tianshou.utils.net.common import BCQN +from tianshou.utils.net.common import Net from tianshou.trainer import offline_trainer from tianshou.policy import DiscreteBCQPolicy @@ -28,14 +28,8 @@ def get_args(): parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--step-per-epoch", type=int, default=1000) parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument( - "--policy-model-hidden-dim", type=int, nargs="*", - default=[256, 256], - ) - parser.add_argument( - "--imitation-model-hidden-dim", type=int, nargs="*", - default=[256, 256], - ) + parser.add_argument("--layer-num", type=int, default=2) + parser.add_argument("--hidden-layer-size", type=int, default=128) parser.add_argument("--test-num", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) @@ -63,20 +57,27 @@ def test_discrete_bcq(args=get_args()): torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - net = BCQN( - args.state_shape, args.action_shape, args.policy_model_hidden_dim, - args.imitation_model_hidden_dim, args.device, + policy_net = Net( + args.layer_num, args.state_shape, args.action_shape, args.device, + hidden_layer_size=args.hidden_layer_size, + ).to(args.device) + imitation_net = Net( + args.layer_num, args.state_shape, args.action_shape, args.device, + hidden_layer_size=args.hidden_layer_size, ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = torch.optim.Adam( + list(policy_net.parameters()) + list(imitation_net.parameters()), + lr=args.lr + ) policy = DiscreteBCQPolicy( - net, optim, args.gamma, args.n_step, args.target_update_freq, - args.eps_test, args.unlikely_action_threshold, + policy_net, imitation_net, optim, args.gamma, args.n_step, + args.target_update_freq, args.eps_test, args.unlikely_action_threshold, args.imitation_logits_penalty, ) # buffer assert os.path.exists(args.load_buffer_name), \ - "Please run test_dqn.py first to get expert data buffer." + "Please run test_dqn.py first to get expert's data buffer." 
buffer = pickle.load(open(args.load_buffer_name, "rb")) # collector diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index d65cbc87c..954bc81f6 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -57,7 +57,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: a_ = to_torch(batch.act, dtype=torch.float32, device=a.device) loss = F.mse_loss(a, a_) # type: ignore elif self.mode == "discrete": # classification - a = self(batch).logits + a = F.log_softmax(self(batch).logits, dim=-1) a_ = to_torch(batch.act, dtype=torch.long, device=a.device) loss = F.nll_loss(a, a_) # type: ignore loss.backward() diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 7f3b232e5..5d01f35db 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -1,3 +1,4 @@ +import math import torch import numpy as np import torch.nn.functional as F @@ -8,11 +9,35 @@ class DiscreteBCQPolicy(DQNPolicy): - """Implementation of discrete BCQ algorithm. arXiv:1910.01708.""" + """Implementation of discrete BCQ algorithm. arXiv:1910.01708. + + :param torch.nn.Module model: a model following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> q_value) + :param torch.nn.Module imitator: a model following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits) + :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. + :param float discount_factor: in [0, 1]. + :param int estimation_step: greater than 1, the number of steps to look + ahead. + :param int target_update_freq: the target network update frequency. + :param float eval_eps: the epsilon-greedy noise added in evaluation. + :param float unlikely_action_threshold: the threshold (tau) for unlikely + actions, as shown in Eq. (17) in the paper, defaults to 0.3. + :param float imitation_logits_penalty: regularization weight for imitation + logits, defaults to 1e-2. + :param bool reward_normalization: normalize the reward to Normal(0, 1), + defaults to False. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ def __init__( self, model: torch.nn.Module, + imitator: torch.nn.Module, optim: torch.optim.Optimizer, discount_factor: float = 0.99, estimation_step: int = 1, @@ -20,17 +45,26 @@ def __init__( eval_eps: float = 1e-3, unlikely_action_threshold: float = 0.3, imitation_logits_penalty: float = 1e-2, + reward_normalization: bool = False, **kwargs: Any, ) -> None: super().__init__(model, optim, discount_factor, estimation_step, - target_update_freq, **kwargs) + target_update_freq, reward_normalization, **kwargs) + assert target_update_freq > 0, "BCQ needs target network setting."
assert ( 0.0 <= unlikely_action_threshold < 1.0 ), "unlikely_action_threshold should be in [0, 1)" - self._thres = unlikely_action_threshold + self.imitator = imitator + self._log_tau = math.log(unlikely_action_threshold) self._eps = eval_eps self._w_imitation = imitation_logits_penalty + def train(self, mode: bool = True) -> "DiscreteBCQPolicy": + self.training = mode + self.model.train(mode) + self.imitator.train(mode) + return self + def _target_q( self, buffer: ReplayBuffer, indice: np.ndarray ) -> torch.Tensor: @@ -38,17 +72,14 @@ def _target_q( # target_Q = Q_old(s_, argmax(Q_new(s_, *))) with torch.no_grad(): act = self(batch, input="obs_next", eps=0.0).act - target_q = self( - batch, model="model_old", input="obs_next", eps=0.0 - ).logits + target_q, _ = self.model_old(batch.obs_next) target_q = target_q[np.arange(len(act)), act] return target_q - def forward( + def forward( # type: ignore self, batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, - model: str = "model", input: str = "obs", eps: Optional[float] = None, **kwargs: Any, @@ -56,18 +87,20 @@ def forward( if eps is None: eps = self._eps obs = batch[input] - (q, imt, i), state = self.model(obs, state=state, info=batch.info) - imt = imt.exp() - imt = (imt / imt.max(1, keepdim=True)[0] > self._thres).float() + q_value, state = self.model(obs, state=state, info=batch.info) + imt, _ = self.imitator(obs, state=state, info=batch.info) + # mask actions for argmax - action = (imt * q + (1.0 - imt) * -np.inf).argmax(-1) + ratio = imt - imt.max(dim=-1, keepdim=True).values + mask = (ratio < self._log_tau).float() + action = (q_value - np.inf * mask).argmax(dim=-1) # add eps to act if not np.isclose(eps, 0.0) and np.random.rand() < eps: - bsz, action_num = q.shape + bsz, action_num = q_value.shape action = np.random.randint(action_num, size=bsz) - return Batch(logits=q, act=action, state=state) + return Batch(logits=q_value, act=action, state=state, imt=imt) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: if self._iter % self._freq == 0: @@ -75,12 +108,13 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: self._iter += 1 target_q = batch.returns.flatten() - (current_q, imt, i), _ = self.model(batch.obs) - current_q = current_q[np.arange(len(target_q)), batch.act] + result = self(batch, eps=0.0) + imt = result.imt + current_q = result.logits[np.arange(len(target_q)), batch.act] act = to_torch(batch.act, dtype=torch.long, device=target_q.device) q_loss = F.smooth_l1_loss(current_q, target_q) - i_loss = F.nll_loss(imt, act) # type: ignore - reg_loss = i.pow(2).mean() + i_loss = F.nll_loss(F.log_softmax(imt, dim=-1), act) # type: ignore + reg_loss = imt.pow(2).mean() loss = q_loss + i_loss + self._w_imitation * reg_loss self.optim.zero_grad() diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 3af6101ce..dfffd71a4 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -69,6 +69,7 @@ def gather_info( "best_reward": best_reward, "best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}", "duration": f"{duration:.2f}s", + "train_time/model": f"{model_time:.2f}s", } if train_c is not None: model_time -= train_c.collect_time diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 7ec37963e..40f050af8 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -1,7 +1,6 @@ import torch import numpy as np from torch import nn -import torch.nn.functional as F from typing import Any, Dict, List, Tuple, Union, 
Callable, Optional, Sequence from tianshou.data import to_torch @@ -113,45 +112,6 @@ def forward( return logits, state -class BCQN(nn.Module): - """A double-head MLP (imitation and q-value) for BCQ algorithm.""" - - def __init__( - self, - state_shape: tuple, - action_shape: Union[tuple, int], - policy_model_hidden_dim: List[int] = [256, 256], - imitation_model_hidden_dim: List[int] = [256, 256], - device: Union[str, int, torch.device] = "cpu", - norm_layer: Optional[Callable[[int], nn.modules.Module]] = None, - ) -> None: - super().__init__() - self.device = device - policy_dim = [np.prod(state_shape)] + policy_model_hidden_dim - imitation_dim = [np.prod(state_shape)] + imitation_model_hidden_dim - self.Q = nn.Sequential(*[ - layer - for (inp, oup) in zip(policy_dim[:-1], policy_dim[1:]) - for layer in miniblock(inp, oup, norm_layer) - ], nn.Linear(policy_dim[-1], np.prod(action_shape))) - self.imitation = nn.Sequential(*[ - layer - for (inp, oup) in zip(imitation_dim[:-1], imitation_dim[1:]) - for layer in miniblock(inp, oup, norm_layer) - ], nn.Linear(imitation_dim[-1], np.prod(action_shape))) - - def forward( - self, - s: Union[np.ndarray, torch.Tensor], - state: Optional[Dict[str, torch.Tensor]] = None, - info: Dict[str, Any] = {}, - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: - s = to_torch(s, device=self.device, dtype=torch.float32) - s = s.reshape(s.size(0), -1) - i = self.imitation(s) - return (self.Q(s), F.log_softmax(i, dim=1), i), state - - class Recurrent(nn.Module): """Simple Recurrent network based on LSTM. From 8489d17bf27e3ace8b38bf1a2f45805c66e7e574 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 14 Jan 2021 20:05:52 +0800 Subject: [PATCH 16/23] fix --- docs/tutorials/concepts.rst | 2 +- setup.py | 2 +- test/discrete/test_il_bcq.py | 4 +-- tianshou/policy/imitation/discrete_bcq.py | 24 ++++++++++------- tianshou/policy/modelfree/a2c.py | 7 ++--- tianshou/policy/modelfree/ppo.py | 7 ++--- tianshou/trainer/offline.py | 27 ++++++++++--------- tianshou/trainer/offpolicy.py | 33 ++++++++++++----------- tianshou/trainer/onpolicy.py | 33 ++++++++++++----------- 9 files changed, 74 insertions(+), 65 deletions(-) diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index 749ef4b03..8ea2d272e 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -201,7 +201,7 @@ Trainer Once you have a collector and a policy, you can start writing the training method for your RL agent. Trainer, to be honest, is a simple wrapper. It helps you save energy for writing the training loop. You can also construct your own trainer: :ref:`customized_trainer`. -Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.trainer.offpolicy_trainer`, and :func:`~tianshou.trainer.offline_trainer`, corresponding to on-policy algorithms (such as Policy Gradient), off-policy algorithms (such as DQN), and offline algorithms (such as imitation learning or BCQ). Please check out :doc:`/api/tianshou.trainer` for the usage. +Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` for on-policy algorithms such as Policy Gradient, :func:`~tianshou.trainer.offpolicy_trainer` for off-policy algorithms such as DQN, and :func:`~tianshou.trainer.offline_trainer` for offline algorithms such as BCQ. Please check out :doc:`/api/tianshou.trainer` for the usage. .. 
_pseudocode: diff --git a/setup.py b/setup.py index 38c73b55f..4b2247218 100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,7 @@ def get_version() -> str: "tensorboard", "torch>=1.4.0", "numba>=0.51.0", - "h5py>=2.10.0", # to match tensorflow's minimal reqiurements + "h5py>=2.10.0", # to match tensorflow's minimal requirements ], extras_require={ "dev": [ diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py index 7ee5bfd9d..7d57bec6e 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/discrete/test_il_bcq.py @@ -72,8 +72,8 @@ def test_discrete_bcq(args=get_args()): policy = DiscreteBCQPolicy( policy_net, imitation_net, optim, args.gamma, args.n_step, - args.target_update_freq, args.eps_test, args.unlikely_action_threshold, - args.imitation_logits_penalty, + args.target_update_freq, args.eps_test, + args.unlikely_action_threshold, args.imitation_logits_penalty, ) # buffer assert os.path.exists(args.load_buffer_name), \ diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 5d01f35db..eae24f01f 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -51,13 +51,14 @@ def __init__( super().__init__(model, optim, discount_factor, estimation_step, target_update_freq, reward_normalization, **kwargs) assert target_update_freq > 0, "BCQ needs target network setting." + self.imitator = imitator assert ( 0.0 <= unlikely_action_threshold < 1.0 ), "unlikely_action_threshold should be in [0, 1)" - self.imitator = imitator self._log_tau = math.log(unlikely_action_threshold) + assert 0.0 <= eval_eps < 1.0 self._eps = eval_eps - self._w_imitation = imitation_logits_penalty + self._weight_reg = imitation_logits_penalty def train(self, mode: bool = True) -> "DiscreteBCQPolicy": self.training = mode @@ -88,10 +89,11 @@ def forward( # type: ignore eps = self._eps obs = batch[input] q_value, state = self.model(obs, state=state, info=batch.info) - imt, _ = self.imitator(obs, state=state, info=batch.info) + imitation_logits, _ = self.imitator(obs, state=state, info=batch.info) # mask actions for argmax - ratio = imt - imt.max(dim=-1, keepdim=True).values + ratio = imitation_logits - imitation_logits.max( + dim=-1, keepdim=True).values mask = (ratio < self._log_tau).float() action = (q_value - np.inf * mask).argmax(dim=-1) @@ -100,7 +102,8 @@ def forward( # type: ignore bsz, action_num = q_value.shape action = np.random.randint(action_num, size=bsz) - return Batch(logits=q_value, act=action, state=state, imt=imt) + return Batch(act=action, state=state, q_value=q_value, + imitation_logits=imitation_logits) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: if self._iter % self._freq == 0: @@ -109,13 +112,14 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: target_q = batch.returns.flatten() result = self(batch, eps=0.0) - imt = result.imt - current_q = result.logits[np.arange(len(target_q)), batch.act] + imitation_logits = result.imitation_logits + current_q = result.q_value[np.arange(len(target_q)), batch.act] act = to_torch(batch.act, dtype=torch.long, device=target_q.device) q_loss = F.smooth_l1_loss(current_q, target_q) - i_loss = F.nll_loss(F.log_softmax(imt, dim=-1), act) # type: ignore - reg_loss = imt.pow(2).mean() - loss = q_loss + i_loss + self._w_imitation * reg_loss + i_loss = F.nll_loss( + F.log_softmax(imitation_logits, dim=-1), act) # type: ignore + reg_loss = imitation_logits.pow(2).mean() + loss = q_loss + i_loss + self._weight_reg * reg_loss 
self.optim.zero_grad() loss.backward() diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index df34e01b6..e215c6bbd 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -58,8 +58,8 @@ def __init__( self.critic = critic assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]." self._lambda = gae_lambda - self._w_vf = vf_coef - self._w_ent = ent_coef + self._weight_vf = vf_coef + self._weight_ent = ent_coef self._grad_norm = max_grad_norm self._batch = max_batchsize self._rew_norm = reward_normalization @@ -122,7 +122,8 @@ def learn( # type: ignore a_loss = -(log_prob * (r - v).detach()).mean() vf_loss = F.mse_loss(r, v) # type: ignore ent_loss = dist.entropy().mean() - loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss + loss = a_loss + self._weight_vf * vf_loss - \ + self._weight_ent * ent_loss loss.backward() if self._grad_norm is not None: nn.utils.clip_grad_norm_( diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 5a04ec6df..4cf9f9054 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -68,8 +68,8 @@ def __init__( super().__init__(None, optim, dist_fn, discount_factor, **kwargs) self._max_grad_norm = max_grad_norm self._eps_clip = eps_clip - self._w_vf = vf_coef - self._w_ent = ent_coef + self._weight_vf = vf_coef + self._weight_ent = ent_coef self._range = action_range self.actor = actor self.critic = critic @@ -174,7 +174,8 @@ def learn( # type: ignore vf_losses.append(vf_loss.item()) e_loss = dist.entropy().mean() ent_losses.append(e_loss.item()) - loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss + loss = clip_loss + self._weight_vf * vf_loss - \ + self._weight_ent * e_loss losses.append(loss.item()) self.optim.zero_grad() loss.backward() diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 6b8175744..af7a40aa6 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -33,23 +33,24 @@ def offline_trainer( class. :param test_collector: the collector used for testing. :type test_collector: :class:`~tianshou.data.Collector` - :param int max_epoch: the maximum of epochs for training. The training - process might be finished before reaching the ``max_epoch``. - :param int step_per_epoch: the number of step for updating policy network - in one epoch. + :param int max_epoch: the maximum number of epochs for training. The + training process might be finished before reaching the ``max_epoch``. + :param int step_per_epoch: the number of policy network updates, so-called + gradient steps, per epoch. :param episode_per_test: the number of episodes for one policy evaluation. :param int batch_size: the batch size of sample data, which is going to feed in the policy network. - :param function test_fn: a function receives the current number of epoch - and step index, and performs some operations at the beginning of - testing in this epoch. - :param function save_fn: a function for saving policy when the undiscounted - average mean reward in evaluation phase gets better. - :param function stop_fn: a function receives the average undiscounted - returns of the testing result, return a boolean which indicates whether - reaching the goal. + :param function test_fn: a hook called at the beginning of testing in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. 
+ :param function save_fn: a hook called when the undiscounted average mean + reward in evaluation phase gets better, with the signature ``f(policy: + BasePolicy) -> None``. + :param function stop_fn: a function with signature ``f(mean_rewards: float) + -> bool``, receives the average undiscounted returns of the testing + result, returns a boolean which indicates whether reaching the goal. :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard - SummaryWriter. + SummaryWriter; if None is given, it will not write logs to TensorBoard. :param int log_interval: the log interval of the writer. :param bool verbose: whether to print the information. diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 2642be734..f34f5b281 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -39,10 +39,10 @@ def offpolicy_trainer( :type train_collector: :class:`~tianshou.data.Collector` :param test_collector: the collector used for testing. :type test_collector: :class:`~tianshou.data.Collector` - :param int max_epoch: the maximum of epochs for training. The training - process might be finished before reaching the ``max_epoch``. - :param int step_per_epoch: the number of step for updating policy network - in one epoch. + :param int max_epoch: the maximum number of epochs for training. The + training process might be finished before reaching the ``max_epoch``. + :param int step_per_epoch: the number of policy network updates, so-called + gradient steps, per epoch. :param int collect_per_step: the number of frames the collector would collect before the network update. In other words, collect some frames and do some policy network update. @@ -53,19 +53,20 @@ def offpolicy_trainer( be updated after frames are collected, for example, set it to 256 means it updates policy 256 times once after ``collect_per_step`` frames are collected. - :param function train_fn: a function receives the current number of epoch - and step index, and performs some operations at the beginning of - training in this epoch. - :param function test_fn: a function receives the current number of epoch - and step index, and performs some operations at the beginning of - testing in this epoch. - :param function save_fn: a function for saving policy when the undiscounted - average mean reward in evaluation phase gets better. - :param function stop_fn: a function receives the average undiscounted - returns of the testing result, return a boolean which indicates whether - reaching the goal. + :param function train_fn: a hook called at the beginning of training in + each epoch. It can be used to perform custom additional operations, + with the signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean + reward in evaluation phase gets better, with the signature ``f(policy: + BasePolicy) -> None``. + :param function stop_fn: a function with signature ``f(mean_rewards: float) + -> bool``, receives the average undiscounted returns of the testing + result, returns a boolean which indicates whether reaching the goal. :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard - SummaryWriter. + SummaryWriter; if None is given, it will not write logs to TensorBoard. 
:param int log_interval: the log interval of the writer. :param bool verbose: whether to print the information. :param bool test_in_train: whether to test in the training phase. diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index d2f2b38cf..f094ddd7d 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -39,10 +39,10 @@ def onpolicy_trainer( :type train_collector: :class:`~tianshou.data.Collector` :param test_collector: the collector used for testing. :type test_collector: :class:`~tianshou.data.Collector` - :param int max_epoch: the maximum of epochs for training. The training - process might be finished before reaching the ``max_epoch``. - :param int step_per_epoch: the number of step for updating policy network - in one epoch. + :param int max_epoch: the maximum number of epochs for training. The + training process might be finished before reaching the ``max_epoch``. + :param int step_per_epoch: the number of policy network updates, so-called + gradient steps, per epoch. :param int collect_per_step: the number of episodes the collector would collect before the network update. In other words, collect some episodes and do one policy network update. @@ -53,19 +53,20 @@ def onpolicy_trainer( :type episode_per_test: int or list of ints :param int batch_size: the batch size of sample data, which is going to feed in the policy network. - :param function train_fn: a function receives the current number of epoch - and step index, and performs some operations at the beginning of - training in this poch. - :param function test_fn: a function receives the current number of epoch - and step index, and performs some operations at the beginning of - testing in this epoch. - :param function save_fn: a function for saving policy when the undiscounted - average mean reward in evaluation phase gets better. - :param function stop_fn: a function receives the average undiscounted - returns of the testing result, return a boolean which indicates whether - reaching the goal. + :param function train_fn: a hook called at the beginning of training in + each epoch. It can be used to perform custom additional operations, + with the signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean + reward in evaluation phase gets better, with the signature ``f(policy: + BasePolicy) -> None``. + :param function stop_fn: a function with signature ``f(mean_rewards: float) + -> bool``, receives the average undiscounted returns of the testing + result, returns a boolean which indicates whether reaching the goal. :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard - SummaryWriter. + SummaryWriter; if None is given, it will not write logs to TensorBoard. :param int log_interval: the log interval of the writer. :param bool verbose: whether to print the information. :param bool test_in_train: whether to test in the training phase. 
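For reference, the hooks documented above compose into a minimal offline training run. The sketch below is illustrative only: it assumes a `policy`, a pre-collected `buffer`, a `test_collector`, a `log_path`, and a SummaryWriter `writer` already exist (echoing test/discrete/test_il_bcq.py), and the stop threshold of 195 is a placeholder value:

    import os
    import torch
    from tianshou.trainer import offline_trainer

    def stop_fn(mean_rewards):
        # receives the average undiscounted return of one evaluation
        return mean_rewards >= 195  # placeholder stop threshold

    def save_fn(policy):
        # called whenever the evaluation reward improves
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    result = offline_trainer(
        policy, buffer, test_collector,
        max_epoch=5, step_per_epoch=1000, episode_per_test=100,
        batch_size=64, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
    print(result["best_reward"])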
From c2cf97241d4be8928308e76beeaa95359683bea4 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Fri, 15 Jan 2021 19:17:47 +0800 Subject: [PATCH 17/23] add atari_bcq, still need check --- .gitignore | 1 + examples/atari/atari_bcq.py | 151 ++++++++++++++++++++++ examples/atari/atari_dqn.py | 21 ++- test/discrete/test_il_bcq.py | 2 +- tianshou/policy/imitation/discrete_bcq.py | 5 +- tianshou/utils/net/discrete.py | 17 ++- 6 files changed, 185 insertions(+), 12 deletions(-) create mode 100644 examples/atari/atari_bcq.py diff --git a/.gitignore b/.gitignore index 769ca8bd6..be8453abd 100644 --- a/.gitignore +++ b/.gitignore @@ -146,3 +146,4 @@ MUJOCO_LOG.TXT *.pstats *.swp *.pkl +*.hdf5 diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py new file mode 100644 index 000000000..e3d61dcb1 --- /dev/null +++ b/examples/atari/atari_bcq.py @@ -0,0 +1,151 @@ +import os +import gym +import torch +import pickle +import pprint +import argparse +import numpy as np +from torch import nn +from torch.utils.tensorboard import SummaryWriter + +from tianshou.env import SubprocVectorEnv +from tianshou.utils.net.discrete import DQN +from tianshou.trainer import offline_trainer +from tianshou.policy import DiscreteBCQPolicy +from tianshou.data import Collector, ReplayBuffer + +from atari_wrapper import wrap_deepmind + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") + parser.add_argument("--seed", type=int, default=1626) + parser.add_argument("--eps-test", type=float, default=0.001) + parser.add_argument("--lr", type=float, default=6.25e-5) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--n-step", type=int, default=1) + parser.add_argument("--target-update-freq", type=int, default=8000) + parser.add_argument("--unlikely-action-threshold", type=float, default=0.3) + parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) + parser.add_argument("--epoch", type=int, default=100) + parser.add_argument("--step-per-epoch", type=int, default=10000) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--layer-num", type=int, default=2) + parser.add_argument("--hidden-layer-size", type=int, default=512) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument('--frames_stack', type=int, default=4) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.) 
+ parser.add_argument( + "--load-buffer-name", type=str, + default="./expert_DQN_PongNoFrameskip-v4.hdf5", + ) + parser.add_argument( + "--device", type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + ) + args = parser.parse_known_args()[0] + return args + + +def make_atari_env(args): + return wrap_deepmind(args.task, frame_stack=args.frames_stack) + + +def make_atari_env_watch(args): + return wrap_deepmind(args.task, frame_stack=args.frames_stack, + episode_life=False, clip_rewards=False) + + +class Net(nn.Module): + def __init__(self, preprocess_net, action_shape, hidden_layer_size): + super().__init__() + self.preprocess = preprocess_net + self.last = nn.Sequential( + nn.Linear(3136, hidden_layer_size), + nn.ReLU(inplace=True), + nn.Linear(hidden_layer_size, np.prod(action_shape)) + ) + + def forward(self, s, state=None, **kwargs): + feature, h = self.preprocess(s, state) + return self.last(feature), h + + +def test_discrete_bcq(args=get_args()): + # envs + env = make_atari_env(args) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + # make environments + test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) + for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + test_envs.seed(args.seed) + # model + feature_net = DQN(*args.state_shape, args.action_shape, + args.device, features_only=True).to(args.device) + policy_net = Net(feature_net, args.action_shape, + args.hidden_layer_size).to(args.device) + imitation_net = Net(feature_net, args.action_shape, + args.hidden_layer_size).to(args.device) + print(feature_net) + print(policy_net) + print(imitation_net) + optim = torch.optim.Adam( + list(set(policy_net).union(imitation_net)), lr=args.lr + ) + + policy = DiscreteBCQPolicy( + policy_net, imitation_net, optim, args.gamma, args.n_step, + args.target_update_freq, args.eps_test, + args.unlikely_action_threshold, args.imitation_logits_penalty, + ) + # buffer + assert os.path.exists(args.load_buffer_name), \ + "Please run atari_dqn.py first to get expert's data buffer." + if args.load_buffer_name.endswith('.pkl'): + buffer = pickle.load(open(args.load_buffer_name, "rb")) + elif args.load_buffer_name.endswith('.hdf5'): + buffer = ReplayBuffer.load_hdf5(args.load_buffer_name) + else: + print(f"Unknown buffer format: {args.load_buffer_name}") + exit(0) + + # collector + test_collector = Collector(policy, test_envs) + + log_path = os.path.join(args.logdir, args.task, 'discrete_bcq') + writer = SummaryWriter(log_path) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return False + + result = offline_trainer( + policy, buffer, test_collector, + args.epoch, args.step_per_epoch, args.test_num, args.batch_size, + stop_fn=stop_fn, save_fn=save_fn, writer=writer, + ) + + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! 
+ env = gym.make(args.task) + policy.eval() + policy.set_eps(args.eps_test) + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + print(f'Final reward: {result["rew"]}, length: {result["len"]}') + + +if __name__ == "__main__": + test_discrete_bcq(get_args()) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index b7d8ffef4..1b405f714 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -41,6 +41,7 @@ def get_args(): parser.add_argument('--resume_path', type=str, default=None) parser.add_argument('--watch', default=False, action='store_true', help='watch the play of pre-trained policy only') + parser.add_argument('--save-buffer-name', type=str, default=None) return parser.parse_args() @@ -120,13 +121,25 @@ def test_fn(epoch, env_step): # watch agent's performance def watch(): - print("Testing agent ...") + print("Setup test envs ...") policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) - test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, - render=args.render) + if args.save_buffer_name: + print(f"Generate buffer with size {args.buffer_size}") + buffer = ReplayBuffer( + args.buffer_size, ignore_obs_next=True, + save_only_last_obs=True, stack_num=args.frames_stack) + collector = Collector(policy, test_envs, buffer) + result = collector.collect(n_step=args.buffer_size) + print(f"Save buffer into {args.save_buffer_name}") + # Unfortunately, pickle will cause oom with 1M buffer size + buffer.save_hdf5(args.save_buffer_name) + else: + print("Testing agent ...") + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) pprint.pprint(result) if args.watch: diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py index 7d57bec6e..45eb59b65 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/discrete/test_il_bcq.py @@ -67,7 +67,7 @@ def test_discrete_bcq(args=get_args()): ).to(args.device) optim = torch.optim.Adam( list(policy_net.parameters()) + list(imitation_net.parameters()), - lr=args.lr + lr=args.lr, ) policy = DiscreteBCQPolicy( diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index eae24f01f..4f9564a86 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -55,7 +55,10 @@ def __init__( assert ( 0.0 <= unlikely_action_threshold < 1.0 ), "unlikely_action_threshold should be in [0, 1)" - self._log_tau = math.log(unlikely_action_threshold) + if unlikely_action_threshold > 0: + self._log_tau = math.log(unlikely_action_threshold) + else: + self._log_tau = -np.inf assert 0.0 <= eval_eps < 1.0 self._eps = eval_eps self._weight_reg = imitation_logits_penalty diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index bff229cd7..05d647ff8 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -80,6 +80,7 @@ def __init__( w: int, action_shape: Sequence[int], device: Union[str, int, torch.device] = "cpu", + features_only: bool = False, ) -> None: super().__init__() self.device = device @@ -107,18 +108,22 @@ def conv2d_layers_size_out( convh = conv2d_layers_size_out(h) linear_input_size = convw * convh * 64 - self.net = nn.Sequential( + layers = [ nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True), nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True), nn.Conv2d(64, 64, kernel_size=3, stride=1), 
nn.ReLU(inplace=True), - nn.Flatten(), - nn.Linear(linear_input_size, 512), - nn.ReLU(inplace=True), - nn.Linear(512, np.prod(action_shape)), - ) + nn.Flatten() + ] + if not features_only: + layers += [ + nn.Linear(linear_input_size, 512), + nn.ReLU(inplace=True), + nn.Linear(512, np.prod(action_shape)), + ] + self.net = nn.Sequential(*layers) def forward( self, From a3b51a2d6ffd80792e733545bb435e1e580017ac Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Sat, 16 Jan 2021 16:26:15 +0800 Subject: [PATCH 18/23] update examples --- examples/atari/atari_bcq.py | 51 ++++++++++++++++++++++++------------- examples/atari/atari_dqn.py | 4 +-- 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py index e3d61dcb1..68b9d4987 100644 --- a/examples/atari/atari_bcq.py +++ b/examples/atari/atari_bcq.py @@ -1,5 +1,4 @@ import os -import gym import torch import pickle import pprint @@ -24,7 +23,7 @@ def get_args(): parser.add_argument("--eps-test", type=float, default=0.001) parser.add_argument("--lr", type=float, default=6.25e-5) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--n-step", type=int, default=1) + parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--target-update-freq", type=int, default=8000) parser.add_argument("--unlikely-action-threshold", type=float, default=0.3) parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) @@ -33,10 +32,14 @@ def get_args(): parser.add_argument("--batch-size", type=int, default=32) parser.add_argument("--layer-num", type=int, default=2) parser.add_argument("--hidden-layer-size", type=int, default=512) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=100) parser.add_argument('--frames_stack', type=int, default=4) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--watch", default=False, action="store_true", + help="watch the play of pre-trained policy only") + parser.add_argument("--log-interval", type=int, default=1000) parser.add_argument( "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5", @@ -95,18 +98,22 @@ def test_discrete_bcq(args=get_args()): args.hidden_layer_size).to(args.device) imitation_net = Net(feature_net, args.action_shape, args.hidden_layer_size).to(args.device) - print(feature_net) - print(policy_net) - print(imitation_net) optim = torch.optim.Adam( - list(set(policy_net).union(imitation_net)), lr=args.lr + set(policy_net.parameters()).union(imitation_net.parameters()), + lr=args.lr, ) - + # define policy policy = DiscreteBCQPolicy( policy_net, imitation_net, optim, args.gamma, args.n_step, args.target_update_freq, args.eps_test, args.unlikely_action_threshold, args.imitation_logits_penalty, ) + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load( + args.resume_path, map_location=args.device + )) + print("Loaded agent from: ", args.resume_path) # buffer assert os.path.exists(args.load_buffer_name), \ "Please run atari_dqn.py first to get expert's data buffer." 
@@ -130,21 +137,31 @@ def save_fn(policy): def stop_fn(mean_rewards): return False + # watch agent's performance + def watch(): + print("Setup test envs ...") + policy.eval() + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + print("Testing agent ...") + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) + pprint.pprint(result) + + if args.watch: + watch() + exit(0) + result = offline_trainer( policy, buffer, test_collector, args.epoch, args.step_per_epoch, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer, + log_interval=args.log_interval, ) - if __name__ == '__main__': - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - policy.eval() - policy.set_eps(args.eps_test) - collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + pprint.pprint(result) + watch() if __name__ == "__main__": diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 1b405f714..57ff4608f 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -37,8 +37,8 @@ def get_args(): parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') - parser.add_argument('--frames_stack', type=int, default=4) - parser.add_argument('--resume_path', type=str, default=None) + parser.add_argument('--frames-stack', type=int, default=4) + parser.add_argument('--resume-path', type=str, default=None) parser.add_argument('--watch', default=False, action='store_true', help='watch the play of pre-trained policy only') parser.add_argument('--save-buffer-name', type=str, default=None) From 667b2f8bd17845150c1e1a45ece6e4d08ae097b2 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Sat, 16 Jan 2021 22:49:57 +0800 Subject: [PATCH 19/23] tune eps code --- tianshou/policy/imitation/discrete_bcq.py | 6 ++++-- tianshou/trainer/offline.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 4f9564a86..6bd780714 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -101,9 +101,11 @@ def forward( # type: ignore action = (q_value - np.inf * mask).argmax(dim=-1) # add eps to act - if not np.isclose(eps, 0.0) and np.random.rand() < eps: + if not np.isclose(eps, 0.0): bsz, action_num = q_value.shape - action = np.random.randint(action_num, size=bsz) + mask = np.random.rand(bsz) < eps + action_rand = np.random.randint(action_num, size=bsz) + action[mask] = action_rand[mask] return Batch(act=action, state=state, q_value=q_value, imitation_logits=imitation_logits) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index af7a40aa6..e69364135 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -76,7 +76,8 @@ def offline_trainer( data[k] = f"{stat[k].get():.6f}" if writer and gradient_step % log_interval == 0: writer.add_scalar( - k, stat[k].get(), global_step=gradient_step) + "train/" + k, stat[k].get(), + global_step=gradient_step) t.set_postfix(**data) # test result = test_episode(policy, test_collector, test_fn, epoch, From 151fd0b2919135b2dd46f4d52a2f0ebeba167372 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Sat, 16 Jan 2021 22:54:55 +0800 Subject: [PATCH 20/23] fix eps mask --- tianshou/policy/imitation/discrete_bcq.py | 3 ++- 1 file 
changed, 2 insertions(+), 1 deletion(-) diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 6bd780714..faae34aaf 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -104,7 +104,8 @@ def forward( # type: ignore if not np.isclose(eps, 0.0): bsz, action_num = q_value.shape mask = np.random.rand(bsz) < eps - action_rand = np.random.randint(action_num, size=bsz) + action_rand = torch.randint( + action_num, size=[bsz], device=action.device) action[mask] = action_rand[mask] return Batch(act=action, state=state, q_value=q_value, From 5ef0c4c16d9600a2970bcb9af97ddec6fad9bc35 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Wed, 20 Jan 2021 17:17:37 +0800 Subject: [PATCH 21/23] fix test --- examples/atari/atari_bcq.py | 35 ++++++++++------------------------ test/discrete/test_il_bcq.py | 16 +++++++--------- tianshou/utils/net/discrete.py | 2 +- 3 files changed, 18 insertions(+), 35 deletions(-) diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py index 68b9d4987..e9edb8310 100644 --- a/examples/atari/atari_bcq.py +++ b/examples/atari/atari_bcq.py @@ -4,15 +4,15 @@ import pprint import argparse import numpy as np -from torch import nn from torch.utils.tensorboard import SummaryWriter from tianshou.env import SubprocVectorEnv -from tianshou.utils.net.discrete import DQN from tianshou.trainer import offline_trainer +from tianshou.utils.net.discrete import Actor from tianshou.policy import DiscreteBCQPolicy from tianshou.data import Collector, ReplayBuffer +from atari_network import DQN from atari_wrapper import wrap_deepmind @@ -30,13 +30,13 @@ def get_args(): parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--step-per-epoch", type=int, default=10000) parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--layer-num", type=int, default=2) - parser.add_argument("--hidden-layer-size", type=int, default=512) + parser.add_argument('--hidden-sizes', type=int, + nargs='*', default=[512]) parser.add_argument("--test-num", type=int, default=100) parser.add_argument('--frames_stack', type=int, default=4) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) 
- parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume-path", type=str, default=None) parser.add_argument("--watch", default=False, action="store_true", help="watch the play of pre-trained policy only") parser.add_argument("--log-interval", type=int, default=1000) @@ -61,21 +61,6 @@ def make_atari_env_watch(args): episode_life=False, clip_rewards=False) -class Net(nn.Module): - def __init__(self, preprocess_net, action_shape, hidden_layer_size): - super().__init__() - self.preprocess = preprocess_net - self.last = nn.Sequential( - nn.Linear(3136, hidden_layer_size), - nn.ReLU(inplace=True), - nn.Linear(hidden_layer_size, np.prod(action_shape)) - ) - - def forward(self, s, state=None, **kwargs): - feature, h = self.preprocess(s, state) - return self.last(feature), h - - def test_discrete_bcq(args=get_args()): # envs env = make_atari_env(args) @@ -93,11 +78,11 @@ def test_discrete_bcq(args=get_args()): test_envs.seed(args.seed) # model feature_net = DQN(*args.state_shape, args.action_shape, - args.device, features_only=True).to(args.device) - policy_net = Net(feature_net, args.action_shape, - args.hidden_layer_size).to(args.device) - imitation_net = Net(feature_net, args.action_shape, - args.hidden_layer_size).to(args.device) + device=args.device, features_only=True).to(args.device) + policy_net = Actor(feature_net, args.action_shape, + hidden_sizes=args.hidden_sizes).to(args.device) + imitation_net = Actor(feature_net, args.action_shape, + hidden_sizes=args.hidden_sizes).to(args.device) optim = torch.optim.Adam( set(policy_net.parameters()).union(imitation_net.parameters()), lr=args.lr, diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py index 45eb59b65..b817ab9c5 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/discrete/test_il_bcq.py @@ -28,8 +28,8 @@ def get_args(): parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--step-per-epoch", type=int, default=1000) parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--layer-num", type=int, default=2) - parser.add_argument("--hidden-layer-size", type=int, default=128) + parser.add_argument('--hidden-sizes', type=int, + nargs='*', default=[128, 128, 128]) parser.add_argument("--test-num", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) 
@@ -58,15 +58,13 @@ def test_discrete_bcq(args=get_args()): test_envs.seed(args.seed) # model policy_net = Net( - args.layer_num, args.state_shape, args.action_shape, args.device, - hidden_layer_size=args.hidden_layer_size, - ).to(args.device) + args.state_shape, args.action_shape, + hidden_sizes=args.hidden_sizes, device=args.device).to(args.device) imitation_net = Net( - args.layer_num, args.state_shape, args.action_shape, args.device, - hidden_layer_size=args.hidden_layer_size, - ).to(args.device) + args.state_shape, args.action_shape, + hidden_sizes=args.hidden_sizes, device=args.device).to(args.device) optim = torch.optim.Adam( - list(policy_net.parameters()) + list(imitation_net.parameters()), + set(policy_net.parameters()).union(imitation_net.parameters()), lr=args.lr, ) diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 5bf3eb5de..05c02361c 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -104,4 +104,4 @@ def forward( ) -> torch.Tensor: """Mapping: s -> V(s).""" logits, _ = self.preprocess(s, state=kwargs.get("state", None)) - return self.last(logits) \ No newline at end of file + return self.last(logits) From a37a5424730a7642b52771aec3ebb28c72e92dfa Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Wed, 20 Jan 2021 17:35:19 +0800 Subject: [PATCH 22/23] update readme --- examples/atari/README.md | 13 ++++++++++++- examples/atari/atari_network.py | 3 +-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/examples/atari/README.md b/examples/atari/README.md index 7fd034461..933415e58 100644 --- a/examples/atari/README.md +++ b/examples/atari/README.md @@ -38,4 +38,15 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. | SeaquestNoFrameskip-v4 | 6226 | ![](results/c51/Seaquest_rew.png) | `python3 atari_c51.py --task "SeaquestNoFrameskip-v4"` | | SpaceInvadersNoFrameskip-v4 | 988.5 | ![](results/c51/SpaceInvader_rew.png) | `python3 atari_c51.py --task "SpaceInvadersNoFrameskip-v4"` | -Note: The selection of `n_step` is based on Figure 6 in the [Rainbow](https://arxiv.org/abs/1710.02298) paper. \ No newline at end of file +Note: The selection of `n_step` is based on Figure 6 in the [Rainbow](https://arxiv.org/abs/1710.02298) paper. + +# BCQ + +TODO: after the `done` issue is fixed, the results should be re-tuned and placed here. + +To run the BCQ algorithm on Atari, you need to do the following things: + +- Train an expert by using the command listed in the DQN section above; +- Generate a buffer with noise: `python3 atari_dqn.py --task {your_task} --watch --resume-path log/{your_task}/dqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error); +- Train BCQ: `python3 atari_bcq.py --task {your_task} --load-buffer-name expert.hdf5`.
+ diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index 0b7adaa04..c31a6c8cf 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -44,8 +44,7 @@ def forward( info: Dict[str, Any] = {}, ) -> Tuple[torch.Tensor, Any]: r"""Mapping: x -> Q(x, \*).""" - x = torch.as_tensor( - x, device=self.device, dtype=torch.float32) # type: ignore + x = torch.as_tensor(x, device=self.device, dtype=torch.float32) return self.net(x), state From 0b291deee51e1cf0a7c6089c416c1ce6d91f8b69 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Wed, 20 Jan 2021 17:38:40 +0800 Subject: [PATCH 23/23] trailing comma --- test/discrete/test_il_bcq.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py index b817ab9c5..e9e857ef1 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/discrete/test_il_bcq.py @@ -65,8 +65,7 @@ def test_discrete_bcq(args=get_args()): hidden_sizes=args.hidden_sizes, device=args.device).to(args.device) optim = torch.optim.Adam( set(policy_net.parameters()).union(imitation_net.parameters()), - lr=args.lr, - ) + lr=args.lr) policy = DiscreteBCQPolicy( policy_net, imitation_net, optim, args.gamma, args.n_step, @@ -93,8 +92,8 @@ def stop_fn(mean_rewards): result = offline_trainer( policy, buffer, test_collector, args.epoch, args.step_per_epoch, args.test_num, args.batch_size, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, - ) + stop_fn=stop_fn, save_fn=save_fn, writer=writer) + assert stop_fn(result['best_reward']) if __name__ == '__main__':