From f14e4c6297b5afa82d51b0aa26317e73f8a6162d Mon Sep 17 00:00:00 2001 From: Yao Date: Fri, 4 Sep 2020 20:31:50 +0800 Subject: [PATCH 01/62] add PSRL policy --- examples/FrozenLake_psrl.py | 101 +++++++++++++++++ examples/NChain_psrl.py | 101 +++++++++++++++++ examples/Taxi_psrl.py | 101 +++++++++++++++++ tianshou/policy/__init__.py | 2 + tianshou/policy/psrl/__init__.py | 0 tianshou/policy/psrl/psrl.py | 182 +++++++++++++++++++++++++++++++ 6 files changed, 487 insertions(+) create mode 100644 examples/FrozenLake_psrl.py create mode 100644 examples/NChain_psrl.py create mode 100644 examples/Taxi_psrl.py create mode 100644 tianshou/policy/psrl/__init__.py create mode 100644 tianshou/policy/psrl/psrl.py diff --git a/examples/FrozenLake_psrl.py b/examples/FrozenLake_psrl.py new file mode 100644 index 000000000..c92fcad11 --- /dev/null +++ b/examples/FrozenLake_psrl.py @@ -0,0 +1,101 @@ +import torch +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.env import SubprocVectorEnv +from tianshou.trainer import onpolicy_trainer, offpolicy_trainer +from tianshou.data import Collector, ReplayBuffer + +import gym +from tianshou.env import VectorEnv +import sys +import os +sys.path.append(os.path.abspath(os.path.join(os.path.split(__file__)[0], os.pardir)) + '/tianshou/policy/psrl') +from psrl import PSRLPolicy, PSRLModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='FrozenLake-v0') + parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--eps-test', type=float, default=0) + parser.add_argument('--eps-train', type=float, default=0.1) + parser.add_argument('--buffer-size', type=int, default=1) + parser.add_argument('--epoch', type=int, default=100) + parser.add_argument('--step-per-epoch', type=int, default=100) + parser.add_argument('--collect-per-step', type=int, default=100) + parser.add_argument('--repeat-per-collect', type=int, default=1) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--test-num', type=int, default=1) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) 
+ args = parser.parse_known_args()[0] + return args + + +def test_psrl(args=get_args()): + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.env.action_space.shape or env.env.action_space.n + # train_envs = gym.make(args.task) + # train_envs = gym.make(args.task) + train_envs = VectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)]) + # test_envs = gym.make(args.task) + test_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + n_action = args.action_shape + n_state = args.state_shape + p_pri = 1e-3 * np.ones((n_action, n_state, n_state)) + rew_mean = np.zeros((n_state, n_action)) + rew_std = np.ones((n_state, n_action)) + model = PSRLModel(p_pri, rew_mean, rew_std) + policy = PSRLPolicy(model) + # collector + train_collector = Collector( + policy, train_envs, ReplayBuffer(args.buffer_size)) + test_collector = Collector(policy, test_envs) + # log + writer = SummaryWriter(args.logdir + '/' + 'FrozenLake') + + def train_fn(x): + policy.set_eps(args.eps_train) + + def test_fn(x): + policy.set_eps(args.eps_test) + + def stop_fn(x): + if env.env.spec.reward_threshold: + return x >= env.spec.reward_threshold + else: + return False + # trainer + result = offpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.collect_per_step, + args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, + stop_fn=stop_fn, writer=writer) + + train_collector.close() + test_collector.close() + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! 
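# --- Editorial aside (not part of the patch) --------------------------------
# The example above builds its transition prior as 1e-3 * np.ones(...), i.e. a
# very weak Dirichlet prior, so the sampled/estimated transition models are
# driven almost entirely by the observed counts.  A quick self-contained
# illustration with a toy 4-outcome categorical (all values illustrative):
import numpy as np

alpha_weak = 1e-3 * np.ones(4)              # the kind of prior the example uses
alpha_strong = 10.0 * np.ones(4)            # a much more confident prior
counts = np.array([5., 0., 1., 0.])         # pretend transition counts

print(np.random.dirichlet(alpha_weak + counts))    # mass on the observed outcomes
print(np.random.dirichlet(alpha_strong + counts))  # pulled back toward uniform
# ---------------------------------------------------------------------------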
+ env = gym.make(args.task) + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + print(f'Final reward: {result["rew"]}, length: {result["len"]}') + collector.close() + + +if __name__ == '__main__': + test_psrl() diff --git a/examples/NChain_psrl.py b/examples/NChain_psrl.py new file mode 100644 index 000000000..a817aa1b0 --- /dev/null +++ b/examples/NChain_psrl.py @@ -0,0 +1,101 @@ +import torch +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.env import SubprocVectorEnv +from tianshou.trainer import onpolicy_trainer, offpolicy_trainer +from tianshou.data import Collector, ReplayBuffer + +import gym +from tianshou.env import VectorEnv +import sys +import os +sys.path.append(os.path.abspath(os.path.join(os.path.split(__file__)[0], os.pardir)) + '/tianshou/policy/psrl') +from psrl import PSRLPolicy, PSRLModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='NChain-v0') + parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--eps-test', type=float, default=0) + parser.add_argument('--eps-train', type=float, default=0.1) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--epoch', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=100) + parser.add_argument('--collect-per-step', type=int, default=100) + parser.add_argument('--repeat-per-collect', type=int, default=1) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--test-num', type=int, default=1) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) 
+ args = parser.parse_known_args()[0] + return args + + +def test_psrl(args=get_args()): + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.env.action_space.shape or env.env.action_space.n + # train_envs = gym.make(args.task) + # train_envs = gym.make(args.task) + train_envs = VectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)]) + # test_envs = gym.make(args.task) + test_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + n_action = args.action_shape + n_state = args.state_shape + p_pri = 1e-3 * np.ones((n_action, n_state, n_state)) + rew_mean = np.ones((n_state, n_action)) + rew_std = np.ones((n_state, n_action)) + model = PSRLModel(p_pri, rew_mean, rew_std) + policy = PSRLPolicy(model) + # collector + train_collector = Collector( + policy, train_envs, ReplayBuffer(args.buffer_size)) + test_collector = Collector(policy, test_envs) + # log + writer = SummaryWriter(args.logdir + '/' + 'NChain') + + def train_fn(x): + policy.set_eps(args.eps_train) + + def test_fn(x): + policy.set_eps(args.eps_test) + + def stop_fn(x): + if env.env.spec.reward_threshold: + return x >= env.spec.reward_threshold + else: + return False + # trainer + result = offpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.collect_per_step, + args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, + stop_fn=stop_fn, writer=writer) + + train_collector.close() + test_collector.close() + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! 
+ env = gym.make(args.task) + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + print(f'Final reward: {result["rew"]}, length: {result["len"]}') + collector.close() + + +if __name__ == '__main__': + test_psrl() diff --git a/examples/Taxi_psrl.py b/examples/Taxi_psrl.py new file mode 100644 index 000000000..46dfdd65d --- /dev/null +++ b/examples/Taxi_psrl.py @@ -0,0 +1,101 @@ +import torch +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.env import SubprocVectorEnv +from tianshou.trainer import onpolicy_trainer, offpolicy_trainer +from tianshou.data import Collector, ReplayBuffer + +import gym +from tianshou.env import VectorEnv +import sys +import os +sys.path.append(os.path.abspath(os.path.join(os.path.split(__file__)[0], os.pardir)) + '/tianshou/policy/psrl') +from psrl import PSRLPolicy, PSRLModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='Taxi-v3') + parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--eps-test', type=float, default=0) + parser.add_argument('--eps-train', type=float, default=0.1) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--epoch', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=100) + parser.add_argument('--collect-per-step', type=int, default=100) + parser.add_argument('--repeat-per-collect', type=int, default=1) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--test-num', type=int, default=1) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) 
+ args = parser.parse_known_args()[0] + return args + + +def test_psrl(args=get_args()): + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.env.action_space.shape or env.env.action_space.n + # train_envs = gym.make(args.task) + # train_envs = gym.make(args.task) + train_envs = VectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)]) + # test_envs = gym.make(args.task) + test_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + n_action = args.action_shape + n_state = args.state_shape + p_pri = 1e-3 * np.ones((n_action, n_state, n_state)) + rew_mean = np.ones((n_state, n_action)) + rew_std = np.ones((n_state, n_action)) + model = PSRLModel(p_pri, rew_mean, rew_std) + policy = PSRLPolicy(model) + # collector + train_collector = Collector( + policy, train_envs, ReplayBuffer(args.buffer_size)) + test_collector = Collector(policy, test_envs) + # log + writer = SummaryWriter(args.logdir + '/' + 'Taxi') + + def train_fn(x): + policy.set_eps(args.eps_train) + + def test_fn(x): + policy.set_eps(args.eps_test) + + def stop_fn(x): + if env.env.spec.reward_threshold: + return x >= env.spec.reward_threshold + else: + return False + # trainer + result = offpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.collect_per_step, + args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, + stop_fn=stop_fn, writer=writer) + + train_collector.close() + test_collector.close() + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! + env = gym.make(args.task) + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + print(f'Final reward: {result["rew"]}, length: {result["len"]}') + collector.close() + + +if __name__ == '__main__': + test_psrl() diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 95b7f0eeb..3517f0368 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -9,6 +9,7 @@ from tianshou.policy.modelfree.td3 import TD3Policy from tianshou.policy.modelfree.sac import SACPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager +from tianshou.policy.psrl.psrl import PSRLPolicy __all__ = [ @@ -23,4 +24,5 @@ 'TD3Policy', 'SACPolicy', 'MultiAgentPolicyManager', + 'PSRLPolicy', ] diff --git a/tianshou/policy/psrl/__init__.py b/tianshou/policy/psrl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tianshou/policy/psrl/psrl.py b/tianshou/policy/psrl/psrl.py new file mode 100644 index 000000000..fd53c6f0a --- /dev/null +++ b/tianshou/policy/psrl/psrl.py @@ -0,0 +1,182 @@ +import torch +import numpy as np +from typing import Dict, List, Union, Optional +import mdptoolbox + +from tianshou.policy import BasePolicy +from tianshou.data import Batch, ReplayBuffer, to_torch_as + + +class PSRLModel(object): + """Implementation of Posterior Sampling Reinforcement Learning Model. + + :param np.ndarray p_prior: dirichlet prior (alphas). + :param np.ndarray rew_mean_prior: means of the normal priors of rewards. + :param np.ndarray rew_std_prior: standard deviations of the normal priors of rewards. + :param float discount_factor: in [0, 1]. + + .. 
seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__(self, p_prior, rew_mean_prior, rew_std_prior, discount_factor: float = 0.99): + self.p = p_prior + self.n_action = len(self.p) + self.n_state = len(self.p[0]) + self.rew_mean = rew_mean_prior + self.rew_std = rew_std_prior + self.discount_factor = discount_factor + self.rew_count = np.zeros_like(rew_mean_prior) + self.sample_p = None + self.sample_rew = None + self.policy = None + self.updated = False + + def observe(self, p, rew_sum, rew_count): + """Add data.""" + self.updated = False + self.p += p + sum_count_nonzero = np.where(self.rew_count + rew_count == 0, 1, self.rew_count + rew_count) + rew_count_nonzero = np.where(rew_count == 0, 1, rew_count) + self.rew_mean = np.where(self.rew_count == 0, + np.where(rew_count == 0, self.rew_mean, rew_sum / rew_count_nonzero), + (self.rew_mean * self.rew_count + rew_sum) / sum_count_nonzero) + self.rew_std *= np.where(self.rew_count == 0, 1, self.rew_count) / sum_count_nonzero + self.rew_count += rew_count + + def sample_from_p(self): + sample_p = [] + for a in range(self.n_action): + for i in range(self.n_state): + param = self.p[a][i] + 1e-5 * np.random.randn(len(self.p[a][i])) + sample_p.append(param / np.sum(param)) + sample_p = np.array(sample_p).reshape(self.n_action, self.n_state, self.n_state) + return sample_p + + def sample_from_rew(self): + sample_rew = np.random.randn(len(self.rew_mean), len(self.rew_mean[0])) + sample_rew = sample_rew * self.rew_std + self.rew_mean + return sample_rew + + def solve_policy(self): + self.updated = True + self.sample_p = self.sample_from_p() + self.sample_rew = self.sample_from_rew() + problem = mdptoolbox.mdp.ValueIteration(self.sample_p, self.sample_rew, self.discount_factor) + problem.run() + self.policy = np.array(problem.policy) + return self.policy + + def __call__(self, obs, state=None, info=None): + if self.updated is False: + self.solve_policy() + act = self.policy[obs] + return act + + +class PSRLPolicy(BasePolicy): + """Implementation of Posterior Sampling Reinforcement Learning. + + :param PSRLModel model: a model following the rules in + :class:`PSRLModel`. + :param torch.distributions.Distribution dist_fn: for computing the action. + :param float discount_factor: in [0, 1]. + :param bool reward_normalization: normalize the reward to Normal(0, 1), + defaults to ``False``. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__(self, + model: PSRLModel, + dist_fn: torch.distributions.Distribution + = torch.distributions.Categorical, + discount_factor: float = 0.2, + reward_normalization: bool = False, + **kwargs) -> None: + super().__init__(**kwargs) + self.model = model + self.dist_fn = dist_fn + assert 0 <= discount_factor <= 1, 'discount factor should in [0, 1]' + self._gamma = discount_factor + self._rew_norm = reward_normalization + self.eps = 0 + + def set_eps(self, eps: float) -> None: + """Set the eps for epsilon-greedy exploration.""" + self.eps = eps + + def process_fn(self, batch: Batch, buffer: ReplayBuffer, + indice: np.ndarray) -> Batch: + r"""Compute the discounted returns for each frame: + + .. math:: + G_t = \sum_{i=t}^T \gamma^{i-t}r_i + + , where :math:`T` is the terminal time step, :math:`\gamma` is the + discount factor, :math:`\gamma \in [0, 1]`. 
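# --- Editorial sketch (not part of the patch) -------------------------------
# The PSRLModel above keeps Dirichlet parameters for the transition model and a
# Gaussian posterior over mean rewards.  The snippet below writes out the
# textbook version of that posterior-sampling loop in plain NumPy, assuming a
# Dirichlet prior per (action, state) row and unit-variance observation noise
# on rewards.  All names (sample_mdp, value_iteration, update) are illustrative
# only and do not exist in Tianshou.
import numpy as np


def sample_mdp(trans_alpha, rew_mean, rew_std):
    # trans_alpha: Dirichlet parameters, shape (n_action, n_state, n_state)
    # rew_mean, rew_std: reward posterior, shape (n_state, n_action)
    n_action, n_state, _ = trans_alpha.shape
    trans = np.array([[np.random.dirichlet(trans_alpha[a, s])
                       for s in range(n_state)] for a in range(n_action)])
    rew = np.random.normal(rew_mean, rew_std)
    return trans, rew


def value_iteration(trans, rew, gamma=0.99, eps=1e-6):
    # Greedy policy of the sampled MDP: q[s, a] = r(s, a) + gamma * p(.|s, a) @ v
    value = np.zeros(trans.shape[1])
    while True:
        q = rew + gamma * np.einsum('ast,t->sa', trans, value)
        new_value = q.max(axis=1)
        if np.abs(new_value - value).max() < eps:
            return q.argmax(axis=1)
        value = new_value


def update(trans_alpha, rew_mean, rew_std, s, a, r, s_next):
    # Conjugate updates after observing one transition (s, a, r, s').
    trans_alpha[a, s, s_next] += 1
    post_prec = 1.0 / rew_std[s, a] ** 2 + 1.0   # add one unit-variance observation
    rew_mean[s, a] = (rew_mean[s, a] / rew_std[s, a] ** 2 + r) / post_prec
    rew_std[s, a] = np.sqrt(1.0 / post_prec)
# ---------------------------------------------------------------------------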
+ """ + # batch.returns = self._vanilla_returns(batch) + # batch.returns = self._vectorized_returns(batch) + # return batch + return self.compute_episodic_return( + batch, gamma=self._gamma, gae_lambda=1.) + + def forward(self, batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + eps: Optional[float] = None, + **kwargs) -> Batch: + """Compute action over the given batch data. + + :return: A :class:`~tianshou.data.Batch` which has 1 key: + + * ``act`` the action. + + .. seealso:: + + Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for + more detailed explanation. + """ + act = self.model(batch.obs, state=state, info=batch.info) + if eps is None: + eps = self.eps + if not np.isclose(eps, 0): + for i in range(len(act)): + if np.random.rand() < eps: + act[i] = np.random.randint(0, self.model.n_action) + return Batch(act=act) + + def learn(self, batch: Batch, **kwargs) -> Dict[str, List[float]]: + r = batch.returns + if self._rew_norm and not np.isclose(r.std(), 0): + batch.returns = (r - r.mean()) / r.std() + p = np.zeros((self.model.n_action, self.model.n_state, self.model.n_state)) + rew_sum = np.zeros((self.model.n_state, self.model.n_action)) + rew_count = np.zeros((self.model.n_state, self.model.n_action)) + a = batch.act + r = batch.returns + obs = batch.obs + obs_next = batch.obs_next + for i in range(len(obs)): + p[a[i]][obs[i]][obs_next[i]] += 1 + rew_sum[obs[i]][a[i]] += r[i] + rew_count[obs[i]][a[i]] += 1 + self.model.observe(p, rew_sum, rew_count) + return {'loss': [0.0]} + + +if __name__ == "__main__": + n_action = 3 + n_state = 4 + p_pri = np.ones((n_action, n_state, n_state)) + rew_mean = np.ones((n_state, n_action)) + rew_std = np.ones((n_state, n_action)) + model = PSRLModel(p_pri, rew_mean, rew_std) + policy = PSRLPolicy(model) + import pdb + pdb.set_trace() From f26255dbfa97e53e871c6ce3166b84150b573c4a Mon Sep 17 00:00:00 2001 From: Yao Date: Sat, 5 Sep 2020 10:54:26 +0800 Subject: [PATCH 02/62] improve PSRL code --- README.md | 1 + docs/index.rst | 1 + examples/{ => modelbase}/FrozenLake_psrl.py | 27 ++---- examples/{ => modelbase}/NChain_psrl.py | 27 ++---- examples/{ => modelbase}/Taxi_psrl.py | 29 ++---- setup.py | 1 + test/modelbase/NChain_psrl.py | 92 +++++++++++++++++++ tianshou/policy/__init__.py | 2 +- .../policy/{psrl => modelbase}/__init__.py | 0 tianshou/policy/{psrl => modelbase}/psrl.py | 88 +++++++++--------- 10 files changed, 166 insertions(+), 102 deletions(-) rename examples/{ => modelbase}/FrozenLake_psrl.py (79%) rename examples/{ => modelbase}/NChain_psrl.py (80%) rename examples/{ => modelbase}/Taxi_psrl.py (78%) create mode 100644 test/modelbase/NChain_psrl.py rename tianshou/policy/{psrl => modelbase}/__init__.py (100%) rename tianshou/policy/{psrl => modelbase}/psrl.py (69%) diff --git a/README.md b/README.md index 1764051cd..b5818ae7b 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ - Vanilla Imitation Learning - [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf) - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf) +- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf) Here is Tianshou's other features: diff --git a/docs/index.rst b/docs/index.rst index 9ef598a81..d16b96dc0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -17,6 +17,7 @@ Welcome to Tianshou! 
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization `_ * :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG `_ * :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic `_ +* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning `_ * :class:`~tianshou.policy.ImitationPolicy` Imitation Learning * :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay `_ * :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator `_ diff --git a/examples/FrozenLake_psrl.py b/examples/modelbase/FrozenLake_psrl.py similarity index 79% rename from examples/FrozenLake_psrl.py rename to examples/modelbase/FrozenLake_psrl.py index c92fcad11..372eb8f91 100644 --- a/examples/FrozenLake_psrl.py +++ b/examples/modelbase/FrozenLake_psrl.py @@ -4,16 +4,12 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter -from tianshou.env import SubprocVectorEnv -from tianshou.trainer import onpolicy_trainer, offpolicy_trainer +from tianshou.env import DummyVectorEnv, SubprocVectorEnv +from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer import gym -from tianshou.env import VectorEnv -import sys -import os -sys.path.append(os.path.abspath(os.path.join(os.path.split(__file__)[0], os.pardir)) + '/tianshou/policy/psrl') -from psrl import PSRLPolicy, PSRLModel +from tianshou.policy import PSRLPolicy def get_args(): @@ -23,17 +19,15 @@ def get_args(): parser.add_argument('--eps-test', type=float, default=0) parser.add_argument('--eps-train', type=float, default=0.1) parser.add_argument('--buffer-size', type=int, default=1) - parser.add_argument('--epoch', type=int, default=100) + parser.add_argument('--epoch', type=int, default=20) parser.add_argument('--step-per-epoch', type=int, default=100) parser.add_argument('--collect-per-step', type=int, default=100) - parser.add_argument('--repeat-per-collect', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--training-num', type=int, default=8) parser.add_argument('--test-num', type=int, default=1) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
- args = parser.parse_known_args()[0] - return args + return parser.parse_args() def test_psrl(args=get_args()): @@ -42,7 +36,7 @@ def test_psrl(args=get_args()): args.action_shape = env.env.action_space.shape or env.env.action_space.n # train_envs = gym.make(args.task) # train_envs = gym.make(args.task) - train_envs = VectorEnv( + train_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.training_num)]) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( @@ -58,14 +52,13 @@ def test_psrl(args=get_args()): p_pri = 1e-3 * np.ones((n_action, n_state, n_state)) rew_mean = np.zeros((n_state, n_action)) rew_std = np.ones((n_state, n_action)) - model = PSRLModel(p_pri, rew_mean, rew_std) - policy = PSRLPolicy(model) + policy = PSRLPolicy(p_pri, rew_mean, rew_std) # collector train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size)) test_collector = Collector(policy, test_envs) # log - writer = SummaryWriter(args.logdir + '/' + 'FrozenLake') + writer = SummaryWriter(args.logdir + '/' + args.task) def train_fn(x): policy.set_eps(args.eps_train) @@ -85,16 +78,14 @@ def stop_fn(x): args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, writer=writer) - train_collector.close() - test_collector.close() if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) collector = Collector(policy, env) + policy.eval() result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') - collector.close() if __name__ == '__main__': diff --git a/examples/NChain_psrl.py b/examples/modelbase/NChain_psrl.py similarity index 80% rename from examples/NChain_psrl.py rename to examples/modelbase/NChain_psrl.py index a817aa1b0..e98c44673 100644 --- a/examples/NChain_psrl.py +++ b/examples/modelbase/NChain_psrl.py @@ -4,16 +4,12 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter -from tianshou.env import SubprocVectorEnv -from tianshou.trainer import onpolicy_trainer, offpolicy_trainer +from tianshou.env import DummyVectorEnv, SubprocVectorEnv +from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer import gym -from tianshou.env import VectorEnv -import sys -import os -sys.path.append(os.path.abspath(os.path.join(os.path.split(__file__)[0], os.pardir)) + '/tianshou/policy/psrl') -from psrl import PSRLPolicy, PSRLModel +from tianshou.policy import PSRLPolicy def get_args(): @@ -26,14 +22,12 @@ def get_args(): parser.add_argument('--epoch', type=int, default=5) parser.add_argument('--step-per-epoch', type=int, default=100) parser.add_argument('--collect-per-step', type=int, default=100) - parser.add_argument('--repeat-per-collect', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--training-num', type=int, default=8) parser.add_argument('--test-num', type=int, default=1) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
- args = parser.parse_known_args()[0] - return args + return parser.parse_args() def test_psrl(args=get_args()): @@ -42,7 +36,7 @@ def test_psrl(args=get_args()): args.action_shape = env.env.action_space.shape or env.env.action_space.n # train_envs = gym.make(args.task) # train_envs = gym.make(args.task) - train_envs = VectorEnv( + train_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.training_num)]) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( @@ -56,16 +50,15 @@ def test_psrl(args=get_args()): n_action = args.action_shape n_state = args.state_shape p_pri = 1e-3 * np.ones((n_action, n_state, n_state)) - rew_mean = np.ones((n_state, n_action)) + rew_mean = np.zeros((n_state, n_action)) rew_std = np.ones((n_state, n_action)) - model = PSRLModel(p_pri, rew_mean, rew_std) - policy = PSRLPolicy(model) + policy = PSRLPolicy(p_pri, rew_mean, rew_std) # collector train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size)) test_collector = Collector(policy, test_envs) # log - writer = SummaryWriter(args.logdir + '/' + 'NChain') + writer = SummaryWriter(args.logdir + '/' + args.task) def train_fn(x): policy.set_eps(args.eps_train) @@ -85,16 +78,14 @@ def stop_fn(x): args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, writer=writer) - train_collector.close() - test_collector.close() if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) collector = Collector(policy, env) + policy.eval() result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') - collector.close() if __name__ == '__main__': diff --git a/examples/Taxi_psrl.py b/examples/modelbase/Taxi_psrl.py similarity index 78% rename from examples/Taxi_psrl.py rename to examples/modelbase/Taxi_psrl.py index 46dfdd65d..d1283ae29 100644 --- a/examples/Taxi_psrl.py +++ b/examples/modelbase/Taxi_psrl.py @@ -4,16 +4,12 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter -from tianshou.env import SubprocVectorEnv -from tianshou.trainer import onpolicy_trainer, offpolicy_trainer +from tianshou.env import DummyVectorEnv, SubprocVectorEnv +from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer import gym -from tianshou.env import VectorEnv -import sys -import os -sys.path.append(os.path.abspath(os.path.join(os.path.split(__file__)[0], os.pardir)) + '/tianshou/policy/psrl') -from psrl import PSRLPolicy, PSRLModel +from tianshou.policy import PSRLPolicy def get_args(): @@ -23,17 +19,15 @@ def get_args(): parser.add_argument('--eps-test', type=float, default=0) parser.add_argument('--eps-train', type=float, default=0.1) parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--epoch', type=int, default=10) + parser.add_argument('--epoch', type=int, default=20) parser.add_argument('--step-per-epoch', type=int, default=100) parser.add_argument('--collect-per-step', type=int, default=100) - parser.add_argument('--repeat-per-collect', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--training-num', type=int, default=8) parser.add_argument('--test-num', type=int, default=1) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
- args = parser.parse_known_args()[0] - return args + return parser.parse_args() def test_psrl(args=get_args()): @@ -42,7 +36,7 @@ def test_psrl(args=get_args()): args.action_shape = env.env.action_space.shape or env.env.action_space.n # train_envs = gym.make(args.task) # train_envs = gym.make(args.task) - train_envs = VectorEnv( + train_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.training_num)]) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( @@ -56,16 +50,15 @@ def test_psrl(args=get_args()): n_action = args.action_shape n_state = args.state_shape p_pri = 1e-3 * np.ones((n_action, n_state, n_state)) - rew_mean = np.ones((n_state, n_action)) + rew_mean = np.zeros((n_state, n_action)) rew_std = np.ones((n_state, n_action)) - model = PSRLModel(p_pri, rew_mean, rew_std) - policy = PSRLPolicy(model) + policy = PSRLPolicy(p_pri, rew_mean, rew_std) # collector train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size)) test_collector = Collector(policy, test_envs) # log - writer = SummaryWriter(args.logdir + '/' + 'Taxi') + writer = SummaryWriter(args.logdir + '/' + args.task) def train_fn(x): policy.set_eps(args.eps_train) @@ -85,16 +78,14 @@ def stop_fn(x): args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, writer=writer) - train_collector.close() - test_collector.close() if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) collector = Collector(policy, env) + policy.eval() result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') - collector.close() if __name__ == '__main__': diff --git a/setup.py b/setup.py index 64aac40b2..0d7ef8152 100644 --- a/setup.py +++ b/setup.py @@ -51,6 +51,7 @@ 'numpy', 'tensorboard', 'torch>=1.4.0', + 'pymdptoolbox', ], extras_require={ 'dev': [ diff --git a/test/modelbase/NChain_psrl.py b/test/modelbase/NChain_psrl.py new file mode 100644 index 000000000..e98c44673 --- /dev/null +++ b/test/modelbase/NChain_psrl.py @@ -0,0 +1,92 @@ +import torch +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.env import DummyVectorEnv, SubprocVectorEnv +from tianshou.trainer import offpolicy_trainer +from tianshou.data import Collector, ReplayBuffer + +import gym +from tianshou.policy import PSRLPolicy + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='NChain-v0') + parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--eps-test', type=float, default=0) + parser.add_argument('--eps-train', type=float, default=0.1) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--epoch', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=100) + parser.add_argument('--collect-per-step', type=int, default=100) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--test-num', type=int, default=1) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) 
+ return parser.parse_args() + + +def test_psrl(args=get_args()): + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.env.action_space.shape or env.env.action_space.n + # train_envs = gym.make(args.task) + # train_envs = gym.make(args.task) + train_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)]) + # test_envs = gym.make(args.task) + test_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + n_action = args.action_shape + n_state = args.state_shape + p_pri = 1e-3 * np.ones((n_action, n_state, n_state)) + rew_mean = np.zeros((n_state, n_action)) + rew_std = np.ones((n_state, n_action)) + policy = PSRLPolicy(p_pri, rew_mean, rew_std) + # collector + train_collector = Collector( + policy, train_envs, ReplayBuffer(args.buffer_size)) + test_collector = Collector(policy, test_envs) + # log + writer = SummaryWriter(args.logdir + '/' + args.task) + + def train_fn(x): + policy.set_eps(args.eps_train) + + def test_fn(x): + policy.set_eps(args.eps_test) + + def stop_fn(x): + if env.env.spec.reward_threshold: + return x >= env.spec.reward_threshold + else: + return False + # trainer + result = offpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.collect_per_step, + args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, + stop_fn=stop_fn, writer=writer) + + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! + env = gym.make(args.task) + collector = Collector(policy, env) + policy.eval() + result = collector.collect(n_episode=1, render=args.render) + print(f'Final reward: {result["rew"]}, length: {result["len"]}') + + +if __name__ == '__main__': + test_psrl() diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 3517f0368..c9faf653a 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -9,7 +9,7 @@ from tianshou.policy.modelfree.td3 import TD3Policy from tianshou.policy.modelfree.sac import SACPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager -from tianshou.policy.psrl.psrl import PSRLPolicy +from tianshou.policy.modelbase.psrl import PSRLPolicy __all__ = [ diff --git a/tianshou/policy/psrl/__init__.py b/tianshou/policy/modelbase/__init__.py similarity index 100% rename from tianshou/policy/psrl/__init__.py rename to tianshou/policy/modelbase/__init__.py diff --git a/tianshou/policy/psrl/psrl.py b/tianshou/policy/modelbase/psrl.py similarity index 69% rename from tianshou/policy/psrl/psrl.py rename to tianshou/policy/modelbase/psrl.py index fd53c6f0a..d1a017687 100644 --- a/tianshou/policy/psrl/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -1,10 +1,9 @@ -import torch import numpy as np from typing import Dict, List, Union, Optional import mdptoolbox from tianshou.policy import BasePolicy -from tianshou.data import Batch, ReplayBuffer, to_torch_as +from tianshou.data import Batch, ReplayBuffer class PSRLModel(object): @@ -12,16 +11,20 @@ class PSRLModel(object): :param np.ndarray p_prior: dirichlet prior (alphas). :param np.ndarray rew_mean_prior: means of the normal priors of rewards. - :param np.ndarray rew_std_prior: standard deviations of the normal priors of rewards. - :param float discount_factor: in [0, 1]. 
+ :param np.ndarray rew_std_prior: standard deviations of the normal + priors of rewards. + :param float discount_factor: in (0, 1]. .. seealso:: Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. + Strens M. A Bayesian framework for reinforcement learning[C] + //ICML. 2000, 2000: 943-950. """ - def __init__(self, p_prior, rew_mean_prior, rew_std_prior, discount_factor: float = 0.99): + def __init__(self, p_prior: np.ndarray, rew_mean_prior: np.ndarray, + rew_std_prior: np.ndarray, discount_factor: float = 0.99): self.p = p_prior self.n_action = len(self.p) self.n_state = len(self.p[0]) @@ -34,25 +37,32 @@ def __init__(self, p_prior, rew_mean_prior, rew_std_prior, discount_factor: floa self.policy = None self.updated = False - def observe(self, p, rew_sum, rew_count): + def observe(self, p: np.ndarray, rew_sum: np.ndarray, + rew_count: np.ndarray): """Add data.""" self.updated = False self.p += p - sum_count_nonzero = np.where(self.rew_count + rew_count == 0, 1, self.rew_count + rew_count) + sum_count_nonzero = np.where(self.rew_count + rew_count == 0, + 1, self.rew_count + rew_count) rew_count_nonzero = np.where(rew_count == 0, 1, rew_count) self.rew_mean = np.where(self.rew_count == 0, - np.where(rew_count == 0, self.rew_mean, rew_sum / rew_count_nonzero), - (self.rew_mean * self.rew_count + rew_sum) / sum_count_nonzero) - self.rew_std *= np.where(self.rew_count == 0, 1, self.rew_count) / sum_count_nonzero + np.where(rew_count == 0, self.rew_mean, + rew_sum / rew_count_nonzero), + (self.rew_mean * self.rew_count + rew_sum) + / sum_count_nonzero) + self.rew_std *= np.where(self.rew_count == 0, 1, + self.rew_count) / sum_count_nonzero self.rew_count += rew_count def sample_from_p(self): sample_p = [] for a in range(self.n_action): for i in range(self.n_state): - param = self.p[a][i] + 1e-5 * np.random.randn(len(self.p[a][i])) + param = self.p[a][i] + \ + 1e-5 * np.random.randn(len(self.p[a][i])) sample_p.append(param / np.sum(param)) - sample_p = np.array(sample_p).reshape(self.n_action, self.n_state, self.n_state) + sample_p = np.array(sample_p).reshape( + self.n_action, self.n_state, self.n_state) return sample_p def sample_from_rew(self): @@ -64,12 +74,13 @@ def solve_policy(self): self.updated = True self.sample_p = self.sample_from_p() self.sample_rew = self.sample_from_rew() - problem = mdptoolbox.mdp.ValueIteration(self.sample_p, self.sample_rew, self.discount_factor) + problem = mdptoolbox.mdp.ValueIteration( + self.sample_p, self.sample_rew, self.discount_factor) problem.run() self.policy = np.array(problem.policy) return self.policy - def __call__(self, obs, state=None, info=None): + def __call__(self, obs: np.ndarray, state=None, info=None): if self.updated is False: self.solve_policy() act = self.policy[obs] @@ -79,8 +90,10 @@ def __call__(self, obs, state=None, info=None): class PSRLPolicy(BasePolicy): """Implementation of Posterior Sampling Reinforcement Learning. - :param PSRLModel model: a model following the rules in - :class:`PSRLModel`. + :param np.ndarray p_prior: dirichlet prior (alphas). + :param np.ndarray rew_mean_prior: means of the normal priors of rewards. + :param np.ndarray rew_std_prior: standard deviations of the normal + priors of rewards. :param torch.distributions.Distribution dist_fn: for computing the action. :param float discount_factor: in [0, 1]. 
:param bool reward_normalization: normalize the reward to Normal(0, 1), @@ -90,18 +103,18 @@ class PSRLPolicy(BasePolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. + """ def __init__(self, - model: PSRLModel, - dist_fn: torch.distributions.Distribution - = torch.distributions.Categorical, - discount_factor: float = 0.2, + p_prior: np.ndarray, + rew_mean_prior: np.ndarray, + rew_std_prior: np.ndarray, + discount_factor: float = 0, reward_normalization: bool = False, **kwargs) -> None: super().__init__(**kwargs) - self.model = model - self.dist_fn = dist_fn + self.model = PSRLModel(p_prior, rew_mean_prior, rew_std_prior) assert 0 <= discount_factor <= 1, 'discount factor should in [0, 1]' self._gamma = discount_factor self._rew_norm = reward_normalization @@ -121,11 +134,9 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer, , where :math:`T` is the terminal time step, :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`. """ - # batch.returns = self._vanilla_returns(batch) - # batch.returns = self._vectorized_returns(batch) - # return batch - return self.compute_episodic_return( - batch, gamma=self._gamma, gae_lambda=1.) + return self.compute_episodic_return(batch, gamma=self._gamma, + gae_lambda=1., + rew_norm=self._rew_norm) def forward(self, batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, @@ -133,9 +144,8 @@ def forward(self, batch: Batch, **kwargs) -> Batch: """Compute action over the given batch data. - :return: A :class:`~tianshou.data.Batch` which has 1 key: - - * ``act`` the action. + :return: A :class:`~tianshou.data.Batch` with "act" key containing + the action. .. seealso:: @@ -152,10 +162,8 @@ def forward(self, batch: Batch, return Batch(act=act) def learn(self, batch: Batch, **kwargs) -> Dict[str, List[float]]: - r = batch.returns - if self._rew_norm and not np.isclose(r.std(), 0): - batch.returns = (r - r.mean()) / r.std() - p = np.zeros((self.model.n_action, self.model.n_state, self.model.n_state)) + p = np.zeros((self.model.n_action, self.model.n_state, + self.model.n_state)) rew_sum = np.zeros((self.model.n_state, self.model.n_action)) rew_count = np.zeros((self.model.n_state, self.model.n_action)) a = batch.act @@ -168,15 +176,3 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, List[float]]: rew_count[obs[i]][a[i]] += 1 self.model.observe(p, rew_sum, rew_count) return {'loss': [0.0]} - - -if __name__ == "__main__": - n_action = 3 - n_state = 4 - p_pri = np.ones((n_action, n_state, n_state)) - rew_mean = np.ones((n_state, n_action)) - rew_std = np.ones((n_state, n_action)) - model = PSRLModel(p_pri, rew_mean, rew_std) - policy = PSRLPolicy(model) - import pdb - pdb.set_trace() From a317933dca033ee54ef8be655cd1fb40accf115f Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 5 Sep 2020 12:44:20 +0800 Subject: [PATCH 03/62] polish --- examples/modelbase/NChain_psrl.py | 92 ------------------- examples/modelbase/README.md | 7 ++ examples/modelbase/Taxi_psrl.py | 92 ------------------- .../modelbase/{FrozenLake_psrl.py => psrl.py} | 18 ++-- tianshou/policy/modelbase/psrl.py | 85 +++++++++-------- 5 files changed, 66 insertions(+), 228 deletions(-) delete mode 100644 examples/modelbase/NChain_psrl.py create mode 100644 examples/modelbase/README.md delete mode 100644 examples/modelbase/Taxi_psrl.py rename examples/modelbase/{FrozenLake_psrl.py => psrl.py} (91%) diff --git a/examples/modelbase/NChain_psrl.py b/examples/modelbase/NChain_psrl.py deleted file mode 100644 
index e98c44673..000000000 --- a/examples/modelbase/NChain_psrl.py +++ /dev/null @@ -1,92 +0,0 @@ -import torch -import pprint -import argparse -import numpy as np -from torch.utils.tensorboard import SummaryWriter - -from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer - -import gym -from tianshou.policy import PSRLPolicy - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='NChain-v0') - parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--eps-test', type=float, default=0) - parser.add_argument('--eps-train', type=float, default=0.1) - parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--epoch', type=int, default=5) - parser.add_argument('--step-per-epoch', type=int, default=100) - parser.add_argument('--collect-per-step', type=int, default=100) - parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--training-num', type=int, default=8) - parser.add_argument('--test-num', type=int, default=1) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) - return parser.parse_args() - - -def test_psrl(args=get_args()): - env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.env.action_space.shape or env.env.action_space.n - # train_envs = gym.make(args.task) - # train_envs = gym.make(args.task) - train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) - train_envs.seed(args.seed) - test_envs.seed(args.seed) - # model - n_action = args.action_shape - n_state = args.state_shape - p_pri = 1e-3 * np.ones((n_action, n_state, n_state)) - rew_mean = np.zeros((n_state, n_action)) - rew_std = np.ones((n_state, n_action)) - policy = PSRLPolicy(p_pri, rew_mean, rew_std) - # collector - train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) - test_collector = Collector(policy, test_envs) - # log - writer = SummaryWriter(args.logdir + '/' + args.task) - - def train_fn(x): - policy.set_eps(args.eps_train) - - def test_fn(x): - policy.set_eps(args.eps_test) - - def stop_fn(x): - if env.env.spec.reward_threshold: - return x >= env.spec.reward_threshold - else: - return False - # trainer - result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, - args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, writer=writer) - - if __name__ == '__main__': - pprint.pprint(result) - # Let's watch its performance! 
- env = gym.make(args.task) - collector = Collector(policy, env) - policy.eval() - result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') - - -if __name__ == '__main__': - test_psrl() diff --git a/examples/modelbase/README.md b/examples/modelbase/README.md new file mode 100644 index 000000000..0112cb523 --- /dev/null +++ b/examples/modelbase/README.md @@ -0,0 +1,7 @@ +# PSRL + +`NChain-v0`: `python3 psrl.py --task NChain-v0 --buffer-size 20000 --epoch 5` + +`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --buffer-size 1 --epoch 20` + +`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --buffer-size 20000 --epoch 20` diff --git a/examples/modelbase/Taxi_psrl.py b/examples/modelbase/Taxi_psrl.py deleted file mode 100644 index d1283ae29..000000000 --- a/examples/modelbase/Taxi_psrl.py +++ /dev/null @@ -1,92 +0,0 @@ -import torch -import pprint -import argparse -import numpy as np -from torch.utils.tensorboard import SummaryWriter - -from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer - -import gym -from tianshou.policy import PSRLPolicy - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='Taxi-v3') - parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--eps-test', type=float, default=0) - parser.add_argument('--eps-train', type=float, default=0.1) - parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=100) - parser.add_argument('--collect-per-step', type=int, default=100) - parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--training-num', type=int, default=8) - parser.add_argument('--test-num', type=int, default=1) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) 
- return parser.parse_args() - - -def test_psrl(args=get_args()): - env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.env.action_space.shape or env.env.action_space.n - # train_envs = gym.make(args.task) - # train_envs = gym.make(args.task) - train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) - train_envs.seed(args.seed) - test_envs.seed(args.seed) - # model - n_action = args.action_shape - n_state = args.state_shape - p_pri = 1e-3 * np.ones((n_action, n_state, n_state)) - rew_mean = np.zeros((n_state, n_action)) - rew_std = np.ones((n_state, n_action)) - policy = PSRLPolicy(p_pri, rew_mean, rew_std) - # collector - train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) - test_collector = Collector(policy, test_envs) - # log - writer = SummaryWriter(args.logdir + '/' + args.task) - - def train_fn(x): - policy.set_eps(args.eps_train) - - def test_fn(x): - policy.set_eps(args.eps_test) - - def stop_fn(x): - if env.env.spec.reward_threshold: - return x >= env.spec.reward_threshold - else: - return False - # trainer - result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, - args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, writer=writer) - - if __name__ == '__main__': - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - policy.eval() - result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') - - -if __name__ == '__main__': - test_psrl() diff --git a/examples/modelbase/FrozenLake_psrl.py b/examples/modelbase/psrl.py similarity index 91% rename from examples/modelbase/FrozenLake_psrl.py rename to examples/modelbase/psrl.py index 372eb8f91..3f42a6023 100644 --- a/examples/modelbase/FrozenLake_psrl.py +++ b/examples/modelbase/psrl.py @@ -1,15 +1,14 @@ +import gym import torch import pprint import argparse import numpy as np from torch.utils.tensorboard import SummaryWriter -from tianshou.env import DummyVectorEnv, SubprocVectorEnv +from tianshou.policy import PSRLPolicy from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer - -import gym -from tianshou.policy import PSRLPolicy +from tianshou.env import DummyVectorEnv, SubprocVectorEnv def get_args(): @@ -81,12 +80,15 @@ def stop_fn(x): if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! 
- env = gym.make(args.task) - collector = Collector(policy, env) policy.eval() - result = collector.collect(n_episode=1, render=args.render) + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) + pprint.pprint(result) print(f'Final reward: {result["rew"]}, length: {result["len"]}') if __name__ == '__main__': - test_psrl() + test_psrl(get_args()) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index d1a017687..bffd23937 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -1,6 +1,6 @@ -import numpy as np -from typing import Dict, List, Union, Optional import mdptoolbox +import numpy as np +from typing import Any, Dict, List, Union, Optional from tianshou.policy import BasePolicy from tianshou.data import Batch, ReplayBuffer @@ -19,12 +19,15 @@ class PSRLModel(object): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. - Strens M. A Bayesian framework for reinforcement learning[C] - //ICML. 2000, 2000: 943-950. """ - def __init__(self, p_prior: np.ndarray, rew_mean_prior: np.ndarray, - rew_std_prior: np.ndarray, discount_factor: float = 0.99): + def __init__( + self, + p_prior: np.ndarray, + rew_mean_prior: np.ndarray, + rew_std_prior: np.ndarray, + discount_factor: float = 0.99, + ) -> None: self.p = p_prior self.n_action = len(self.p) self.n_state = len(self.p[0]) @@ -37,9 +40,10 @@ def __init__(self, p_prior: np.ndarray, rew_mean_prior: np.ndarray, self.policy = None self.updated = False - def observe(self, p: np.ndarray, rew_sum: np.ndarray, - rew_count: np.ndarray): - """Add data.""" + def observe( + self, p: np.ndarray, rew_sum: np.ndarray, rew_count: np.ndarray + ) -> None: + """Add data into memory pool.""" self.updated = False self.p += p sum_count_nonzero = np.where(self.rew_count + rew_count == 0, @@ -54,23 +58,23 @@ def observe(self, p: np.ndarray, rew_sum: np.ndarray, self.rew_count) / sum_count_nonzero self.rew_count += rew_count - def sample_from_p(self): + def sample_from_p(self) -> np.ndarray: sample_p = [] for a in range(self.n_action): for i in range(self.n_state): param = self.p[a][i] + \ - 1e-5 * np.random.randn(len(self.p[a][i])) + 1e-5 * np.random.randn(len(self.p[a][i])) sample_p.append(param / np.sum(param)) sample_p = np.array(sample_p).reshape( self.n_action, self.n_state, self.n_state) return sample_p - def sample_from_rew(self): + def sample_from_rew(self) -> np.ndarray: sample_rew = np.random.randn(len(self.rew_mean), len(self.rew_mean[0])) sample_rew = sample_rew * self.rew_std + self.rew_mean return sample_rew - def solve_policy(self): + def solve_policy(self) -> np.ndarray: self.updated = True self.sample_p = self.sample_from_p() self.sample_rew = self.sample_from_rew() @@ -80,16 +84,18 @@ def solve_policy(self): self.policy = np.array(problem.policy) return self.policy - def __call__(self, obs: np.ndarray, state=None, info=None): + def __call__(self, obs: np.ndarray, state=None, info=None) -> np.ndarray: if self.updated is False: self.solve_policy() - act = self.policy[obs] - return act + return self.policy[obs] class PSRLPolicy(BasePolicy): """Implementation of Posterior Sampling Reinforcement Learning. + Reference: Strens M. A Bayesian framework for reinforcement learning [C] + //ICML. 2000, 2000: 943-950. + :param np.ndarray p_prior: dirichlet prior (alphas). :param np.ndarray rew_mean_prior: means of the normal priors of rewards. 
:param np.ndarray rew_std_prior: standard deviations of the normal @@ -103,16 +109,17 @@ class PSRLPolicy(BasePolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. - """ - def __init__(self, - p_prior: np.ndarray, - rew_mean_prior: np.ndarray, - rew_std_prior: np.ndarray, - discount_factor: float = 0, - reward_normalization: bool = False, - **kwargs) -> None: + def __init__( + self, + p_prior: np.ndarray, + rew_mean_prior: np.ndarray, + rew_std_prior: np.ndarray, + discount_factor: float = 0, + reward_normalization: bool = False, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self.model = PSRLModel(p_prior, rew_mean_prior, rew_std_prior) assert 0 <= discount_factor <= 1, 'discount factor should in [0, 1]' @@ -124,8 +131,9 @@ def set_eps(self, eps: float) -> None: """Set the eps for epsilon-greedy exploration.""" self.eps = eps - def process_fn(self, batch: Batch, buffer: ReplayBuffer, - indice: np.ndarray) -> Batch: + def process_fn( + self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray + ) -> Batch: r"""Compute the discounted returns for each frame: .. math:: @@ -134,14 +142,17 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer, , where :math:`T` is the terminal time step, :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`. """ - return self.compute_episodic_return(batch, gamma=self._gamma, - gae_lambda=1., - rew_norm=self._rew_norm) - - def forward(self, batch: Batch, - state: Optional[Union[dict, Batch, np.ndarray]] = None, - eps: Optional[float] = None, - **kwargs) -> Batch: + return self.compute_episodic_return( + batch, gamma=self._gamma, gae_lambda=1., rew_norm=self._rew_norm + ) + + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + eps: Optional[float] = None, + **kwargs: Any, + ) -> Batch: """Compute action over the given batch data. 
:return: A :class:`~tianshou.data.Batch` with "act" key containing @@ -161,7 +172,9 @@ def forward(self, batch: Batch, act[i] = np.random.randint(0, self.model.n_action) return Batch(act=act) - def learn(self, batch: Batch, **kwargs) -> Dict[str, List[float]]: + def learn( # type: ignore + self, batch: Batch, **kwargs: Any + ) -> Dict[str, float]: p = np.zeros((self.model.n_action, self.model.n_state, self.model.n_state)) rew_sum = np.zeros((self.model.n_state, self.model.n_action)) @@ -175,4 +188,4 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, List[float]]: rew_sum[obs[i]][a[i]] += r[i] rew_count[obs[i]][a[i]] += 1 self.model.observe(p, rew_sum, rew_count) - return {'loss': [0.0]} + return {'loss': 0.0} From 6c25cea1cab46c1bf8c2f65608d4dab0f8e763c8 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 5 Sep 2020 12:48:35 +0800 Subject: [PATCH 04/62] pep8 --- tianshou/policy/modelbase/psrl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index bffd23937..3eca9c37c 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -1,6 +1,6 @@ import mdptoolbox import numpy as np -from typing import Any, Dict, List, Union, Optional +from typing import Any, Dict, Union, Optional from tianshou.policy import BasePolicy from tianshou.data import Batch, ReplayBuffer From 43e25817c629f0a2cc99249d58efc4fe6f399755 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 5 Sep 2020 12:53:12 +0800 Subject: [PATCH 05/62] fix docs error --- tianshou/policy/modelbase/psrl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 3eca9c37c..5effa2be6 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -12,7 +12,7 @@ class PSRLModel(object): :param np.ndarray p_prior: dirichlet prior (alphas). :param np.ndarray rew_mean_prior: means of the normal priors of rewards. :param np.ndarray rew_std_prior: standard deviations of the normal - priors of rewards. + priors of rewards. :param float discount_factor: in (0, 1]. .. seealso:: @@ -99,7 +99,7 @@ class PSRLPolicy(BasePolicy): :param np.ndarray p_prior: dirichlet prior (alphas). :param np.ndarray rew_mean_prior: means of the normal priors of rewards. :param np.ndarray rew_std_prior: standard deviations of the normal - priors of rewards. + priors of rewards. :param torch.distributions.Distribution dist_fn: for computing the action. :param float discount_factor: in [0, 1]. 
:param bool reward_normalization: normalize the reward to Normal(0, 1), From b78cba4451c05d61372ee763df93bd6f217a43b8 Mon Sep 17 00:00:00 2001 From: Yao Date: Sat, 5 Sep 2020 15:29:54 +0800 Subject: [PATCH 06/62] add value iteration --- examples/modelbase/psrl.py | 6 ++--- setup.py | 1 - tianshou/policy/modelbase/psrl.py | 39 +++++++++++++++---------------- 3 files changed, 22 insertions(+), 24 deletions(-) diff --git a/examples/modelbase/psrl.py b/examples/modelbase/psrl.py index 3f42a6023..e10f52a1e 100644 --- a/examples/modelbase/psrl.py +++ b/examples/modelbase/psrl.py @@ -16,10 +16,10 @@ def get_args(): parser.add_argument('--task', type=str, default='FrozenLake-v0') parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--eps-test', type=float, default=0) - parser.add_argument('--eps-train', type=float, default=0.1) - parser.add_argument('--buffer-size', type=int, default=1) + parser.add_argument('--eps-train', type=float, default=0.3) + parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=100) + parser.add_argument('--step-per-epoch', type=int, default=500) parser.add_argument('--collect-per-step', type=int, default=100) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--training-num', type=int, default=8) diff --git a/setup.py b/setup.py index 61f32e766..175112c44 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,6 @@ 'numpy', 'tensorboard', 'torch>=1.4.0', - 'pymdptoolbox', 'numba>=0.51.0', ], extras_require={ diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 5effa2be6..79d36a4e8 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -1,4 +1,3 @@ -import mdptoolbox import numpy as np from typing import Any, Dict, Union, Optional @@ -13,7 +12,6 @@ class PSRLModel(object): :param np.ndarray rew_mean_prior: means of the normal priors of rewards. :param np.ndarray rew_std_prior: standard deviations of the normal priors of rewards. - :param float discount_factor: in (0, 1]. .. 
seealso:: @@ -26,16 +24,14 @@ def __init__( p_prior: np.ndarray, rew_mean_prior: np.ndarray, rew_std_prior: np.ndarray, - discount_factor: float = 0.99, ) -> None: self.p = p_prior self.n_action = len(self.p) self.n_state = len(self.p[0]) self.rew_mean = rew_mean_prior self.rew_std = rew_std_prior - self.discount_factor = discount_factor self.rew_count = np.zeros_like(rew_mean_prior) - self.sample_p = None + self.p_ml = None self.sample_rew = None self.policy = None self.updated = False @@ -58,16 +54,9 @@ def observe( self.rew_count) / sum_count_nonzero self.rew_count += rew_count - def sample_from_p(self) -> np.ndarray: - sample_p = [] - for a in range(self.n_action): - for i in range(self.n_state): - param = self.p[a][i] + \ - 1e-5 * np.random.randn(len(self.p[a][i])) - sample_p.append(param / np.sum(param)) - sample_p = np.array(sample_p).reshape( - self.n_action, self.n_state, self.n_state) - return sample_p + def get_p_ml(self) -> np.ndarray: + p_ml = self.p / np.sum(self.p, axis=-1, keepdims=True) + return p_ml def sample_from_rew(self) -> np.ndarray: sample_rew = np.random.randn(len(self.rew_mean), len(self.rew_mean[0])) @@ -76,14 +65,24 @@ def sample_from_rew(self) -> np.ndarray: def solve_policy(self) -> np.ndarray: self.updated = True - self.sample_p = self.sample_from_p() + self.p_ml = self.get_p_ml() self.sample_rew = self.sample_from_rew() - problem = mdptoolbox.mdp.ValueIteration( - self.sample_p, self.sample_rew, self.discount_factor) - problem.run() - self.policy = np.array(problem.policy) + self.policy = self.value_iteration(self.p_ml, self.sample_rew) return self.policy + @staticmethod + def value_iteration(p: np.ndarray, rew: np.ndarray, + epsilon: float = 0.01) -> np.ndarray: + value = np.zeros(len(rew)) + while True: + Q = rew + np.matmul(p, value).T + new_value = np.max(Q, axis=1) + if np.max(np.abs(new_value - value) / + (np.abs(new_value) + 1e-5)) < epsilon: + return np.argmax(Q, axis=1) + else: + value = new_value + def __call__(self, obs: np.ndarray, state=None, info=None) -> np.ndarray: if self.updated is False: self.solve_policy() From 577a26985d7f81408fecdba9e692afab230d9b5e Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 5 Sep 2020 15:51:46 +0800 Subject: [PATCH 07/62] polish --- examples/modelbase/README.md | 6 +++--- examples/modelbase/psrl.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/modelbase/README.md b/examples/modelbase/README.md index 0112cb523..2f3bda58c 100644 --- a/examples/modelbase/README.md +++ b/examples/modelbase/README.md @@ -1,7 +1,7 @@ # PSRL -`NChain-v0`: `python3 psrl.py --task NChain-v0 --buffer-size 20000 --epoch 5` +`NChain-v0`: `python3 psrl.py --task NChain-v0 --epoch 5` -`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --buffer-size 1 --epoch 20` +`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --epoch 20` -`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --buffer-size 20000 --epoch 20` +`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --epoch 20` diff --git a/examples/modelbase/psrl.py b/examples/modelbase/psrl.py index e10f52a1e..4e40229fa 100644 --- a/examples/modelbase/psrl.py +++ b/examples/modelbase/psrl.py @@ -23,7 +23,7 @@ def get_args(): parser.add_argument('--collect-per-step', type=int, default=100) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--training-num', type=int, default=8) - parser.add_argument('--test-num', type=int, default=1) + parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', 
type=str, default='log') parser.add_argument('--render', type=float, default=0.) return parser.parse_args() @@ -31,6 +31,7 @@ def get_args(): def test_psrl(args=get_args()): env = gym.make(args.task) + print('Reward threshold: ', env.spec.reward_threshold) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.env.action_space.shape or env.env.action_space.n # train_envs = gym.make(args.task) From 554d81e2eae7d5a677cf96bbf2676b002cf45a19 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 5 Sep 2020 16:06:01 +0800 Subject: [PATCH 08/62] soft link --- examples/modelbase/README.md | 6 +- examples/modelbase/psrl.py | 96 +------------------ .../{NChain_psrl.py => test_psrl.py} | 16 ++-- 3 files changed, 13 insertions(+), 105 deletions(-) mode change 100644 => 120000 examples/modelbase/psrl.py rename test/modelbase/{NChain_psrl.py => test_psrl.py} (91%) diff --git a/examples/modelbase/README.md b/examples/modelbase/README.md index 2f3bda58c..a530d23fd 100644 --- a/examples/modelbase/README.md +++ b/examples/modelbase/README.md @@ -1,7 +1,7 @@ # PSRL -`NChain-v0`: `python3 psrl.py --task NChain-v0 --epoch 5` +`NChain-v0`: `python3 psrl.py` -`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --epoch 20` +`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --epoch 20 --step-per-epoch 500 --collect-per-step 100` -`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --epoch 20` +`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --epoch 20 --step-per-epoch 500 --collect-per-step 100` diff --git a/examples/modelbase/psrl.py b/examples/modelbase/psrl.py deleted file mode 100644 index 4e40229fa..000000000 --- a/examples/modelbase/psrl.py +++ /dev/null @@ -1,95 +0,0 @@ -import gym -import torch -import pprint -import argparse -import numpy as np -from torch.utils.tensorboard import SummaryWriter - -from tianshou.policy import PSRLPolicy -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer -from tianshou.env import DummyVectorEnv, SubprocVectorEnv - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='FrozenLake-v0') - parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--eps-test', type=float, default=0) - parser.add_argument('--eps-train', type=float, default=0.3) - parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=500) - parser.add_argument('--collect-per-step', type=int, default=100) - parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--training-num', type=int, default=8) - parser.add_argument('--test-num', type=int, default=100) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) 
- return parser.parse_args() - - -def test_psrl(args=get_args()): - env = gym.make(args.task) - print('Reward threshold: ', env.spec.reward_threshold) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.env.action_space.shape or env.env.action_space.n - # train_envs = gym.make(args.task) - # train_envs = gym.make(args.task) - train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) - train_envs.seed(args.seed) - test_envs.seed(args.seed) - # model - n_action = args.action_shape - n_state = args.state_shape - p_pri = 1e-3 * np.ones((n_action, n_state, n_state)) - rew_mean = np.zeros((n_state, n_action)) - rew_std = np.ones((n_state, n_action)) - policy = PSRLPolicy(p_pri, rew_mean, rew_std) - # collector - train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) - test_collector = Collector(policy, test_envs) - # log - writer = SummaryWriter(args.logdir + '/' + args.task) - - def train_fn(x): - policy.set_eps(args.eps_train) - - def test_fn(x): - policy.set_eps(args.eps_test) - - def stop_fn(x): - if env.env.spec.reward_threshold: - return x >= env.spec.reward_threshold - else: - return False - # trainer - result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, - args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, writer=writer) - - if __name__ == '__main__': - pprint.pprint(result) - # Let's watch its performance! - policy.eval() - policy.set_eps(args.eps_test) - test_envs.seed(args.seed) - test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, - render=args.render) - pprint.pprint(result) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') - - -if __name__ == '__main__': - test_psrl(get_args()) diff --git a/examples/modelbase/psrl.py b/examples/modelbase/psrl.py new file mode 120000 index 000000000..228d2594c --- /dev/null +++ b/examples/modelbase/psrl.py @@ -0,0 +1 @@ +../../test/modelbase/test_psrl.py \ No newline at end of file diff --git a/test/modelbase/NChain_psrl.py b/test/modelbase/test_psrl.py similarity index 91% rename from test/modelbase/NChain_psrl.py rename to test/modelbase/test_psrl.py index e98c44673..d87049027 100644 --- a/test/modelbase/NChain_psrl.py +++ b/test/modelbase/test_psrl.py @@ -1,15 +1,14 @@ +import gym import torch import pprint import argparse import numpy as np from torch.utils.tensorboard import SummaryWriter -from tianshou.env import DummyVectorEnv, SubprocVectorEnv +from tianshou.policy import PSRLPolicy from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer - -import gym -from tianshou.policy import PSRLPolicy +from tianshou.env import DummyVectorEnv, SubprocVectorEnv def get_args(): @@ -19,12 +18,12 @@ def get_args(): parser.add_argument('--eps-test', type=float, default=0) parser.add_argument('--eps-train', type=float, default=0.1) parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--epoch', type=int, default=5) + parser.add_argument('--epoch', type=int, default=20) parser.add_argument('--step-per-epoch', type=int, default=100) - parser.add_argument('--collect-per-step', type=int, default=100) + 
parser.add_argument('--collect-per-step', type=int, default=10) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--training-num', type=int, default=8) - parser.add_argument('--test-num', type=int, default=1) + parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) return parser.parse_args() @@ -32,6 +31,9 @@ def get_args(): def test_psrl(args=get_args()): env = gym.make(args.task) + if args.task == "NChain-v0": + env.spec.reward_threshold = 3650 # discribed in PSRL paper + print("reward threahold:", env.spec.reward_threshold) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.env.action_space.shape or env.env.action_space.n # train_envs = gym.make(args.task) From 91d5bddf883b58185bc7896c7b42f9752ee3f60d Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 5 Sep 2020 16:09:31 +0800 Subject: [PATCH 09/62] fix pytest --- test/modelbase/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 test/modelbase/__init__.py diff --git a/test/modelbase/__init__.py b/test/modelbase/__init__.py new file mode 100644 index 000000000..e69de29bb From df90377e07dd885bb01e610f7c2844cef8d56ad0 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 5 Sep 2020 16:18:11 +0800 Subject: [PATCH 10/62] fix pytest --- test/modelbase/test_psrl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index d87049027..3d6857504 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -26,7 +26,7 @@ def get_args(): parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
- return parser.parse_args() + return parser.parse_known_args()[0] def test_psrl(args=get_args()): From bd1c34cb16faa2af189732c41936e0bb69764e89 Mon Sep 17 00:00:00 2001 From: Yao Date: Sat, 5 Sep 2020 19:57:21 +0800 Subject: [PATCH 11/62] polish --- examples/modelbase/psrl.py | 2 +- test/modelbase/test_psrl.py | 10 +++++----- tianshou/policy/modelbase/psrl.py | 30 +++++++++++++++++++----------- 3 files changed, 25 insertions(+), 17 deletions(-) diff --git a/examples/modelbase/psrl.py b/examples/modelbase/psrl.py index 228d2594c..010933b85 120000 --- a/examples/modelbase/psrl.py +++ b/examples/modelbase/psrl.py @@ -1 +1 @@ -../../test/modelbase/test_psrl.py \ No newline at end of file +../../test/modelbase/test_psrl.py diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 3d6857504..62cf0bc83 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -33,7 +33,7 @@ def test_psrl(args=get_args()): env = gym.make(args.task) if args.task == "NChain-v0": env.spec.reward_threshold = 3650 # discribed in PSRL paper - print("reward threahold:", env.spec.reward_threshold) + print("reward threshold:", env.spec.reward_threshold) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.env.action_space.shape or env.env.action_space.n # train_envs = gym.make(args.task) @@ -51,10 +51,10 @@ def test_psrl(args=get_args()): # model n_action = args.action_shape n_state = args.state_shape - p_pri = 1e-3 * np.ones((n_action, n_state, n_state)) - rew_mean = np.zeros((n_state, n_action)) - rew_std = np.ones((n_state, n_action)) - policy = PSRLPolicy(p_pri, rew_mean, rew_std) + p_prior = 1e-3 * np.ones((n_action, n_state, n_state)) + rew_mean_prior = np.zeros((n_state, n_action)) + rew_std_prior = np.ones((n_state, n_action)) + policy = PSRLPolicy(p_prior, rew_mean_prior, rew_std_prior) # collector train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size)) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 79d36a4e8..0b70f7056 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -8,10 +8,12 @@ class PSRLModel(object): """Implementation of Posterior Sampling Reinforcement Learning Model. - :param np.ndarray p_prior: dirichlet prior (alphas). - :param np.ndarray rew_mean_prior: means of the normal priors of rewards. + :param np.ndarray p_prior: dirichlet prior (alphas), + shape: (n_action, n_state, n_state). + :param np.ndarray rew_mean_prior: means of the normal priors of rewards, + shape: (n_state, n_action) :param np.ndarray rew_std_prior: standard deviations of the normal - priors of rewards. + priors of rewards, shape: (n_state, n_action). .. seealso:: @@ -39,7 +41,14 @@ def __init__( def observe( self, p: np.ndarray, rew_sum: np.ndarray, rew_count: np.ndarray ) -> None: - """Add data into memory pool.""" + """Add data into memory pool. + :param np.ndarray p: the number of observations, + shape: (n_action, n_state, n_state). + :param np.ndarray rew_sum: total rewards, + shape: (n_state, n_action) + :param np.ndarray rew_count: the number of rewards, + shape: (n_state, n_action). 
+ """ self.updated = False self.p += p sum_count_nonzero = np.where(self.rew_count + rew_count == 0, @@ -55,11 +64,10 @@ def observe( self.rew_count += rew_count def get_p_ml(self) -> np.ndarray: - p_ml = self.p / np.sum(self.p, axis=-1, keepdims=True) - return p_ml + return self.p / np.sum(self.p, axis=-1, keepdims=True) def sample_from_rew(self) -> np.ndarray: - sample_rew = np.random.randn(len(self.rew_mean), len(self.rew_mean[0])) + sample_rew = np.random.randn(*self.rew_mean.shape) sample_rew = sample_rew * self.rew_std + self.rew_mean return sample_rew @@ -77,8 +85,7 @@ def value_iteration(p: np.ndarray, rew: np.ndarray, while True: Q = rew + np.matmul(p, value).T new_value = np.max(Q, axis=1) - if np.max(np.abs(new_value - value) / - (np.abs(new_value) + 1e-5)) < epsilon: + if np.allclose(new_value, value, epsilon): return np.argmax(Q, axis=1) else: value = new_value @@ -121,10 +128,11 @@ def __init__( ) -> None: super().__init__(**kwargs) self.model = PSRLModel(p_prior, rew_mean_prior, rew_std_prior) - assert 0 <= discount_factor <= 1, 'discount factor should in [0, 1]' + assert 0.0 <= discount_factor <= 1.0, \ + "discount factor should in [0, 1]" self._gamma = discount_factor self._rew_norm = reward_normalization - self.eps = 0 + self.eps = 0.0 def set_eps(self, eps: float) -> None: """Set the eps for epsilon-greedy exploration.""" From f5fc5f8f3e378fab9122c7302fb3898e53294a31 Mon Sep 17 00:00:00 2001 From: Yao Date: Sat, 5 Sep 2020 20:05:57 +0800 Subject: [PATCH 12/62] bug fixed --- examples/modelbase/psrl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/modelbase/psrl.py b/examples/modelbase/psrl.py index 010933b85..228d2594c 120000 --- a/examples/modelbase/psrl.py +++ b/examples/modelbase/psrl.py @@ -1 +1 @@ -../../test/modelbase/test_psrl.py +../../test/modelbase/test_psrl.py \ No newline at end of file From 077b01578ba692ef15792d6e13b930be7dccedf9 Mon Sep 17 00:00:00 2001 From: Yao Date: Sat, 5 Sep 2020 20:52:19 +0800 Subject: [PATCH 13/62] polish --- tianshou/policy/modelbase/psrl.py | 37 ++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 0b70f7056..e45a14d6b 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -8,12 +8,12 @@ class PSRLModel(object): """Implementation of Posterior Sampling Reinforcement Learning Model. - :param np.ndarray p_prior: dirichlet prior (alphas), - shape: (n_action, n_state, n_state). + :param np.ndarray p_prior: dirichlet prior (alphas), with shape + (n_action, n_state, n_state). :param np.ndarray rew_mean_prior: means of the normal priors of rewards, - shape: (n_state, n_action) + with shape (n_state, n_action). :param np.ndarray rew_std_prior: standard deviations of the normal - priors of rewards, shape: (n_state, n_action). + priors of rewards, with shape (n_state, n_action). .. seealso:: @@ -42,12 +42,20 @@ def observe( self, p: np.ndarray, rew_sum: np.ndarray, rew_count: np.ndarray ) -> None: """Add data into memory pool. - :param np.ndarray p: the number of observations, - shape: (n_action, n_state, n_state). - :param np.ndarray rew_sum: total rewards, - shape: (n_state, n_action) - :param np.ndarray rew_count: the number of rewards, - shape: (n_state, n_action). + + :param np.ndarray p: the number of observations, with shape + (n_action, n_state, n_state). + :param np.ndarray rew_sum: total rewards, with shape + (n_state, n_action). 
+ :param np.ndarray rew_count: the number of rewards, with + shape (n_state, n_action). + Here self.p += p updates p. + For rewards, we have a normal prior at first. After + we observed a reward for a given state-action pair, we use + the mean value of our observations instead of the prior mean + as the posterior mean. The standard deviations are in + inverse proportion to the number of corresponding + observations. """ self.updated = False self.p += p @@ -81,6 +89,13 @@ def solve_policy(self) -> np.ndarray: @staticmethod def value_iteration(p: np.ndarray, rew: np.ndarray, epsilon: float = 0.01) -> np.ndarray: + """Value iteration solver for MDPs. + + :param np.ndarray p: transition probabilities, with shape + (n_action, n_state, n_state). + :param np.ndarray rew: rewards, with shape (n_state, n_action). + :param float epsilon: for precision control. + """ value = np.zeros(len(rew)) while True: Q = rew + np.matmul(p, value).T @@ -160,7 +175,7 @@ def forward( eps: Optional[float] = None, **kwargs: Any, ) -> Batch: - """Compute action over the given batch data. + """Compute action over the given batch data with PSRL model. :return: A :class:`~tianshou.data.Batch` with "act" key containing the action. From caf5ea4fba742ca675f179657e66aa52e1b3d303 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 5 Sep 2020 21:48:57 +0800 Subject: [PATCH 14/62] minor update --- tianshou/policy/modelbase/psrl.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index e45a14d6b..4a0a78a15 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -43,19 +43,18 @@ def observe( ) -> None: """Add data into memory pool. + For rewards, we have a normal prior at first. After we observed a + reward for a given state-action pair, we use the mean value of our + observations instead of the prior mean as the posterior mean. The + standard deviations are in inverse proportion to the number of the + corresponding observations. + :param np.ndarray p: the number of observations, with shape (n_action, n_state, n_state). :param np.ndarray rew_sum: total rewards, with shape (n_state, n_action). :param np.ndarray rew_count: the number of rewards, with shape (n_state, n_action). - Here self.p += p updates p. - For rewards, we have a normal prior at first. After - we observed a reward for a given state-action pair, we use - the mean value of our observations instead of the prior mean - as the posterior mean. The standard deviations are in - inverse proportion to the number of corresponding - observations. 
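The update described here (a count-weighted mean, with the standard deviation shrinking as observations accumulate) can be checked in isolation. A small numeric sketch with made-up numbers, written in the simplified form that later patches in this series converge on:

    import numpy as np

    rew_mean = np.array([0.0])    # prior mean for one (state, action) pair
    rew_std = np.array([1.0])     # prior standard deviation
    rew_count = np.array([1.0])   # pseudo-count keeping the update well defined

    rew_sum, new_count = np.array([5.0]), np.array([10.0])  # observed this round

    total = rew_count + new_count
    rew_mean = (rew_mean * rew_count + rew_sum) / total   # about 0.45
    rew_std = rew_std * rew_count / total                 # shrinks toward zero
    rew_count = total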
""" self.updated = False self.p += p @@ -79,12 +78,11 @@ def sample_from_rew(self) -> np.ndarray: sample_rew = sample_rew * self.rew_std + self.rew_mean return sample_rew - def solve_policy(self) -> np.ndarray: + def solve_policy(self) -> None: self.updated = True self.p_ml = self.get_p_ml() self.sample_rew = self.sample_from_rew() self.policy = self.value_iteration(self.p_ml, self.sample_rew) - return self.policy @staticmethod def value_iteration(p: np.ndarray, rew: np.ndarray, @@ -194,9 +192,9 @@ def forward( act[i] = np.random.randint(0, self.model.n_action) return Batch(act=act) - def learn( # type: ignore + def learn( self, batch: Batch, **kwargs: Any - ) -> Dict[str, float]: + ) -> Dict[str, Union[float, List[float]]]: p = np.zeros((self.model.n_action, self.model.n_state, self.model.n_state)) rew_sum = np.zeros((self.model.n_state, self.model.n_action)) @@ -210,4 +208,4 @@ def learn( # type: ignore rew_sum[obs[i]][a[i]] += r[i] rew_count[obs[i]][a[i]] += 1 self.model.observe(p, rew_sum, rew_count) - return {'loss': 0.0} + return {} From 2bc2fb194cfee5f1e4633a9906f295fa98e253ef Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 5 Sep 2020 22:02:36 +0800 Subject: [PATCH 15/62] polish --- tianshou/policy/modelbase/psrl.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 4a0a78a15..02d81fef3 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -28,14 +28,11 @@ def __init__( rew_std_prior: np.ndarray, ) -> None: self.p = p_prior - self.n_action = len(self.p) - self.n_state = len(self.p[0]) + self.n_action, self.n_state, _ = p_prior.shape self.rew_mean = rew_mean_prior self.rew_std = rew_std_prior self.rew_count = np.zeros_like(rew_mean_prior) - self.p_ml = None - self.sample_rew = None - self.policy = None + self.policy: Optional[np.ndarray] = None self.updated = False def observe( @@ -80,9 +77,8 @@ def sample_from_rew(self) -> np.ndarray: def solve_policy(self) -> None: self.updated = True - self.p_ml = self.get_p_ml() - self.sample_rew = self.sample_from_rew() - self.policy = self.value_iteration(self.p_ml, self.sample_rew) + self.policy = self.value_iteration( + self.get_p_ml(), self.sample_from_rew()) @staticmethod def value_iteration(p: np.ndarray, rew: np.ndarray, @@ -163,7 +159,7 @@ def process_fn( discount factor, :math:`\gamma \in [0, 1]`. 
""" return self.compute_episodic_return( - batch, gamma=self._gamma, gae_lambda=1., rew_norm=self._rew_norm + batch, gamma=self._gamma, gae_lambda=1.0, rew_norm=self._rew_norm ) def forward( @@ -192,17 +188,15 @@ def forward( act[i] = np.random.randint(0, self.model.n_action) return Batch(act=act) - def learn( + def learn( # type: ignore self, batch: Batch, **kwargs: Any - ) -> Dict[str, Union[float, List[float]]]: + ) -> Dict[str, float]: p = np.zeros((self.model.n_action, self.model.n_state, self.model.n_state)) rew_sum = np.zeros((self.model.n_state, self.model.n_action)) - rew_count = np.zeros((self.model.n_state, self.model.n_action)) - a = batch.act - r = batch.returns - obs = batch.obs - obs_next = batch.obs_next + rew_count = np.zeros_like(rew_sum) + a, r = batch.act, batch.returns + obs, obs_next = batch.obs, batch.obs_next for i in range(len(obs)): p[a[i]][obs[i]][obs_next[i]] += 1 rew_sum[obs[i]][a[i]] += r[i] From e825bfc5e74878f8c80db80ba7b203fa2b9cb8b0 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 5 Sep 2020 22:09:21 +0800 Subject: [PATCH 16/62] docs --- tianshou/policy/modelbase/psrl.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 02d81fef3..ac2033ae3 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -12,8 +12,8 @@ class PSRLModel(object): (n_action, n_state, n_state). :param np.ndarray rew_mean_prior: means of the normal priors of rewards, with shape (n_state, n_action). - :param np.ndarray rew_std_prior: standard deviations of the normal - priors of rewards, with shape (n_state, n_action). + :param np.ndarray rew_std_prior: standard deviations of the normal priors + of rewards, with shape (n_state, n_action). .. seealso:: @@ -50,8 +50,8 @@ def observe( (n_action, n_state, n_state). :param np.ndarray rew_sum: total rewards, with shape (n_state, n_action). - :param np.ndarray rew_count: the number of rewards, with - shape (n_state, n_action). + :param np.ndarray rew_count: the number of rewards, with shape + (n_state, n_action). """ self.updated = False self.p += p @@ -89,6 +89,8 @@ def value_iteration(p: np.ndarray, rew: np.ndarray, (n_action, n_state, n_state). :param np.ndarray rew: rewards, with shape (n_state, n_action). :param float epsilon: for precision control. + + :return: the optimal policy with shape (n_state, ). 
""" value = np.zeros(len(rew)) while True: From ee8a1978ec9d73f8aebb9816ef620a31d1bb1eca Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sun, 6 Sep 2020 08:28:45 +0800 Subject: [PATCH 17/62] simplify PSRLModel.observe --- tianshou/policy/modelbase/psrl.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index ac2033ae3..2cc2f1105 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -27,11 +27,12 @@ def __init__( rew_mean_prior: np.ndarray, rew_std_prior: np.ndarray, ) -> None: + self.__eps = np.finfo(np.float32).eps.item() self.p = p_prior self.n_action, self.n_state, _ = p_prior.shape self.rew_mean = rew_mean_prior self.rew_std = rew_std_prior - self.rew_count = np.zeros_like(rew_mean_prior) + self.rew_count = np.zeros_like(rew_mean_prior) + self.__eps self.policy: Optional[np.ndarray] = None self.updated = False @@ -55,17 +56,10 @@ def observe( """ self.updated = False self.p += p - sum_count_nonzero = np.where(self.rew_count + rew_count == 0, - 1, self.rew_count + rew_count) - rew_count_nonzero = np.where(rew_count == 0, 1, rew_count) - self.rew_mean = np.where(self.rew_count == 0, - np.where(rew_count == 0, self.rew_mean, - rew_sum / rew_count_nonzero), - (self.rew_mean * self.rew_count + rew_sum) - / sum_count_nonzero) - self.rew_std *= np.where(self.rew_count == 0, 1, - self.rew_count) / sum_count_nonzero - self.rew_count += rew_count + sum_count = self.rew_count + rew_count + self.rew_mean = (self.rew_mean * self.rew_count + rew_sum) / sum_count + self.rew_std *= self.rew_count / sum_count + self.rew_count = sum_count def get_p_ml(self) -> np.ndarray: return self.p / np.sum(self.p, axis=-1, keepdims=True) @@ -81,14 +75,15 @@ def solve_policy(self) -> None: self.get_p_ml(), self.sample_from_rew()) @staticmethod - def value_iteration(p: np.ndarray, rew: np.ndarray, - epsilon: float = 0.01) -> np.ndarray: + def value_iteration( + p: np.ndarray, rew: np.ndarray, eps: float = 0.01 + ) -> np.ndarray: """Value iteration solver for MDPs. :param np.ndarray p: transition probabilities, with shape (n_action, n_state, n_state). :param np.ndarray rew: rewards, with shape (n_state, n_action). - :param float epsilon: for precision control. + :param float eps: for precision control. :return: the optimal policy with shape (n_state, ). """ @@ -96,7 +91,7 @@ def value_iteration(p: np.ndarray, rew: np.ndarray, while True: Q = rew + np.matmul(p, value).T new_value = np.max(Q, axis=1) - if np.allclose(new_value, value, epsilon): + if np.allclose(new_value, value, eps): return np.argmax(Q, axis=1) else: value = new_value From e97f8c29ae032d56806e8648b0185d43b9f4ff6d Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sun, 6 Sep 2020 08:44:12 +0800 Subject: [PATCH 18/62] remove unnecessary part --- tianshou/policy/modelbase/psrl.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 2cc2f1105..da356fbe3 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -113,9 +113,6 @@ class PSRLPolicy(BasePolicy): :param np.ndarray rew_std_prior: standard deviations of the normal priors of rewards. :param torch.distributions.Distribution dist_fn: for computing the action. - :param float discount_factor: in [0, 1]. 
- :param bool reward_normalization: normalize the reward to Normal(0, 1), - defaults to ``False``. .. seealso:: @@ -128,37 +125,16 @@ def __init__( p_prior: np.ndarray, rew_mean_prior: np.ndarray, rew_std_prior: np.ndarray, - discount_factor: float = 0, - reward_normalization: bool = False, **kwargs: Any, ) -> None: super().__init__(**kwargs) self.model = PSRLModel(p_prior, rew_mean_prior, rew_std_prior) - assert 0.0 <= discount_factor <= 1.0, \ - "discount factor should in [0, 1]" - self._gamma = discount_factor - self._rew_norm = reward_normalization self.eps = 0.0 def set_eps(self, eps: float) -> None: """Set the eps for epsilon-greedy exploration.""" self.eps = eps - def process_fn( - self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray - ) -> Batch: - r"""Compute the discounted returns for each frame: - - .. math:: - G_t = \sum_{i=t}^T \gamma^{i-t}r_i - - , where :math:`T` is the terminal time step, :math:`\gamma` is the - discount factor, :math:`\gamma \in [0, 1]`. - """ - return self.compute_episodic_return( - batch, gamma=self._gamma, gae_lambda=1.0, rew_norm=self._rew_norm - ) - def forward( self, batch: Batch, From c77510173dce945bad5fca8a4305360586254c58 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sun, 6 Sep 2020 08:45:40 +0800 Subject: [PATCH 19/62] fix pep8 --- tianshou/policy/modelbase/psrl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index da356fbe3..acf9e7bba 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -1,8 +1,8 @@ import numpy as np from typing import Any, Dict, Union, Optional +from tianshou.data import Batch from tianshou.policy import BasePolicy -from tianshou.data import Batch, ReplayBuffer class PSRLModel(object): From 12698ed49c1a06e4d491dcfb124e923bc5f3dec3 Mon Sep 17 00:00:00 2001 From: Yao Date: Sun, 6 Sep 2020 09:02:01 +0800 Subject: [PATCH 20/62] use onpolicy instead of offpolicy --- test/modelbase/test_psrl.py | 7 ++++--- tianshou/policy/modelbase/psrl.py | 28 +++++++++++++++------------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 62cf0bc83..65062ff46 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -6,7 +6,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import PSRLPolicy -from tianshou.trainer import offpolicy_trainer +from tianshou.trainer import onpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv @@ -21,6 +21,7 @@ def get_args(): parser.add_argument('--epoch', type=int, default=20) parser.add_argument('--step-per-epoch', type=int, default=100) parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--repeat', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--training-num', type=int, default=8) parser.add_argument('--test-num', type=int, default=100) @@ -74,9 +75,9 @@ def stop_fn(x): else: return False # trainer - result = offpolicy_trainer( + result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, + args.step_per_epoch, args.collect_per_step, args.repeat, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, writer=writer) diff --git a/tianshou/policy/modelbase/psrl.py 
b/tianshou/policy/modelbase/psrl.py index 2cc2f1105..18b225cb4 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -186,17 +186,19 @@ def forward( return Batch(act=act) def learn( # type: ignore - self, batch: Batch, **kwargs: Any - ) -> Dict[str, float]: - p = np.zeros((self.model.n_action, self.model.n_state, - self.model.n_state)) - rew_sum = np.zeros((self.model.n_state, self.model.n_action)) - rew_count = np.zeros_like(rew_sum) - a, r = batch.act, batch.returns - obs, obs_next = batch.obs, batch.obs_next - for i in range(len(obs)): - p[a[i]][obs[i]][obs_next[i]] += 1 - rew_sum[obs[i]][a[i]] += r[i] - rew_count[obs[i]][a[i]] += 1 - self.model.observe(p, rew_sum, rew_count) + self, batch: Batch, batch_size: int, repeat: int, + **kwargs: Any) -> Dict[str, float]: + for _ in range(repeat): + for b in batch.split(batch_size, merge_last=True): + p = np.zeros((self.model.n_action, self.model.n_state, + self.model.n_state)) + rew_sum = np.zeros((self.model.n_state, self.model.n_action)) + rew_count = np.zeros_like(rew_sum) + a, r = b.act, b.returns + obs, obs_next = b.obs, b.obs_next + for i in range(len(obs)): + p[a[i]][obs[i]][obs_next[i]] += 1 + rew_sum[obs[i]][a[i]] += r[i] + rew_count[obs[i]][a[i]] += 1 + self.model.observe(p, rew_sum, rew_count) return {} From 13b83e8f2df4475555cd7bcd7671c45a96d45c93 Mon Sep 17 00:00:00 2001 From: Yao Date: Sun, 6 Sep 2020 09:06:27 +0800 Subject: [PATCH 21/62] add discount factor --- tianshou/policy/modelbase/psrl.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index e4a324be9..77b9f1c4a 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -1,7 +1,7 @@ import numpy as np from typing import Any, Dict, Union, Optional -from tianshou.data import Batch +from tianshou.data import Batch, ReplayBuffer from tianshou.policy import BasePolicy @@ -125,16 +125,35 @@ def __init__( p_prior: np.ndarray, rew_mean_prior: np.ndarray, rew_std_prior: np.ndarray, + discount_factor: float = 0, + reward_normalization: bool = False, **kwargs: Any, ) -> None: super().__init__(**kwargs) self.model = PSRLModel(p_prior, rew_mean_prior, rew_std_prior) + assert 0.0 <= discount_factor <= 1.0, \ + "discount factor should in [0, 1]" + self._gamma = discount_factor + self._rew_norm = reward_normalization self.eps = 0.0 def set_eps(self, eps: float) -> None: """Set the eps for epsilon-greedy exploration.""" self.eps = eps + def process_fn( + self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray + ) -> Batch: + r"""Compute the discounted returns for each frame: + .. math:: + G_t = \sum_{i=t}^T \gamma^{i-t}r_i + , where :math:`T` is the terminal time step, :math:`\gamma` is the + discount factor, :math:`\gamma \in [0, 1]`. 
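The learn() hunk above consumes the collected data in minibatches through batch.split(batch_size, merge_last=True). A tiny usage sketch of that pattern, assuming only the tianshou.data.Batch API already used throughout these patches (the field values are made up):

    import numpy as np
    from tianshou.data import Batch

    batch = Batch(obs=np.arange(10), act=np.zeros(10, dtype=int),
                  rew=np.ones(10), obs_next=np.arange(1, 11))
    for minibatch in batch.split(4, merge_last=True):
        # chunks of 4 transitions; a smaller tail chunk is merged into the
        # previous one instead of being yielded on its own
        print(len(minibatch.obs))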
+ """ + return self.compute_episodic_return( + batch, gamma=self._gamma, gae_lambda=1.0, rew_norm=self._rew_norm + ) + def forward( self, batch: Batch, From 19fcce0cda1fa6c548225cd5fef16ab37761c07a Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sun, 6 Sep 2020 09:16:53 +0800 Subject: [PATCH 22/62] fix config --- test/modelbase/test_psrl.py | 4 +-- tianshou/policy/modelbase/psrl.py | 41 +++++++++++++++---------------- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 65062ff46..117a1c4c8 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -19,8 +19,8 @@ def get_args(): parser.add_argument('--eps-train', type=float, default=0.1) parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=100) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=1) + parser.add_argument('--collect-per-step', type=int, default=1) parser.add_argument('--repeat', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--training-num', type=int, default=8) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 77b9f1c4a..4712ac34c 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -1,8 +1,8 @@ import numpy as np from typing import Any, Dict, Union, Optional -from tianshou.data import Batch, ReplayBuffer from tianshou.policy import BasePolicy +from tianshou.data import Batch, ReplayBuffer class PSRLModel(object): @@ -112,6 +112,7 @@ class PSRLPolicy(BasePolicy): :param np.ndarray rew_mean_prior: means of the normal priors of rewards. :param np.ndarray rew_std_prior: standard deviations of the normal priors of rewards. + :param float discount_factor: in [0, 1]. :param torch.distributions.Distribution dist_fn: for computing the action. .. seealso:: @@ -125,8 +126,7 @@ def __init__( p_prior: np.ndarray, rew_mean_prior: np.ndarray, rew_std_prior: np.ndarray, - discount_factor: float = 0, - reward_normalization: bool = False, + discount_factor: float = 0.0, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -134,7 +134,6 @@ def __init__( assert 0.0 <= discount_factor <= 1.0, \ "discount factor should in [0, 1]" self._gamma = discount_factor - self._rew_norm = reward_normalization self.eps = 0.0 def set_eps(self, eps: float) -> None: @@ -145,14 +144,16 @@ def process_fn( self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray ) -> Batch: r"""Compute the discounted returns for each frame: + .. math:: + G_t = \sum_{i=t}^T \gamma^{i-t}r_i + , where :math:`T` is the terminal time step, :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`. 
""" return self.compute_episodic_return( - batch, gamma=self._gamma, gae_lambda=1.0, rew_norm=self._rew_norm - ) + batch, gamma=self._gamma, gae_lambda=1.0) def forward( self, @@ -181,19 +182,17 @@ def forward( return Batch(act=act) def learn( # type: ignore - self, batch: Batch, batch_size: int, repeat: int, - **kwargs: Any) -> Dict[str, float]: - for _ in range(repeat): - for b in batch.split(batch_size, merge_last=True): - p = np.zeros((self.model.n_action, self.model.n_state, - self.model.n_state)) - rew_sum = np.zeros((self.model.n_state, self.model.n_action)) - rew_count = np.zeros_like(rew_sum) - a, r = b.act, b.returns - obs, obs_next = b.obs, b.obs_next - for i in range(len(obs)): - p[a[i]][obs[i]][obs_next[i]] += 1 - rew_sum[obs[i]][a[i]] += r[i] - rew_count[obs[i]][a[i]] += 1 - self.model.observe(p, rew_sum, rew_count) + self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any + ) -> Dict[str, float]: + p = np.zeros((self.model.n_action, self.model.n_state, + self.model.n_state)) + rew_sum = np.zeros((self.model.n_state, self.model.n_action)) + rew_count = np.zeros_like(rew_sum) + a, r = batch.act, batch.returns + obs, obs_next = batch.obs, batch.obs_next + for i in range(len(obs)): + p[a[i]][obs[i]][obs_next[i]] += 1 + rew_sum[obs[i]][a[i]] += r[i] + rew_count[obs[i]][a[i]] += 1 + self.model.observe(p, rew_sum, rew_count) return {} From bee48e59043783148e7327e49b48c3f45418d1c0 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sun, 6 Sep 2020 09:19:16 +0800 Subject: [PATCH 23/62] fix docstring --- tianshou/policy/modelbase/psrl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 4712ac34c..51395f6b4 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -113,7 +113,6 @@ class PSRLPolicy(BasePolicy): :param np.ndarray rew_std_prior: standard deviations of the normal priors of rewards. :param float discount_factor: in [0, 1]. - :param torch.distributions.Distribution dist_fn: for computing the action. .. 
seealso:: From f60fd9dbcb4f61f58276dcd4bfe99e0941669ae0 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sun, 6 Sep 2020 12:36:12 +0800 Subject: [PATCH 24/62] tune taxi --- examples/modelbase/README.md | 2 +- test/modelbase/test_psrl.py | 16 +++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/examples/modelbase/README.md b/examples/modelbase/README.md index a530d23fd..c45134152 100644 --- a/examples/modelbase/README.md +++ b/examples/modelbase/README.md @@ -4,4 +4,4 @@ `FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --epoch 20 --step-per-epoch 500 --collect-per-step 100` -`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --epoch 20 --step-per-epoch 500 --collect-per-step 100` +`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --gamma 0 --step-per-epoch 20` diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 117a1c4c8..00f7a1c24 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -15,13 +15,13 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='NChain-v0') parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--eps-test', type=float, default=0) + parser.add_argument('--gamma', type=float, default=1.0) + parser.add_argument('--eps-test', type=float, default=0.0) parser.add_argument('--eps-train', type=float, default=0.1) parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--epoch', type=int, default=20) parser.add_argument('--step-per-epoch', type=int, default=1) parser.add_argument('--collect-per-step', type=int, default=1) - parser.add_argument('--repeat', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--training-num', type=int, default=8) parser.add_argument('--test-num', type=int, default=100) @@ -55,7 +55,7 @@ def test_psrl(args=get_args()): p_prior = 1e-3 * np.ones((n_action, n_state, n_state)) rew_mean_prior = np.zeros((n_state, n_action)) rew_std_prior = np.ones((n_state, n_action)) - policy = PSRLPolicy(p_prior, rew_mean_prior, rew_std_prior) + policy = PSRLPolicy(p_prior, rew_mean_prior, rew_std_prior, args.gamma) # collector train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size)) @@ -77,17 +77,19 @@ def stop_fn(x): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat, + args.step_per_epoch, args.collect_per_step, 1, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, writer=writer) if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! 
- env = gym.make(args.task) - collector = Collector(policy, env) policy.eval() - result = collector.collect(n_episode=1, render=args.render) + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') From 99f92392c3fbff919f34baaf8cc47ce97f51b167 Mon Sep 17 00:00:00 2001 From: Yao Date: Sun, 6 Sep 2020 15:33:40 +0800 Subject: [PATCH 25/62] add operations for absorbing states --- examples/modelbase/README.md | 2 +- tianshou/policy/modelbase/psrl.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/modelbase/README.md b/examples/modelbase/README.md index c45134152..2b12faa82 100644 --- a/examples/modelbase/README.md +++ b/examples/modelbase/README.md @@ -2,6 +2,6 @@ `NChain-v0`: `python3 psrl.py` -`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --epoch 20 --step-per-epoch 500 --collect-per-step 100` +`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --epoch 20 --gamma 0 --step-per-epoch 1000` `Taxi-v3`: `python3 psrl.py --task Taxi-v3 --gamma 0 --step-per-epoch 20` diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 51395f6b4..3fe580189 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -29,6 +29,7 @@ def __init__( ) -> None: self.__eps = np.finfo(np.float32).eps.item() self.p = p_prior + self.p_prior_sum = np.sum(p_prior, axis=2) self.n_action, self.n_state, _ = p_prior.shape self.rew_mean = rew_mean_prior self.rew_std = rew_std_prior @@ -60,6 +61,12 @@ def observe( self.rew_mean = (self.rew_mean * self.rew_count + rew_sum) / sum_count self.rew_std *= self.rew_count / sum_count self.rew_count = sum_count + if np.sum(self.p) > np.sum(self.p_prior_sum) + 1000: + min_index = np.argmin(np.sum(self.p, axis=2), axis=1) + mask = np.isclose(np.sum(self.p, axis=2), + self.p_prior_sum).astype("float32") + self.p[np.array(range(self.n_action)), min_index, min_index] += \ + mask[np.array(range(self.n_action)), min_index] def get_p_ml(self) -> np.ndarray: return self.p / np.sum(self.p, axis=-1, keepdims=True) From 9aa23940dfa65d90e1c55afceb7b2317a92b605d Mon Sep 17 00:00:00 2001 From: Yao Date: Sun, 6 Sep 2020 16:12:34 +0800 Subject: [PATCH 26/62] polish --- test/modelbase/test_psrl.py | 5 ++- tianshou/policy/modelbase/psrl.py | 67 +++++++++++++++++-------------- 2 files changed, 40 insertions(+), 32 deletions(-) diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 00f7a1c24..1174a5ee5 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -52,10 +52,11 @@ def test_psrl(args=get_args()): # model n_action = args.action_shape n_state = args.state_shape - p_prior = 1e-3 * np.ones((n_action, n_state, n_state)) + trans_count_prior = 1e-3 * np.ones((n_action, n_state, n_state)) rew_mean_prior = np.zeros((n_state, n_action)) rew_std_prior = np.ones((n_state, n_action)) - policy = PSRLPolicy(p_prior, rew_mean_prior, rew_std_prior, args.gamma) + policy = PSRLPolicy(trans_count_prior, + rew_mean_prior, rew_std_prior, args.gamma) # collector train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size)) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 3fe580189..46837e2fc 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -23,14 +23,14 @@ class PSRLModel(object): def __init__( 
self, - p_prior: np.ndarray, + trans_count_prior: np.ndarray, rew_mean_prior: np.ndarray, rew_std_prior: np.ndarray, ) -> None: self.__eps = np.finfo(np.float32).eps.item() - self.p = p_prior - self.p_prior_sum = np.sum(p_prior, axis=2) - self.n_action, self.n_state, _ = p_prior.shape + self.trans_count = trans_count_prior + self.trans_count_prior_sum = np.sum(trans_count_prior, axis=2) + self.n_action, self.n_state, _ = trans_count_prior.shape self.rew_mean = rew_mean_prior self.rew_std = rew_std_prior self.rew_count = np.zeros_like(rew_mean_prior) + self.__eps @@ -38,7 +38,8 @@ def __init__( self.updated = False def observe( - self, p: np.ndarray, rew_sum: np.ndarray, rew_count: np.ndarray + self, trans_count: np.ndarray, + rew_sum: np.ndarray, rew_count: np.ndarray ) -> None: """Add data into memory pool. @@ -48,7 +49,7 @@ def observe( standard deviations are in inverse proportion to the number of the corresponding observations. - :param np.ndarray p: the number of observations, with shape + :param np.ndarray trans_count: the number of observations, with shape (n_action, n_state, n_state). :param np.ndarray rew_sum: total rewards, with shape (n_state, n_action). @@ -56,20 +57,24 @@ def observe( (n_state, n_action). """ self.updated = False - self.p += p + self.trans_count += trans_count sum_count = self.rew_count + rew_count self.rew_mean = (self.rew_mean * self.rew_count + rew_sum) / sum_count self.rew_std *= self.rew_count / sum_count self.rew_count = sum_count - if np.sum(self.p) > np.sum(self.p_prior_sum) + 1000: - min_index = np.argmin(np.sum(self.p, axis=2), axis=1) - mask = np.isclose(np.sum(self.p, axis=2), - self.p_prior_sum).astype("float32") - self.p[np.array(range(self.n_action)), min_index, min_index] += \ + if np.sum(self.trans_count) > \ + np.sum(self.trans_count_prior_sum) + 1000: + min_index = np.argmin(np.sum(self.trans_count, axis=2), axis=1) + mask = np.isclose(np.sum(self.trans_count, axis=2), + self.trans_count_prior_sum).astype("float32") + self.trans_count[np.array(range(self.n_action)), + min_index, min_index] += \ mask[np.array(range(self.n_action)), min_index] - def get_p_ml(self) -> np.ndarray: - return self.p / np.sum(self.p, axis=-1, keepdims=True) + def get_trans_prob_ml(self) -> np.ndarray: + """Here ml means maximum likelihood.""" + return self.trans_count / np.sum(self.trans_count, + axis=-1, keepdims=True) def sample_from_rew(self) -> np.ndarray: sample_rew = np.random.randn(*self.rew_mean.shape) @@ -79,29 +84,30 @@ def sample_from_rew(self) -> np.ndarray: def solve_policy(self) -> None: self.updated = True self.policy = self.value_iteration( - self.get_p_ml(), self.sample_from_rew()) + self.get_trans_prob_ml(), self.sample_from_rew()) @staticmethod def value_iteration( - p: np.ndarray, rew: np.ndarray, eps: float = 0.01 + trans_prob: np.ndarray, rew: np.ndarray, eps: float = 0.01 ) -> np.ndarray: """Value iteration solver for MDPs. - :param np.ndarray p: transition probabilities, with shape + :param np.ndarray trans_prob: transition probabilities, with shape (n_action, n_state, n_state). :param np.ndarray rew: rewards, with shape (n_state, n_action). :param float eps: for precision control. :return: the optimal policy with shape (n_state, ). 
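get_trans_prob_ml() above collapses the accumulated Dirichlet counts into a point estimate by normalizing over the next-state axis. A standalone illustration with made-up counts:

    import numpy as np

    # trans_count[a, s, s'] accumulated so far (toy numbers)
    trans_count = np.array([[[3.0, 1.0], [0.5, 0.5]],
                            [[1.0, 1.0], [4.0, 2.0]]])
    trans_prob = trans_count / trans_count.sum(axis=-1, keepdims=True)
    # every row is now a distribution, e.g. trans_prob[0, 0] == [0.75, 0.25]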
""" + # print(trans_prob, rew) value = np.zeros(len(rew)) - while True: - Q = rew + np.matmul(p, value).T + Q = rew + np.matmul(trans_prob, value).T + new_value = np.max(Q, axis=1) + while not np.allclose(new_value, value, eps): + value = new_value + Q = rew + np.matmul(trans_prob, value).T new_value = np.max(Q, axis=1) - if np.allclose(new_value, value, eps): - return np.argmax(Q, axis=1) - else: - value = new_value + return np.argmax(Q, axis=1) def __call__(self, obs: np.ndarray, state=None, info=None) -> np.ndarray: if self.updated is False: @@ -115,7 +121,7 @@ class PSRLPolicy(BasePolicy): Reference: Strens M. A Bayesian framework for reinforcement learning [C] //ICML. 2000, 2000: 943-950. - :param np.ndarray p_prior: dirichlet prior (alphas). + :param np.ndarray trans_count_prior: dirichlet prior (alphas). :param np.ndarray rew_mean_prior: means of the normal priors of rewards. :param np.ndarray rew_std_prior: standard deviations of the normal priors of rewards. @@ -129,14 +135,15 @@ class PSRLPolicy(BasePolicy): def __init__( self, - p_prior: np.ndarray, + trans_count_prior: np.ndarray, rew_mean_prior: np.ndarray, rew_std_prior: np.ndarray, discount_factor: float = 0.0, **kwargs: Any, ) -> None: super().__init__(**kwargs) - self.model = PSRLModel(p_prior, rew_mean_prior, rew_std_prior) + self.model = PSRLModel(trans_count_prior, rew_mean_prior, + rew_std_prior) assert 0.0 <= discount_factor <= 1.0, \ "discount factor should in [0, 1]" self._gamma = discount_factor @@ -190,15 +197,15 @@ def forward( def learn( # type: ignore self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any ) -> Dict[str, float]: - p = np.zeros((self.model.n_action, self.model.n_state, - self.model.n_state)) + trans_count = np.zeros((self.model.n_action, self.model.n_state, + self.model.n_state)) rew_sum = np.zeros((self.model.n_state, self.model.n_action)) rew_count = np.zeros_like(rew_sum) a, r = batch.act, batch.returns obs, obs_next = batch.obs, batch.obs_next for i in range(len(obs)): - p[a[i]][obs[i]][obs_next[i]] += 1 + trans_count[a[i]][obs[i]][obs_next[i]] += 1 rew_sum[obs[i]][a[i]] += r[i] rew_count[obs[i]][a[i]] += 1 - self.model.observe(p, rew_sum, rew_count) + self.model.observe(trans_count, rew_sum, rew_count) return {} From f580c4c5764ac8b569f20e047bd9f184c34d0263 Mon Sep 17 00:00:00 2001 From: Yao Date: Sun, 6 Sep 2020 16:30:24 +0800 Subject: [PATCH 27/62] polish --- tianshou/policy/modelbase/psrl.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 46837e2fc..10cb5403a 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -62,6 +62,12 @@ def observe( self.rew_mean = (self.rew_mean * self.rew_count + rew_sum) / sum_count self.rew_std *= self.rew_count / sum_count self.rew_count = sum_count + # I find that maybe we did not gather information + # for absorbing states' transition probabilities. + # After 1000 observations, if for some states, their + # trans_count almost do not change, the code adds + # the counts of these states transiting to themselves + # by 1. 
if np.sum(self.trans_count) > \ np.sum(self.trans_count_prior_sum) + 1000: min_index = np.argmin(np.sum(self.trans_count, axis=2), axis=1) @@ -71,8 +77,7 @@ def observe( min_index, min_index] += \ mask[np.array(range(self.n_action)), min_index] - def get_trans_prob_ml(self) -> np.ndarray: - """Here ml means maximum likelihood.""" + def get_trans_prob_max_likelihood(self) -> np.ndarray: return self.trans_count / np.sum(self.trans_count, axis=-1, keepdims=True) @@ -84,7 +89,7 @@ def sample_from_rew(self) -> np.ndarray: def solve_policy(self) -> None: self.updated = True self.policy = self.value_iteration( - self.get_trans_prob_ml(), self.sample_from_rew()) + self.get_trans_prob_max_likelihood(), self.sample_from_rew()) @staticmethod def value_iteration( @@ -99,7 +104,6 @@ def value_iteration( :return: the optimal policy with shape (n_state, ). """ - # print(trans_prob, rew) value = np.zeros(len(rew)) Q = rew + np.matmul(trans_prob, value).T new_value = np.max(Q, axis=1) From 40df71c424d68033f1f2041598a446be09bd18de Mon Sep 17 00:00:00 2001 From: Yao Date: Sun, 6 Sep 2020 16:46:40 +0800 Subject: [PATCH 28/62] fix gamma=0 --- examples/modelbase/README.md | 4 ++-- test/modelbase/test_psrl.py | 3 +-- tianshou/policy/modelbase/psrl.py | 24 ++---------------------- 3 files changed, 5 insertions(+), 26 deletions(-) diff --git a/examples/modelbase/README.md b/examples/modelbase/README.md index 2b12faa82..447957be1 100644 --- a/examples/modelbase/README.md +++ b/examples/modelbase/README.md @@ -2,6 +2,6 @@ `NChain-v0`: `python3 psrl.py` -`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --epoch 20 --gamma 0 --step-per-epoch 1000` +`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --epoch 20 --step-per-epoch 1000` -`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --gamma 0 --step-per-epoch 20` +`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --step-per-epoch 20` diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 1174a5ee5..f63b7cce5 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -15,7 +15,6 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='NChain-v0') parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--gamma', type=float, default=1.0) parser.add_argument('--eps-test', type=float, default=0.0) parser.add_argument('--eps-train', type=float, default=0.1) parser.add_argument('--buffer-size', type=int, default=20000) @@ -56,7 +55,7 @@ def test_psrl(args=get_args()): rew_mean_prior = np.zeros((n_state, n_action)) rew_std_prior = np.ones((n_state, n_action)) policy = PSRLPolicy(trans_count_prior, - rew_mean_prior, rew_std_prior, args.gamma) + rew_mean_prior, rew_std_prior) # collector train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size)) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 10cb5403a..1feb44685 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -2,7 +2,7 @@ from typing import Any, Dict, Union, Optional from tianshou.policy import BasePolicy -from tianshou.data import Batch, ReplayBuffer +from tianshou.data import Batch class PSRLModel(object): @@ -129,7 +129,6 @@ class PSRLPolicy(BasePolicy): :param np.ndarray rew_mean_prior: means of the normal priors of rewards. :param np.ndarray rew_std_prior: standard deviations of the normal priors of rewards. - :param float discount_factor: in [0, 1]. .. 
seealso:: @@ -142,36 +141,17 @@ def __init__( trans_count_prior: np.ndarray, rew_mean_prior: np.ndarray, rew_std_prior: np.ndarray, - discount_factor: float = 0.0, **kwargs: Any, ) -> None: super().__init__(**kwargs) self.model = PSRLModel(trans_count_prior, rew_mean_prior, rew_std_prior) - assert 0.0 <= discount_factor <= 1.0, \ - "discount factor should in [0, 1]" - self._gamma = discount_factor self.eps = 0.0 def set_eps(self, eps: float) -> None: """Set the eps for epsilon-greedy exploration.""" self.eps = eps - def process_fn( - self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray - ) -> Batch: - r"""Compute the discounted returns for each frame: - - .. math:: - - G_t = \sum_{i=t}^T \gamma^{i-t}r_i - - , where :math:`T` is the terminal time step, :math:`\gamma` is the - discount factor, :math:`\gamma \in [0, 1]`. - """ - return self.compute_episodic_return( - batch, gamma=self._gamma, gae_lambda=1.0) - def forward( self, batch: Batch, @@ -205,7 +185,7 @@ def learn( # type: ignore self.model.n_state)) rew_sum = np.zeros((self.model.n_state, self.model.n_action)) rew_count = np.zeros_like(rew_sum) - a, r = batch.act, batch.returns + a, r = batch.act, batch.rew obs, obs_next = batch.obs, batch.obs_next for i in range(len(obs)): trans_count[a[i]][obs[i]][obs_next[i]] += 1 From 874fb68c488531dc17d030c89805ed50b0616a4f Mon Sep 17 00:00:00 2001 From: Yao Date: Mon, 7 Sep 2020 12:27:59 +0800 Subject: [PATCH 29/62] use sampling, delete epsilon greedy --- examples/modelbase/README.md | 4 ++-- test/modelbase/test_psrl.py | 14 ++------------ tianshou/policy/modelbase/psrl.py | 30 +++++++++++------------------- 3 files changed, 15 insertions(+), 33 deletions(-) diff --git a/examples/modelbase/README.md b/examples/modelbase/README.md index 447957be1..073748472 100644 --- a/examples/modelbase/README.md +++ b/examples/modelbase/README.md @@ -1,7 +1,7 @@ # PSRL -`NChain-v0`: `python3 psrl.py` +`NChain-v0`: `python3 psrl.py --step-per-epoch 100` `FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --epoch 20 --step-per-epoch 1000` -`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --step-per-epoch 20` +`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --epoch 20 --step-per-epoch 100` diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index f63b7cce5..271d3d570 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -15,8 +15,6 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='NChain-v0') parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--eps-test', type=float, default=0.0) - parser.add_argument('--eps-train', type=float, default=0.1) parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--epoch', type=int, default=20) parser.add_argument('--step-per-epoch', type=int, default=1) @@ -51,7 +49,7 @@ def test_psrl(args=get_args()): # model n_action = args.action_shape n_state = args.state_shape - trans_count_prior = 1e-3 * np.ones((n_action, n_state, n_state)) + trans_count_prior = np.ones((n_action, n_state, n_state)) rew_mean_prior = np.zeros((n_state, n_action)) rew_std_prior = np.ones((n_state, n_action)) policy = PSRLPolicy(trans_count_prior, @@ -63,12 +61,6 @@ def test_psrl(args=get_args()): # log writer = SummaryWriter(args.logdir + '/' + args.task) - def train_fn(x): - policy.set_eps(args.eps_train) - - def test_fn(x): - policy.set_eps(args.eps_test) - def stop_fn(x): if env.env.spec.reward_threshold: return x >= env.spec.reward_threshold @@ -78,14 
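The `process_fn` removed in this hunk computed the discounted return G_t = sum_i gamma^(i-t) r_i described in its docstring (via `compute_episodic_return` with `gae_lambda=1.0`), and switching `learn` to `batch.rew` means the model is now fed raw one-step rewards instead. A standalone sketch of the quantity being dropped, following the removed docstring's formula rather than Tianshou's implementation:

```python
import numpy as np

def discounted_returns(rew, done, gamma):
    """G_t = sum_{i >= t} gamma^(i - t) * r_i, restarting at episode boundaries."""
    returns = np.zeros_like(rew, dtype=float)
    running = 0.0
    for t in reversed(range(len(rew))):
        if done[t]:
            running = 0.0          # no reward flows back across a terminal step
        running = rew[t] + gamma * running
        returns[t] = running
    return returns

rew = np.array([1.0, 0.0, 2.0, 1.0])
done = np.array([False, False, True, False])
print(discounted_returns(rew, done, 0.99))  # [2.9602, 1.98, 2.0, 1.0]
```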
+70,12 @@ def stop_fn(x): result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.collect_per_step, 1, - args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, writer=writer) + args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! policy.eval() - policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=[1] * args.test_num, diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 1feb44685..cfa46e213 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -27,13 +27,12 @@ def __init__( rew_mean_prior: np.ndarray, rew_std_prior: np.ndarray, ) -> None: - self.__eps = np.finfo(np.float32).eps.item() self.trans_count = trans_count_prior self.trans_count_prior_sum = np.sum(trans_count_prior, axis=2) self.n_action, self.n_state, _ = trans_count_prior.shape self.rew_mean = rew_mean_prior self.rew_std = rew_std_prior - self.rew_count = np.zeros_like(rew_mean_prior) + self.__eps + self.rew_count = np.ones_like(rew_mean_prior) self.policy: Optional[np.ndarray] = None self.updated = False @@ -67,7 +66,7 @@ def observe( # After 1000 observations, if for some states, their # trans_count almost do not change, the code adds # the counts of these states transiting to themselves - # by 1. + # by 100. if np.sum(self.trans_count) > \ np.sum(self.trans_count_prior_sum) + 1000: min_index = np.argmin(np.sum(self.trans_count, axis=2), axis=1) @@ -75,11 +74,15 @@ def observe( self.trans_count_prior_sum).astype("float32") self.trans_count[np.array(range(self.n_action)), min_index, min_index] += \ - mask[np.array(range(self.n_action)), min_index] + mask[np.array(range(self.n_action)), min_index] * 100 - def get_trans_prob_max_likelihood(self) -> np.ndarray: - return self.trans_count / np.sum(self.trans_count, - axis=-1, keepdims=True) + def sample_from_prob(self) -> np.ndarray: + sample_prob = np.zeros_like(self.trans_count) + for i in range(self.n_action): + for j in range(self.n_state): + sample_prob[i][j] = np.random.dirichlet( + self.trans_count[i][j]) + return sample_prob def sample_from_rew(self) -> np.ndarray: sample_rew = np.random.randn(*self.rew_mean.shape) @@ -89,7 +92,7 @@ def sample_from_rew(self) -> np.ndarray: def solve_policy(self) -> None: self.updated = True self.policy = self.value_iteration( - self.get_trans_prob_max_likelihood(), self.sample_from_rew()) + self.sample_from_prob(), self.sample_from_rew()) @staticmethod def value_iteration( @@ -146,11 +149,6 @@ def __init__( super().__init__(**kwargs) self.model = PSRLModel(trans_count_prior, rew_mean_prior, rew_std_prior) - self.eps = 0.0 - - def set_eps(self, eps: float) -> None: - """Set the eps for epsilon-greedy exploration.""" - self.eps = eps def forward( self, @@ -170,12 +168,6 @@ def forward( more detailed explanation. 
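The move above from epsilon-greedy exploration to pure posterior sampling is the heart of PSRL: draw one plausible MDP from the current posterior and act greedily in it until the next update. A compact sketch of the two sampling steps, with hypothetical shapes and the `(n_action, n_state, n_state)` layout still in use at this point:

```python
import numpy as np

rng = np.random.default_rng(0)
n_action, n_state = 2, 3

# Dirichlet counts over next states and Normal posteriors over mean rewards.
trans_count = np.ones((n_action, n_state, n_state))   # prior alphas
rew_mean = np.zeros((n_state, n_action))
rew_std = np.ones((n_state, n_action))

# Sample one transition model: each (action, state) row is a Dirichlet draw.
sample_prob = np.zeros_like(trans_count)
for a in range(n_action):
    for s in range(n_state):
        sample_prob[a, s] = rng.dirichlet(trans_count[a, s])

# Sample one reward table from the Normal posterior (element-wise broadcast).
sample_rew = rng.normal(rew_mean, rew_std)

print(sample_prob.sum(axis=-1))  # each row sums to 1
print(sample_rew.shape)          # (n_state, n_action)
```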
""" act = self.model(batch.obs, state=state, info=batch.info) - if eps is None: - eps = self.eps - if not np.isclose(eps, 0): - for i in range(len(act)): - if np.random.rand() < eps: - act[i] = np.random.randint(0, self.model.n_action) return Batch(act=act) def learn( # type: ignore From 5fb3a77c16585cc8dd020bb66098fcb57a99ab21 Mon Sep 17 00:00:00 2001 From: Yao Date: Mon, 7 Sep 2020 21:21:22 +0800 Subject: [PATCH 30/62] polish --- tianshou/policy/modelbase/psrl.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index cfa46e213..7f87e3da2 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -85,8 +85,7 @@ def sample_from_prob(self) -> np.ndarray: return sample_prob def sample_from_rew(self) -> np.ndarray: - sample_rew = np.random.randn(*self.rew_mean.shape) - sample_rew = sample_rew * self.rew_std + self.rew_mean + sample_rew = np.random.normal(self.rew_mean, self.rew_std) return sample_rew def solve_policy(self) -> None: From 3a4e45472d448c280bce4c7231e295a7be967e7d Mon Sep 17 00:00:00 2001 From: n+e Date: Mon, 7 Sep 2020 21:39:04 +0800 Subject: [PATCH 31/62] Update tianshou/policy/modelbase/psrl.py --- tianshou/policy/modelbase/psrl.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 7f87e3da2..4a8088ed1 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -85,8 +85,7 @@ def sample_from_prob(self) -> np.ndarray: return sample_prob def sample_from_rew(self) -> np.ndarray: - sample_rew = np.random.normal(self.rew_mean, self.rew_std) - return sample_rew + return np.random.normal(self.rew_mean, self.rew_std) def solve_policy(self) -> None: self.updated = True From ea766bda0e50a3bc5790e525b6667721283c7f17 Mon Sep 17 00:00:00 2001 From: Yao Date: Tue, 8 Sep 2020 13:20:48 +0800 Subject: [PATCH 32/62] polish --- tianshou/policy/modelbase/psrl.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 7f87e3da2..111dfa12a 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -61,20 +61,6 @@ def observe( self.rew_mean = (self.rew_mean * self.rew_count + rew_sum) / sum_count self.rew_std *= self.rew_count / sum_count self.rew_count = sum_count - # I find that maybe we did not gather information - # for absorbing states' transition probabilities. - # After 1000 observations, if for some states, their - # trans_count almost do not change, the code adds - # the counts of these states transiting to themselves - # by 100. 
- if np.sum(self.trans_count) > \ - np.sum(self.trans_count_prior_sum) + 1000: - min_index = np.argmin(np.sum(self.trans_count, axis=2), axis=1) - mask = np.isclose(np.sum(self.trans_count, axis=2), - self.trans_count_prior_sum).astype("float32") - self.trans_count[np.array(range(self.n_action)), - min_index, min_index] += \ - mask[np.array(range(self.n_action)), min_index] * 100 def sample_from_prob(self) -> np.ndarray: sample_prob = np.zeros_like(self.trans_count) @@ -182,5 +168,10 @@ def learn( # type: ignore trans_count[a[i]][obs[i]][obs_next[i]] += 1 rew_sum[obs[i]][a[i]] += r[i] rew_count[obs[i]][a[i]] += 1 + if batch.done[i]: + if hasattr(batch.info, 'TimeLimit.truncated') \ + and batch.info['TimeLimit.truncated'][i]: + continue + trans_count[:, obs_next[i], obs_next[i]] += 1 self.model.observe(trans_count, rew_sum, rew_count) return {} From 9222afa59a1ddeeeed217cc01cb0f63b408d63ce Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 9 Sep 2020 15:51:53 +0800 Subject: [PATCH 33/62] polish --- test/modelbase/test_psrl.py | 13 +++++++------ tianshou/policy/modelbase/psrl.py | 31 ++++++++++++++++--------------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 271d3d570..ebfb564cb 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -17,9 +17,8 @@ def get_args(): parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=1) + parser.add_argument('--step-per-epoch', type=int, default=10) parser.add_argument('--collect-per-step', type=int, default=1) - parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--training-num', type=int, default=8) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') @@ -30,10 +29,10 @@ def get_args(): def test_psrl(args=get_args()): env = gym.make(args.task) if args.task == "NChain-v0": - env.spec.reward_threshold = 3650 # discribed in PSRL paper + env.spec.reward_threshold = 3650 # described in PSRL paper print("reward threshold:", env.spec.reward_threshold) args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.env.action_space.shape or env.env.action_space.n + args.action_shape = env.action_space.shape or env.action_space.n # train_envs = gym.make(args.task) # train_envs = gym.make(args.task) train_envs = DummyVectorEnv( @@ -62,7 +61,7 @@ def test_psrl(args=get_args()): writer = SummaryWriter(args.logdir + '/' + args.task) def stop_fn(x): - if env.env.spec.reward_threshold: + if env.spec.reward_threshold: return x >= env.spec.reward_threshold else: return False @@ -70,7 +69,7 @@ def stop_fn(x): result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.collect_per_step, 1, - args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) + args.test_num, 0, stop_fn=stop_fn, writer=writer) if __name__ == '__main__': pprint.pprint(result) @@ -81,6 +80,8 @@ def stop_fn(x): result = test_collector.collect(n_episode=[1] * args.test_num, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') + elif env.spec.reward_threshold: + assert result["best_reward"] >= env.spec.reward_threshold if __name__ == '__main__': diff --git a/tianshou/policy/modelbase/psrl.py 
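The check added to `learn` above treats a genuinely terminal next state as absorbing (a self-transition count) while skipping episodes that were merely cut off by Gym's `TimeLimit` wrapper. A small sketch of the same distinction outside Tianshou's `Batch`, assuming the wrapper's usual `info` key:

```python
# Sketch only: `info` is the dict gym's TimeLimit wrapper attaches at the final step.
def is_true_terminal(done: bool, info: dict) -> bool:
    """True only when the episode ended inside the MDP, not because of the time limit."""
    return done and not info.get("TimeLimit.truncated", False)

print(is_true_terminal(True, {}))                             # True: real terminal state
print(is_true_terminal(True, {"TimeLimit.truncated": True}))  # False: cut off by the wrapper
print(is_true_terminal(False, {}))                            # False: episode still running
```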
b/tianshou/policy/modelbase/psrl.py index 6de8d6e0f..b1be1d6d4 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -1,8 +1,8 @@ import numpy as np from typing import Any, Dict, Union, Optional -from tianshou.policy import BasePolicy from tianshou.data import Batch +from tianshou.policy import BasePolicy class PSRLModel(object): @@ -37,8 +37,10 @@ def __init__( self.updated = False def observe( - self, trans_count: np.ndarray, - rew_sum: np.ndarray, rew_count: np.ndarray + self, + trans_count: np.ndarray, + rew_sum: np.ndarray, + rew_count: np.ndarray ) -> None: """Add data into memory pool. @@ -131,8 +133,8 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(**kwargs) - self.model = PSRLModel(trans_count_prior, rew_mean_prior, - rew_std_prior) + self.model = PSRLModel( + trans_count_prior, rew_mean_prior, rew_std_prior) def forward( self, @@ -151,22 +153,21 @@ def forward( Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for more detailed explanation. """ - act = self.model(batch.obs, state=state, info=batch.info) - return Batch(act=act) + return Batch(act=self.model(batch.obs, state=state, info=batch.info)) def learn( # type: ignore self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any ) -> Dict[str, float]: - trans_count = np.zeros((self.model.n_action, self.model.n_state, - self.model.n_state)) - rew_sum = np.zeros((self.model.n_state, self.model.n_action)) - rew_count = np.zeros_like(rew_sum) - a, r = batch.act, batch.rew + n_s, n_a = self.model.n_state, self.model.n_action + trans_count = np.zeros((n_a, n_s, n_s)) + rew_sum = np.zeros((n_s, n_a)) + rew_count = np.zeros((n_s, n_a)) + act, rew = batch.act, batch.rew obs, obs_next = batch.obs, batch.obs_next for i in range(len(obs)): - trans_count[a[i]][obs[i]][obs_next[i]] += 1 - rew_sum[obs[i]][a[i]] += r[i] - rew_count[obs[i]][a[i]] += 1 + trans_count[act[i]][obs[i]][obs_next[i]] += 1 + rew_sum[obs[i]][act[i]] += rew[i] + rew_count[obs[i]][act[i]] += 1 if batch.done[i]: if hasattr(batch.info, 'TimeLimit.truncated') \ and batch.info['TimeLimit.truncated'][i]: From 336718fc0e8a117e6267cdd2c09d61bacc5aca1e Mon Sep 17 00:00:00 2001 From: Yao Date: Wed, 9 Sep 2020 16:34:42 +0800 Subject: [PATCH 34/62] polish --- test/modelbase/test_psrl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 271d3d570..fdf9b98e0 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -51,7 +51,7 @@ def test_psrl(args=get_args()): n_state = args.state_shape trans_count_prior = np.ones((n_action, n_state, n_state)) rew_mean_prior = np.zeros((n_state, n_action)) - rew_std_prior = np.ones((n_state, n_action)) + rew_std_prior = 10 * np.ones((n_state, n_action)) policy = PSRLPolicy(trans_count_prior, rew_mean_prior, rew_std_prior) # collector From c759161af29b3bfec7b4c78bc1f555fe257a97ce Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 9 Sep 2020 17:23:59 +0800 Subject: [PATCH 35/62] add rew-mean-prior and rew-std-prior argument in test_psrl --- examples/modelbase/README.md | 2 +- test/modelbase/test_psrl.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/examples/modelbase/README.md b/examples/modelbase/README.md index 073748472..2a298c005 100644 --- a/examples/modelbase/README.md +++ b/examples/modelbase/README.md @@ -1,6 +1,6 @@ # PSRL -`NChain-v0`: `python3 psrl.py --step-per-epoch 100` +`NChain-v0`: `python3 psrl.py` `FrozenLake-v0`: 
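The per-sample loop in `learn` above is clear and fast enough for these small tabular tasks; for larger batches the same counting can be done in one call with `np.add.at`. This is only an alternative sketch, not what the patch uses, with hypothetical indices and the `(n_action, n_state, n_state)` convention still in place here:

```python
import numpy as np

n_state, n_action = 4, 2
obs      = np.array([0, 1, 1, 3])
act      = np.array([1, 0, 1, 0])
obs_next = np.array([1, 1, 2, 3])
rew      = np.array([0.0, 1.0, 0.0, 2.0])

trans_count = np.zeros((n_action, n_state, n_state))
rew_sum     = np.zeros((n_state, n_action))
rew_count   = np.zeros((n_state, n_action))

# Unbuffered in-place accumulation: repeated index triples are each counted.
np.add.at(trans_count, (act, obs, obs_next), 1)
np.add.at(rew_sum,     (obs, act), rew)
np.add.at(rew_count,   (obs, act), 1)

print(trans_count[1, 0, 1], rew_sum[1, 0], rew_count[1, 1])
```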
`python3 psrl.py --task FrozenLake-v0 --epoch 20 --step-per-epoch 1000` diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index c36a888bd..9f2813e00 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -22,7 +22,9 @@ def get_args(): parser.add_argument('--training-num', type=int, default=8) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) + parser.add_argument('--render', type=float, default=0.0) + parser.add_argument('--rew-mean-prior', type=float, default=0.0) + parser.add_argument('--rew-std-prior', type=float, default=1.0) return parser.parse_known_args()[0] @@ -49,10 +51,10 @@ def test_psrl(args=get_args()): n_action = args.action_shape n_state = args.state_shape trans_count_prior = np.ones((n_action, n_state, n_state)) - rew_mean_prior = np.zeros((n_state, n_action)) - rew_std_prior = 10 * np.ones((n_state, n_action)) - policy = PSRLPolicy(trans_count_prior, - rew_mean_prior, rew_std_prior) + rew_mean_prior = np.full((n_state, n_action), args.rew_mean_prior) + rew_std_prior = np.full((n_state, n_action), args.rew_std_prior) + policy = PSRLPolicy( + trans_count_prior, rew_mean_prior, rew_std_prior) # collector train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size)) @@ -69,7 +71,8 @@ def stop_fn(x): result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.collect_per_step, 1, - args.test_num, 0, stop_fn=stop_fn, writer=writer) + args.test_num, 0, stop_fn=stop_fn, writer=writer, + test_in_train=False) if __name__ == '__main__': pprint.pprint(result) From a0cf86d9cee2288cef5061d25018de52a2a34de8 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 9 Sep 2020 17:36:11 +0800 Subject: [PATCH 36/62] fix test --- test/modelbase/test_psrl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 9f2813e00..bac30e5b1 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -16,7 +16,7 @@ def get_args(): parser.add_argument('--task', type=str, default='NChain-v0') parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--epoch', type=int, default=20) + parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=10) parser.add_argument('--collect-per-step', type=int, default=1) parser.add_argument('--training-num', type=int, default=8) @@ -24,7 +24,7 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.0) parser.add_argument('--rew-mean-prior', type=float, default=0.0) - parser.add_argument('--rew-std-prior', type=float, default=1.0) + parser.add_argument('--rew-std-prior', type=float, default=10.0) return parser.parse_known_args()[0] From 3d1ffb8da8dd5327edade4fd85ffc255243f69f7 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 9 Sep 2020 18:11:24 +0800 Subject: [PATCH 37/62] change (a, s, s) to (s, a, s) --- test/modelbase/test_psrl.py | 2 +- tianshou/policy/modelbase/psrl.py | 39 ++++++++++++++----------------- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index bac30e5b1..76699111b 100644 --- a/test/modelbase/test_psrl.py +++ 
b/test/modelbase/test_psrl.py @@ -50,7 +50,7 @@ def test_psrl(args=get_args()): # model n_action = args.action_shape n_state = args.state_shape - trans_count_prior = np.ones((n_action, n_state, n_state)) + trans_count_prior = np.ones((n_state, n_action, n_state)) rew_mean_prior = np.full((n_state, n_action), args.rew_mean_prior) rew_std_prior = np.full((n_state, n_action), args.rew_std_prior) policy = PSRLPolicy( diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index b1be1d6d4..96b5f4bb1 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -9,16 +9,11 @@ class PSRLModel(object): """Implementation of Posterior Sampling Reinforcement Learning Model. :param np.ndarray p_prior: dirichlet prior (alphas), with shape - (n_action, n_state, n_state). + (n_state, n_action, n_state). :param np.ndarray rew_mean_prior: means of the normal priors of rewards, with shape (n_state, n_action). :param np.ndarray rew_std_prior: standard deviations of the normal priors of rewards, with shape (n_state, n_action). - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. """ def __init__( @@ -28,8 +23,7 @@ def __init__( rew_std_prior: np.ndarray, ) -> None: self.trans_count = trans_count_prior - self.trans_count_prior_sum = np.sum(trans_count_prior, axis=2) - self.n_action, self.n_state, _ = trans_count_prior.shape + self.n_state, self.n_action = rew_mean_prior.shape self.rew_mean = rew_mean_prior self.rew_std = rew_std_prior self.rew_count = np.ones_like(rew_mean_prior) @@ -51,7 +45,7 @@ def observe( corresponding observations. :param np.ndarray trans_count: the number of observations, with shape - (n_action, n_state, n_state). + (n_state, n_action, n_state). :param np.ndarray rew_sum: total rewards, with shape (n_state, n_action). :param np.ndarray rew_count: the number of rewards, with shape @@ -66,8 +60,8 @@ def observe( def sample_from_prob(self) -> np.ndarray: sample_prob = np.zeros_like(self.trans_count) - for i in range(self.n_action): - for j in range(self.n_state): + for i in range(self.n_state): + for j in range(self.n_action): sample_prob[i][j] = np.random.dirichlet( self.trans_count[i][j]) return sample_prob @@ -94,11 +88,12 @@ def value_iteration( :return: the optimal policy with shape (n_state, ). """ value = np.zeros(len(rew)) - Q = rew + np.matmul(trans_prob, value).T - new_value = np.max(Q, axis=1) + print(trans_prob.shape, value.shape) + Q = rew + trans_prob.dot(value) # (s, a) = (s, a) + (s, a, s) * (s) + new_value = np.max(Q, axis=1) # (s) = (s, a).max(axis=1) while not np.allclose(new_value, value, eps): value = new_value - Q = rew + np.matmul(trans_prob, value).T + Q = rew + trans_prob.dot(value) new_value = np.max(Q, axis=1) return np.argmax(Q, axis=1) @@ -114,10 +109,12 @@ class PSRLPolicy(BasePolicy): Reference: Strens M. A Bayesian framework for reinforcement learning [C] //ICML. 2000, 2000: 943-950. - :param np.ndarray trans_count_prior: dirichlet prior (alphas). - :param np.ndarray rew_mean_prior: means of the normal priors of rewards. - :param np.ndarray rew_std_prior: standard deviations of the normal - priors of rewards. + :param np.ndarray trans_count_prior: dirichlet prior (alphas), with shape + (n_state, n_action, n_state). + :param np.ndarray rew_mean_prior: means of the normal priors of rewards, + with shape (n_state, n_action). + :param np.ndarray rew_std_prior: standard deviations of the normal priors + of rewards, with shape (n_state, n_action). .. 
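The `(n_state, n_action, n_state)` layout introduced in this patch is what lets the Bellman backup collapse to a single `dot`: contracting the trailing next-state axis with the value vector yields a `(n_state, n_action)` Q-table directly, as the inline shape comments indicate. A quick shape check under hypothetical sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
n_state, n_action = 3, 2

# (n_state, n_action, n_state) transition tensor with next-state rows summing to 1.
trans_prob = rng.random((n_state, n_action, n_state))
trans_prob /= trans_prob.sum(axis=-1, keepdims=True)

rew = rng.random((n_state, n_action))    # (s, a)
value = np.zeros(n_state)                # (s,)

Q = rew + trans_prob.dot(value)   # (s, a, s) . (s,) -> (s, a), then + (s, a)
new_value = Q.max(axis=1)         # (s,)
policy = Q.argmax(axis=1)         # greedy action per state

print(Q.shape, new_value.shape, policy.shape)  # (3, 2) (3,) (3,)
```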
seealso:: @@ -159,19 +156,19 @@ def learn( # type: ignore self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any ) -> Dict[str, float]: n_s, n_a = self.model.n_state, self.model.n_action - trans_count = np.zeros((n_a, n_s, n_s)) + trans_count = np.zeros((n_s, n_a, n_s)) rew_sum = np.zeros((n_s, n_a)) rew_count = np.zeros((n_s, n_a)) act, rew = batch.act, batch.rew obs, obs_next = batch.obs, batch.obs_next for i in range(len(obs)): - trans_count[act[i]][obs[i]][obs_next[i]] += 1 + trans_count[obs[i]][act[i]][obs_next[i]] += 1 rew_sum[obs[i]][act[i]] += rew[i] rew_count[obs[i]][act[i]] += 1 if batch.done[i]: if hasattr(batch.info, 'TimeLimit.truncated') \ and batch.info['TimeLimit.truncated'][i]: continue - trans_count[:, obs_next[i], obs_next[i]] += 1 + trans_count[obs_next[i], :, obs_next[i]] += 1 self.model.observe(trans_count, rew_sum, rew_count) return {} From 6e9ba48b63588ba96f5c50630c07cf1c5f5e81a9 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 9 Sep 2020 18:58:23 +0800 Subject: [PATCH 38/62] add value iteration eps to arguments --- test/modelbase/test_psrl.py | 3 +- tianshou/policy/modelbase/psrl.py | 48 +++++++++++++++++-------------- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 76699111b..cdb3bbbbc 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -25,6 +25,7 @@ def get_args(): parser.add_argument('--render', type=float, default=0.0) parser.add_argument('--rew-mean-prior', type=float, default=0.0) parser.add_argument('--rew-std-prior', type=float, default=10.0) + parser.add_argument('--eps', type=float, default=0.01) return parser.parse_known_args()[0] @@ -54,7 +55,7 @@ def test_psrl(args=get_args()): rew_mean_prior = np.full((n_state, n_action), args.rew_mean_prior) rew_std_prior = np.full((n_state, n_action), args.rew_std_prior) policy = PSRLPolicy( - trans_count_prior, rew_mean_prior, rew_std_prior) + trans_count_prior, rew_mean_prior, rew_std_prior, args.eps) # collector train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size)) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 96b5f4bb1..437ffeb94 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -14,6 +14,7 @@ class PSRLModel(object): with shape (n_state, n_action). :param np.ndarray rew_std_prior: standard deviations of the normal priors of rewards, with shape (n_state, n_action). + :param float epsilon: for precision control in value iteration. 
""" def __init__( @@ -21,12 +22,14 @@ def __init__( trans_count_prior: np.ndarray, rew_mean_prior: np.ndarray, rew_std_prior: np.ndarray, + epsilon: float, ) -> None: self.trans_count = trans_count_prior self.n_state, self.n_action = rew_mean_prior.shape self.rew_mean = rew_mean_prior self.rew_std = rew_std_prior self.rew_count = np.ones_like(rew_mean_prior) + self.eps = epsilon self.policy: Optional[np.ndarray] = None self.updated = False @@ -58,12 +61,13 @@ def observe( self.rew_std *= self.rew_count / sum_count self.rew_count = sum_count - def sample_from_prob(self) -> np.ndarray: - sample_prob = np.zeros_like(self.trans_count) - for i in range(self.n_state): - for j in range(self.n_action): - sample_prob[i][j] = np.random.dirichlet( - self.trans_count[i][j]) + @staticmethod + def sample_from_prob(trans_count: np.ndarray) -> np.ndarray: + sample_prob = np.zeros_like(trans_count) + n_s, n_a = trans_count.shape[:2] + for i in range(n_s): + for j in range(n_a): # numba does not support dirichlet :( + sample_prob[i][j] = np.random.dirichlet(trans_count[i][j]) return sample_prob def sample_from_rew(self) -> np.ndarray: @@ -72,11 +76,14 @@ def sample_from_rew(self) -> np.ndarray: def solve_policy(self) -> None: self.updated = True self.policy = self.value_iteration( - self.sample_from_prob(), self.sample_from_rew()) + self.sample_from_prob(self.trans_count), + self.sample_from_rew(), + self.eps, + ) @staticmethod def value_iteration( - trans_prob: np.ndarray, rew: np.ndarray, eps: float = 0.01 + trans_prob: np.ndarray, rew: np.ndarray, eps: float ) -> np.ndarray: """Value iteration solver for MDPs. @@ -87,18 +94,16 @@ def value_iteration( :return: the optimal policy with shape (n_state, ). """ - value = np.zeros(len(rew)) - print(trans_prob.shape, value.shape) - Q = rew + trans_prob.dot(value) # (s, a) = (s, a) + (s, a, s) * (s) - new_value = np.max(Q, axis=1) # (s) = (s, a).max(axis=1) + value = -np.nan + new_value = rew.max(axis=1) while not np.allclose(new_value, value, eps): value = new_value Q = rew + trans_prob.dot(value) - new_value = np.max(Q, axis=1) - return np.argmax(Q, axis=1) + new_value = Q.max(axis=1) + return Q.argmax(axis=1) def __call__(self, obs: np.ndarray, state=None, info=None) -> np.ndarray: - if self.updated is False: + if not self.updated: self.solve_policy() return self.policy[obs] @@ -115,6 +120,7 @@ class PSRLPolicy(BasePolicy): with shape (n_state, n_action). :param np.ndarray rew_std_prior: standard deviations of the normal priors of rewards, with shape (n_state, n_action). + :param float epsilon: for precision control in value iteration. .. seealso:: @@ -127,17 +133,17 @@ def __init__( trans_count_prior: np.ndarray, rew_mean_prior: np.ndarray, rew_std_prior: np.ndarray, + epsilon: float = 0.01, **kwargs: Any, ) -> None: super().__init__(**kwargs) self.model = PSRLModel( - trans_count_prior, rew_mean_prior, rew_std_prior) + trans_count_prior, rew_mean_prior, rew_std_prior, epsilon) def forward( self, batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, - eps: Optional[float] = None, **kwargs: Any, ) -> Batch: """Compute action over the given batch data with PSRL model. 
@@ -153,7 +159,7 @@ def forward( return Batch(act=self.model(batch.obs, state=state, info=batch.info)) def learn( # type: ignore - self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any + self, batch: Batch, *args: Any, **kwargs: Any ) -> Dict[str, float]: n_s, n_a = self.model.n_state, self.model.n_action trans_count = np.zeros((n_s, n_a, n_s)) @@ -162,9 +168,9 @@ def learn( # type: ignore act, rew = batch.act, batch.rew obs, obs_next = batch.obs, batch.obs_next for i in range(len(obs)): - trans_count[obs[i]][act[i]][obs_next[i]] += 1 - rew_sum[obs[i]][act[i]] += rew[i] - rew_count[obs[i]][act[i]] += 1 + trans_count[obs[i], act[i], obs_next[i]] += 1 + rew_sum[obs[i], act[i]] += rew[i] + rew_count[obs[i], act[i]] += 1 if batch.done[i]: if hasattr(batch.info, 'TimeLimit.truncated') \ and batch.info['TimeLimit.truncated'][i]: From 71c1e0faab632da6decf7a3d17f6bdb41ef5b795 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 9 Sep 2020 21:48:34 +0800 Subject: [PATCH 39/62] add rew-count-prior and improve value-iteration efficiency --- test/modelbase/test_psrl.py | 9 ++++++-- tianshou/policy/modelbase/psrl.py | 36 +++++++++++++++++++++---------- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index cdb3bbbbc..0d7d8aabb 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -23,8 +23,9 @@ def get_args(): parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.0) - parser.add_argument('--rew-mean-prior', type=float, default=0.0) + parser.add_argument('--rew-mean-prior', type=float, default=1.0) parser.add_argument('--rew-std-prior', type=float, default=10.0) + parser.add_argument('--rew-count-prior', type=int, default=1) parser.add_argument('--eps', type=float, default=0.01) return parser.parse_known_args()[0] @@ -54,8 +55,10 @@ def test_psrl(args=get_args()): trans_count_prior = np.ones((n_state, n_action, n_state)) rew_mean_prior = np.full((n_state, n_action), args.rew_mean_prior) rew_std_prior = np.full((n_state, n_action), args.rew_std_prior) + rew_count_prior = np.full((n_state, n_action), args.rew_count_prior) policy = PSRLPolicy( - trans_count_prior, rew_mean_prior, rew_std_prior, args.eps) + trans_count_prior, rew_mean_prior, rew_std_prior, rew_count_prior, + args.eps) # collector train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size)) @@ -64,6 +67,8 @@ def test_psrl(args=get_args()): writer = SummaryWriter(args.logdir + '/' + args.task) def stop_fn(x): + print(policy.model.rew_mean, policy.model.rew_mean.mean()) + print(policy.model.rew_std, policy.model.rew_std.mean()) if env.spec.reward_threshold: return x >= env.spec.reward_threshold else: diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 437ffeb94..dae360807 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -14,6 +14,8 @@ class PSRLModel(object): with shape (n_state, n_action). :param np.ndarray rew_std_prior: standard deviations of the normal priors of rewards, with shape (n_state, n_action). + :param np.ndarray rew_count_prior: count (weight) of the normal priors of + rewards, with shape (n_state, n_action). :param float epsilon: for precision control in value iteration. 
""" @@ -22,15 +24,17 @@ def __init__( trans_count_prior: np.ndarray, rew_mean_prior: np.ndarray, rew_std_prior: np.ndarray, + rew_count_prior: np.ndarray, epsilon: float, ) -> None: self.trans_count = trans_count_prior self.n_state, self.n_action = rew_mean_prior.shape self.rew_mean = rew_mean_prior self.rew_std = rew_std_prior - self.rew_count = np.ones_like(rew_mean_prior) + self.rew_count = rew_count_prior self.eps = epsilon self.policy: Optional[np.ndarray] = None + self.value = np.zeros(self.n_state) self.updated = False def observe( @@ -75,15 +79,16 @@ def sample_from_rew(self) -> np.ndarray: def solve_policy(self) -> None: self.updated = True - self.policy = self.value_iteration( + self.policy, self.value = self.value_iteration( self.sample_from_prob(self.trans_count), self.sample_from_rew(), self.eps, + self.value, ) @staticmethod def value_iteration( - trans_prob: np.ndarray, rew: np.ndarray, eps: float + trans_prob: np.ndarray, rew: np.ndarray, eps: float, value: np.ndarray ) -> np.ndarray: """Value iteration solver for MDPs. @@ -91,16 +96,18 @@ def value_iteration( (n_action, n_state, n_state). :param np.ndarray rew: rewards, with shape (n_state, n_action). :param float eps: for precision control. + :param np.ndarray value: the initialize value of value array, with + shape (n_state, ). :return: the optimal policy with shape (n_state, ). """ - value = -np.nan - new_value = rew.max(axis=1) + Q = rew + trans_prob.dot(value) + new_value = Q.max(axis=1) while not np.allclose(new_value, value, eps): value = new_value Q = rew + trans_prob.dot(value) new_value = Q.max(axis=1) - return Q.argmax(axis=1) + return Q.argmax(axis=1), new_value def __call__(self, obs: np.ndarray, state=None, info=None) -> np.ndarray: if not self.updated: @@ -120,6 +127,8 @@ class PSRLPolicy(BasePolicy): with shape (n_state, n_action). :param np.ndarray rew_std_prior: standard deviations of the normal priors of rewards, with shape (n_state, n_action). + :param np.ndarray rew_count_prior: count (weight) of the normal priors of + rewards, with shape (n_state, n_action). :param float epsilon: for precision control in value iteration. .. 
seealso:: @@ -133,12 +142,13 @@ def __init__( trans_count_prior: np.ndarray, rew_mean_prior: np.ndarray, rew_std_prior: np.ndarray, + rew_count_prior: np.ndarray, epsilon: float = 0.01, **kwargs: Any, ) -> None: super().__init__(**kwargs) - self.model = PSRLModel( - trans_count_prior, rew_mean_prior, rew_std_prior, epsilon) + self.model = PSRLModel(trans_count_prior, rew_mean_prior, + rew_std_prior, rew_count_prior, epsilon) def forward( self, @@ -172,9 +182,13 @@ def learn( # type: ignore rew_sum[obs[i], act[i]] += rew[i] rew_count[obs[i], act[i]] += 1 if batch.done[i]: - if hasattr(batch.info, 'TimeLimit.truncated') \ - and batch.info['TimeLimit.truncated'][i]: + if hasattr(batch.info, "TimeLimit.truncated") \ + and batch.info["TimeLimit.truncated"][i]: continue trans_count[obs_next[i], :, obs_next[i]] += 1 + rew_count[obs_next[i], :] += 1 self.model.observe(trans_count, rew_sum, rew_count) - return {} + return { + "psrl/rew_mean": self.model.rew_mean.mean(), + "psrl/rew_std": self.model.rew_std.mean(), + } From 80f188e9d4ac4494849250766fbc8d2d2c779a92 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 9 Sep 2020 21:50:34 +0800 Subject: [PATCH 40/62] remove print --- test/modelbase/test_psrl.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 0d7d8aabb..1bb22a04c 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -67,8 +67,6 @@ def test_psrl(args=get_args()): writer = SummaryWriter(args.logdir + '/' + args.task) def stop_fn(x): - print(policy.model.rew_mean, policy.model.rew_mean.mean()) - print(policy.model.rew_std, policy.model.rew_std.mean()) if env.spec.reward_threshold: return x >= env.spec.reward_threshold else: From a4118fcb557f2b561f8e6fc791a120238d1d8812 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 10 Sep 2020 08:32:37 +0800 Subject: [PATCH 41/62] remove weight-prior --- test/modelbase/test_psrl.py | 9 ++++----- tianshou/policy/modelbase/psrl.py | 12 +++--------- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 1bb22a04c..b9c3fc8a3 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -15,7 +15,7 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='NChain-v0') parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--buffer-size', type=int, default=50000) parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=10) parser.add_argument('--collect-per-step', type=int, default=1) @@ -25,7 +25,6 @@ def get_args(): parser.add_argument('--render', type=float, default=0.0) parser.add_argument('--rew-mean-prior', type=float, default=1.0) parser.add_argument('--rew-std-prior', type=float, default=10.0) - parser.add_argument('--rew-count-prior', type=int, default=1) parser.add_argument('--eps', type=float, default=0.01) return parser.parse_known_args()[0] @@ -55,10 +54,8 @@ def test_psrl(args=get_args()): trans_count_prior = np.ones((n_state, n_action, n_state)) rew_mean_prior = np.full((n_state, n_action), args.rew_mean_prior) rew_std_prior = np.full((n_state, n_action), args.rew_std_prior) - rew_count_prior = np.full((n_state, n_action), args.rew_count_prior) policy = PSRLPolicy( - trans_count_prior, rew_mean_prior, rew_std_prior, rew_count_prior, - args.eps) 
+ trans_count_prior, rew_mean_prior, rew_std_prior, args.eps) # collector train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size)) @@ -71,6 +68,8 @@ def stop_fn(x): return x >= env.spec.reward_threshold else: return False + + train_collector.collect(n_step=args.buffer_size, random=True) # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index dae360807..46a3227e8 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -14,8 +14,6 @@ class PSRLModel(object): with shape (n_state, n_action). :param np.ndarray rew_std_prior: standard deviations of the normal priors of rewards, with shape (n_state, n_action). - :param np.ndarray rew_count_prior: count (weight) of the normal priors of - rewards, with shape (n_state, n_action). :param float epsilon: for precision control in value iteration. """ @@ -24,14 +22,13 @@ def __init__( trans_count_prior: np.ndarray, rew_mean_prior: np.ndarray, rew_std_prior: np.ndarray, - rew_count_prior: np.ndarray, epsilon: float, ) -> None: self.trans_count = trans_count_prior self.n_state, self.n_action = rew_mean_prior.shape self.rew_mean = rew_mean_prior self.rew_std = rew_std_prior - self.rew_count = rew_count_prior + self.rew_count = np.full(rew_mean_prior.shape, epsilon) # no weight self.eps = epsilon self.policy: Optional[np.ndarray] = None self.value = np.zeros(self.n_state) @@ -127,8 +124,6 @@ class PSRLPolicy(BasePolicy): with shape (n_state, n_action). :param np.ndarray rew_std_prior: standard deviations of the normal priors of rewards, with shape (n_state, n_action). - :param np.ndarray rew_count_prior: count (weight) of the normal priors of - rewards, with shape (n_state, n_action). :param float epsilon: for precision control in value iteration. .. 
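Replacing the explicit count prior with a tiny `epsilon` pseudo-count keeps the reward update well defined for (state, action) pairs that have no data yet, while giving the prior essentially no weight once real samples arrive; the random warm-up collection added to the test script plays a complementary role by filling the buffer with exploratory data before training. A tiny sketch of why the pseudo-count matters (a zero prior mean is assumed for brevity):

```python
import numpy as np

eps = 0.01
rew_count = np.full((2, 2), eps)                  # pseudo-count instead of zero
rew_sum = np.array([[1.0, 0.0], [0.0, 0.0]])      # only (s=0, a=0) was visited
new_count = np.array([[1.0, 0.0], [0.0, 0.0]])

sum_count = rew_count + new_count
rew_mean = rew_sum / sum_count   # no division by zero for unvisited pairs
print(rew_mean)                  # ~[[0.99, 0.], [0., 0.]]; unvisited entries stay at 0
```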
seealso:: @@ -142,13 +137,12 @@ def __init__( trans_count_prior: np.ndarray, rew_mean_prior: np.ndarray, rew_std_prior: np.ndarray, - rew_count_prior: np.ndarray, epsilon: float = 0.01, **kwargs: Any, ) -> None: super().__init__(**kwargs) - self.model = PSRLModel(trans_count_prior, rew_mean_prior, - rew_std_prior, rew_count_prior, epsilon) + self.model = PSRLModel( + trans_count_prior, rew_mean_prior, rew_std_prior, epsilon) def forward( self, From efa9a9a477d3044deca063bd61536dc85547b4fe Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 10 Sep 2020 17:37:39 +0800 Subject: [PATCH 42/62] discount factor regression --- test/modelbase/test_psrl.py | 3 ++- tianshou/policy/modelbase/psrl.py | 15 +++++++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index b9c3fc8a3..58e59bc85 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -25,6 +25,7 @@ def get_args(): parser.add_argument('--render', type=float, default=0.0) parser.add_argument('--rew-mean-prior', type=float, default=1.0) parser.add_argument('--rew-std-prior', type=float, default=10.0) + parser.add_argument('--gamma', type=float, default=0.0) parser.add_argument('--eps', type=float, default=0.01) return parser.parse_known_args()[0] @@ -55,7 +56,7 @@ def test_psrl(args=get_args()): rew_mean_prior = np.full((n_state, n_action), args.rew_mean_prior) rew_std_prior = np.full((n_state, n_action), args.rew_std_prior) policy = PSRLPolicy( - trans_count_prior, rew_mean_prior, rew_std_prior, args.eps) + trans_count_prior, rew_mean_prior, rew_std_prior, args.gamma, args.eps) # collector train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size)) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 46a3227e8..93baa0b2a 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -1,8 +1,8 @@ import numpy as np from typing import Any, Dict, Union, Optional -from tianshou.data import Batch from tianshou.policy import BasePolicy +from tianshou.data import Batch, ReplayBuffer class PSRLModel(object): @@ -124,6 +124,7 @@ class PSRLPolicy(BasePolicy): with shape (n_state, n_action). :param np.ndarray rew_std_prior: standard deviations of the normal priors of rewards, with shape (n_state, n_action). + :param float discount_factor: in [0, 1]. :param float epsilon: for precision control in value iteration. .. seealso:: @@ -137,12 +138,16 @@ def __init__( trans_count_prior: np.ndarray, rew_mean_prior: np.ndarray, rew_std_prior: np.ndarray, + discount_factor: float = 0.99, epsilon: float = 0.01, **kwargs: Any, ) -> None: super().__init__(**kwargs) self.model = PSRLModel( trans_count_prior, rew_mean_prior, rew_std_prior, epsilon) + assert 0.0 <= discount_factor <= 1.0, \ + "discount factor should be in [0, 1]" + self._gamma = discount_factor def forward( self, @@ -162,6 +167,12 @@ def forward( """ return Batch(act=self.model(batch.obs, state=state, info=batch.info)) + def process_fn( + self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray + ) -> Batch: + return self.compute_episodic_return( + batch, gamma=self._gamma, gae_lambda=1.) 
+ def learn( # type: ignore self, batch: Batch, *args: Any, **kwargs: Any ) -> Dict[str, float]: @@ -169,7 +180,7 @@ def learn( # type: ignore trans_count = np.zeros((n_s, n_a, n_s)) rew_sum = np.zeros((n_s, n_a)) rew_count = np.zeros((n_s, n_a)) - act, rew = batch.act, batch.rew + act, rew = batch.act, batch.returns obs, obs_next = batch.obs, batch.obs_next for i in range(len(obs)): trans_count[obs[i], act[i], obs_next[i]] += 1 From dd7f99fbc8c61eead22d78c5e9b38603e31f239a Mon Sep 17 00:00:00 2001 From: Yao Date: Thu, 10 Sep 2020 21:32:31 +0800 Subject: [PATCH 43/62] polish --- tianshou/policy/modelbase/psrl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 93baa0b2a..817dc7083 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -104,6 +104,7 @@ def value_iteration( value = new_value Q = rew + trans_prob.dot(value) new_value = Q.max(axis=1) + Q += eps * np.random.randn(*np.shape(Q)) return Q.argmax(axis=1), new_value def __call__(self, obs: np.ndarray, state=None, info=None) -> np.ndarray: From 5a094ff346247982af9c2abc2a36f5991fdf73cf Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 11 Sep 2020 09:56:42 +0800 Subject: [PATCH 44/62] small update --- tianshou/policy/modelbase/psrl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 817dc7083..7d94c961a 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -104,7 +104,7 @@ def value_iteration( value = new_value Q = rew + trans_prob.dot(value) new_value = Q.max(axis=1) - Q += eps * np.random.randn(*np.shape(Q)) + Q += eps * np.random.randn(*Q.shape) return Q.argmax(axis=1), new_value def __call__(self, obs: np.ndarray, state=None, info=None) -> np.ndarray: From a231fd7b2f309f04962d03a51bd9f7abaadbac4f Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 12 Sep 2020 09:42:42 +0800 Subject: [PATCH 45/62] modify readme --- examples/modelbase/README.md | 6 +++--- test/modelbase/test_psrl.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/modelbase/README.md b/examples/modelbase/README.md index 2a298c005..542c2ec3c 100644 --- a/examples/modelbase/README.md +++ b/examples/modelbase/README.md @@ -1,7 +1,7 @@ # PSRL -`NChain-v0`: `python3 psrl.py` +`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 1` -`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --epoch 20 --step-per-epoch 1000` +`NChain-v0`: `python3 psrl.py --task NChain-v0 --step-per-epoch 10 --rew-mean-prior 1 --rew-std-prior 10` -`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --epoch 20 --step-per-epoch 100` +`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 2 --test-num 8 --epoch 20` diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 58e59bc85..76e239216 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -13,18 +13,18 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='NChain-v0') + parser.add_argument('--task', type=str, default='FrozenLake-v0') parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--buffer-size', type=int, default=50000) - parser.add_argument('--epoch', type=int, default=10) - 
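The small Gaussian perturbation added to `Q` before the final `argmax` is a tie-breaking device: `argmax` always returns the first maximal index, so exactly equal Q-values would deterministically favor lower-numbered actions. A sketch of the effect on a contrived tie:

```python
import numpy as np

rng = np.random.default_rng(0)
Q = np.array([[1.0, 1.0, 0.5]])      # actions 0 and 1 tie in this single state

print(Q.argmax(axis=1))              # always [0]: plain argmax breaks ties by index
picks = [
    (Q + 1e-2 * rng.standard_normal(Q.shape)).argmax(axis=1)[0]
    for _ in range(1000)
]
print(np.bincount(picks))            # roughly half 0s, half 1s; action 2 essentially never wins
```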
parser.add_argument('--step-per-epoch', type=int, default=10) + parser.add_argument('--epoch', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=1000) parser.add_argument('--collect-per-step', type=int, default=1) parser.add_argument('--training-num', type=int, default=8) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.0) - parser.add_argument('--rew-mean-prior', type=float, default=1.0) - parser.add_argument('--rew-std-prior', type=float, default=10.0) + parser.add_argument('--rew-mean-prior', type=float, default=0.0) + parser.add_argument('--rew-std-prior', type=float, default=1.0) parser.add_argument('--gamma', type=float, default=0.0) parser.add_argument('--eps', type=float, default=0.01) return parser.parse_known_args()[0] From 8e801de915c5d76c553e22b554ffff5ba4ef4f27 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 12 Sep 2020 09:54:29 +0800 Subject: [PATCH 46/62] NChain hparam --- examples/modelbase/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/modelbase/README.md b/examples/modelbase/README.md index 542c2ec3c..c2a83e916 100644 --- a/examples/modelbase/README.md +++ b/examples/modelbase/README.md @@ -2,6 +2,6 @@ `FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 1` -`NChain-v0`: `python3 psrl.py --task NChain-v0 --step-per-epoch 10 --rew-mean-prior 1 --rew-std-prior 10` +`NChain-v0`: `python3 psrl.py --task NChain-v0 --step-per-epoch 10 --rew-mean-prior 1 --rew-std-prior 1` `Taxi-v3`: `python3 psrl.py --task Taxi-v3 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 2 --test-num 8 --epoch 20` From 8a94c48aea0e772c8f8629e366966fc6db4c24fc Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 12 Sep 2020 10:07:31 +0800 Subject: [PATCH 47/62] NChain hparam --- examples/modelbase/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/modelbase/README.md b/examples/modelbase/README.md index c2a83e916..83f2fa506 100644 --- a/examples/modelbase/README.md +++ b/examples/modelbase/README.md @@ -2,6 +2,6 @@ `FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 1` -`NChain-v0`: `python3 psrl.py --task NChain-v0 --step-per-epoch 10 --rew-mean-prior 1 --rew-std-prior 1` +`NChain-v0`: `python3 psrl.py --task NChain-v0 --step-per-epoch 10 --rew-mean-prior 0 --rew-std-prior 1` `Taxi-v3`: `python3 psrl.py --task Taxi-v3 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 2 --test-num 8 --epoch 20` From b852ad9cdf694a55fd6620ce1c7475f653780a64 Mon Sep 17 00:00:00 2001 From: Yao Date: Sun, 13 Sep 2020 10:17:02 +0800 Subject: [PATCH 48/62] add discount factor for value iteration --- test/modelbase/test_psrl.py | 2 +- tianshou/policy/modelbase/psrl.py | 25 ++++++++++++------------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 76e239216..b24da5791 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -25,7 +25,7 @@ def get_args(): parser.add_argument('--render', type=float, default=0.0) parser.add_argument('--rew-mean-prior', type=float, default=0.0) parser.add_argument('--rew-std-prior', type=float, default=1.0) - parser.add_argument('--gamma', type=float, default=0.0) + parser.add_argument('--gamma', 
type=float, default=0.99) parser.add_argument('--eps', type=float, default=0.01) return parser.parse_known_args()[0] diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 7d94c961a..a2918dc3b 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -14,6 +14,7 @@ class PSRLModel(object): with shape (n_state, n_action). :param np.ndarray rew_std_prior: standard deviations of the normal priors of rewards, with shape (n_state, n_action). + :param float discount_factor: in [0, 1]. :param float epsilon: for precision control in value iteration. """ @@ -22,12 +23,14 @@ def __init__( trans_count_prior: np.ndarray, rew_mean_prior: np.ndarray, rew_std_prior: np.ndarray, + discount_factor: float, epsilon: float, ) -> None: self.trans_count = trans_count_prior self.n_state, self.n_action = rew_mean_prior.shape self.rew_mean = rew_mean_prior self.rew_std = rew_std_prior + self.discount_factor = discount_factor self.rew_count = np.full(rew_mean_prior.shape, epsilon) # no weight self.eps = epsilon self.policy: Optional[np.ndarray] = None @@ -79,13 +82,15 @@ def solve_policy(self) -> None: self.policy, self.value = self.value_iteration( self.sample_from_prob(self.trans_count), self.sample_from_rew(), + self.discount_factor, self.eps, self.value, ) @staticmethod def value_iteration( - trans_prob: np.ndarray, rew: np.ndarray, eps: float, value: np.ndarray + trans_prob: np.ndarray, rew: np.ndarray, + discount_factor: float, eps: float, value: np.ndarray ) -> np.ndarray: """Value iteration solver for MDPs. @@ -93,16 +98,16 @@ def value_iteration( (n_action, n_state, n_state). :param np.ndarray rew: rewards, with shape (n_state, n_action). :param float eps: for precision control. + :param float discount_factor: in [0, 1]. :param np.ndarray value: the initialize value of value array, with shape (n_state, ). - :return: the optimal policy with shape (n_state, ). """ - Q = rew + trans_prob.dot(value) + Q = rew + discount_factor * trans_prob.dot(value) new_value = Q.max(axis=1) while not np.allclose(new_value, value, eps): value = new_value - Q = rew + trans_prob.dot(value) + Q = rew + discount_factor * trans_prob.dot(value) new_value = Q.max(axis=1) Q += eps * np.random.randn(*Q.shape) return Q.argmax(axis=1), new_value @@ -145,10 +150,10 @@ def __init__( ) -> None: super().__init__(**kwargs) self.model = PSRLModel( - trans_count_prior, rew_mean_prior, rew_std_prior, epsilon) + trans_count_prior, rew_mean_prior, rew_std_prior, + discount_factor, epsilon) assert 0.0 <= discount_factor <= 1.0, \ "discount factor should be in [0, 1]" - self._gamma = discount_factor def forward( self, @@ -168,12 +173,6 @@ def forward( """ return Batch(act=self.model(batch.obs, state=state, info=batch.info)) - def process_fn( - self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray - ) -> Batch: - return self.compute_episodic_return( - batch, gamma=self._gamma, gae_lambda=1.) 
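With the discount factor folded into the backup, each sweep is a gamma-contraction in the sup norm, which is what guarantees the convergence loop above terminates for gamma < 1 (the undiscounted form could keep growing on tasks with positive average reward). A small numerical check of the contraction property on a random MDP; all sizes are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)
n_s, n_a, gamma = 4, 2, 0.9
P = rng.random((n_s, n_a, n_s))
P /= P.sum(-1, keepdims=True)
R = rng.random((n_s, n_a))

def backup(v):
    # One Bellman optimality sweep: Q = R + gamma * P v, then max over actions.
    return (R + gamma * P.dot(v)).max(axis=1)

v1, v2 = rng.random(n_s), rng.random(n_s)
lhs = np.max(np.abs(backup(v1) - backup(v2)))
rhs = gamma * np.max(np.abs(v1 - v2))
print(lhs <= rhs + 1e-12)   # True: the backup shrinks sup-norm distances by gamma
```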
- def learn( # type: ignore self, batch: Batch, *args: Any, **kwargs: Any ) -> Dict[str, float]: @@ -181,7 +180,7 @@ def learn( # type: ignore trans_count = np.zeros((n_s, n_a, n_s)) rew_sum = np.zeros((n_s, n_a)) rew_count = np.zeros((n_s, n_a)) - act, rew = batch.act, batch.returns + act, rew = batch.act, batch.rew obs, obs_next = batch.obs, batch.obs_next for i in range(len(obs)): trans_count[obs[i], act[i], obs_next[i]] += 1 From 324fb4ac138564f281c690e90f26b462e2eaa958 Mon Sep 17 00:00:00 2001 From: Yao Date: Sun, 13 Sep 2020 10:40:30 +0800 Subject: [PATCH 49/62] polish and fix an annotation --- tianshou/policy/modelbase/psrl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index a2918dc3b..574fa05e7 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -2,13 +2,13 @@ from typing import Any, Dict, Union, Optional from tianshou.policy import BasePolicy -from tianshou.data import Batch, ReplayBuffer +from tianshou.data import Batch class PSRLModel(object): """Implementation of Posterior Sampling Reinforcement Learning Model. - :param np.ndarray p_prior: dirichlet prior (alphas), with shape + :param np.ndarray trans_count_prior: dirichlet prior (alphas), with shape (n_state, n_action, n_state). :param np.ndarray rew_mean_prior: means of the normal priors of rewards, with shape (n_state, n_action). From d002785d0f517b725dcbc63e486c5edea1058355 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sun, 13 Sep 2020 15:43:37 +0800 Subject: [PATCH 50/62] fix timing in trainer --- test/modelbase/test_psrl.py | 6 +++--- tianshou/data/collector.py | 6 +++++- tianshou/trainer/offpolicy.py | 2 ++ tianshou/trainer/onpolicy.py | 2 ++ 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index b24da5791..9232a390b 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -13,11 +13,11 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='FrozenLake-v0') + parser.add_argument('--task', type=str, default='NChain-v0') parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--buffer-size', type=int, default=50000) parser.add_argument('--epoch', type=int, default=5) - parser.add_argument('--step-per-epoch', type=int, default=1000) + parser.add_argument('--step-per-epoch', type=int, default=5) parser.add_argument('--collect-per-step', type=int, default=1) parser.add_argument('--training-num', type=int, default=8) parser.add_argument('--test-num', type=int, default=100) @@ -33,7 +33,7 @@ def get_args(): def test_psrl(args=get_args()): env = gym.make(args.task) if args.task == "NChain-v0": - env.spec.reward_threshold = 3650 # described in PSRL paper + env.spec.reward_threshold = 3647 # described in PSRL paper print("reward threshold:", env.spec.reward_threshold) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index aacaa9b6d..8fe823896 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -129,10 +129,14 @@ def reset(self) -> None: obs_next={}, policy={}) self.reset_env() self.reset_buffer() - self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0 + self.reset_stat() if self._action_noise is not None: 
self._action_noise.reset() + def reset_stat(self) -> None: + """Reset the statistic variables.""" + self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0 + def reset_buffer(self) -> None: """Reset the main data buffer.""" if self.buffer is not None: diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index bbf523396..a9d37d80d 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -75,6 +75,8 @@ def offpolicy_trainer( best_epoch, best_reward = -1, -1.0 stat = {} start_time = time.time() + train_collector.reset_stat() + test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy for epoch in range(1, 1 + max_epoch): # train diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index ac97ba782..b330b98e5 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -75,6 +75,8 @@ def onpolicy_trainer( best_epoch, best_reward = -1, -1.0 stat = {} start_time = time.time() + train_collector.reset_stat() + test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy for epoch in range(1, 1 + max_epoch): # train From a0fae3ab31c4e9965c0c7ecfb5accdae114a10ad Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sun, 13 Sep 2020 19:48:12 +0800 Subject: [PATCH 51/62] fix test --- tianshou/policy/modelbase/psrl.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 574fa05e7..b1cf9c094 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -1,8 +1,8 @@ import numpy as np from typing import Any, Dict, Union, Optional -from tianshou.policy import BasePolicy from tianshou.data import Batch +from tianshou.policy import BasePolicy class PSRLModel(object): @@ -33,7 +33,7 @@ def __init__( self.discount_factor = discount_factor self.rew_count = np.full(rew_mean_prior.shape, epsilon) # no weight self.eps = epsilon - self.policy: Optional[np.ndarray] = None + self.policy: np.ndarray self.value = np.zeros(self.n_state) self.updated = False @@ -112,7 +112,12 @@ def value_iteration( Q += eps * np.random.randn(*Q.shape) return Q.argmax(axis=1), new_value - def __call__(self, obs: np.ndarray, state=None, info=None) -> np.ndarray: + def __call__( + self, + obs: np.ndarray, + state: Optional[Any] = None, + info: Dict[str, Any] = {}, + ) -> np.ndarray: if not self.updated: self.solve_policy() return self.policy[obs] @@ -173,7 +178,7 @@ def forward( """ return Batch(act=self.model(batch.obs, state=state, info=batch.info)) - def learn( # type: ignore + def learn( self, batch: Batch, *args: Any, **kwargs: Any ) -> Dict[str, float]: n_s, n_a = self.model.n_state, self.model.n_action From bef8ba4251e89a77afe8fa0e92cb6e145e5a29be Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Mon, 14 Sep 2020 11:31:47 +0800 Subject: [PATCH 52/62] fix docs --- tianshou/policy/modelbase/psrl.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index b1cf9c094..f70cf8547 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -89,13 +89,16 @@ def solve_policy(self) -> None: @staticmethod def value_iteration( - trans_prob: np.ndarray, rew: np.ndarray, - discount_factor: float, eps: float, value: np.ndarray + trans_prob: np.ndarray, + rew: np.ndarray, + discount_factor: 
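
The __call__ that receives a typed signature above is, at its core, a table lookup: value iteration produces an (n_state,)-shaped array of greedy actions, and indexing it with a batch of discrete observations yields one action per observation. A tiny sketch with made-up numbers; the 3-state policy below is purely illustrative.

import numpy as np

policy = np.array([1, 0, 1])   # hypothetical greedy action for each of 3 states
obs = np.array([0, 2, 2, 1])   # a batch of discrete observations
act = policy[obs]              # fancy indexing -> array([1, 1, 1, 0])
print(act)
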
float, + eps: float, + value: np.ndarray ) -> np.ndarray: """Value iteration solver for MDPs. :param np.ndarray trans_prob: transition probabilities, with shape - (n_action, n_state, n_state). + (n_state, n_action, n_state). :param np.ndarray rew: rewards, with shape (n_state, n_action). :param float eps: for precision control. :param float discount_factor: in [0, 1]. @@ -109,6 +112,7 @@ def value_iteration( value = new_value Q = rew + discount_factor * trans_prob.dot(value) new_value = Q.max(axis=1) + # this is to make sure if Q(s, a1) == Q(s, a2) -> choose a1/a2 randomly Q += eps * np.random.randn(*Q.shape) return Q.argmax(axis=1), new_value @@ -195,6 +199,7 @@ def learn( if hasattr(batch.info, "TimeLimit.truncated") \ and batch.info["TimeLimit.truncated"][i]: continue + # special operation for terminal states: add a self-loop trans_count[obs_next[i], :, obs_next[i]] += 1 rew_count[obs_next[i], :] += 1 self.model.observe(trans_count, rew_sum, rew_count) From 9a5ea8a13009abed40832a827eb42fbdf8622e54 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Mon, 14 Sep 2020 11:35:30 +0800 Subject: [PATCH 53/62] update --- tianshou/policy/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 98963aaf6..3cae35450 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -8,8 +8,8 @@ from tianshou.policy.modelfree.ppo import PPOPolicy from tianshou.policy.modelfree.td3 import TD3Policy from tianshou.policy.modelfree.sac import SACPolicy -from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager from tianshou.policy.modelbase.psrl import PSRLPolicy +from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager __all__ = [ @@ -23,6 +23,6 @@ "PPOPolicy", "TD3Policy", "SACPolicy", - "MultiAgentPolicyManager", "PSRLPolicy", + "MultiAgentPolicyManager", ] From 93ebee98755f159d270370b8aa3afd420ed3624f Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Mon, 14 Sep 2020 19:28:28 +0800 Subject: [PATCH 54/62] polish --- tianshou/policy/modelbase/psrl.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index f70cf8547..c661c0a12 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -104,6 +104,7 @@ def value_iteration( :param float discount_factor: in [0, 1]. :param np.ndarray value: the initialize value of value array, with shape (n_state, ). + :return: the optimal policy with shape (n_state, ). """ Q = rew + discount_factor * trans_prob.dot(value) @@ -158,11 +159,12 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(**kwargs) + assert ( + 0.0 <= discount_factor <= 1.0 + ), "discount factor should be in [0, 1]" self.model = PSRLModel( trans_count_prior, rew_mean_prior, rew_std_prior, discount_factor, epsilon) - assert 0.0 <= discount_factor <= 1.0, \ - "discount factor should be in [0, 1]" def forward( self, @@ -180,7 +182,8 @@ def forward( Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for more detailed explanation. 
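
The "self-loop" comment added above is worth spelling out: when an episode really ends (and was not merely cut off by a TimeLimit wrapper), the terminal state is made absorbing by counting a transition back to itself under every action, together with a reward count but no reward sum, which helps the sampled MDP keep the terminal state's value near zero instead of spreading probability mass to unvisited states. A minimal sketch of that bookkeeping with made-up shapes and a single hypothetical transition; the boolean truncated flag stands in for the info["TimeLimit.truncated"] check used in the patch.

import numpy as np

n_state, n_action = 4, 2
trans_count = np.zeros((n_state, n_action, n_state))
rew_count = np.zeros((n_state, n_action))

obs_next, done, truncated = 3, True, False   # hypothetical terminal transition
if done and not truncated:
    trans_count[obs_next, :, obs_next] += 1  # self-loop under every action
    rew_count[obs_next, :] += 1              # counted with zero reward sum
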
""" - return Batch(act=self.model(batch.obs, state=state, info=batch.info)) + act = self.model(batch.obs, state=state, info=batch.info) + return Batch(act=act) def learn( self, batch: Batch, *args: Any, **kwargs: Any @@ -189,19 +192,18 @@ def learn( trans_count = np.zeros((n_s, n_a, n_s)) rew_sum = np.zeros((n_s, n_a)) rew_count = np.zeros((n_s, n_a)) - act, rew = batch.act, batch.rew - obs, obs_next = batch.obs, batch.obs_next - for i in range(len(obs)): - trans_count[obs[i], act[i], obs_next[i]] += 1 - rew_sum[obs[i], act[i]] += rew[i] - rew_count[obs[i], act[i]] += 1 - if batch.done[i]: - if hasattr(batch.info, "TimeLimit.truncated") \ - and batch.info["TimeLimit.truncated"][i]: + for (obs, act, rew, done, obs_next, info) in zip( + batch.obs, batch.act, batch.rew, batch.done, batch.obs_next, + batch.info): + trans_count[obs, act, obs_next] += 1 + rew_sum[obs, act] += rew + rew_count[obs, act] += 1 + if done: + if info.get("TimeLimit.truncated"): continue # special operation for terminal states: add a self-loop - trans_count[obs_next[i], :, obs_next[i]] += 1 - rew_count[obs_next[i], :] += 1 + trans_count[obs_next, :, obs_next] += 1 + rew_count[obs_next, :] += 1 self.model.observe(trans_count, rew_sum, rew_count) return { "psrl/rew_mean": self.model.rew_mean.mean(), From 10c3ee0a5cff2b4c0ad2d89b7328a874ef7f4935 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Mon, 14 Sep 2020 21:17:18 +0800 Subject: [PATCH 55/62] polish --- tianshou/policy/modelbase/psrl.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index c661c0a12..aee4e9ef7 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -192,15 +192,12 @@ def learn( trans_count = np.zeros((n_s, n_a, n_s)) rew_sum = np.zeros((n_s, n_a)) rew_count = np.zeros((n_s, n_a)) - for (obs, act, rew, done, obs_next, info) in zip( - batch.obs, batch.act, batch.rew, batch.done, batch.obs_next, - batch.info): + for b in batch.split(size=1): + obs, act, obs_next = b.obs, b.act, b.obs_next trans_count[obs, act, obs_next] += 1 - rew_sum[obs, act] += rew + rew_sum[obs, act] += b.rew rew_count[obs, act] += 1 - if done: - if info.get("TimeLimit.truncated"): - continue + if b.done: # special operation for terminal states: add a self-loop trans_count[obs_next, :, obs_next] += 1 rew_count[obs_next, :] += 1 From 1edb735ade849935387a18eaabc4e533f63dcf1b Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 16 Sep 2020 17:56:56 +0800 Subject: [PATCH 56/62] add_done_loop --- examples/modelbase/README.md | 6 +++--- test/modelbase/test_psrl.py | 4 +++- tianshou/policy/modelbase/psrl.py | 6 +++++- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/examples/modelbase/README.md b/examples/modelbase/README.md index 83f2fa506..88ef8c16d 100644 --- a/examples/modelbase/README.md +++ b/examples/modelbase/README.md @@ -1,7 +1,7 @@ # PSRL -`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 1` - `NChain-v0`: `python3 psrl.py --task NChain-v0 --step-per-epoch 10 --rew-mean-prior 0 --rew-std-prior 1` -`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 2 --test-num 8 --epoch 20` +`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 1 --add-done-loop` + +`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --step-per-epoch 1000 
--rew-mean-prior 0 --rew-std-prior 2 --epoch 20` diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 9232a390b..b54f96267 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -27,6 +27,7 @@ def get_args(): parser.add_argument('--rew-std-prior', type=float, default=1.0) parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--eps', type=float, default=0.01) + parser.add_argument('--add-done-loop', action='store_true') return parser.parse_known_args()[0] @@ -56,7 +57,8 @@ def test_psrl(args=get_args()): rew_mean_prior = np.full((n_state, n_action), args.rew_mean_prior) rew_std_prior = np.full((n_state, n_action), args.rew_std_prior) policy = PSRLPolicy( - trans_count_prior, rew_mean_prior, rew_std_prior, args.gamma, args.eps) + trans_count_prior, rew_mean_prior, rew_std_prior, args.gamma, args.eps, + args.add_done_loop) # collector train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size)) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index aee4e9ef7..0aca272ee 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -142,6 +142,8 @@ class PSRLPolicy(BasePolicy): of rewards, with shape (n_state, n_action). :param float discount_factor: in [0, 1]. :param float epsilon: for precision control in value iteration. + :param bool add_done_loop: whether to add an extra self-loop for the + terminal state in MDP, defaults to False. .. seealso:: @@ -156,6 +158,7 @@ def __init__( rew_std_prior: np.ndarray, discount_factor: float = 0.99, epsilon: float = 0.01, + add_done_loop: bool = False, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -165,6 +168,7 @@ def __init__( self.model = PSRLModel( trans_count_prior, rew_mean_prior, rew_std_prior, discount_factor, epsilon) + self._add_done_loop = add_done_loop def forward( self, @@ -197,7 +201,7 @@ def learn( trans_count[obs, act, obs_next] += 1 rew_sum[obs, act] += b.rew rew_count[obs, act] += 1 - if b.done: + if self._add_done_loop and b.done: # special operation for terminal states: add a self-loop trans_count[obs_next, :, obs_next] += 1 rew_count[obs_next, :] += 1 From c5963b4c3c83dc0daec7955cd1d30fde46d25778 Mon Sep 17 00:00:00 2001 From: Yao Date: Wed, 16 Sep 2020 21:12:25 +0800 Subject: [PATCH 57/62] fix rew_std calculation --- tianshou/policy/modelbase/psrl.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index 0aca272ee..c92159790 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -30,6 +30,8 @@ def __init__( self.n_state, self.n_action = rew_mean_prior.shape self.rew_mean = rew_mean_prior self.rew_std = rew_std_prior + self.rew_square_sum = np.zeros_like(rew_mean_prior) + self.rew_std_prior = rew_std_prior self.discount_factor = discount_factor self.rew_count = np.full(rew_mean_prior.shape, epsilon) # no weight self.eps = epsilon @@ -41,6 +43,7 @@ def observe( self, trans_count: np.ndarray, rew_sum: np.ndarray, + rew_square_sum: np.ndarray, rew_count: np.ndarray ) -> None: """Add data into memory pool. @@ -55,6 +58,8 @@ def observe( (n_state, n_action, n_state). :param np.ndarray rew_sum: total rewards, with shape (n_state, n_action). + :param np.ndarray rew_square_sum: total rewards' squares, with shape + (n_state, n_action). :param np.ndarray rew_count: the number of rewards, with shape (n_state, n_action). 
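
For orientation, here is how a user would put the pieces together after this change. The PSRLPolicy signature matches the one in this series, while the concrete prior values and the FrozenLake-v0 sizes (16 states, 4 actions) are illustrative choices, not taken from the test file.

import numpy as np
from tianshou.policy import PSRLPolicy  # available once this series is applied

n_state, n_action = 16, 4   # FrozenLake-v0: 4x4 grid, 4 actions
trans_count_prior = np.ones((n_state, n_action, n_state))  # flat Dirichlet prior
rew_mean_prior = np.zeros((n_state, n_action))
rew_std_prior = np.ones((n_state, n_action))
policy = PSRLPolicy(trans_count_prior, rew_mean_prior, rew_std_prior,
                    discount_factor=0.99, epsilon=0.01, add_done_loop=True)
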
""" @@ -62,7 +67,10 @@ def observe( self.trans_count += trans_count sum_count = self.rew_count + rew_count self.rew_mean = (self.rew_mean * self.rew_count + rew_sum) / sum_count - self.rew_std *= self.rew_count / sum_count + self.rew_square_sum += rew_square_sum + self.rew_std = np.sqrt(1 / ((sum_count / + (self.rew_square_sum + 1e-6)) + + 1 / self.rew_std_prior ** 2)) self.rew_count = sum_count @staticmethod @@ -195,17 +203,19 @@ def learn( n_s, n_a = self.model.n_state, self.model.n_action trans_count = np.zeros((n_s, n_a, n_s)) rew_sum = np.zeros((n_s, n_a)) + rew_square_sum = np.zeros((n_s, n_a)) rew_count = np.zeros((n_s, n_a)) for b in batch.split(size=1): obs, act, obs_next = b.obs, b.act, b.obs_next trans_count[obs, act, obs_next] += 1 rew_sum[obs, act] += b.rew + rew_square_sum[obs, act] += b.rew ** 2 rew_count[obs, act] += 1 if self._add_done_loop and b.done: # special operation for terminal states: add a self-loop trans_count[obs_next, :, obs_next] += 1 rew_count[obs_next, :] += 1 - self.model.observe(trans_count, rew_sum, rew_count) + self.model.observe(trans_count, rew_sum, rew_square_sum, rew_count) return { "psrl/rew_mean": self.model.rew_mean.mean(), "psrl/rew_std": self.model.rew_std.mean(), From 6eaf94ffb6279b89a2fc0860cfb02365ab897d40 Mon Sep 17 00:00:00 2001 From: Yao Date: Wed, 16 Sep 2020 22:55:27 +0800 Subject: [PATCH 58/62] bug fixed --- tianshou/policy/modelbase/psrl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index c92159790..a1ad84c79 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -69,7 +69,8 @@ def observe( self.rew_mean = (self.rew_mean * self.rew_count + rew_sum) / sum_count self.rew_square_sum += rew_square_sum self.rew_std = np.sqrt(1 / ((sum_count / - (self.rew_square_sum + 1e-6)) + + (self.rew_square_sum / sum_count - + self.rew_mean ** 2 + 1e-6)) + 1 / self.rew_std_prior ** 2)) self.rew_count = sum_count From 1a58d57754948fc82f9357ddbfd31f1ee2ec634a Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 17 Sep 2020 10:46:05 +0800 Subject: [PATCH 59/62] polish --- examples/modelbase/README.md | 2 +- tianshou/policy/modelbase/psrl.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/modelbase/README.md b/examples/modelbase/README.md index 88ef8c16d..c3563f629 100644 --- a/examples/modelbase/README.md +++ b/examples/modelbase/README.md @@ -2,6 +2,6 @@ `NChain-v0`: `python3 psrl.py --task NChain-v0 --step-per-epoch 10 --rew-mean-prior 0 --rew-std-prior 1` -`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 1 --add-done-loop` +`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 1 --add-done-loop --epoch 20` `Taxi-v3`: `python3 psrl.py --task Taxi-v3 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 2 --epoch 20` diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index a1ad84c79..edcf43d9a 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -38,13 +38,14 @@ def __init__( self.policy: np.ndarray self.value = np.zeros(self.n_state) self.updated = False + self.__eps = np.finfo(np.float32).eps.item() def observe( self, trans_count: np.ndarray, rew_sum: np.ndarray, rew_square_sum: np.ndarray, - rew_count: np.ndarray + rew_count: np.ndarray, ) -> None: """Add data into memory 
pool. @@ -68,10 +69,9 @@ def observe( sum_count = self.rew_count + rew_count self.rew_mean = (self.rew_mean * self.rew_count + rew_sum) / sum_count self.rew_square_sum += rew_square_sum - self.rew_std = np.sqrt(1 / ((sum_count / - (self.rew_square_sum / sum_count - - self.rew_mean ** 2 + 1e-6)) + - 1 / self.rew_std_prior ** 2)) + raw_std2 = self.rew_square_sum / sum_count - self.rew_mean ** 2 + self.rew_std = np.sqrt(1 / ( + sum_count / (raw_std2 + self.__eps) + 1 / self.rew_std_prior ** 2)) self.rew_count = sum_count @staticmethod @@ -102,7 +102,7 @@ def value_iteration( rew: np.ndarray, discount_factor: float, eps: float, - value: np.ndarray + value: np.ndarray, ) -> np.ndarray: """Value iteration solver for MDPs. From 0496b88facc069eb1481f503bb81c072f1b742bd Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 19 Sep 2020 10:37:08 +0800 Subject: [PATCH 60/62] update readme --- README.md | 1 + docs/index.rst | 1 + 2 files changed, 2 insertions(+) diff --git a/README.md b/README.md index 2c9a098dc..636b487ee 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf) - [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf) - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf) +- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf) - Vanilla Imitation Learning - [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf) - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf) diff --git a/docs/index.rst b/docs/index.rst index 92a38d1d5..7d96f3fdd 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -18,6 +18,7 @@ Welcome to Tianshou! * :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization `_ * :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG `_ * :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic `_ +* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic `_ * :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning `_ * :class:`~tianshou.policy.ImitationPolicy` Imitation Learning * :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay `_ From 73fac5dacd357ab49d42960c0a6d3b565b22731f Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 23 Sep 2020 08:54:52 +0800 Subject: [PATCH 61/62] polish --- tianshou/policy/modelbase/psrl.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index edcf43d9a..dcf6a5d05 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -1,3 +1,4 @@ +import torch import numpy as np from typing import Any, Dict, Union, Optional @@ -74,23 +75,19 @@ def observe( sum_count / (raw_std2 + self.__eps) + 1 / self.rew_std_prior ** 2)) self.rew_count = sum_count - @staticmethod - def sample_from_prob(trans_count: np.ndarray) -> np.ndarray: - sample_prob = np.zeros_like(trans_count) - n_s, n_a = trans_count.shape[:2] - for i in range(n_s): - for j in range(n_a): # numba does not support dirichlet :( - sample_prob[i][j] = np.random.dirichlet(trans_count[i][j]) + def sample_trans_prob(self) -> np.ndarray: + sample_prob = torch.distributions.Dirichlet( + torch.from_numpy(self.trans_count)).sample().numpy() return sample_prob - def sample_from_rew(self) -> np.ndarray: + def sample_reward(self) -> np.ndarray: return np.random.normal(self.rew_mean, 
                                 self.rew_std)

     def solve_policy(self) -> None:
         self.updated = True
         self.policy, self.value = self.value_iteration(
-            self.sample_from_prob(self.trans_count),
-            self.sample_from_rew(),
+            self.sample_trans_prob(),
+            self.sample_reward(),
             self.discount_factor,
             self.eps,
             self.value,
         )

From 1b8a0b96dded29dadad66a13bc1cc0f3c6700e93 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Wed, 23 Sep 2020 09:03:03 +0800
Subject: [PATCH 62/62] faster test

---
 test/modelbase/test_psrl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py
index b54f96267..6fb0e16ad 100644
--- a/test/modelbase/test_psrl.py
+++ b/test/modelbase/test_psrl.py
@@ -19,7 +19,7 @@ def get_args():
     parser.add_argument('--epoch', type=int, default=5)
     parser.add_argument('--step-per-epoch', type=int, default=5)
     parser.add_argument('--collect-per-step', type=int, default=1)
-    parser.add_argument('--training-num', type=int, default=8)
+    parser.add_argument('--training-num', type=int, default=1)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.0)
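
Taken together, the sampling step that solve_policy relies on draws one plausible MDP from the current posterior: a Dirichlet sample per (state, action) row for the dynamics (the vectorized torch.distributions.Dirichlet call adopted above) and a Normal sample per cell for the rewards; value iteration on that sampled MDP then gives the policy used until the next update. A minimal standalone sketch with made-up counts:

import numpy as np
import torch

n_state, n_action = 3, 2
trans_count = (np.ones((n_state, n_action, n_state))
               + np.random.randint(0, 5, size=(n_state, n_action, n_state)))
rew_mean = np.zeros((n_state, n_action))
rew_std = np.ones((n_state, n_action))

# one posterior sample of the transition model: every [s, a] row is a
# probability distribution over next states
trans_prob = torch.distributions.Dirichlet(
    torch.from_numpy(trans_count)).sample().numpy()
assert np.allclose(trans_prob.sum(axis=-1), 1.0)

# one posterior sample of the reward table
rew = np.random.normal(rew_mean, rew_std)
# feeding (trans_prob, rew) to value iteration yields this sample's greedy policy
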