add PSRL policy #202

Merged: 81 commits, Sep 23, 2020

Commits
f14e4c6
add PSRL policy
Sep 4, 2020
71ec861
Merge branch 'master' into dev
Trinkle23897 Sep 4, 2020
f26255d
improve PSRL code
Sep 5, 2020
5716fd1
merge
Sep 5, 2020
a317933
polish
Trinkle23897 Sep 5, 2020
6c25cea
pep8
Trinkle23897 Sep 5, 2020
43e2581
fix docs error
Trinkle23897 Sep 5, 2020
b78cba4
add value iteration
Sep 5, 2020
577a269
polish
Trinkle23897 Sep 5, 2020
554d81e
soft link
Trinkle23897 Sep 5, 2020
91d5bdd
fix pytest
Trinkle23897 Sep 5, 2020
df90377
fix pytest
Trinkle23897 Sep 5, 2020
bd1c34c
polish
Sep 5, 2020
f5fc5f8
bug fixed
Sep 5, 2020
077b015
polish
Sep 5, 2020
caf5ea4
minor update
Trinkle23897 Sep 5, 2020
2bc2fb1
polish
Trinkle23897 Sep 5, 2020
e825bfc
docs
Trinkle23897 Sep 5, 2020
ee8a197
simplify PSRLModel.observe
Trinkle23897 Sep 6, 2020
e97f8c2
remove unnecessary part
Trinkle23897 Sep 6, 2020
c775101
fix pep8
Trinkle23897 Sep 6, 2020
12698ed
use onpolicy instead of offpolicy
Sep 6, 2020
89c87f1
Merge branch 'dev' of https://github.com/feng-y16/tianshou into dev
Sep 6, 2020
13b83e8
add discount factor
Sep 6, 2020
19fcce0
fix config
Trinkle23897 Sep 6, 2020
bee48e5
fix docstring
Trinkle23897 Sep 6, 2020
f60fd9d
tune taxi
Trinkle23897 Sep 6, 2020
99f9239
add operations for absorbing states
Sep 6, 2020
9aa2394
polish
Sep 6, 2020
f580c4c
polish
Sep 6, 2020
40df71c
fix gamma=0
Sep 6, 2020
458ecfa
Merge branch 'master' into dev
Trinkle23897 Sep 6, 2020
874fb68
use sampling, delete epsilon greedy
Sep 7, 2020
460039f
Merge branch 'dev' of https://github.com/feng-y16/tianshou into dev
Sep 7, 2020
e9b66d9
Merge branch 'master' into dev
Trinkle23897 Sep 7, 2020
5fb3a77
polish
Sep 7, 2020
7a21aa9
Merge branch 'dev' of https://github.com/feng-y16/tianshou into dev
Sep 7, 2020
3a4e454
Update tianshou/policy/modelbase/psrl.py
Trinkle23897 Sep 7, 2020
ea766bd
polish
Sep 8, 2020
aba288e
Merge branch 'dev' of https://github.com/feng-y16/tianshou into dev
Sep 8, 2020
46054bd
Merge branch 'master' into dev
Trinkle23897 Sep 8, 2020
9222afa
polish
Trinkle23897 Sep 9, 2020
336718f
polish
Sep 9, 2020
3cdb8f8
Merge branch 'dev' of https://github.com/feng-y16/tianshou into dev
Sep 9, 2020
c759161
add rew-mean-prior and rew-std-prior argument in test_psrl
Trinkle23897 Sep 9, 2020
a0cf86d
fix test
Trinkle23897 Sep 9, 2020
3d1ffb8
change (a, s, s) to (s, a, s)
Trinkle23897 Sep 9, 2020
6e9ba48
add value iteration eps to arguments
Trinkle23897 Sep 9, 2020
71c1e0f
add rew-count-prior and improve value-iteration efficiency
Trinkle23897 Sep 9, 2020
80f188e
remove print
Trinkle23897 Sep 9, 2020
a4118fc
remove weight-prior
Trinkle23897 Sep 10, 2020
efa9a9a
discount factor regression
Trinkle23897 Sep 10, 2020
dd7f99f
polish
Sep 10, 2020
c98774c
Merge branch 'master' into dev
Trinkle23897 Sep 11, 2020
5a094ff
small update
Trinkle23897 Sep 11, 2020
a231fd7
modify readme
Trinkle23897 Sep 12, 2020
8e801de
NChain hparam
Trinkle23897 Sep 12, 2020
3407ab0
Merge branch 'master' into dev
Trinkle23897 Sep 12, 2020
8a94c48
NChain hparam
Trinkle23897 Sep 12, 2020
5330ccc
Merge branch 'dev' of github.com:feng-y16/tianshou into dev
Trinkle23897 Sep 12, 2020
505af59
Merge branch 'master' into dev
Trinkle23897 Sep 12, 2020
b852ad9
add discount factor for value iteration
Sep 13, 2020
324fb4a
polish and fix an annotation
Sep 13, 2020
d002785
fix timing in trainer
Trinkle23897 Sep 13, 2020
c0fe899
Merge branch 'master' into dev
Trinkle23897 Sep 13, 2020
a0fae3a
fix test
Trinkle23897 Sep 13, 2020
bef8ba4
fix docs
Trinkle23897 Sep 14, 2020
9a5ea8a
update
Trinkle23897 Sep 14, 2020
1649e28
Merge branch 'master' into dev
Trinkle23897 Sep 14, 2020
04247d9
Merge branch 'master' into dev
Trinkle23897 Sep 14, 2020
93ebee9
polish
Trinkle23897 Sep 14, 2020
10c3ee0
polish
Trinkle23897 Sep 14, 2020
693c864
Merge branch 'master' into dev
Trinkle23897 Sep 16, 2020
1edb735
add_done_loop
Trinkle23897 Sep 16, 2020
c5963b4
fix rew_std calculation
Sep 16, 2020
6eaf94f
bug fixed
Sep 16, 2020
1a58d57
polish
Trinkle23897 Sep 17, 2020
0496b88
update readme
Trinkle23897 Sep 19, 2020
7569490
Merge branch 'master' into dev
Trinkle23897 Sep 22, 2020
73fac5d
polish
Trinkle23897 Sep 23, 2020
1b8a0b9
faster test
Trinkle23897 Sep 23, 2020
1 change: 1 addition & 0 deletions README.md
@@ -31,6 +31,7 @@
- Vanilla Imitation Learning
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)

Here is Tianshou's other features:

1 change: 1 addition & 0 deletions docs/index.rst
@@ -19,6 +19,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_
7 changes: 7 additions & 0 deletions examples/modelbase/README.md
@@ -0,0 +1,7 @@
# PSRL

`NChain-v0`: `python3 psrl.py --task NChain-v0 --step-per-epoch 10 --rew-mean-prior 0 --rew-std-prior 1`

`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 1 --add-done-loop --epoch 20`

`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 2 --epoch 20`
1 change: 1 addition & 0 deletions examples/modelbase/psrl.py
Empty file added test/modelbase/__init__.py
Empty file.
97 changes: 97 additions & 0 deletions test/modelbase/test_psrl.py
@@ -0,0 +1,97 @@
import gym
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import PSRLPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import DummyVectorEnv, SubprocVectorEnv


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='NChain-v0')
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--buffer-size', type=int, default=50000)
    parser.add_argument('--epoch', type=int, default=5)
    parser.add_argument('--step-per-epoch', type=int, default=5)
    parser.add_argument('--collect-per-step', type=int, default=1)
    parser.add_argument('--training-num', type=int, default=1)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.0)
    parser.add_argument('--rew-mean-prior', type=float, default=0.0)
    parser.add_argument('--rew-std-prior', type=float, default=1.0)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--eps', type=float, default=0.01)
    parser.add_argument('--add-done-loop', action='store_true')
    return parser.parse_known_args()[0]


def test_psrl(args=get_args()):
    env = gym.make(args.task)
    if args.task == "NChain-v0":
        env.spec.reward_threshold = 3647  # described in PSRL paper
    print("reward threshold:", env.spec.reward_threshold)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # train_envs = gym.make(args.task)
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    n_action = args.action_shape
    n_state = args.state_shape
    trans_count_prior = np.ones((n_state, n_action, n_state))
    rew_mean_prior = np.full((n_state, n_action), args.rew_mean_prior)
    rew_std_prior = np.full((n_state, n_action), args.rew_std_prior)
    policy = PSRLPolicy(
        trans_count_prior, rew_mean_prior, rew_std_prior, args.gamma, args.eps,
        args.add_done_loop)
    # collector
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(policy, test_envs)
    # log
    writer = SummaryWriter(args.logdir + '/' + args.task)

    def stop_fn(x):
        if env.spec.reward_threshold:
            return x >= env.spec.reward_threshold
        else:
            return False

    train_collector.collect(n_step=args.buffer_size, random=True)
    # trainer
    result = onpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, 1,
        args.test_num, 0, stop_fn=stop_fn, writer=writer,
        test_in_train=False)

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        policy.eval()
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=[1] * args.test_num,
                                        render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
    elif env.spec.reward_threshold:
        assert result["best_reward"] >= env.spec.reward_threshold


if __name__ == '__main__':
    test_psrl()
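
The prior arrays the test passes to `PSRLPolicy` can be inspected on their own. The sketch below is only an illustration: the sizes `n_state = 5` and `n_action = 2` are hypothetical (the test reads them from the environment's spaces), the variable `prior_mean_trans` is not part of the PR, and nothing here touches Tianshou or Gym.

```python
import numpy as np

# Hypothetical sizes for illustration; the test reads them from the env.
n_state, n_action = 5, 2

# Dirichlet pseudo-counts: one count per (state, action, next_state),
# i.e. a uniform prior over next states for every state-action pair.
trans_count_prior = np.ones((n_state, n_action, n_state))

# Normal prior over each state-action reward, using the defaults of
# --rew-mean-prior and --rew-std-prior from the test script.
rew_mean_prior = np.full((n_state, n_action), 0.0)
rew_std_prior = np.full((n_state, n_action), 1.0)

# The mean of a Dirichlet is its normalized pseudo-counts, so under this
# prior every next state is expected to be equally likely.
prior_mean_trans = trans_count_prior / trans_count_prior.sum(
    axis=-1, keepdims=True)
```

With all-ones pseudo-counts the prior carries very little weight, so a handful of observed transitions is enough to dominate the sampled transition model.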
2 changes: 2 additions & 0 deletions tianshou/policy/__init__.py
@@ -9,6 +9,7 @@
from tianshou.policy.modelfree.td3 import TD3Policy
from tianshou.policy.modelfree.sac import SACPolicy
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
from tianshou.policy.modelbase.psrl import PSRLPolicy
from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager


@@ -24,5 +25,6 @@
    "TD3Policy",
    "SACPolicy",
    "DiscreteSACPolicy",
    "PSRLPolicy",
    "MultiAgentPolicyManager",
]
Empty file.
220 changes: 220 additions & 0 deletions tianshou/policy/modelbase/psrl.py
@@ -0,0 +1,220 @@
import torch
import numpy as np
from typing import Any, Dict, Union, Optional

from tianshou.data import Batch
from tianshou.policy import BasePolicy


class PSRLModel(object):
    """Implementation of Posterior Sampling Reinforcement Learning Model.

    :param np.ndarray trans_count_prior: dirichlet prior (alphas), with shape
        (n_state, n_action, n_state).
    :param np.ndarray rew_mean_prior: means of the normal priors of rewards,
        with shape (n_state, n_action).
    :param np.ndarray rew_std_prior: standard deviations of the normal priors
        of rewards, with shape (n_state, n_action).
    :param float discount_factor: in [0, 1].
    :param float epsilon: for precision control in value iteration.
    """

    def __init__(
        self,
        trans_count_prior: np.ndarray,
        rew_mean_prior: np.ndarray,
        rew_std_prior: np.ndarray,
        discount_factor: float,
        epsilon: float,
    ) -> None:
        self.trans_count = trans_count_prior
        self.n_state, self.n_action = rew_mean_prior.shape
        self.rew_mean = rew_mean_prior
        self.rew_std = rew_std_prior
        self.rew_square_sum = np.zeros_like(rew_mean_prior)
        self.rew_std_prior = rew_std_prior
        self.discount_factor = discount_factor
        self.rew_count = np.full(rew_mean_prior.shape, epsilon)  # no weight
        self.eps = epsilon
        self.policy: np.ndarray
        self.value = np.zeros(self.n_state)
        self.updated = False
        self.__eps = np.finfo(np.float32).eps.item()

    def observe(
        self,
        trans_count: np.ndarray,
        rew_sum: np.ndarray,
        rew_square_sum: np.ndarray,
        rew_count: np.ndarray,
    ) -> None:
        """Add data into memory pool.

        For rewards, we have a normal prior at first. After we observed a
        reward for a given state-action pair, we use the mean value of our
        observations instead of the prior mean as the posterior mean. The
        posterior standard deviations shrink as the number of the
        corresponding observations grows.

        :param np.ndarray trans_count: the number of observations, with shape
            (n_state, n_action, n_state).
        :param np.ndarray rew_sum: total rewards, with shape
            (n_state, n_action).
        :param np.ndarray rew_square_sum: total rewards' squares, with shape
            (n_state, n_action).
        :param np.ndarray rew_count: the number of rewards, with shape
            (n_state, n_action).
        """
        self.updated = False
        self.trans_count += trans_count
        sum_count = self.rew_count + rew_count
        self.rew_mean = (self.rew_mean * self.rew_count + rew_sum) / sum_count
        self.rew_square_sum += rew_square_sum
        raw_std2 = self.rew_square_sum / sum_count - self.rew_mean ** 2
        self.rew_std = np.sqrt(1 / (
            sum_count / (raw_std2 + self.__eps) + 1 / self.rew_std_prior ** 2))
        self.rew_count = sum_count

Review comments on the `self.rew_std = np.sqrt(1 / (` line:

Collaborator: I don't understand this. Can you explain it more?

Collaborator: The line of calculating self.rew_std is strange.
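
One way to read the `self.rew_std` line that the reviewers asked about is as a precision (inverse-variance) update for a normal prior with a normal likelihood: the posterior precision is `n / sigma^2 + 1 / sigma_0^2`, where `sigma_0` is the prior standard deviation and the observation variance `sigma^2` is approximated by the sample variance `raw_std2`. This is an interpretation of the diff, not something the PR states. The sketch below restates the update as a standalone function; the name `update_reward_posterior` and the argument `tiny` are hypothetical, and, as in the diff, the mean is the count-weighted running mean rather than the exact conjugate posterior mean.

```python
import numpy as np


def update_reward_posterior(rew_mean, rew_count, rew_square_sum,
                            rew_std_prior, new_sum, new_square_sum,
                            new_count, tiny=1e-8):
    """Standalone restatement of the reward update (hypothetical helper)."""
    sum_count = rew_count + new_count
    # Running mean over all rewards seen so far; the prior mean carries
    # only the small initial weight stored in rew_count.
    mean = (rew_mean * rew_count + new_sum) / sum_count
    square_sum = rew_square_sum + new_square_sum
    # Sample variance estimate: E[r^2] - (E[r])^2.
    raw_var = square_sum / sum_count - mean ** 2
    # Posterior precision = n / sigma^2 + 1 / sigma_0^2, with the sample
    # variance plugged in for the observation variance sigma^2.
    precision = sum_count / (raw_var + tiny) + 1.0 / rew_std_prior ** 2
    return mean, np.sqrt(1.0 / precision), square_sum, sum_count


# One state-action pair, prior N(0, 1), two observed rewards: 1 and 3.
mean, std, square_sum, count = update_reward_posterior(
    rew_mean=np.zeros((1, 1)),
    rew_count=np.full((1, 1), 0.01),
    rew_square_sum=np.zeros((1, 1)),
    rew_std_prior=np.ones((1, 1)),
    new_sum=np.array([[4.0]]),
    new_square_sum=np.array([[10.0]]),
    new_count=np.array([[2.0]]),
)
```

Under this reading, the posterior standard deviation can never exceed the prior one, and for a fixed sample variance it falls roughly as one over the square root of the observation count.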

    def sample_trans_prob(self) -> np.ndarray:
        sample_prob = torch.distributions.Dirichlet(
            torch.from_numpy(self.trans_count)).sample().numpy()
        return sample_prob

    def sample_reward(self) -> np.ndarray:
        return np.random.normal(self.rew_mean, self.rew_std)

    def solve_policy(self) -> None:
        self.updated = True
        self.policy, self.value = self.value_iteration(
            self.sample_trans_prob(),
            self.sample_reward(),
            self.discount_factor,
            self.eps,
            self.value,
        )

    @staticmethod
    def value_iteration(
        trans_prob: np.ndarray,
        rew: np.ndarray,
        discount_factor: float,
        eps: float,
        value: np.ndarray,
    ) -> np.ndarray:
        """Value iteration solver for MDPs.

        :param np.ndarray trans_prob: transition probabilities, with shape
            (n_state, n_action, n_state).
        :param np.ndarray rew: rewards, with shape (n_state, n_action).
        :param float eps: for precision control.
        :param float discount_factor: in [0, 1].
        :param np.ndarray value: the initialize value of value array, with
            shape (n_state, ).

        :return: the optimal policy with shape (n_state, ).
        """
        Q = rew + discount_factor * trans_prob.dot(value)
        new_value = Q.max(axis=1)
        while not np.allclose(new_value, value, eps):
            value = new_value
            Q = rew + discount_factor * trans_prob.dot(value)
            new_value = Q.max(axis=1)
        # this is to make sure if Q(s, a1) == Q(s, a2) -> choose a1/a2 randomly
        Q += eps * np.random.randn(*Q.shape)
        return Q.argmax(axis=1), new_value

    def __call__(
        self,
        obs: np.ndarray,
        state: Optional[Any] = None,
        info: Dict[str, Any] = {},
    ) -> np.ndarray:
        if not self.updated:
            self.solve_policy()
        return self.policy[obs]

Review comments on lines +129 to +130 (the `state` and `info` arguments):

Collaborator: These two arguments are not used.

Collaborator: That's to be consistent with model API

Collaborator: What do you mean by "model API"? PSRLModel is the first model and I don't see it has any base-class.
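
The planning step behind `solve_policy` is tabular value iteration on one sampled MDP. The sketch below shows that step in isolation on a hypothetical two-state, two-action MDP. It is a simplified variant, not the PR's code: it starts from a zero value array, stops on an absolute tolerance `tol` with a `max_iter` cap (both names are assumptions), and omits the random tie-breaking noise added in the diff.

```python
import numpy as np


def value_iteration(trans_prob, rew, discount_factor,
                    tol=1e-8, max_iter=10_000):
    """Value iteration on a tabular MDP.

    trans_prob has shape (n_state, n_action, n_state) and rew has shape
    (n_state, n_action).
    """
    value = np.zeros(trans_prob.shape[0])
    for _ in range(max_iter):
        # Bellman optimality backup: Q(s, a) = r(s, a) + gamma * E[V(s')].
        q = rew + discount_factor * trans_prob.dot(value)
        new_value = q.max(axis=1)
        if np.max(np.abs(new_value - value)) < tol:
            break
        value = new_value
    return q.argmax(axis=1), new_value


# Hypothetical MDP: action 0 always leads to state 0, action 1 always
# leads to state 1, and only (state 1, action 1) pays a reward of 1.
trans_prob = np.zeros((2, 2, 2))
trans_prob[:, 0, 0] = 1.0
trans_prob[:, 1, 1] = 1.0
rew = np.array([[0.0, 0.0], [0.0, 1.0]])
policy, value = value_iteration(trans_prob, rew, discount_factor=0.9)
```

In this toy MDP the greedy policy picks action 1 in both states, since staying in state 1 collects the only reward on every step.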


class PSRLPolicy(BasePolicy):
    """Implementation of Posterior Sampling Reinforcement Learning.

    Reference: Strens M. A Bayesian framework for reinforcement learning [C]
    //ICML. 2000, 2000: 943-950.

    :param np.ndarray trans_count_prior: dirichlet prior (alphas), with shape
        (n_state, n_action, n_state).
    :param np.ndarray rew_mean_prior: means of the normal priors of rewards,
        with shape (n_state, n_action).
    :param np.ndarray rew_std_prior: standard deviations of the normal priors
        of rewards, with shape (n_state, n_action).
    :param float discount_factor: in [0, 1].
    :param float epsilon: for precision control in value iteration.
    :param bool add_done_loop: whether to add an extra self-loop for the
        terminal state in MDP, defaults to False.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        trans_count_prior: np.ndarray,
        rew_mean_prior: np.ndarray,
        rew_std_prior: np.ndarray,
        discount_factor: float = 0.99,
        epsilon: float = 0.01,
        add_done_loop: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        assert (
            0.0 <= discount_factor <= 1.0
        ), "discount factor should be in [0, 1]"
        self.model = PSRLModel(
            trans_count_prior, rew_mean_prior, rew_std_prior,
            discount_factor, epsilon)
        self._add_done_loop = add_done_loop

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        **kwargs: Any,
    ) -> Batch:
        """Compute action over the given batch data with PSRL model.

        :return: A :class:`~tianshou.data.Batch` with "act" key containing
            the action.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        act = self.model(batch.obs, state=state, info=batch.info)
        return Batch(act=act)

    def learn(
        self, batch: Batch, *args: Any, **kwargs: Any
    ) -> Dict[str, float]:
        n_s, n_a = self.model.n_state, self.model.n_action
        trans_count = np.zeros((n_s, n_a, n_s))
        rew_sum = np.zeros((n_s, n_a))
        rew_square_sum = np.zeros((n_s, n_a))
        rew_count = np.zeros((n_s, n_a))
        for b in batch.split(size=1):
            obs, act, obs_next = b.obs, b.act, b.obs_next
            trans_count[obs, act, obs_next] += 1
            rew_sum[obs, act] += b.rew
            rew_square_sum[obs, act] += b.rew ** 2
            rew_count[obs, act] += 1
            if self._add_done_loop and b.done:
                # special operation for terminal states: add a self-loop
                trans_count[obs_next, :, obs_next] += 1
                rew_count[obs_next, :] += 1
        self.model.observe(trans_count, rew_sum, rew_square_sum, rew_count)
        return {
            "psrl/rew_mean": self.model.rew_mean.mean(),
            "psrl/rew_std": self.model.rew_std.mean(),
        }
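
The interplay between the transition tallies in `learn` and the Dirichlet draw in `sample_trans_prob` can be sketched without Tianshou or PyTorch. The example below is an assumption-laden illustration: the transitions are made up, the sizes are hypothetical, and NumPy's `Generator.dirichlet` is swapped in for the `torch.distributions.Dirichlet` call used in the diff.

```python
import numpy as np

rng = np.random.default_rng(0)
n_state, n_action = 3, 2

# Hypothetical (obs, act, obs_next) transitions standing in for a batch.
transitions = [(0, 1, 1), (1, 1, 2), (0, 1, 1), (2, 0, 0)]

# Start from a uniform Dirichlet prior of one pseudo-count per entry,
# then add one count for every observed transition.
trans_count = np.ones((n_state, n_action, n_state))
for obs, act, obs_next in transitions:
    trans_count[obs, act, obs_next] += 1

# Draw one transition model from the posterior: an independent Dirichlet
# sample over next states for every (state, action) pair.
sampled = np.empty_like(trans_count)
for s in range(n_state):
    for a in range(n_action):
        sampled[s, a] = rng.dirichlet(trans_count[s, a])
```

Each row `sampled[s, a]` is a valid probability distribution over next states, so the array can be passed straight to a tabular planner such as value iteration.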