add PSRL policy #202
Changes from all commits
@@ -0,0 +1,7 @@
# PSRL

`NChain-v0`: `python3 psrl.py --task NChain-v0 --step-per-epoch 10 --rew-mean-prior 0 --rew-std-prior 1`

`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 1 --add-done-loop --epoch 20`

`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 2 --epoch 20`
@@ -0,0 +1 @@
../../test/modelbase/test_psrl.py
@@ -0,0 +1,97 @@
import gym
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import PSRLPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import DummyVectorEnv, SubprocVectorEnv


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='NChain-v0')
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--buffer-size', type=int, default=50000)
    parser.add_argument('--epoch', type=int, default=5)
    parser.add_argument('--step-per-epoch', type=int, default=5)
    parser.add_argument('--collect-per-step', type=int, default=1)
    parser.add_argument('--training-num', type=int, default=1)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.0)
    parser.add_argument('--rew-mean-prior', type=float, default=0.0)
    parser.add_argument('--rew-std-prior', type=float, default=1.0)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--eps', type=float, default=0.01)
    parser.add_argument('--add-done-loop', action='store_true')
    return parser.parse_known_args()[0]


def test_psrl(args=get_args()):
    env = gym.make(args.task)
    if args.task == "NChain-v0":
        env.spec.reward_threshold = 3647  # described in PSRL paper
    print("reward threshold:", env.spec.reward_threshold)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # train_envs = gym.make(args.task)
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    n_action = args.action_shape
    n_state = args.state_shape
    trans_count_prior = np.ones((n_state, n_action, n_state))
    rew_mean_prior = np.full((n_state, n_action), args.rew_mean_prior)
    rew_std_prior = np.full((n_state, n_action), args.rew_std_prior)
    policy = PSRLPolicy(
        trans_count_prior, rew_mean_prior, rew_std_prior, args.gamma, args.eps,
        args.add_done_loop)
    # collector
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(policy, test_envs)
    # log
    writer = SummaryWriter(args.logdir + '/' + args.task)

    def stop_fn(x):
        if env.spec.reward_threshold:
            return x >= env.spec.reward_threshold
        else:
            return False

    train_collector.collect(n_step=args.buffer_size, random=True)
    # trainer
    result = onpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, 1,
        args.test_num, 0, stop_fn=stop_fn, writer=writer,
        test_in_train=False)

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        policy.eval()
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=[1] * args.test_num,
                                        render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
    elif env.spec.reward_threshold:
        assert result["best_reward"] >= env.spec.reward_threshold


if __name__ == '__main__':
    test_psrl()
@@ -0,0 +1,220 @@
import torch
import numpy as np
from typing import Any, Dict, Tuple, Union, Optional

from tianshou.data import Batch
from tianshou.policy import BasePolicy


class PSRLModel(object):
    """Implementation of Posterior Sampling Reinforcement Learning Model.

    :param np.ndarray trans_count_prior: Dirichlet prior (alphas), with shape
        (n_state, n_action, n_state).
    :param np.ndarray rew_mean_prior: means of the normal priors of rewards,
        with shape (n_state, n_action).
    :param np.ndarray rew_std_prior: standard deviations of the normal priors
        of rewards, with shape (n_state, n_action).
    :param float discount_factor: in [0, 1].
    :param float epsilon: for precision control in value iteration.
    """

    def __init__(
        self,
        trans_count_prior: np.ndarray,
        rew_mean_prior: np.ndarray,
        rew_std_prior: np.ndarray,
        discount_factor: float,
        epsilon: float,
    ) -> None:
        self.trans_count = trans_count_prior
        self.n_state, self.n_action = rew_mean_prior.shape
        self.rew_mean = rew_mean_prior
        self.rew_std = rew_std_prior
        self.rew_square_sum = np.zeros_like(rew_mean_prior)
        self.rew_std_prior = rew_std_prior
        self.discount_factor = discount_factor
        # epsilon pseudo-count: the prior mean gets (almost) no weight
        self.rew_count = np.full(rew_mean_prior.shape, epsilon)
        self.eps = epsilon
        self.policy: np.ndarray
        self.value = np.zeros(self.n_state)
        self.updated = False
        self.__eps = np.finfo(np.float32).eps.item()

    def observe(
        self,
        trans_count: np.ndarray,
        rew_sum: np.ndarray,
        rew_square_sum: np.ndarray,
        rew_count: np.ndarray,
    ) -> None:
        """Add data into the memory pool.

        For rewards, we start from a Normal prior. Once rewards have been
        observed for a given state-action pair, the posterior mean is the
        mean of the observations rather than the prior mean, and the
        standard deviation shrinks as the number of corresponding
        observations grows.

        :param np.ndarray trans_count: the number of observations, with shape
            (n_state, n_action, n_state).
        :param np.ndarray rew_sum: total rewards, with shape
            (n_state, n_action).
        :param np.ndarray rew_square_sum: total squared rewards, with shape
            (n_state, n_action).
        :param np.ndarray rew_count: the number of rewards, with shape
            (n_state, n_action).
        """
        self.updated = False
        self.trans_count += trans_count
        sum_count = self.rew_count + rew_count
        self.rew_mean = (self.rew_mean * self.rew_count + rew_sum) / sum_count
        self.rew_square_sum += rew_square_sum
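        # Reward-noise posterior: precisions (inverse variances) add up,
        #   1 / rew_std ** 2 = sum_count / raw_std2 + 1 / rew_std_prior ** 2,
        # where raw_std2 (computed below) is the empirical variance of the
        # observed rewards, so more observations mean a tighter posterior std.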
        raw_std2 = self.rew_square_sum / sum_count - self.rew_mean ** 2
        self.rew_std = np.sqrt(1 / (
            sum_count / (raw_std2 + self.__eps) + 1 / self.rew_std_prior ** 2))
        self.rew_count = sum_count

    def sample_trans_prob(self) -> np.ndarray:
        sample_prob = torch.distributions.Dirichlet(
            torch.from_numpy(self.trans_count)).sample().numpy()
        return sample_prob

    def sample_reward(self) -> np.ndarray:
        return np.random.normal(self.rew_mean, self.rew_std)

    def solve_policy(self) -> None:
        self.updated = True
        self.policy, self.value = self.value_iteration(
            self.sample_trans_prob(),
            self.sample_reward(),
            self.discount_factor,
            self.eps,
            self.value,
        )

    @staticmethod
    def value_iteration(
        trans_prob: np.ndarray,
        rew: np.ndarray,
        discount_factor: float,
        eps: float,
        value: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Value iteration solver for MDPs.

        :param np.ndarray trans_prob: transition probabilities, with shape
            (n_state, n_action, n_state).
        :param np.ndarray rew: rewards, with shape (n_state, n_action).
        :param float discount_factor: in [0, 1].
        :param float eps: for precision control.
        :param np.ndarray value: the initial value of the value array, with
            shape (n_state, ).

        :return: the optimal policy with shape (n_state, ) and its value with
            shape (n_state, ).
        """
        Q = rew + discount_factor * trans_prob.dot(value)
        new_value = Q.max(axis=1)
        while not np.allclose(new_value, value, eps):
            value = new_value
            Q = rew + discount_factor * trans_prob.dot(value)
            new_value = Q.max(axis=1)
        # add small noise so ties Q(s, a1) == Q(s, a2) are broken randomly
        Q += eps * np.random.randn(*Q.shape)
        return Q.argmax(axis=1), new_value

    def __call__(
        self,
        obs: np.ndarray,
        state: Optional[Any] = None,
        info: Dict[str, Any] = {},
    ) -> np.ndarray:
        # state and info are unused; they are kept for consistency with the
        # model call signature used elsewhere in tianshou.
        if not self.updated:
            self.solve_policy()
        return self.policy[obs]


class PSRLPolicy(BasePolicy):
    """Implementation of Posterior Sampling Reinforcement Learning.

    Reference: Strens M. A Bayesian Framework for Reinforcement Learning.
    ICML, 2000: 943-950.

    :param np.ndarray trans_count_prior: Dirichlet prior (alphas), with shape
        (n_state, n_action, n_state).
    :param np.ndarray rew_mean_prior: means of the normal priors of rewards,
        with shape (n_state, n_action).
    :param np.ndarray rew_std_prior: standard deviations of the normal priors
        of rewards, with shape (n_state, n_action).
    :param float discount_factor: in [0, 1].
    :param float epsilon: for precision control in value iteration.
    :param bool add_done_loop: whether to add an extra self-loop for the
        terminal state in MDP, defaults to False.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        trans_count_prior: np.ndarray,
        rew_mean_prior: np.ndarray,
        rew_std_prior: np.ndarray,
        discount_factor: float = 0.99,
        epsilon: float = 0.01,
        add_done_loop: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        assert (
            0.0 <= discount_factor <= 1.0
        ), "discount factor should be in [0, 1]"
        self.model = PSRLModel(
            trans_count_prior, rew_mean_prior, rew_std_prior,
            discount_factor, epsilon)
        self._add_done_loop = add_done_loop

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        **kwargs: Any,
    ) -> Batch:
        """Compute action over the given batch data with PSRL model.

        :return: A :class:`~tianshou.data.Batch` with "act" key containing
            the action.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        act = self.model(batch.obs, state=state, info=batch.info)
        return Batch(act=act)

    def learn(
        self, batch: Batch, *args: Any, **kwargs: Any
    ) -> Dict[str, float]:
        n_s, n_a = self.model.n_state, self.model.n_action
        trans_count = np.zeros((n_s, n_a, n_s))
        rew_sum = np.zeros((n_s, n_a))
        rew_square_sum = np.zeros((n_s, n_a))
        rew_count = np.zeros((n_s, n_a))
        for b in batch.split(size=1):
            obs, act, obs_next = b.obs, b.act, b.obs_next
            trans_count[obs, act, obs_next] += 1
            rew_sum[obs, act] += b.rew
            rew_square_sum[obs, act] += b.rew ** 2
            rew_count[obs, act] += 1
            if self._add_done_loop and b.done:
                # special operation for terminal states: add a self-loop
                trans_count[obs_next, :, obs_next] += 1
                rew_count[obs_next, :] += 1
        self.model.observe(trans_count, rew_sum, rew_square_sum, rew_count)
        return {
            "psrl/rew_mean": self.model.rew_mean.mean(),
            "psrl/rew_std": self.model.rew_std.mean(),
        }
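
For reviewers who want to exercise the new policy outside the trainer loop, here is a minimal hand-driven sketch. The two-state MDP, the hand-built priors, and the single fake transition `Batch` below are illustrative assumptions for this sketch only, not part of the PR; the PR's end-to-end usage is the `psrl.py` example above.

```python
import numpy as np

from tianshou.data import Batch
from tianshou.policy import PSRLPolicy  # added by this PR

# Toy tabular MDP: 2 states, 2 actions, flat Dirichlet / standard Normal priors.
n_state, n_action = 2, 2
policy = PSRLPolicy(
    trans_count_prior=np.ones((n_state, n_action, n_state)),
    rew_mean_prior=np.zeros((n_state, n_action)),
    rew_std_prior=np.ones((n_state, n_action)),
    discount_factor=0.99,
    epsilon=0.01,
)

# A single fake transition (s=0, a=1) -> s'=1 with reward 1.0.
batch = Batch(
    obs=np.array([0]), act=np.array([1]), obs_next=np.array([1]),
    rew=np.array([1.0]), done=np.array([False]), info=np.array([{}]),
)

print(policy(batch).act)    # greedy action for state 0 under a sampled MDP
print(policy.learn(batch))  # {'psrl/rew_mean': ..., 'psrl/rew_std': ...}
```

Both calls only exercise code added in this diff: `forward` samples an MDP from the posterior (if needed) and acts greedily on it, while `learn` turns the batch into count/sum statistics for `PSRLModel.observe`.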