From 8229b479bf500a229e453b7f67653217a179e7a8 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Tue, 19 Jan 2021 17:00:52 +0800 Subject: [PATCH 001/104] add first version of cached replay buffer(baseline), add standard api for replaybuffer --- tianshou/data/buffer.py | 336 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 327 insertions(+), 9 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 74299df9d..dd97585dd 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -11,7 +11,8 @@ class ReplayBuffer: """:class:`~tianshou.data.ReplayBuffer` stores data generated from \ - interaction between the policy and environment. + interaction between the policy and environment. ReplayBuffer can be \ + considered as a specialized form(management) of Batch. The current implementation of Tianshou typically use 7 reserved keys in :class:`~tianshou.data.Batch`: @@ -138,6 +139,7 @@ class ReplayBuffer: index when using frame-stack sampling method, defaults to False. This feature is not supported in Prioritized Replay Buffer currently. """ + _reserved_keys = {'obs', 'act', 'rew', 'done', 'obs_next', 'info', 'policy'} def __init__( self, @@ -148,8 +150,10 @@ def __init__( sample_avail: bool = False, ) -> None: super().__init__() + #TODO _maxsize == 0 handle self._maxsize = size self._indices = np.arange(size) + # consider move stacking option to another self.stack_num = stack_num self._avail = sample_avail and stack_num > 1 self._avail_index: List[int] = [] @@ -189,18 +193,18 @@ def __getstate__(self) -> dict: state = {k: v for k, v in self.__dict__.items() if k not in exclude} return state + def __setattr__(self, key: str, value: Any) -> None: + """Set self.key = value.""" + assert key not in self._reserved_keys, ( + "key '{}' is reserved and cannot be assigned".format(key)) + super().__setattr__(key, value) + def _add_to_buffer(self, name: str, inst: Any) -> None: try: value = self._meta.__dict__[name] except KeyError: self._meta.__dict__[name] = _create_value(inst, self._maxsize) value = self._meta.__dict__[name] - if isinstance(inst, (torch.Tensor, np.ndarray)): - if inst.shape != value.shape[1:]: - raise ValueError( - "Cannot add data to a buffer with different shape with key" - f" {name}, expect {value.shape[1:]}, given {inst.shape}." - ) try: value[self._index] = inst except KeyError: @@ -219,6 +223,10 @@ def stack_num(self, num: int) -> None: def update(self, buffer: "ReplayBuffer") -> None: """Move the data from the given buffer to self.""" + # TODO 'one by one copying' can be greatly improved. + # Can we move data as a whole batch to save time? + # what if self._maxsize << buffer._maxize, can we detect that + # and just ignore those element to be rewrite? if len(buffer) == 0: return i = begin = buffer._index % len(buffer) @@ -242,7 +250,9 @@ def add( policy: Optional[Union[dict, Batch]] = {}, **kwargs: Any, ) -> None: - """Add a batch of data into replay buffer.""" + """Add a batch of data into replay buffer. + Expect all input to be batch, dict, or numpy array""" + # TODO should we consider to support batch input? assert isinstance( info, (dict, Batch) ), "You should return a dict in the last argument of env.step()." @@ -379,6 +389,32 @@ def __getitem__( policy=self.get(index, "policy"), ) + def set_batch(self, batch: "Batch"): + """Manually choose the batch you want the ReplayBuffer to manage. This + method should be called instantly after the ReplayBuffer is initialised. 
+ """ + assert self._meta.is_empty(), "This method cannot be called after add() method" + self._meta = batch + assert not self._is_meta_corrupted(), ( + "Input batch doesn't meet ReplayBuffer's data form requirement.") + + def _is_meta_corrupted(self)-> bool: + """Assert if self._meta: Batch is still in legal form. + """ + #TODO do we need to check the chlid? is cache + if set(self._meta.keys()) != self._reserved_keys: + return True + for v in self._meta.values(): + if isinstance(v, Batch): + if not v.is_empty() and v.shape[0]!= self._maxsize: + return True + elif isinstance(v, np.ndarray): + if v.shape[0]!= self._maxsize: + return True + else: + return True + return False + def save_hdf5(self, path: str) -> None: """Save replay buffer to HDF5 file.""" with h5py.File(path, "w") as f: @@ -394,6 +430,37 @@ def load_hdf5( buf.__setstate__(from_hdf5(f, device=device)) return buf + def start(self, index): + """return start indices of given indices""" + assert index < len(self) and index >= 0, "Input index illegal." + sorted_starts = np.sort(self.starts()) + ret = np.searchsorted(sorted_starts, index, side="right") - 1 + ret[ret < 0] = sorted_starts[-1] + return ret + + def next(self, index): + """return next n step indices""" + assert index < len(self) and index >= 0, "Input index illegal." + return (index + ~(self.done[index]|index==self._index))%len(self) + + def starts(self): + """return indices of all episodes""" + if len(self) > 0: + return (self.ends()+1)%len(self) + else: + return np.array([], dtype = np.int) + + def ends(self): + """return unfinished indices.""" + if len(self) > 0: + last_write_in = int((self._index - 1)%len(self)) + if self.done[last_write_in]: + return np.where(self.done[:len(self)])[0] + else: + return np.append(np.where(self.done[:len(self)])[0], last_write_in) + else: + return np.array([], dtype = np.int) + class ListReplayBuffer(ReplayBuffer): """List-based replay buffer. @@ -505,7 +572,7 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: def update_weight( self, indice: Union[np.ndarray], - new_weight: Union[np.ndarray, torch.Tensor] + new_weight: np.ndarray ) -> None: """Update priority weight by indice in this buffer. @@ -530,3 +597,254 @@ def __getitem__( policy=self.get(index, "policy"), weight=self.weight[index], ) + +class CachedReplayBuffer(ReplayBuffer): + """CachedReplayBuffer can be considered as a combination of one main buffer + and a list of cached_buffers. It's designed to used by collector to allow + parallel collecting in collector. In CachedReplayBuffer is not organized + chronologically, but standard API like start()/starts()/ends/next() are provided + to help CachedReplayBuffer to be used just like ReplayBuffer. + """ + def __init__( + self, + size: int, + cached_buf_n: int, + max_length: int, + **kwargs: Any, + ) -> None: + """ + TODO support stack in the future + """ + assert cached_buf_n > 0 + # TODO what if people don't care about how buffer is organized + assert max_length > 0 + if cached_buf_n == 1: + import warnings + warnings.warn( + "CachedReplayBuffer with cached_buf_n = 1 will cause low efficiency. 
" + "Please consider using ReplayBuffer which is not in cached form.", + Warning) + + _maxsize = size+cached_buf_n*max_length + self.cached_bufs_n = cached_buf_n + # TODO see if we can generalize to all kinds of buufer + self.main_buf = ReplayBuffer(size, **kwargs) + # TODO cached_bufs can be consider to be replced by vector + # buffer in the future + self.cached_bufs = np.array([ReplayBuffer(max_length, **kwargs) + for _ in range(cached_buf_n)]) + super().__init__(size= _maxsize, **kwargs) + # TODO support, or just delete stack_num option from Replay buffer for now + assert self.stack_num == 1 + + def __len__(self) -> int: + """Return len(self).""" + return len(self.main_buf) + np.sum([len(b) for b in self.cached_bufs]) + + def update(self, buffer: "ReplayBuffer") -> int: + """CachedReplayBuffer will only update data from buffer which is in + episode form. Return an integer which indicates the number of steps + being ignored.""" + # For now update method copy element one by one, which is too slow. + if isinstance(buffer, CachedReplayBuffer): + buffer = buffer.main_buf + # now treat buffer like a normal ReplayBuffer and remove those incomplete steps + if len(buffer) == 0: + return 0 + diposed_count = 0 + # TODO use standard API now + end = (buffer._index - 1) % len(buffer) + begin = buffer._index % len(buffer) + while True: + if buffer.done[end] > 0: + break + else: + diposed_count = diposed_count + 1 + if end == begin: + assert diposed_count == len(self) + return diposed_count + end = (end - 1) % len(buffer) + while True: + self.main_buf.add(**buffer[begin]) + if begin == end: + return diposed_count + begin = (begin + 1) % len(buffer) + + def add( + self, + obs: Any, + act: Any, + rew: Union[Number, np.number, np.ndarray], + done: Union[Number, np.number, np.bool_], + obs_next: Any = None, + info: Optional[Union[dict, Batch]] = {}, + policy: Optional[Union[dict, Batch]] = {}, + index: Optional[Union[int, np.integer, np.ndarray, List[int]]] = None, + **kwargs: Any + ) -> None: + """ + + """ + if index is None: + index = range(self.cached_bufs_n) + index = np.atleast_1d(index).astype(np.int) + assert(index.ndim == 1) + + obs = np.atleast_1d(obs) + act = np.atleast_1d(act) + rew = np.atleast_1d(rew) + done = np.atleast_1d(done) + # TODO ugly code + if isinstance(obs_next, Batch) and obs_next.is_empty(): + obs_next = None + if isinstance(info, Batch) and info.is_empty(): + info = {} + if isinstance(policy, Batch) and policy.is_empty(): + policy = {} + obs_next = np.atleast_1d([None]*len(index)) if obs_next is None else np.atleast_1d(obs_next) + info = np.atleast_1d([{}]*len(index)) if info == {} else np.atleast_1d(info) + policy = np.atleast_1d([{}]*len(index)) if policy == {} else np.atleast_1d(policy) + + # TODO what if data is already in episodes, what if i want to add mutiple data ? 
+ # can accelerate + if self._meta.is_empty(): + self._cache_initialise(obs[0], act[0], rew[0], done[0], obs_next[0], + info[0], policy[0]) + # now we add data to selected cached_bufs one by one + cached_bufs_slice = self.cached_bufs[index] + for i, b in enumerate(cached_bufs_slice): + b.add(obs[i], act[i], rew[i], done[i], + obs_next[i], info[i], policy[i]) + return self._main_buf_update() + + def _main_buf_update(self): + lens = np.zeros((self.cached_bufs_n, ), dtype = np.int) + rews = np.zeros((self.cached_bufs_n, )) + start_indexs = np.zeros((self.cached_bufs_n, ), dtype = np.int) + for i, buf in enumerate(self.cached_bufs): + if buf.done[buf._index - 1] > 0: + lens[i] = len(buf) + rews[i] = np.sum(buf.rew[:lens[i]]) + start_indexs[i] = self.main_buf._index + if self.main_buf._maxsize > 0: + # _maxsize of main_buf might be 0 in test collector. + self.main_buf.update(buf) + buf.reset() + return lens, rews, start_indexs + + def reset(self) -> None: + for buf in self.cached_bufs: + buf.reset() + self.main_buf.reset() + self._avail_index = [] + #TODO finish + + def sample(self, batch_size: int, + is_from_main_buf = False) -> Tuple[Batch, np.ndarray]: + if is_from_main_buf: + return self.main_buf.sample(batch_size) + + _all = np.arange(len(self), dtype=np.int) + start = len(self.main_buf) + add = self.main_buf._maxsize - len(self.main_buf) + for buf in self.cached_bufs: + end = start + len(buf) + _all[start:end] = _all[start:end] + add + start = end + add = add + buf._maxsize - len(buf) + indice = np.random.choice(_all, batch_size) + assert len(indice) > 0, "No available indice can be sampled." + return self[indice], indice + + def get( + self, + indice: Union[slice, int, np.integer, np.ndarray], + key: str, + stack_num: Optional[int] = None, + ) -> Union[Batch, np.ndarray]: + if stack_num is None: + stack_num = self.stack_num + assert(stack_num == 1) + #TODO support stack + return super().get(indice, key, stack_num) + + def _cache_initialise( + self, + obs: Any, + act: Any, + rew: Union[Number, np.number, np.ndarray], + done: Union[Number, np.number, np.bool_], + obs_next: Any = None, + info: Optional[Union[dict, Batch]] = {}, + policy: Optional[Union[dict, Batch]] = {} + ) -> None: + assert(self._meta.is_empty()) + # to initialise self._meta + super().add(obs, act, rew, done, obs_next, info, policy) + super().reset() #TODO delete useless varible? + del self._index + del self._size + self.main_buf.set_batch(self._meta[:self.main_buf._maxsize]) + start = self.main_buf._maxsize + for buf in self.cached_bufs: + end = start + buf._maxsize + buf.set_batch(self._meta[start: end]) + start = end + + def ends(self): + return np.concatenate( + [self.main_buf.ends(), *[b.ends() for b in self.cached_bufs]]) + + def starts(self): + return np.concatenate( + [self.main_buf.starts(), *[b.starts() for b in self.cached_bufs]]) + + def next(self, index): + assert index >= 0 and index < self._maxsize, "Input index illegal." 
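# --- Usage sketch for CachedReplayBuffer (illustrative only, not part of this
# patch): the buffer size, per-episode max_length, 4-env setup and shapes
# below are assumptions. Each add() call feeds one transition per selected
# environment into its cached buffer; once an env's done flag is seen, that
# episode is moved into main_buf and its length/return/start index are
# reported back by _main_buf_update().
import numpy as np
from tianshou.data.buffer import CachedReplayBuffer  # assumed import path

buf = CachedReplayBuffer(size=1000, cached_buf_n=4, max_length=200)
obs = np.zeros((4, 8))                        # one row per parallel env
act = np.zeros((4, 2))
rew = np.zeros(4)
done = np.array([False, True, False, False])  # env 1 finishes its episode
obs_next = np.ones((4, 8))
info = [{"env_id": i} for i in range(4)]
lens, rews, starts = buf.add(obs, act, rew, done, obs_next, info,
                             index=np.arange(4))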
+ all_buffer = [self.main_buf, *self.cached_bufs] + ret = index.copy() + upper = 0 + lower = 0 + for b in all_buffer: + lower = upper + upper += b._maxsize + mask = ret>=lower and ret=lower and ret=0 and global_index=lower and global_index Date: Wed, 20 Jan 2021 16:46:22 +0800 Subject: [PATCH 002/104] add cached buffer, vec buffer --- tianshou/data/buffer.py | 342 +++++++++++++++++++++++++++++----------- 1 file changed, 252 insertions(+), 90 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index dd97585dd..202f9574d 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -434,9 +434,9 @@ def start(self, index): """return start indices of given indices""" assert index < len(self) and index >= 0, "Input index illegal." sorted_starts = np.sort(self.starts()) - ret = np.searchsorted(sorted_starts, index, side="right") - 1 - ret[ret < 0] = sorted_starts[-1] - return ret + start_indices = np.searchsorted(sorted_starts, index, side="right") - 1 + start_indices[start_indices < 0] = sorted_starts[-1] + return start_indices def next(self, index): """return next n step indices""" @@ -451,7 +451,7 @@ def starts(self): return np.array([], dtype = np.int) def ends(self): - """return unfinished indices.""" + """return last indices of finished episodes. """ if len(self) > 0: last_write_in = int((self._index - 1)%len(self)) if self.done[last_write_in]: @@ -598,6 +598,179 @@ def __getitem__( weight=self.weight[index], ) +class VecReplayBuffer(ReplayBuffer): + def __init__( + self, + size: int, + buf_n: int, + **kwargs: Any, + ) -> None: + # TODO can size==0? + assert size > 0 + assert buf_n > 0 + if buf_n == 1: + import warnings + warnings.warn( + "VecReplayBuffer with buf_n = 1 will cause low efficiency. " + "Please consider using ReplayBuffer which is not in vector form.", + Warning) + _maxsize = buf_n*size + self.buf_n = buf_n + self.bufs = np.array([ReplayBuffer(size, **kwargs) + for _ in range(buf_n)]) + super().__init__(size= _maxsize, **kwargs) + + def __len__(self) -> int: + return np.sum([len(b) for b in self.bufs]) + + def update(self, **kwargs): + raise NotImplementedError + + def add( + self, + obs: Any, + act: Any, + rew: Union[Number, np.number, np.ndarray], + done: Union[Number, np.number, np.bool_], + obs_next: Any = None, + info: Optional[Union[dict, Batch]] = {}, + policy: Optional[Union[dict, Batch]] = {}, + index: Optional[Union[int, np.integer, np.ndarray, List[int]]] = None, + type_check: bool = True, + **kwargs: Any + ) -> None: + if type_check: + if index is None: + index = range(self.cached_bufs_n) + index = np.atleast_1d(index).astype(np.int) + assert(index.ndim == 1) + + obs = np.atleast_1d(obs) + act = np.atleast_1d(act) + rew = np.atleast_1d(rew) + done = np.atleast_1d(done) + # TODO ugly code + if isinstance(obs_next, Batch) and obs_next.is_empty(): + obs_next = None + if isinstance(info, Batch) and info.is_empty(): + info = {} + if isinstance(policy, Batch) and policy.is_empty(): + policy = {} + obs_next = np.atleast_1d([None]*len(index)) if obs_next is None else np.atleast_1d(obs_next) + info = np.atleast_1d([{}]*len(index)) if info == {} else np.atleast_1d(info) + policy = np.atleast_1d([{}]*len(index)) if policy == {} else np.atleast_1d(policy) + + # can accelerate + if self._meta.is_empty(): + self._initialise(obs[0], act[0], rew[0], done[0], obs_next[0], + info[0], policy[0]) + # now we add data to selected bufs one by one + bufs_slice = self.bufs[index] + for i, b in enumerate(bufs_slice): + b.add(obs[i], act[i], rew[i], done[i], 
+ obs_next[i], info[i], policy[i]) + + def _initialise( + self, + obs: Any, + act: Any, + rew: Union[Number, np.number, np.ndarray], + done: Union[Number, np.number, np.bool_], + obs_next: Any = None, + info: Optional[Union[dict, Batch]] = {}, + policy: Optional[Union[dict, Batch]] = {} + ) -> None: + assert(self._meta.is_empty()) + # to initialise self._meta + super().add(obs, act, rew, done, obs_next, info, policy) + super().reset() #TODO delete useless varible? + del self._index + del self._size + del self._avail_index + self._set_batch() + + def _set_batch(self): + start = 0 + for buf in self.bufs: + end = start + buf._maxsize + buf.set_batch(self._meta[start: end]) + start = end + + def set_batch(self, batch: "Batch"): + """Manually choose the batch you want the ReplayBuffer to manage. This + method should be called instantly after the ReplayBuffer is initialised. + """ + assert self.bufs.is_empty(), "This method cannot be called after add() method" + self._meta = batch + assert not self._is_meta_corrupted(), ( + "Input batch doesn't meet ReplayBuffer's data form requirement.") + self._set_batch() + + def reset(self) -> None: + for buf in self.bufs: + buf.reset() + #TODO finish + + def sample(self, batch_size: int, return_only_indice: bool = False) -> Tuple[Batch, np.ndarray]: + _all = np.arange(len(self), dtype=np.int) + start = 0 + add = 0 + for buf in self.bufs: + end = start + len(buf) + _all[start:end] = _all[start:end] + add + start = end + add = add + buf._maxsize - len(buf) + # TODO consider making _all a seperate method + indice = np.random.choice(_all, batch_size) + assert len(indice) > 0, "No available indice can be sampled." + if return_only_indice: + return indice + else: + return self[indice], indice + + def get( + self, + indice: Union[slice, int, np.integer, np.ndarray], + key: str, + stack_num: Optional[int] = None, + ) -> Union[Batch, np.ndarray]: + if stack_num is None: + stack_num = self.stack_num + assert(stack_num == 1) + #TODO support stack + return super().get(indice, key, stack_num) + + def ends(self): + return np.concatenate([b.ends() for b in self.bufs]) + + def starts(self): + return np.concatenate([b.starts() for b in self.cached_buffer.bufs]) + + def next(self, index): + assert index >= 0 and index < self._maxsize, "Input index illegal." + next_indices = np.full(index.shape, -1) + upper = 0 + lower = 0 + for b in self.bufs: + lower = upper + upper += b._maxsize + mask = next_indices>=lower and next_indices=lower and start_indices None: """ TODO support stack in the future """ - assert cached_buf_n > 0 + assert cached_buffer_n > 0 # TODO what if people don't care about how buffer is organized assert max_length > 0 - if cached_buf_n == 1: + if cached_buffer_n == 1: import warnings warnings.warn( - "CachedReplayBuffer with cached_buf_n = 1 will cause low efficiency. " + "CachedReplayBuffer with cached_buffer_n = 1 will cause low efficiency. 
" "Please consider using ReplayBuffer which is not in cached form.", Warning) - _maxsize = size+cached_buf_n*max_length - self.cached_bufs_n = cached_buf_n - # TODO see if we can generalize to all kinds of buufer - self.main_buf = ReplayBuffer(size, **kwargs) - # TODO cached_bufs can be consider to be replced by vector + _maxsize = size+cached_buffer_n*max_length + self.cached_bufs_n = cached_buffer_n + # TODO see if we can generalize to all kinds of buffer + self.main_buffer = ReplayBuffer(size, **kwargs) + # TODO cached_buffer can be consider to be replced by vector # buffer in the future - self.cached_bufs = np.array([ReplayBuffer(max_length, **kwargs) - for _ in range(cached_buf_n)]) + self.cached_buffer = VecReplayBuffer(max_length, cached_buffer_n, **kwargs) super().__init__(size= _maxsize, **kwargs) # TODO support, or just delete stack_num option from Replay buffer for now assert self.stack_num == 1 def __len__(self) -> int: """Return len(self).""" - return len(self.main_buf) + np.sum([len(b) for b in self.cached_bufs]) + return len(self.main_buffer) + len(self.cached_buffer) def update(self, buffer: "ReplayBuffer") -> int: """CachedReplayBuffer will only update data from buffer which is in @@ -647,7 +819,7 @@ def update(self, buffer: "ReplayBuffer") -> int: being ignored.""" # For now update method copy element one by one, which is too slow. if isinstance(buffer, CachedReplayBuffer): - buffer = buffer.main_buf + buffer = buffer.main_buffer # now treat buffer like a normal ReplayBuffer and remove those incomplete steps if len(buffer) == 0: return 0 @@ -665,7 +837,7 @@ def update(self, buffer: "ReplayBuffer") -> int: return diposed_count end = (end - 1) % len(buffer) while True: - self.main_buf.add(**buffer[begin]) + self.main_buffer.add(**buffer[begin]) if begin == end: return diposed_count begin = (begin + 1) % len(buffer) @@ -708,46 +880,45 @@ def add( # TODO what if data is already in episodes, what if i want to add mutiple data ? # can accelerate if self._meta.is_empty(): - self._cache_initialise(obs[0], act[0], rew[0], done[0], obs_next[0], + self._initialise(obs[0], act[0], rew[0], done[0], obs_next[0], info[0], policy[0]) - # now we add data to selected cached_bufs one by one - cached_bufs_slice = self.cached_bufs[index] - for i, b in enumerate(cached_bufs_slice): - b.add(obs[i], act[i], rew[i], done[i], - obs_next[i], info[i], policy[i]) + + self.cached_buffer.add(obs, act, rew, done, obs_next, + info, policy, index, False, **kwargs) return self._main_buf_update() def _main_buf_update(self): lens = np.zeros((self.cached_bufs_n, ), dtype = np.int) rews = np.zeros((self.cached_bufs_n, )) start_indexs = np.zeros((self.cached_bufs_n, ), dtype = np.int) - for i, buf in enumerate(self.cached_bufs): + for i, buf in enumerate(self.cached_buffer.bufs): if buf.done[buf._index - 1] > 0: lens[i] = len(buf) rews[i] = np.sum(buf.rew[:lens[i]]) - start_indexs[i] = self.main_buf._index - if self.main_buf._maxsize > 0: - # _maxsize of main_buf might be 0 in test collector. - self.main_buf.update(buf) + start_indexs[i] = self.main_buffer._index + if self.main_buffer._maxsize > 0: + # _maxsize of main_buffer might be 0 in test collector. 
+ self.main_buffer.update(buf) buf.reset() return lens, rews, start_indexs def reset(self) -> None: - for buf in self.cached_bufs: - buf.reset() - self.main_buf.reset() + self.cached_buffer.reset() + self.main_buffer.reset() self._avail_index = [] #TODO finish def sample(self, batch_size: int, is_from_main_buf = False) -> Tuple[Batch, np.ndarray]: if is_from_main_buf: - return self.main_buf.sample(batch_size) + return self.main_buffer.sample(batch_size) + + # TODO use all() method to replace _all = np.arange(len(self), dtype=np.int) - start = len(self.main_buf) - add = self.main_buf._maxsize - len(self.main_buf) - for buf in self.cached_bufs: + start = len(self.main_buffer) + add = self.main_buffer._maxsize - len(self.main_buffer) + for buf in self.cached_buffer.bufs: end = start + len(buf) _all[start:end] = _all[start:end] + add start = end @@ -768,7 +939,7 @@ def get( #TODO support stack return super().get(indice, key, stack_num) - def _cache_initialise( + def _initialise( self, obs: Any, act: Any, @@ -784,67 +955,58 @@ def _cache_initialise( super().reset() #TODO delete useless varible? del self._index del self._size - self.main_buf.set_batch(self._meta[:self.main_buf._maxsize]) - start = self.main_buf._maxsize - for buf in self.cached_bufs: - end = start + buf._maxsize - buf.set_batch(self._meta[start: end]) - start = end - + self.main_buffer.set_batch(self._meta[:self.main_buffer._maxsize]) + self.cached_buffer.set_batch(self._meta[self.main_buffer._maxsize:]) + + + + #TODO add standard API for vec buffer and use vec buffer to replace self.cached_buffer.bufs def ends(self): return np.concatenate( - [self.main_buf.ends(), *[b.ends() for b in self.cached_bufs]]) + [self.main_buffer.ends(), self.cached_buffer.ends()]) def starts(self): return np.concatenate( - [self.main_buf.starts(), *[b.starts() for b in self.cached_bufs]]) + [self.main_buffer.starts(), self.cached_buffer.starts()]) def next(self, index): assert index >= 0 and index < self._maxsize, "Input index illegal." - all_buffer = [self.main_buf, *self.cached_bufs] - ret = index.copy() - upper = 0 - lower = 0 - for b in all_buffer: - lower = upper - upper += b._maxsize - mask = ret>=lower and ret=lower and ret=0 and global_index=lower and global_index= 0 and index < self._maxsize, "Input index illegal." + start_indices = np.full(index.shape, -1) + mask = index=0 and global_index=lower and global_index Date: Wed, 20 Jan 2021 16:47:26 +0800 Subject: [PATCH 003/104] simple pep8 fix --- tianshou/data/buffer.py | 126 +++++++++++++++++++++------------------- 1 file changed, 67 insertions(+), 59 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 202f9574d..9a5510654 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -139,7 +139,8 @@ class ReplayBuffer: index when using frame-stack sampling method, defaults to False. This feature is not supported in Prioritized Replay Buffer currently. 
""" - _reserved_keys = {'obs', 'act', 'rew', 'done', 'obs_next', 'info', 'policy'} + _reserved_keys = {'obs', 'act', 'rew', + 'done', 'obs_next', 'info', 'policy'} def __init__( self, @@ -150,7 +151,7 @@ def __init__( sample_avail: bool = False, ) -> None: super().__init__() - #TODO _maxsize == 0 handle + # TODO _maxsize == 0 handle self._maxsize = size self._indices = np.arange(size) # consider move stacking option to another @@ -196,7 +197,7 @@ def __getstate__(self) -> dict: def __setattr__(self, key: str, value: Any) -> None: """Set self.key = value.""" assert key not in self._reserved_keys, ( - "key '{}' is reserved and cannot be assigned".format(key)) + "key '{}' is reserved and cannot be assigned".format(key)) super().__setattr__(key, value) def _add_to_buffer(self, name: str, inst: Any) -> None: @@ -223,10 +224,10 @@ def stack_num(self, num: int) -> None: def update(self, buffer: "ReplayBuffer") -> None: """Move the data from the given buffer to self.""" - # TODO 'one by one copying' can be greatly improved. - # Can we move data as a whole batch to save time? - # what if self._maxsize << buffer._maxize, can we detect that - # and just ignore those element to be rewrite? + # TODO 'one by one copying' can be greatly improved. + # Can we move data as a whole batch to save time? + # what if self._maxsize << buffer._maxize, can we detect that + # and just ignore those element to be rewrite? if len(buffer) == 0: return i = begin = buffer._index % len(buffer) @@ -398,18 +399,18 @@ def set_batch(self, batch: "Batch"): assert not self._is_meta_corrupted(), ( "Input batch doesn't meet ReplayBuffer's data form requirement.") - def _is_meta_corrupted(self)-> bool: + def _is_meta_corrupted(self) -> bool: """Assert if self._meta: Batch is still in legal form. """ - #TODO do we need to check the chlid? is cache + # TODO do we need to check the chlid? is cache if set(self._meta.keys()) != self._reserved_keys: return True for v in self._meta.values(): if isinstance(v, Batch): - if not v.is_empty() and v.shape[0]!= self._maxsize: + if not v.is_empty() and v.shape[0] != self._maxsize: return True elif isinstance(v, np.ndarray): - if v.shape[0]!= self._maxsize: + if v.shape[0] != self._maxsize: return True else: return True @@ -441,25 +442,25 @@ def start(self, index): def next(self, index): """return next n step indices""" assert index < len(self) and index >= 0, "Input index illegal." - return (index + ~(self.done[index]|index==self._index))%len(self) + return (index + ~(self.done[index] | index == self._index)) % len(self) def starts(self): """return indices of all episodes""" if len(self) > 0: - return (self.ends()+1)%len(self) + return (self.ends()+1) % len(self) else: - return np.array([], dtype = np.int) + return np.array([], dtype=np.int) def ends(self): """return last indices of finished episodes. """ if len(self) > 0: - last_write_in = int((self._index - 1)%len(self)) + last_write_in = int((self._index - 1) % len(self)) if self.done[last_write_in]: return np.where(self.done[:len(self)])[0] else: return np.append(np.where(self.done[:len(self)])[0], last_write_in) else: - return np.array([], dtype = np.int) + return np.array([], dtype=np.int) class ListReplayBuffer(ReplayBuffer): @@ -598,6 +599,7 @@ def __getitem__( weight=self.weight[index], ) + class VecReplayBuffer(ReplayBuffer): def __init__( self, @@ -605,7 +607,7 @@ def __init__( buf_n: int, **kwargs: Any, ) -> None: - # TODO can size==0? + # TODO can size==0? 
assert size > 0 assert buf_n > 0 if buf_n == 1: @@ -617,9 +619,9 @@ def __init__( _maxsize = buf_n*size self.buf_n = buf_n self.bufs = np.array([ReplayBuffer(size, **kwargs) - for _ in range(buf_n)]) - super().__init__(size= _maxsize, **kwargs) - + for _ in range(buf_n)]) + super().__init__(size=_maxsize, **kwargs) + def __len__(self) -> int: return np.sum([len(b) for b in self.bufs]) @@ -638,7 +640,7 @@ def add( index: Optional[Union[int, np.integer, np.ndarray, List[int]]] = None, type_check: bool = True, **kwargs: Any - ) -> None: + ) -> None: if type_check: if index is None: index = range(self.cached_bufs_n) @@ -649,27 +651,30 @@ def add( act = np.atleast_1d(act) rew = np.atleast_1d(rew) done = np.atleast_1d(done) - # TODO ugly code + # TODO ugly code if isinstance(obs_next, Batch) and obs_next.is_empty(): obs_next = None if isinstance(info, Batch) and info.is_empty(): info = {} if isinstance(policy, Batch) and policy.is_empty(): policy = {} - obs_next = np.atleast_1d([None]*len(index)) if obs_next is None else np.atleast_1d(obs_next) - info = np.atleast_1d([{}]*len(index)) if info == {} else np.atleast_1d(info) - policy = np.atleast_1d([{}]*len(index)) if policy == {} else np.atleast_1d(policy) + obs_next = np.atleast_1d( + [None]*len(index)) if obs_next is None else np.atleast_1d(obs_next) + info = np.atleast_1d( + [{}]*len(index)) if info == {} else np.atleast_1d(info) + policy = np.atleast_1d( + [{}]*len(index)) if policy == {} else np.atleast_1d(policy) # can accelerate if self._meta.is_empty(): self._initialise(obs[0], act[0], rew[0], done[0], obs_next[0], - info[0], policy[0]) + info[0], policy[0]) # now we add data to selected bufs one by one bufs_slice = self.bufs[index] for i, b in enumerate(bufs_slice): b.add(obs[i], act[i], rew[i], done[i], - obs_next[i], info[i], policy[i]) - + obs_next[i], info[i], policy[i]) + def _initialise( self, obs: Any, @@ -683,7 +688,7 @@ def _initialise( assert(self._meta.is_empty()) # to initialise self._meta super().add(obs, act, rew, done, obs_next, info, policy) - super().reset() #TODO delete useless varible? + super().reset() # TODO delete useless varible? 
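# --- Sketch of the shared-storage idea behind _initialise()/set_batch()
# (the 6x4 shapes below are assumptions, not taken from the diff): slicing a
# Batch of numpy arrays yields views, so a child buffer that manages a slice
# of the parent's _meta writes directly into the parent's storage.
import numpy as np
from tianshou.data import Batch

meta = Batch(obs=np.zeros((6, 4)), rew=np.zeros(6),
             done=np.zeros(6, dtype=bool))
child = meta[0:3]            # first 3 rows, a view rather than a copy
child.rew[0] = 1.0           # write through the child's slice
assert meta.rew[0] == 1.0    # the parent sees the update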
del self._index del self._size del self._avail_index @@ -709,7 +714,7 @@ def set_batch(self, batch: "Batch"): def reset(self) -> None: for buf in self.bufs: buf.reset() - #TODO finish + # TODO finish def sample(self, batch_size: int, return_only_indice: bool = False) -> Tuple[Batch, np.ndarray]: _all = np.arange(len(self), dtype=np.int) @@ -717,7 +722,7 @@ def sample(self, batch_size: int, return_only_indice: bool = False) -> Tuple[Bat add = 0 for buf in self.bufs: end = start + len(buf) - _all[start:end] = _all[start:end] + add + _all[start:end] = _all[start:end] + add start = end add = add + buf._maxsize - len(buf) # TODO consider making _all a seperate method @@ -737,7 +742,7 @@ def get( if stack_num is None: stack_num = self.stack_num assert(stack_num == 1) - #TODO support stack + # TODO support stack return super().get(indice, key, stack_num) def ends(self): @@ -754,7 +759,7 @@ def next(self, index): for b in self.bufs: lower = upper upper += b._maxsize - mask = next_indices>=lower and next_indices= lower and next_indices < upper next_indices[mask] = b.next(next_indices[mask]-lower)+lower return next_indices @@ -766,7 +771,7 @@ def start(self, index): for b in self.bufs: lower = upper upper += b._maxsize - mask = start_indices>=lower and start_indices= lower and start_indices < upper start_indices[mask] = b.start(start_indices[mask]-lower)+lower return start_indices @@ -778,6 +783,7 @@ class CachedReplayBuffer(ReplayBuffer): chronologically, but standard API like start()/starts()/ends/next() are provided to help CachedReplayBuffer to be used just like ReplayBuffer. """ + def __init__( self, size: int, @@ -797,18 +803,19 @@ def __init__( "CachedReplayBuffer with cached_buffer_n = 1 will cause low efficiency. " "Please consider using ReplayBuffer which is not in cached form.", Warning) - + _maxsize = size+cached_buffer_n*max_length self.cached_bufs_n = cached_buffer_n # TODO see if we can generalize to all kinds of buffer self.main_buffer = ReplayBuffer(size, **kwargs) - # TODO cached_buffer can be consider to be replced by vector + # TODO cached_buffer can be consider to be replced by vector # buffer in the future - self.cached_buffer = VecReplayBuffer(max_length, cached_buffer_n, **kwargs) - super().__init__(size= _maxsize, **kwargs) + self.cached_buffer = VecReplayBuffer( + max_length, cached_buffer_n, **kwargs) + super().__init__(size=_maxsize, **kwargs) # TODO support, or just delete stack_num option from Replay buffer for now assert self.stack_num == 1 - + def __len__(self) -> int: """Return len(self).""" return len(self.main_buffer) + len(self.cached_buffer) @@ -853,9 +860,9 @@ def add( policy: Optional[Union[dict, Batch]] = {}, index: Optional[Union[int, np.integer, np.ndarray, List[int]]] = None, **kwargs: Any - ) -> None: + ) -> None: """ - + """ if index is None: index = range(self.cached_bufs_n) @@ -866,31 +873,34 @@ def add( act = np.atleast_1d(act) rew = np.atleast_1d(rew) done = np.atleast_1d(done) - # TODO ugly code + # TODO ugly code if isinstance(obs_next, Batch) and obs_next.is_empty(): obs_next = None if isinstance(info, Batch) and info.is_empty(): info = {} if isinstance(policy, Batch) and policy.is_empty(): policy = {} - obs_next = np.atleast_1d([None]*len(index)) if obs_next is None else np.atleast_1d(obs_next) - info = np.atleast_1d([{}]*len(index)) if info == {} else np.atleast_1d(info) - policy = np.atleast_1d([{}]*len(index)) if policy == {} else np.atleast_1d(policy) + obs_next = np.atleast_1d( + [None]*len(index)) if obs_next is None else 
np.atleast_1d(obs_next) + info = np.atleast_1d( + [{}]*len(index)) if info == {} else np.atleast_1d(info) + policy = np.atleast_1d( + [{}]*len(index)) if policy == {} else np.atleast_1d(policy) # TODO what if data is already in episodes, what if i want to add mutiple data ? # can accelerate if self._meta.is_empty(): self._initialise(obs[0], act[0], rew[0], done[0], obs_next[0], - info[0], policy[0]) + info[0], policy[0]) self.cached_buffer.add(obs, act, rew, done, obs_next, info, policy, index, False, **kwargs) return self._main_buf_update() def _main_buf_update(self): - lens = np.zeros((self.cached_bufs_n, ), dtype = np.int) + lens = np.zeros((self.cached_bufs_n, ), dtype=np.int) rews = np.zeros((self.cached_bufs_n, )) - start_indexs = np.zeros((self.cached_bufs_n, ), dtype = np.int) + start_indexs = np.zeros((self.cached_bufs_n, ), dtype=np.int) for i, buf in enumerate(self.cached_buffer.bufs): if buf.done[buf._index - 1] > 0: lens[i] = len(buf) @@ -906,13 +916,13 @@ def reset(self) -> None: self.cached_buffer.reset() self.main_buffer.reset() self._avail_index = [] - #TODO finish + # TODO finish def sample(self, batch_size: int, - is_from_main_buf = False) -> Tuple[Batch, np.ndarray]: + is_from_main_buf=False) -> Tuple[Batch, np.ndarray]: if is_from_main_buf: return self.main_buffer.sample(batch_size) - + # TODO use all() method to replace _all = np.arange(len(self), dtype=np.int) @@ -920,7 +930,7 @@ def sample(self, batch_size: int, add = self.main_buffer._maxsize - len(self.main_buffer) for buf in self.cached_buffer.bufs: end = start + len(buf) - _all[start:end] = _all[start:end] + add + _all[start:end] = _all[start:end] + add start = end add = add + buf._maxsize - len(buf) indice = np.random.choice(_all, batch_size) @@ -936,7 +946,7 @@ def get( if stack_num is None: stack_num = self.stack_num assert(stack_num == 1) - #TODO support stack + # TODO support stack return super().get(indice, key, stack_num) def _initialise( @@ -952,15 +962,14 @@ def _initialise( assert(self._meta.is_empty()) # to initialise self._meta super().add(obs, act, rew, done, obs_next, info, policy) - super().reset() #TODO delete useless varible? + super().reset() # TODO delete useless varible? del self._index del self._size self.main_buffer.set_batch(self._meta[:self.main_buffer._maxsize]) self.cached_buffer.set_batch(self._meta[self.main_buffer._maxsize:]) + # TODO add standard API for vec buffer and use vec buffer to replace self.cached_buffer.bufs - - #TODO add standard API for vec buffer and use vec buffer to replace self.cached_buffer.bufs def ends(self): return np.concatenate( [self.main_buffer.ends(), self.cached_buffer.ends()]) @@ -972,7 +981,7 @@ def starts(self): def next(self, index): assert index >= 0 and index < self._maxsize, "Input index illegal." next_indices = np.full(index.shape, -1) - mask = index= 0 and index < self._maxsize, "Input index illegal." 
start_indices = np.full(index.shape, -1) - mask = index=0 and global_index Date: Wed, 20 Jan 2021 16:49:45 +0800 Subject: [PATCH 004/104] init --- tianshou/data/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index e51b8d161..fb480a983 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -2,7 +2,8 @@ from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as from tianshou.data.utils.segtree import SegmentTree from tianshou.data.buffer import ReplayBuffer, \ - ListReplayBuffer, PrioritizedReplayBuffer + ListReplayBuffer, PrioritizedReplayBuffer, \ + VecReplayBuffer, CachedReplayBuffer from tianshou.data.collector import Collector __all__ = [ @@ -14,5 +15,7 @@ "ReplayBuffer", "ListReplayBuffer", "PrioritizedReplayBuffer", + "CachedReplayBuffer", + "VecReplayBuffer", "Collector", ] From 36e799e09793ec292c08c047306454778c57dc17 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Fri, 22 Jan 2021 10:58:07 +0800 Subject: [PATCH 005/104] some change --- tianshou/data/buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 9a5510654..7bfa1d126 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -983,7 +983,7 @@ def next(self, index): next_indices = np.full(index.shape, -1) mask = index < self.main_buffer._maxsize next_indices[mask] = self.main_buffer.next(index[mask]) - next_indices[~mask] = self.cached_buffer.next(index[~mask]) + next_indices[~mask] = self.cached_buffer.next(index[~mask]) + self.main_buffer._maxsize return next_indices def start(self, index): @@ -992,7 +992,7 @@ def start(self, index): start_indices = np.full(index.shape, -1) mask = index < self.main_buffer._maxsize start_indices[mask] = self.main_buffer.start(index[mask]) - start_indices[~mask] = self.cached_buffer.start(index[~mask]) + start_indices[~mask] = self.cached_buffer.start(index[~mask]) + self.main_buffer._maxsize return start_indices # def _global2local(self, global_index): From 50b20a0530a057660a9c286941ef4caaa2beacf8 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Fri, 22 Jan 2021 14:26:07 +0800 Subject: [PATCH 006/104] update ReplayBuffer --- test/base/test_buffer.py | 14 +- tianshou/data/buffer.py | 317 ++++++++++++++++----------------------- 2 files changed, 133 insertions(+), 198 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 1695195b4..78ca69419 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -38,11 +38,11 @@ def test_replaybuffer(size=10, bufsize=20): assert (data.obs < size).all() assert (0 <= data.done).all() and (data.done <= 1).all() b = ReplayBuffer(size=10) - b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}}) + b.add(1, 1, 1, 1, 'str', {'a': 3, 'b': {'c': 5.0}}) assert b.obs[0] == 1 - assert b.done[0] == 'str' + assert b.obs_next[0] == 'str' assert np.all(b.obs[1:] == 0) - assert np.all(b.done[1:] == np.array(None)) + assert np.all(b.obs_next[1:] == np.array(None)) assert b.info.a[0] == 3 and b.info.a.dtype == np.integer assert np.all(b.info.a[1:] == 0) assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact @@ -320,7 +320,7 @@ def test_hdf5(): assert len(_buffers[k]) == len(buffers[k]) assert np.allclose(_buffers[k].act, buffers[k].act) assert _buffers[k].stack_num == buffers[k].stack_num - assert _buffers[k]._maxsize == buffers[k]._maxsize + assert _buffers[k].maxsize == buffers[k].maxsize assert 
_buffers[k]._index == buffers[k]._index assert np.all(_buffers[k]._indices == buffers[k]._indices) for k in ["array", "prioritized"]: @@ -336,12 +336,12 @@ def test_hdf5(): os.remove(path) # raise exception when value cannot be pickled - data = {"not_supported": lambda x: x*x} + data = {"not_supported": lambda x: x * x} grp = h5py.Group with pytest.raises(NotImplementedError): to_hdf5(data, grp) # ndarray with data type not supported by HDF5 that cannot be pickled - data = {"not_supported": np.array(lambda x: x*x)} + data = {"not_supported": np.array(lambda x: x * x)} grp = h5py.Group with pytest.raises(RuntimeError): to_hdf5(data, grp) @@ -351,9 +351,9 @@ def test_hdf5(): test_hdf5() test_replaybuffer() test_ignore_obs_next() + test_update() test_stack() test_pickle() test_segtree() test_priortized_replaybuffer() test_priortized_replaybuffer(233333, 200000) - test_update() diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 9a5510654..c8e60ede6 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -139,8 +139,8 @@ class ReplayBuffer: index when using frame-stack sampling method, defaults to False. This feature is not supported in Prioritized Replay Buffer currently. """ - _reserved_keys = {'obs', 'act', 'rew', - 'done', 'obs_next', 'info', 'policy'} + _reserved_keys = {"obs", "act", "rew", "done", + "obs_next", "info", "policy"} def __init__( self, @@ -151,17 +151,15 @@ def __init__( sample_avail: bool = False, ) -> None: super().__init__() - # TODO _maxsize == 0 handle - self._maxsize = size - self._indices = np.arange(size) - # consider move stacking option to another + self.maxsize = size + assert stack_num > 0, "stack_num should greater than 0" self.stack_num = stack_num + self._indices = np.arange(size) + self._save_obs_next = not ignore_obs_next + self._save_only_last_obs = save_only_last_obs self._avail = sample_avail and stack_num > 1 - self._avail_index: List[int] = [] - self._save_s_ = not ignore_obs_next - self._last_obs = save_only_last_obs - self._index = 0 - self._size = 0 + self._index = 0 # current index + self._size = 0 # current buffer size self._meta: Batch = Batch() self.reset() @@ -186,14 +184,8 @@ def __setstate__(self, state: Dict[str, Any]) -> None: We need it because pickling buffer does not work out-of-the-box ("buffer.__getattr__" is customized). """ - self._indices = np.arange(state["_maxsize"]) self.__dict__.update(state) - def __getstate__(self) -> dict: - exclude = {"_indices"} - state = {k: v for k, v in self.__dict__.items() if k not in exclude} - return state - def __setattr__(self, key: str, value: Any) -> None: """Set self.key = value.""" assert key not in self._reserved_keys, ( @@ -204,41 +196,60 @@ def _add_to_buffer(self, name: str, inst: Any) -> None: try: value = self._meta.__dict__[name] except KeyError: - self._meta.__dict__[name] = _create_value(inst, self._maxsize) + self._meta.__dict__[name] = _create_value(inst, self.maxsize) value = self._meta.__dict__[name] + if isinstance(inst, (torch.Tensor, np.ndarray)): + if inst.shape != value.shape[1:]: + raise ValueError( + "Cannot add data to a buffer with different shape with key" + f" {name}, expect {value.shape[1:]}, given {inst.shape}." 
+ ) try: value[self._index] = inst - except KeyError: + except KeyError: # inst is a dict/Batch for key in set(inst.keys()).difference(value.__dict__.keys()): - value.__dict__[key] = _create_value(inst[key], self._maxsize) + value.__dict__[key] = _create_value(inst[key], self.maxsize) value[self._index] = inst - @property - def stack_num(self) -> int: - return self._stack + def unfinished_index(self) -> int: + return (self._index - 1) % self._size - @stack_num.setter - def stack_num(self, num: int) -> None: - assert num > 0, "stack_num should greater than 0" - self._stack = num + def prev( + self, + index: Union[int, np.integer, np.ndarray], + within_episode: bool = False, + ) -> np.ndarray: + """Return one step previous index.""" + index = self._indices[:self._size][index] + prev_index = (index - 1) % self._size + if within_episode: + done = self.done[prev_index] | \ + (prev_index == self.unfinished_index()) + prev_index = (prev_index + done) % self._size + return prev_index + + def next( + self, + index: Union[int, np.integer, np.ndarray], + within_episode: bool = False, + ) -> np.ndarray: + """Return one step next index.""" + index = self._indices[:self._size][index] + if within_episode: + done = self.done[index] | (index == self.unfinished_index()) + return (index + (1 - done)) % self._size + else: + return (index + 1) % self._size def update(self, buffer: "ReplayBuffer") -> None: - """Move the data from the given buffer to self.""" - # TODO 'one by one copying' can be greatly improved. - # Can we move data as a whole batch to save time? - # what if self._maxsize << buffer._maxize, can we detect that - # and just ignore those element to be rewrite? + """Move the data from the given buffer to current buffer.""" if len(buffer) == 0: return - i = begin = buffer._index % len(buffer) - stack_num_orig = buffer.stack_num - buffer.stack_num = 1 - while True: - self.add(**buffer[i]) # type: ignore - i = (i + 1) % len(buffer) - if i == begin: - break + stack_num_orig, buffer.stack_num = buffer.stack_num, 1 + batch, _ = buffer.sample(0) buffer.stack_num = stack_num_orig + for b in batch: + self.add(**b) def add( self, @@ -251,55 +262,35 @@ def add( policy: Optional[Union[dict, Batch]] = {}, **kwargs: Any, ) -> None: - """Add a batch of data into replay buffer. - Expect all input to be batch, dict, or numpy array""" - # TODO should we consider to support batch input? + """Add a batch of data into replay buffer.""" assert isinstance( info, (dict, Batch) ), "You should return a dict in the last argument of env.step()." 
- if self._last_obs: + if self._save_only_last_obs: obs = obs[-1] self._add_to_buffer("obs", obs) self._add_to_buffer("act", act) - # make sure the reward is a float instead of an int + # make sure the data type of reward is float instead of int self._add_to_buffer("rew", rew * 1.0) # type: ignore - self._add_to_buffer("done", done) - if self._save_s_: + self._add_to_buffer("done", bool(done)) # done should be a bool scalar + if self._save_obs_next: if obs_next is None: obs_next = Batch() - elif self._last_obs: + elif self._save_only_last_obs: obs_next = obs_next[-1] self._add_to_buffer("obs_next", obs_next) self._add_to_buffer("info", info) self._add_to_buffer("policy", policy) - # maintain available index for frame-stack sampling - if self._avail: - # update current frame - avail = sum(self.done[i] for i in range( - self._index - self.stack_num + 1, self._index)) == 0 - if self._size < self.stack_num - 1: - avail = False - if avail and self._index not in self._avail_index: - self._avail_index.append(self._index) - elif not avail and self._index in self._avail_index: - self._avail_index.remove(self._index) - # remove the later available frame because of broken storage - t = (self._index + self.stack_num - 1) % self._maxsize - if t in self._avail_index: - self._avail_index.remove(t) - - if self._maxsize > 0: - self._size = min(self._size + 1, self._maxsize) - self._index = (self._index + 1) % self._maxsize - else: - self._size = self._index = self._index + 1 + if self.maxsize > 0: + self._size = min(self._size + 1, self.maxsize) + self._index = (self._index + 1) % self.maxsize + else: # TODO: remove this after deleting ListReplayBuffer + self._size = self._index = self._size + 1 def reset(self) -> None: """Clear all the data in replay buffer.""" - self._index = 0 - self._size = 0 - self._avail_index = [] + self._index = self._size = 0 def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: """Get a random sample from buffer with size equal to batch_size. @@ -309,60 +300,36 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: :return: Sample data and its corresponding index inside the buffer. """ if batch_size > 0: - _all = self._avail_index if self._avail else self._size - indice = np.random.choice(_all, batch_size) - else: - if self._avail: - indice = np.array(self._avail_index) - else: - indice = np.concatenate([ - np.arange(self._index, self._size), - np.arange(0, self._index), - ]) + indice = np.random.choice(self._size, batch_size) + else: # construct current available indices + indice = np.concatenate([ + np.arange(self._index, self._size), + np.arange(0, self._index), + ]) assert len(indice) > 0, "No available indice can be sampled." return self[indice], indice def get( self, - indice: Union[slice, int, np.integer, np.ndarray], + indice: Union[int, np.integer, np.ndarray], key: str, stack_num: Optional[int] = None, ) -> Union[Batch, np.ndarray]: """Return the stacked result. E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the - indice. The stack_num (here equals to 4) is given from buffer - initialization procedure. + indice. 
""" if stack_num is None: stack_num = self.stack_num - if stack_num == 1: # the most often case - if key != "obs_next" or self._save_s_: - val = self._meta.__dict__[key] - try: - return val[indice] - except IndexError as e: - if not (isinstance(val, Batch) and val.is_empty()): - raise e # val != Batch() - return Batch() - indice = self._indices[:self._size][indice] - done = self._meta.__dict__["done"] - if key == "obs_next" and not self._save_s_: - indice += 1 - done[indice].astype(np.int) - indice[indice == self._size] = 0 - key = "obs" - val = self._meta.__dict__[key] + val = self._meta[key] try: - if stack_num == 1: + if stack_num == 1: # the most often case return val[indice] stack: List[Any] = [] for _ in range(stack_num): stack = [val[indice]] + stack - pre_indice = np.asarray(indice - 1) - pre_indice[pre_indice == -1] = self._size - 1 - indice = np.asarray( - pre_indice + done[pre_indice].astype(np.int)) - indice[indice == self._size] = 0 + indice = self.prev(indice, within_episode=True) if isinstance(val, Batch): return Batch.stack(stack, axis=indice.ndim) else: @@ -380,46 +347,40 @@ def __getitem__( If stack_num is larger than 1, return the stacked obs and obs_next with shape (batch, len, ...). """ + index = self._indices[:self._size][index] # change slice to np array + if self._save_obs_next: + obs_next = self.get(index, "obs_next", self.stack_num) + else: + next_index = self.next(index, within_episode=True) + obs_next = self.get(next_index, "obs", self.stack_num) return Batch( - obs=self.get(index, "obs"), + obs=self.get(index, "obs", self.stack_num), act=self.act[index], rew=self.rew[index], done=self.done[index], - obs_next=self.get(index, "obs_next"), - info=self.get(index, "info"), - policy=self.get(index, "policy"), + obs_next=obs_next, + info=self.get(index, "info", self.stack_num), + policy=self.get(index, "policy", self.stack_num), ) - def set_batch(self, batch: "Batch"): - """Manually choose the batch you want the ReplayBuffer to manage. This - method should be called instantly after the ReplayBuffer is initialised. - """ - assert self._meta.is_empty(), "This method cannot be called after add() method" - self._meta = batch - assert not self._is_meta_corrupted(), ( + def set_batch(self, batch: Batch): + """Manually choose the batch you want the ReplayBuffer to manage.""" + assert self._is_legal_batch(batch), ( "Input batch doesn't meet ReplayBuffer's data form requirement.") + self._meta = batch - def _is_meta_corrupted(self) -> bool: - """Assert if self._meta: Batch is still in legal form. - """ - # TODO do we need to check the chlid? is cache + def _is_legal_batch(self, batch: Batch) -> bool: + """Check the given batch is in legal form.""" if set(self._meta.keys()) != self._reserved_keys: + return False + if self._meta.is_empty(recurse=True): return True - for v in self._meta.values(): - if isinstance(v, Batch): - if not v.is_empty() and v.shape[0] != self._maxsize: - return True - elif isinstance(v, np.ndarray): - if v.shape[0] != self._maxsize: - return True - else: - return True - return False + return len(self._meta) == self.maxsize def save_hdf5(self, path: str) -> None: """Save replay buffer to HDF5 file.""" with h5py.File(path, "w") as f: - to_hdf5(self.__getstate__(), f) + to_hdf5(self.__dict__, f) @classmethod def load_hdf5( @@ -431,37 +392,6 @@ def load_hdf5( buf.__setstate__(from_hdf5(f, device=device)) return buf - def start(self, index): - """return start indices of given indices""" - assert index < len(self) and index >= 0, "Input index illegal." 
- sorted_starts = np.sort(self.starts()) - start_indices = np.searchsorted(sorted_starts, index, side="right") - 1 - start_indices[start_indices < 0] = sorted_starts[-1] - return start_indices - - def next(self, index): - """return next n step indices""" - assert index < len(self) and index >= 0, "Input index illegal." - return (index + ~(self.done[index] | index == self._index)) % len(self) - - def starts(self): - """return indices of all episodes""" - if len(self) > 0: - return (self.ends()+1) % len(self) - else: - return np.array([], dtype=np.int) - - def ends(self): - """return last indices of finished episodes. """ - if len(self) > 0: - last_write_in = int((self._index - 1) % len(self)) - if self.done[last_write_in]: - return np.where(self.done[:len(self)])[0] - else: - return np.append(np.where(self.done[:len(self)])[0], last_write_in) - else: - return np.array([], dtype=np.int) - class ListReplayBuffer(ReplayBuffer): """List-based replay buffer. @@ -588,15 +518,20 @@ def update_weight( def __getitem__( self, index: Union[slice, int, np.integer, np.ndarray] ) -> Batch: + index = self._indices[:self._size][index] # change slice to np array + if self._save_obs_next: + obs_next = self.get(index, "obs_next", self.stack_num) + else: + next_index = self.next(index, within_episode=True) + obs_next = self.get(next_index, "obs", self.stack_num) return Batch( - obs=self.get(index, "obs"), + obs=self.get(index, "obs", self.stack_num), act=self.act[index], rew=self.rew[index], done=self.done[index], - obs_next=self.get(index, "obs_next"), - info=self.get(index, "info"), - policy=self.get(index, "policy"), - weight=self.weight[index], + obs_next=obs_next, + info=self.get(index, "info", self.stack_num), + policy=self.get(index, "policy", self.stack_num), ) @@ -608,19 +543,19 @@ def __init__( **kwargs: Any, ) -> None: # TODO can size==0? - assert size > 0 - assert buf_n > 0 - if buf_n == 1: - import warnings - warnings.warn( - "VecReplayBuffer with buf_n = 1 will cause low efficiency. " - "Please consider using ReplayBuffer which is not in vector form.", - Warning) - _maxsize = buf_n*size - self.buf_n = buf_n - self.bufs = np.array([ReplayBuffer(size, **kwargs) - for _ in range(buf_n)]) - super().__init__(size=_maxsize, **kwargs) + assert size > 0 + assert buf_n > 0 + if buf_n == 1: + import warnings + warnings.warn( + "VecReplayBuffer with buf_n = 1 will cause low efficiency. 
" + "Please consider using ReplayBuffer which is not in vector form.", + Warning) + _maxsize = buf_n * size + self.buf_n = buf_n + self.bufs = np.array([ReplayBuffer(size, **kwargs) + for _ in range(buf_n)]) + super().__init__(size=_maxsize, **kwargs) def __len__(self) -> int: return np.sum([len(b) for b in self.bufs]) @@ -659,11 +594,11 @@ def add( if isinstance(policy, Batch) and policy.is_empty(): policy = {} obs_next = np.atleast_1d( - [None]*len(index)) if obs_next is None else np.atleast_1d(obs_next) + [None] * len(index)) if obs_next is None else np.atleast_1d(obs_next) info = np.atleast_1d( - [{}]*len(index)) if info == {} else np.atleast_1d(info) + [{}] * len(index)) if info == {} else np.atleast_1d(info) policy = np.atleast_1d( - [{}]*len(index)) if policy == {} else np.atleast_1d(policy) + [{}] * len(index)) if policy == {} else np.atleast_1d(policy) # can accelerate if self._meta.is_empty(): @@ -760,7 +695,7 @@ def next(self, index): lower = upper upper += b._maxsize mask = next_indices >= lower and next_indices < upper - next_indices[mask] = b.next(next_indices[mask]-lower)+lower + next_indices[mask] = b.next(next_indices[mask] - lower) + lower return next_indices def start(self, index): @@ -772,7 +707,7 @@ def start(self, index): lower = upper upper += b._maxsize mask = start_indices >= lower and start_indices < upper - start_indices[mask] = b.start(start_indices[mask]-lower)+lower + start_indices[mask] = b.start(start_indices[mask] - lower) + lower return start_indices @@ -804,7 +739,7 @@ def __init__( "Please consider using ReplayBuffer which is not in cached form.", Warning) - _maxsize = size+cached_buffer_n*max_length + _maxsize = size + cached_buffer_n * max_length self.cached_bufs_n = cached_buffer_n # TODO see if we can generalize to all kinds of buffer self.main_buffer = ReplayBuffer(size, **kwargs) @@ -881,11 +816,11 @@ def add( if isinstance(policy, Batch) and policy.is_empty(): policy = {} obs_next = np.atleast_1d( - [None]*len(index)) if obs_next is None else np.atleast_1d(obs_next) + [None] * len(index)) if obs_next is None else np.atleast_1d(obs_next) info = np.atleast_1d( - [{}]*len(index)) if info == {} else np.atleast_1d(info) + [{}] * len(index)) if info == {} else np.atleast_1d(info) policy = np.atleast_1d( - [{}]*len(index)) if policy == {} else np.atleast_1d(policy) + [{}] * len(index)) if policy == {} else np.atleast_1d(policy) # TODO what if data is already in episodes, what if i want to add mutiple data ? # can accelerate From 0ac97afc2d2bc6f8578a9c06daf84c7614a2ef6e Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Fri, 22 Jan 2021 18:50:59 +0800 Subject: [PATCH 007/104] refactor ReplayBuffer --- tianshou/data/buffer.py | 127 +++++++++++++++++----------------------- 1 file changed, 53 insertions(+), 74 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index d30f63098..04d5633f9 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -135,9 +135,6 @@ class ReplayBuffer: :param bool save_only_last_obs: only save the last obs/obs_next when it has a shape of (timestep, ...) because of temporal stacking, defaults to False. - :param bool sample_avail: the parameter indicating sampling only available - index when using frame-stack sampling method, defaults to False. - This feature is not supported in Prioritized Replay Buffer currently. 
""" _reserved_keys = {"obs", "act", "rew", "done", "obs_next", "info", "policy"} @@ -148,7 +145,6 @@ def __init__( stack_num: int = 1, ignore_obs_next: bool = False, save_only_last_obs: bool = False, - sample_avail: bool = False, ) -> None: super().__init__() self.maxsize = size @@ -157,7 +153,6 @@ def __init__( self._indices = np.arange(size) self._save_obs_next = not ignore_obs_next self._save_only_last_obs = save_only_last_obs - self._avail = sample_avail and stack_num > 1 self._index = 0 # current index self._size = 0 # current buffer size self._meta: Batch = Batch() @@ -198,12 +193,6 @@ def _add_to_buffer(self, name: str, inst: Any) -> None: except KeyError: self._meta.__dict__[name] = _create_value(inst, self.maxsize) value = self._meta.__dict__[name] - if isinstance(inst, (torch.Tensor, np.ndarray)): - if inst.shape != value.shape[1:]: - raise ValueError( - "Cannot add data to a buffer with different shape with key" - f" {name}, expect {value.shape[1:]}, given {inst.shape}." - ) try: value[self._index] = inst except KeyError: # inst is a dict/Batch @@ -220,12 +209,13 @@ def prev( within_episode: bool = False, ) -> np.ndarray: """Return one step previous index.""" - index = self._indices[:self._size][index] + assert np.all(index >= 0) and np.all(index < len(self)), \ + "Illegal index input." prev_index = (index - 1) % self._size if within_episode: - done = self.done[prev_index] | \ + end_flag = self.done[prev_index] | \ (prev_index == self.unfinished_index()) - prev_index = (prev_index + done) % self._size + prev_index = (prev_index + end_flag) % self._size return prev_index def next( @@ -234,10 +224,11 @@ def next( within_episode: bool = False, ) -> np.ndarray: """Return one step next index.""" - index = self._indices[:self._size][index] + assert np.all(index >= 0) and np.all(index < len(self)), \ + "Illegal index input." if within_episode: - done = self.done[index] | (index == self.unfinished_index()) - return (index + (1 - done)) % self._size + end_flag = self.done[index] | (index == self.unfinished_index()) + return (index + (1 - end_flag)) % self._size else: return (index + 1) % self._size @@ -246,10 +237,10 @@ def update(self, buffer: "ReplayBuffer") -> None: if len(buffer) == 0: return stack_num_orig, buffer.stack_num = buffer.stack_num, 1 - batch, _ = buffer.sample(0) + indices = buffer.sample_index(0) buffer.stack_num = stack_num_orig - for b in batch: - self.add(**b) + for i in indices: + self.add(**buffer[i]) def add( self, @@ -292,22 +283,32 @@ def reset(self) -> None: """Clear all the data in replay buffer.""" self._index = self._size = 0 - def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: - """Get a random sample from buffer with size equal to batch_size. + def sample_index(self, batch_size: int) -> np.ndarray: + """Same as sample(), but only return indices to avoid possible overhead. Return all the data in the buffer if batch_size is 0. :return: Sample data and its corresponding index inside the buffer. """ if batch_size > 0: - indice = np.random.choice(self._size, batch_size) + indices = np.random.choice(self._size, batch_size) else: # construct current available indices - indice = np.concatenate([ + indices = np.concatenate([ np.arange(self._index, self._size), np.arange(0, self._index), ]) - assert len(indice) > 0, "No available indice can be sampled." - return self[indice], indice + assert len(indices) > 0, "No available indice can be sampled." 
+ return indices + + def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: + """Get a random sample from buffer with size equal to batch_size. + + Return all the data in the buffer if batch_size is 0. + + :return: Sample data and its corresponding index inside the buffer. + """ + indices = self.sample_index(batch_size) + return self[indices], indices def get( self, @@ -349,18 +350,18 @@ def __getitem__( """ index = self._indices[:self._size][index] # change slice to np array if self._save_obs_next: - obs_next = self.get(index, "obs_next", self.stack_num) + obs_next = self.get(index, "obs_next") else: next_index = self.next(index, within_episode=True) - obs_next = self.get(next_index, "obs", self.stack_num) + obs_next = self.get(next_index, "obs") return Batch( - obs=self.get(index, "obs", self.stack_num), + obs=self.get(index, "obs"), act=self.act[index], rew=self.rew[index], done=self.done[index], obs_next=obs_next, - info=self.get(index, "info", self.stack_num), - policy=self.get(index, "policy", self.stack_num), + info=self.get(index, "info"), + policy=self.get(index, "policy"), ) def set_batch(self, batch: Batch): @@ -371,11 +372,9 @@ def set_batch(self, batch: Batch): def _is_legal_batch(self, batch: Batch) -> bool: """Check the given batch is in legal form.""" - if set(self._meta.keys()) != self._reserved_keys: + if set(batch.keys()) != self._reserved_keys: return False - if self._meta.is_empty(recurse=True): - return True - return len(self._meta) == self.maxsize + return len(batch) == self.maxsize def save_hdf5(self, path: str) -> None: """Save replay buffer to HDF5 file.""" @@ -515,24 +514,6 @@ def update_weight( self._max_prio = max(self._max_prio, weight.max()) self._min_prio = min(self._min_prio, weight.min()) - def __getitem__( - self, index: Union[slice, int, np.integer, np.ndarray] - ) -> Batch: - index = self._indices[:self._size][index] # change slice to np array - if self._save_obs_next: - obs_next = self.get(index, "obs_next", self.stack_num) - else: - next_index = self.next(index, within_episode=True) - obs_next = self.get(next_index, "obs", self.stack_num) - return Batch( - obs=self.get(index, "obs", self.stack_num), - act=self.act[index], - rew=self.rew[index], - done=self.done[index], - obs_next=obs_next, - info=self.get(index, "info", self.stack_num), - policy=self.get(index, "policy", self.stack_num), - ) class VecReplayBuffer(ReplayBuffer): @@ -626,13 +607,12 @@ def _initialise( super().reset() # TODO delete useless varible? del self._index del self._size - del self._avail_index self._set_batch() def _set_batch(self): start = 0 for buf in self.bufs: - end = start + buf._maxsize + end = start + buf.maxsize buf.set_batch(self._meta[start: end]) start = end @@ -659,7 +639,7 @@ def sample(self, batch_size: int, return_only_indice: bool = False) -> Tuple[Bat end = start + len(buf) _all[start:end] = _all[start:end] + add start = end - add = add + buf._maxsize - len(buf) + add = add + buf.maxsize - len(buf) # TODO consider making _all a seperate method indice = np.random.choice(_all, batch_size) assert len(indice) > 0, "No available indice can be sampled." @@ -687,13 +667,13 @@ def starts(self): return np.concatenate([b.starts() for b in self.cached_buffer.bufs]) def next(self, index): - assert index >= 0 and index < self._maxsize, "Input index illegal." + assert index >= 0 and index < self.maxsize, "Input index illegal." 
next_indices = np.full(index.shape, -1) upper = 0 lower = 0 for b in self.bufs: lower = upper - upper += b._maxsize + upper += b.maxsize mask = next_indices >= lower and next_indices < upper next_indices[mask] = b.next(next_indices[mask] - lower) + lower return next_indices @@ -705,7 +685,7 @@ def start(self, index): lower = 0 for b in self.bufs: lower = upper - upper += b._maxsize + upper += b.maxsize mask = start_indices >= lower and start_indices < upper start_indices[mask] = b.start(start_indices[mask] - lower) + lower return start_indices @@ -841,7 +821,7 @@ def _main_buf_update(self): lens[i] = len(buf) rews[i] = np.sum(buf.rew[:lens[i]]) start_indexs[i] = self.main_buffer._index - if self.main_buffer._maxsize > 0: + if self.main_buffer.maxsize > 0: # _maxsize of main_buffer might be 0 in test collector. self.main_buffer.update(buf) buf.reset() @@ -850,7 +830,6 @@ def _main_buf_update(self): def reset(self) -> None: self.cached_buffer.reset() self.main_buffer.reset() - self._avail_index = [] # TODO finish def sample(self, batch_size: int, @@ -862,12 +841,12 @@ def sample(self, batch_size: int, _all = np.arange(len(self), dtype=np.int) start = len(self.main_buffer) - add = self.main_buffer._maxsize - len(self.main_buffer) + add = self.main_buffer.maxsize - len(self.main_buffer) for buf in self.cached_buffer.bufs: end = start + len(buf) _all[start:end] = _all[start:end] + add start = end - add = add + buf._maxsize - len(buf) + add = add + buf.maxsize - len(buf) indice = np.random.choice(_all, batch_size) assert len(indice) > 0, "No available indice can be sampled." return self[indice], indice @@ -900,8 +879,8 @@ def _initialise( super().reset() # TODO delete useless varible? del self._index del self._size - self.main_buffer.set_batch(self._meta[:self.main_buffer._maxsize]) - self.cached_buffer.set_batch(self._meta[self.main_buffer._maxsize:]) + self.main_buffer.set_batch(self._meta[:self.main_buffer.maxsize]) + self.cached_buffer.set_batch(self._meta[self.main_buffer.maxsize:]) # TODO add standard API for vec buffer and use vec buffer to replace self.cached_buffer.bufs @@ -914,24 +893,24 @@ def starts(self): [self.main_buffer.starts(), self.cached_buffer.starts()]) def next(self, index): - assert index >= 0 and index < self._maxsize, "Input index illegal." + assert index >= 0 and index < self.maxsize, "Input index illegal." next_indices = np.full(index.shape, -1) - mask = index < self.main_buffer._maxsize + mask = index < self.main_buffer.maxsize next_indices[mask] = self.main_buffer.next(index[mask]) - next_indices[~mask] = self.cached_buffer.next(index[~mask]) + self.main_buffer._maxsize + next_indices[~mask] = self.cached_buffer.next(index[~mask]) + self.main_buffer.maxsize return next_indices def start(self, index): """return start indices of given indices""" - assert index >= 0 and index < self._maxsize, "Input index illegal." + assert index >= 0 and index < self.maxsize, "Input index illegal." 
start_indices = np.full(index.shape, -1) - mask = index < self.main_buffer._maxsize + mask = index < self.main_buffer.maxsize start_indices[mask] = self.main_buffer.start(index[mask]) - start_indices[~mask] = self.cached_buffer.start(index[~mask]) + self.main_buffer._maxsize + start_indices[~mask] = self.cached_buffer.start(index[~mask]) + self.main_buffer.maxsize return start_indices # def _global2local(self, global_index): - # assert (global_index>=0 and global_index=0 and global_index=lower and global_index Date: Fri, 22 Jan 2021 23:44:07 +0800 Subject: [PATCH 008/104] refactor vec/cached buffer --- tianshou/data/buffer.py | 299 ++++++++++++++++++++-------------------- 1 file changed, 148 insertions(+), 151 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 04d5633f9..eab95fd19 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -200,8 +200,12 @@ def _add_to_buffer(self, name: str, inst: Any) -> None: value.__dict__[key] = _create_value(inst[key], self.maxsize) value[self._index] = inst - def unfinished_index(self) -> int: - return (self._index - 1) % self._size + def unfinished_index(self) -> np.ndarray: + try: + last = (self._index - 1) % self._size + except ZeroDivisionError: + return np.array([]) + return np.array([last]) if not self.done[last] else np.array([]) def prev( self, @@ -210,11 +214,11 @@ def prev( ) -> np.ndarray: """Return one step previous index.""" assert np.all(index >= 0) and np.all(index < len(self)), \ - "Illegal index input." + "Illegal index input." prev_index = (index - 1) % self._size if within_episode: end_flag = self.done[prev_index] | \ - (prev_index == self.unfinished_index()) + np.isin(prev_index, self.unfinished_index()) prev_index = (prev_index + end_flag) % self._size return prev_index @@ -225,9 +229,10 @@ def next( ) -> np.ndarray: """Return one step next index.""" assert np.all(index >= 0) and np.all(index < len(self)), \ - "Illegal index input." + "Illegal index input." if within_episode: - end_flag = self.done[index] | (index == self.unfinished_index()) + end_flag = self.done[index] |\ + np.isin(index, self.unfinished_index()) return (index + (1 - end_flag)) % self._size else: return (index + 1) % self._size @@ -283,7 +288,7 @@ def reset(self) -> None: """Clear all the data in replay buffer.""" self._index = self._size = 0 - def sample_index(self, batch_size: int) -> np.ndarray: + def sample_index(self, batch_size: int, **kwargs) -> np.ndarray: """Same as sample(), but only return indices to avoid possible overhead. Return all the data in the buffer if batch_size is 0. @@ -300,14 +305,14 @@ def sample_index(self, batch_size: int) -> np.ndarray: assert len(indices) > 0, "No available indice can be sampled." return indices - def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: + def sample(self, batch_size: int, **kwargs) -> Tuple[Batch, np.ndarray]: """Get a random sample from buffer with size equal to batch_size. Return all the data in the buffer if batch_size is 0. :return: Sample data and its corresponding index inside the buffer. """ - indices = self.sample_index(batch_size) + indices = self.sample_index(batch_size, **kwargs) return self[indices], indices def get( @@ -348,6 +353,7 @@ def __getitem__( If stack_num is larger than 1, return the stacked obs and obs_next with shape (batch, len, ...). 
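# Illustrative sketch (plain numpy, hypothetical helper names -- not the
# tianshou API): the frame-stack retrieval described in the docstring above
# can be built from prev(): step back inside the episode stack_num - 1 times
# and stack the frames, so the result has shape (batch, stack_num, ...).
# The unfinished-episode guard (unfinished_index) is omitted here for brevity.
import numpy as np

def prev_in_episode(index, done, size):
    """One step back, clamped at the first index of each episode."""
    index = (np.asarray(index) - 1) % size
    end_flag = done[index]  # stepping over a done flag would leave the episode
    return (index + end_flag) % size

def stacked_obs(obs, done, index, stack_num):
    """Stack the last `stack_num` in-episode observations for each index."""
    index = np.asarray(index)
    stack = []
    for _ in range(stack_num):
        stack = [obs[index]] + stack
        index = prev_in_episode(index, done, len(obs))
    return np.stack(stack, axis=1)  # (batch, stack_num, ...)

# two finished episodes of length 3 stored in a buffer of size 6
obs = np.arange(6)
done = np.array([0, 0, 1, 0, 0, 1], dtype=bool)
print(stacked_obs(obs, done, [0, 2, 4], stack_num=3))
# [[0 0 0]
#  [0 1 2]
#  [3 3 4]]   -- frames before an episode start are repeated, not crossed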
""" + # TODO vec_buffer cannot iheritage from this if containing self._size index = self._indices[:self._size][index] # change slice to np array if self._save_obs_next: obs_next = self.get(index, "obs_next") @@ -515,7 +521,6 @@ def update_weight( self._min_prio = min(self._min_prio, weight.min()) - class VecReplayBuffer(ReplayBuffer): def __init__( self, @@ -541,6 +546,56 @@ def __init__( def __len__(self) -> int: return np.sum([len(b) for b in self.bufs]) + def unfinished_index(self) -> np.ndarray: + return np.concatenate( + [b.unfinished_index() for b in self.bufs]) + + def prev( + self, + index: Union[int, np.integer, np.ndarray], + within_episode: bool = False, + ) -> np.ndarray: + """Return one step previous index.""" + assert np.all(index >= 0) and np.all(index < self.maxsize), \ + "Input index illegal." + if not within_episode: + # it's hard to define behavior of next(within_episode = False) + # here, perhaps we don't need within_episode = False option anyway? + raise NotImplementedError + last_indices = np.full(index.shape, -1) + upper = 0 + lower = 0 + for b in self.bufs: + lower = upper + upper += b.maxsize + mask = last_indices >= lower and last_indices < upper + last_indices[mask] = b.prev(last_indices[mask] - lower, + within_episode=True) + lower + return last_indices + + def next( + self, + index: Union[int, np.integer, np.ndarray], + within_episode: bool = False, + ) -> np.ndarray: + """Return one step next index.""" + assert np.all(index >= 0) and np.all(index < self.maxsize), \ + "Input index illegal." + if not within_episode: + # it's hard to define behavior of next(within_episode = False) + # here, perhaps we don't need within_episode = False option anyway? + raise NotImplementedError + next_indices = np.full(index.shape, -1) + upper = 0 + lower = 0 + for b in self.bufs: + lower = upper + upper += b.maxsize + mask = next_indices >= lower and next_indices < upper + next_indices[mask] = b.next(next_indices[mask] - lower, + within_episode=True) + lower + return next_indices + def update(self, **kwargs): raise NotImplementedError @@ -581,7 +636,6 @@ def add( policy = np.atleast_1d( [{}] * len(index)) if policy == {} else np.atleast_1d(policy) - # can accelerate if self._meta.is_empty(): self._initialise(obs[0], act[0], rew[0], done[0], obs_next[0], info[0], policy[0]) @@ -604,12 +658,13 @@ def _initialise( assert(self._meta.is_empty()) # to initialise self._meta super().add(obs, act, rew, done, obs_next, info, policy) - super().reset() # TODO delete useless varible? + super().reset() del self._index del self._size - self._set_batch() + # TODO check method that use these 2 value + self._set_batch_for_children() - def _set_batch(self): + def _set_batch_for_children(self): start = 0 for buf in self.bufs: end = start + buf.maxsize @@ -617,78 +672,29 @@ def _set_batch(self): start = end def set_batch(self, batch: "Batch"): - """Manually choose the batch you want the ReplayBuffer to manage. This - method should be called instantly after the ReplayBuffer is initialised. 
- """ - assert self.bufs.is_empty(), "This method cannot be called after add() method" - self._meta = batch - assert not self._is_meta_corrupted(), ( - "Input batch doesn't meet ReplayBuffer's data form requirement.") - self._set_batch() + """Manually choose the batch you want the ReplayBuffer to manage.""" + super().set_batch(batch) + self._set_batch_for_children() def reset(self) -> None: for buf in self.bufs: buf.reset() - # TODO finish - def sample(self, batch_size: int, return_only_indice: bool = False) -> Tuple[Batch, np.ndarray]: - _all = np.arange(len(self), dtype=np.int) + def sample_index(self, batch_size: int, **kwargs) -> np.ndarray: + avail_indexes = np.arange(len(self), dtype=np.int) start = 0 add = 0 for buf in self.bufs: end = start + len(buf) - _all[start:end] = _all[start:end] + add + avail_indexes[start:end] = avail_indexes[start:end] + add start = end add = add + buf.maxsize - len(buf) - # TODO consider making _all a seperate method - indice = np.random.choice(_all, batch_size) - assert len(indice) > 0, "No available indice can be sampled." - if return_only_indice: - return indice + # TODO consider making avail_indexes a seperate method + if batch_size == 0: + # data not in chronological order + return avail_indexes else: - return self[indice], indice - - def get( - self, - indice: Union[slice, int, np.integer, np.ndarray], - key: str, - stack_num: Optional[int] = None, - ) -> Union[Batch, np.ndarray]: - if stack_num is None: - stack_num = self.stack_num - assert(stack_num == 1) - # TODO support stack - return super().get(indice, key, stack_num) - - def ends(self): - return np.concatenate([b.ends() for b in self.bufs]) - - def starts(self): - return np.concatenate([b.starts() for b in self.cached_buffer.bufs]) - - def next(self, index): - assert index >= 0 and index < self.maxsize, "Input index illegal." - next_indices = np.full(index.shape, -1) - upper = 0 - lower = 0 - for b in self.bufs: - lower = upper - upper += b.maxsize - mask = next_indices >= lower and next_indices < upper - next_indices[mask] = b.next(next_indices[mask] - lower) + lower - return next_indices - - def start(self, index): - """return start indices of given indices""" - start_indices = np.full(index.shape, -1) - upper = 0 - lower = 0 - for b in self.bufs: - lower = upper - upper += b.maxsize - mask = start_indices >= lower and start_indices < upper - start_indices[mask] = b.start(start_indices[mask] - lower) + lower - return start_indices + return np.random.choice(avail_indexes, batch_size) class CachedReplayBuffer(ReplayBuffer): @@ -697,6 +703,7 @@ class CachedReplayBuffer(ReplayBuffer): parallel collecting in collector. In CachedReplayBuffer is not organized chronologically, but standard API like start()/starts()/ends/next() are provided to help CachedReplayBuffer to be used just like ReplayBuffer. 
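# Illustrative sketch (hypothetical helper, not tianshou code): how the flat
# index space of the CachedReplayBuffer described above can be decoded back
# into a sub-buffer plus a local index.  The main buffer occupies [0, size)
# and cached buffer i occupies the next max_length slots, matching the way
# the shared meta-Batch is sliced among the children.
def locate(flat_index, size, cached_buffer_n, max_length):
    """Return (sub-buffer name, local index) for a flat CachedReplayBuffer index."""
    total = size + cached_buffer_n * max_length
    assert 0 <= flat_index < total, "index out of range"
    if flat_index < size:
        return "main", flat_index
    offset = flat_index - size
    return "cached[%d]" % (offset // max_length), offset % max_length

# size=10, 4 cached buffers of max_length=5 -> flat size 30
print(locate(3, 10, 4, 5))    # ('main', 3)
print(locate(10, 10, 4, 5))   # ('cached[0]', 0)
print(locate(27, 10, 4, 5))   # ('cached[3]', 2)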
+ #TODO finsih doc """ def __init__( @@ -710,7 +717,6 @@ def __init__( TODO support stack in the future """ assert cached_buffer_n > 0 - # TODO what if people don't care about how buffer is organized assert max_length > 0 if cached_buffer_n == 1: import warnings @@ -735,6 +741,54 @@ def __len__(self) -> int: """Return len(self).""" return len(self.main_buffer) + len(self.cached_buffer) + def unfinished_index(self) -> np.ndarray: + return np.concatenate([self.main_buffer.unfinished_index(), + self.cached_buffer.unfinished_index()]) + + def prev( + self, + index: Union[int, np.integer, np.ndarray], + within_episode: bool = False, + ) -> np.ndarray: + """Return one step previous index.""" + assert np.all(index >= 0) and np.all(index < self.maxsize), \ + "Input index illegal." + if not within_episode: + # it's hard to define behavior of prev(within_episode = False) + # here, perhaps we don't need within_episode = False option anyway? + raise NotImplementedError + prev_indices = np.full(index.shape, -1) + mask = index < self.main_buffer.maxsize + prev_indices[mask] = self.main_buffer.prev(index[mask], + within_episode=True) + prev_indices[~mask] = self.cached_buffer.prev( + index[~mask] - self.main_buffer.maxsize, + within_episode=True) + \ + self.main_buffer.maxsize + return prev_indices + + def next( + self, + index: Union[int, np.integer, np.ndarray], + within_episode: bool = False, + ) -> np.ndarray: + """Return one step next index.""" + assert np.all(index >= 0) and np.all(index < self.maxsize), \ + "Input index illegal." + if not within_episode: + # it's hard to define behavior of next(within_episode = False) + # here, perhaps we don't need within_episode = False option anyway? + raise NotImplementedError + next_indices = np.full(index.shape, -1) + mask = index < self.main_buffer.maxsize + next_indices[mask] = self.main_buffer.next(index[mask], + within_episode=True) + next_indices[~mask] = self.cached_buffer.next( + index[~mask] - self.main_buffer.maxsize, + within_episode=True) + \ + self.main_buffer.maxsize + return next_indices + def update(self, buffer: "ReplayBuffer") -> int: """CachedReplayBuffer will only update data from buffer which is in episode form. Return an integer which indicates the number of steps @@ -776,9 +830,6 @@ def add( index: Optional[Union[int, np.integer, np.ndarray, List[int]]] = None, **kwargs: Any ) -> None: - """ - - """ if index is None: index = range(self.cached_bufs_n) index = np.atleast_1d(index).astype(np.int) @@ -830,38 +881,27 @@ def _main_buf_update(self): def reset(self) -> None: self.cached_buffer.reset() self.main_buffer.reset() - # TODO finish - def sample(self, batch_size: int, - is_from_main_buf=False) -> Tuple[Batch, np.ndarray]: + def sample_index(self, batch_size: int, + is_from_main_buf=False, **kwargs) -> np.ndarray: if is_from_main_buf: - return self.main_buffer.sample(batch_size) - - # TODO use all() method to replace + return self.main_buffer.sample_index(batch_size, **kwargs) - _all = np.arange(len(self), dtype=np.int) + avail_indexes = np.arange(len(self), dtype=np.int) start = len(self.main_buffer) add = self.main_buffer.maxsize - len(self.main_buffer) for buf in self.cached_buffer.bufs: end = start + len(buf) - _all[start:end] = _all[start:end] + add + avail_indexes[start:end] = avail_indexes[start:end] + add start = end add = add + buf.maxsize - len(buf) - indice = np.random.choice(_all, batch_size) - assert len(indice) > 0, "No available indice can be sampled." 
- return self[indice], indice - - def get( - self, - indice: Union[slice, int, np.integer, np.ndarray], - key: str, - stack_num: Optional[int] = None, - ) -> Union[Batch, np.ndarray]: - if stack_num is None: - stack_num = self.stack_num - assert(stack_num == 1) - # TODO support stack - return super().get(indice, key, stack_num) + assert len(avail_indexes) > 0, "No available indice can be sampled." + # TODO consider making avail_indexes a seperate method + if batch_size == 0: + # data not in chronological order + return avail_indexes + else: + return np.random.choice(avail_indexes, batch_size) def _initialise( self, @@ -879,56 +919,13 @@ def _initialise( super().reset() # TODO delete useless varible? del self._index del self._size + self._set_batch_for_children() + + def _set_batch_for_children(self): self.main_buffer.set_batch(self._meta[:self.main_buffer.maxsize]) self.cached_buffer.set_batch(self._meta[self.main_buffer.maxsize:]) - # TODO add standard API for vec buffer and use vec buffer to replace self.cached_buffer.bufs - - def ends(self): - return np.concatenate( - [self.main_buffer.ends(), self.cached_buffer.ends()]) - - def starts(self): - return np.concatenate( - [self.main_buffer.starts(), self.cached_buffer.starts()]) - - def next(self, index): - assert index >= 0 and index < self.maxsize, "Input index illegal." - next_indices = np.full(index.shape, -1) - mask = index < self.main_buffer.maxsize - next_indices[mask] = self.main_buffer.next(index[mask]) - next_indices[~mask] = self.cached_buffer.next(index[~mask]) + self.main_buffer.maxsize - return next_indices - - def start(self, index): - """return start indices of given indices""" - assert index >= 0 and index < self.maxsize, "Input index illegal." - start_indices = np.full(index.shape, -1) - mask = index < self.main_buffer.maxsize - start_indices[mask] = self.main_buffer.start(index[mask]) - start_indices[~mask] = self.cached_buffer.start(index[~mask]) + self.main_buffer.maxsize - return start_indices - - # def _global2local(self, global_index): - # assert (global_index>=0 and global_index=lower and global_index Date: Fri, 22 Jan 2021 23:50:06 +0800 Subject: [PATCH 009/104] pep8 fix --- tianshou/data/__init__.py | 2 +- tianshou/data/buffer.py | 34 ++++++++++++++++++---------------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index fb480a983..037dd03b1 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -3,7 +3,7 @@ from tianshou.data.utils.segtree import SegmentTree from tianshou.data.buffer import ReplayBuffer, \ ListReplayBuffer, PrioritizedReplayBuffer, \ - VecReplayBuffer, CachedReplayBuffer + VecReplayBuffer, CachedReplayBuffer from tianshou.data.collector import Collector __all__ = [ diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index eab95fd19..90488e4fe 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -1,5 +1,4 @@ import h5py -import torch import numpy as np from numbers import Number from typing import Any, Dict, List, Tuple, Union, Optional @@ -534,9 +533,9 @@ def __init__( if buf_n == 1: import warnings warnings.warn( - "VecReplayBuffer with buf_n = 1 will cause low efficiency. " - "Please consider using ReplayBuffer which is not in vector form.", - Warning) + "VecReplayBuffer with buf_n = 1 will cause low efficiency. 
" + "Please consider using ReplayBuffer which is not in vector" + "form.", Warning) _maxsize = buf_n * size self.buf_n = buf_n self.bufs = np.array([ReplayBuffer(size, **kwargs) @@ -629,8 +628,8 @@ def add( info = {} if isinstance(policy, Batch) and policy.is_empty(): policy = {} - obs_next = np.atleast_1d( - [None] * len(index)) if obs_next is None else np.atleast_1d(obs_next) + obs_next = np.atleast_1d([None] * len(index)) \ + if obs_next is None else np.atleast_1d(obs_next) info = np.atleast_1d( [{}] * len(index)) if info == {} else np.atleast_1d(info) policy = np.atleast_1d( @@ -701,8 +700,8 @@ class CachedReplayBuffer(ReplayBuffer): """CachedReplayBuffer can be considered as a combination of one main buffer and a list of cached_buffers. It's designed to used by collector to allow parallel collecting in collector. In CachedReplayBuffer is not organized - chronologically, but standard API like start()/starts()/ends/next() are provided - to help CachedReplayBuffer to be used just like ReplayBuffer. + chronologically, but standard API like start()/starts()/ends/next() are + provided to help CachedReplayBuffer to be used just like ReplayBuffer. #TODO finsih doc """ @@ -721,8 +720,9 @@ def __init__( if cached_buffer_n == 1: import warnings warnings.warn( - "CachedReplayBuffer with cached_buffer_n = 1 will cause low efficiency. " - "Please consider using ReplayBuffer which is not in cached form.", + "CachedReplayBuffer with cached_buffer_n = 1 will" + "cause low efficiency. Please consider using ReplayBuffer" + "which is not in cached form.", Warning) _maxsize = size + cached_buffer_n * max_length @@ -734,7 +734,8 @@ def __init__( self.cached_buffer = VecReplayBuffer( max_length, cached_buffer_n, **kwargs) super().__init__(size=_maxsize, **kwargs) - # TODO support, or just delete stack_num option from Replay buffer for now + # TODO support, or just delete stack_num option from + # Replay buffer for now assert self.stack_num == 1 def __len__(self) -> int: @@ -796,7 +797,8 @@ def update(self, buffer: "ReplayBuffer") -> int: # For now update method copy element one by one, which is too slow. if isinstance(buffer, CachedReplayBuffer): buffer = buffer.main_buffer - # now treat buffer like a normal ReplayBuffer and remove those incomplete steps + # now treat buffer like a normal ReplayBuffer and + # remove those incomplete steps if len(buffer) == 0: return 0 diposed_count = 0 @@ -846,15 +848,15 @@ def add( info = {} if isinstance(policy, Batch) and policy.is_empty(): policy = {} - obs_next = np.atleast_1d( - [None] * len(index)) if obs_next is None else np.atleast_1d(obs_next) + obs_next = np.atleast_1d([None] * len(index)) \ + if obs_next is None else np.atleast_1d(obs_next) info = np.atleast_1d( [{}] * len(index)) if info == {} else np.atleast_1d(info) policy = np.atleast_1d( [{}] * len(index)) if policy == {} else np.atleast_1d(policy) - # TODO what if data is already in episodes, what if i want to add mutiple data ? - # can accelerate + # TODO what if data is already in episodes, + # what if i want to add mutiple data ? 
if self._meta.is_empty(): self._initialise(obs[0], act[0], rew[0], done[0], obs_next[0], info[0], policy[0]) From 3afb2cb363d6555ef6ae264c04dcba729217359f Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Sat, 23 Jan 2021 22:02:50 +0800 Subject: [PATCH 010/104] update VectorReplayBuffer and add test --- test/base/test_buffer.py | 66 +++++- tianshou/data/__init__.py | 4 +- tianshou/data/buffer.py | 457 +++++++++++++++----------------------- 3 files changed, 245 insertions(+), 282 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 78ca69419..3991d6a4d 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -7,8 +7,9 @@ import numpy as np from timeit import timeit -from tianshou.data import Batch, SegmentTree, \ - ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer +from tianshou.data import Batch, SegmentTree, ReplayBuffer +from tianshou.data import ListReplayBuffer, PrioritizedReplayBuffer +from tianshou.data import VectorReplayBuffer from tianshou.data.utils.converter import to_hdf5 if __name__ == '__main__': @@ -94,13 +95,11 @@ def test_ignore_obs_next(size=10): def test_stack(size=5, bufsize=9, stack_num=4): env = MyTestEnv(size) buf = ReplayBuffer(bufsize, stack_num=stack_num) - buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True) buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True) obs = env.reset(1) for i in range(16): obs_next, rew, done, info = env.step(1) buf.add(obs, 1, rew, done, None, info) - buf2.add(obs, 1, rew, done, None, info) buf3.add([None, None, obs], 1, rew, done, [None, obs], info) obs = obs_next if done: @@ -112,10 +111,6 @@ def test_stack(size=5, bufsize=9, stack_num=4): [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]]) assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs')) assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs_next')) - _, indice = buf2.sample(0) - assert indice.tolist() == [2, 6] - _, indice = buf2.sample(1) - assert indice in [2, 6] with pytest.raises(IndexError): buf[bufsize * 2] @@ -347,7 +342,62 @@ def test_hdf5(): to_hdf5(data, grp) +def test_vectorbuffer(): + buf = VectorReplayBuffer(5, 4) + buf.add(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3], + done=[0, 0, 1], buffer_index=[0, 1, 2]) + batch, indice = buf.sample(10) + batch, indice = buf.sample(0) + assert np.allclose(indice, [0, 5, 10]) + indice_prev = buf.prev(indice) + assert np.allclose(indice_prev, indice), indice_prev + indice_next = buf.next(indice) + assert np.allclose(indice_next, indice), indice_next + buf.add(obs=[4], act=[4], rew=[4], done=[1], buffer_index=[3]) + batch, indice = buf.sample(10) + batch, indice = buf.sample(0) + assert np.allclose(indice, [0, 5, 10, 15]) + indice_prev = buf.prev(indice) + assert np.allclose(indice_prev, indice), indice_prev + indice_next = buf.next(indice) + assert np.allclose(indice_next, indice), indice_next + data = np.array([0, 0, 0, 0]) + buf.add(obs=data, act=data, rew=data, done=data, buffer_index=[0, 1, 2, 3]) + buf.add(obs=data, act=data, rew=data, done=1 - data, + buffer_index=[0, 1, 2, 3]) + assert len(buf) == 12 + buf.add(obs=data, act=data, rew=data, done=data, buffer_index=[0, 1, 2, 3]) + buf.add(obs=data, act=data, rew=data, done=[0, 1, 0, 1], + buffer_index=[0, 1, 2, 3]) + assert len(buf) == 20 + batch, indice = buf.sample(10) + indice = buf.sample_index(0) + assert np.allclose(indice, np.arange(len(buf))) + assert np.allclose(buf.done, [ + 0, 0, 1, 0, 0, + 0, 0, 1, 0, 1, + 1, 0, 1, 0, 0, + 1, 0, 1, 0, 1, + ]) + indice_prev = 
buf.prev(indice) + assert np.allclose(indice_prev, [ + 0, 0, 1, 3, 3, + 5, 5, 6, 8, 8, + 10, 11, 11, 13, 13, + 15, 16, 16, 18, 18, + ]) + indice_next = buf.next(indice) + assert np.allclose(indice_next, [ + 1, 2, 2, 4, 4, + 6, 7, 7, 9, 9, + 10, 12, 12, 14, 14, + 15, 17, 17, 19, 19, + ]) + # TODO: prev/next/stack/hdf5 + + if __name__ == '__main__': + test_vectorbuffer() test_hdf5() test_replaybuffer() test_ignore_obs_next() diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index 037dd03b1..be0959192 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -3,7 +3,7 @@ from tianshou.data.utils.segtree import SegmentTree from tianshou.data.buffer import ReplayBuffer, \ ListReplayBuffer, PrioritizedReplayBuffer, \ - VecReplayBuffer, CachedReplayBuffer + VectorReplayBuffer, CachedReplayBuffer from tianshou.data.collector import Collector __all__ = [ @@ -15,7 +15,7 @@ "ReplayBuffer", "ListReplayBuffer", "PrioritizedReplayBuffer", + "VectorReplayBuffer", "CachedReplayBuffer", - "VecReplayBuffer", "Collector", ] diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 90488e4fe..381418962 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -1,7 +1,8 @@ import h5py +import warnings import numpy as np from numbers import Number -from typing import Any, Dict, List, Tuple, Union, Optional +from typing import Any, Dict, List, Tuple, Union, Callable, Optional from tianshou.data.batch import _create_value from tianshou.data import Batch, SegmentTree, to_numpy @@ -135,8 +136,8 @@ class ReplayBuffer: a shape of (timestep, ...) because of temporal stacking, defaults to False. """ - _reserved_keys = {"obs", "act", "rew", "done", - "obs_next", "info", "policy"} + _reserved_keys = ("obs", "act", "rew", "done", + "obs_next", "info", "policy") def __init__( self, @@ -144,6 +145,8 @@ def __init__( stack_num: int = 1, ignore_obs_next: bool = False, save_only_last_obs: bool = False, + alloc_fn: Optional[Callable[["ReplayBuffer", List[str], Any], None]] + = None, ) -> None: super().__init__() self.maxsize = size @@ -155,6 +158,7 @@ def __init__( self._index = 0 # current index self._size = 0 # current buffer size self._meta: Batch = Batch() + self._alloc = alloc_fn or ReplayBuffer.default_alloc_fn self.reset() def __len__(self) -> int: @@ -186,65 +190,71 @@ def __setattr__(self, key: str, value: Any) -> None: "key '{}' is reserved and cannot be assigned".format(key)) super().__setattr__(key, value) + @staticmethod + def default_alloc_fn( + buffer: "ReplayBuffer", key: List[str], value: Any + ) -> None: + """Allocate memory on buffer._meta for new (key, value) pair.""" + data = buffer._meta + for k in key[:-1]: + data = data[k] + data[key[-1]] = _create_value(value, buffer.maxsize) + def _add_to_buffer(self, name: str, inst: Any) -> None: try: value = self._meta.__dict__[name] except KeyError: - self._meta.__dict__[name] = _create_value(inst, self.maxsize) - value = self._meta.__dict__[name] + self._alloc(self, [name], inst) + value = self._meta[name] try: value[self._index] = inst except KeyError: # inst is a dict/Batch - for key in set(inst.keys()).difference(value.__dict__.keys()): - value.__dict__[key] = _create_value(inst[key], self.maxsize) + for key in set(inst.keys()).difference(value.keys()): + self._alloc(self, [name, key], inst) value[self._index] = inst + def set_batch(self, batch: Batch): + """Manually choose the batch you want the ReplayBuffer to manage.""" + assert len(batch) == self.maxsize, \ + "Input batch doesn't meet ReplayBuffer's 
data form requirement." + self._meta = batch + def unfinished_index(self) -> np.ndarray: + """Return the index of unfinished episode.""" try: last = (self._index - 1) % self._size except ZeroDivisionError: return np.array([]) return np.array([last]) if not self.done[last] else np.array([]) - def prev( - self, - index: Union[int, np.integer, np.ndarray], - within_episode: bool = False, - ) -> np.ndarray: - """Return one step previous index.""" - assert np.all(index >= 0) and np.all(index < len(self)), \ - "Illegal index input." + def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: + """Return the index of previous transition. + + The index won't be modified if it is the beginning of one episode. + """ prev_index = (index - 1) % self._size - if within_episode: - end_flag = self.done[prev_index] | \ - np.isin(prev_index, self.unfinished_index()) - prev_index = (prev_index + end_flag) % self._size + end_flag = self.done[prev_index] | \ + np.isin(prev_index, self.unfinished_index()) + prev_index = (prev_index + end_flag) % self._size return prev_index - def next( - self, - index: Union[int, np.integer, np.ndarray], - within_episode: bool = False, - ) -> np.ndarray: - """Return one step next index.""" - assert np.all(index >= 0) and np.all(index < len(self)), \ - "Illegal index input." - if within_episode: - end_flag = self.done[index] |\ - np.isin(index, self.unfinished_index()) - return (index + (1 - end_flag)) % self._size - else: - return (index + 1) % self._size + def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: + """Return the index of next transition. + + The index won't be modified if it is the end of one episode. + """ + end_flag = self.done[index] | np.isin(index, self.unfinished_index()) + return (index + (1 - end_flag)) % self._size def update(self, buffer: "ReplayBuffer") -> None: """Move the data from the given buffer to current buffer.""" if len(buffer) == 0: return stack_num_orig, buffer.stack_num = buffer.stack_num, 1 - indices = buffer.sample_index(0) - buffer.stack_num = stack_num_orig + indices = buffer.sample_index(0) # get all available indices for i in indices: self.add(**buffer[i]) + buffer.stack_num = stack_num_orig def add( self, @@ -287,31 +297,27 @@ def reset(self) -> None: """Clear all the data in replay buffer.""" self._index = self._size = 0 - def sample_index(self, batch_size: int, **kwargs) -> np.ndarray: - """Same as sample(), but only return indices to avoid possible overhead. + def sample_index(self, batch_size: int) -> np.ndarray: + """Get a random sample of index with size = batch_size. - Return all the data in the buffer if batch_size is 0. - - :return: Sample data and its corresponding index inside the buffer. + Return all available indices in the buffer if batch_size is 0. """ if batch_size > 0: - indices = np.random.choice(self._size, batch_size) + return np.random.choice(self._size, batch_size) else: # construct current available indices - indices = np.concatenate([ + return np.concatenate([ np.arange(self._index, self._size), np.arange(0, self._index), ]) - assert len(indices) > 0, "No available indice can be sampled." - return indices - def sample(self, batch_size: int, **kwargs) -> Tuple[Batch, np.ndarray]: - """Get a random sample from buffer with size equal to batch_size. + def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: + """Get a random sample from buffer with size = batch_size. Return all the data in the buffer if batch_size is 0. 
:return: Sample data and its corresponding index inside the buffer. """ - indices = self.sample_index(batch_size, **kwargs) + indices = self.sample_index(batch_size) return self[indices], indices def get( @@ -334,7 +340,7 @@ def get( stack: List[Any] = [] for _ in range(stack_num): stack = [val[indice]] + stack - indice = self.prev(indice, within_episode=True) + indice = self.prev(indice) if isinstance(val, Batch): return Batch.stack(stack, axis=indice.ndim) else: @@ -352,13 +358,12 @@ def __getitem__( If stack_num is larger than 1, return the stacked obs and obs_next with shape (batch, len, ...). """ - # TODO vec_buffer cannot iheritage from this if containing self._size - index = self._indices[:self._size][index] # change slice to np array + if isinstance(index, slice): # change slice to np array + index = self._indices[:len(self)][index] if self._save_obs_next: obs_next = self.get(index, "obs_next") else: - next_index = self.next(index, within_episode=True) - obs_next = self.get(next_index, "obs") + obs_next = self.get(self.next(index), "obs") return Batch( obs=self.get(index, "obs"), act=self.act[index], @@ -369,18 +374,6 @@ def __getitem__( policy=self.get(index, "policy"), ) - def set_batch(self, batch: Batch): - """Manually choose the batch you want the ReplayBuffer to manage.""" - assert self._is_legal_batch(batch), ( - "Input batch doesn't meet ReplayBuffer's data form requirement.") - self._meta = batch - - def _is_legal_batch(self, batch: Batch) -> bool: - """Check the given batch is in legal form.""" - if set(batch.keys()) != self._reserved_keys: - return False - return len(batch) == self.maxsize - def save_hdf5(self, path: str) -> None: """Save replay buffer to HDF5 file.""" with h5py.File(path, "w") as f: @@ -413,6 +406,7 @@ class ListReplayBuffer(ReplayBuffer): """ def __init__(self, **kwargs: Any) -> None: + warnings.warn("ListReplayBuffer will be replaced soon.") super().__init__(size=0, ignore_obs_next=False, **kwargs) def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: @@ -421,14 +415,14 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: def _add_to_buffer( self, name: str, inst: Union[dict, Batch, np.ndarray, float, int, bool] ) -> None: - if self._meta.__dict__.get(name) is None: + if self._meta.get(name) is None: self._meta.__dict__[name] = [] - self._meta.__dict__[name].append(inst) + self._meta[name].append(inst) def reset(self) -> None: self._index = self._size = 0 - for k in list(self._meta.__dict__.keys()): - if isinstance(self._meta.__dict__[k], list): + for k in self._meta.keys(): + if isinstance(self._meta[k], list): self._meta.__dict__[k] = [] @@ -477,6 +471,17 @@ def add( self.weight[self._index] = weight ** self._alpha super().add(obs, act, rew, done, obs_next, info, policy, **kwargs) + def sample_index(self, batch_size: int) -> np.ndarray: + assert self._size > 0, "Cannot sample a buffer with 0 size." + if batch_size == 0: + return np.concatenate([ + np.arange(self._index, self._size), + np.arange(0, self._index), + ]) + else: + scalar = np.random.rand(batch_size) * self.weight.reduce() + return self.weight.get_prefix_sum_idx(scalar) + def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: """Get a random sample from buffer with priority probability. @@ -488,15 +493,7 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: to de-bias the sampling process (some transition tuples are sampled more often so their losses are weighted less). """ - assert self._size > 0, "Cannot sample a buffer with 0 size!" 
- if batch_size == 0: - indice = np.concatenate([ - np.arange(self._index, self._size), - np.arange(0, self._index), - ]) - else: - scalar = np.random.rand(batch_size) * self.weight.reduce() - indice = self.weight.get_prefix_sum_idx(scalar) + indice = self.sample_index(batch_size) batch = self[indice] # important sampling weight calculation # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta)) @@ -519,181 +516,131 @@ def update_weight( self._max_prio = max(self._max_prio, weight.max()) self._min_prio = min(self._min_prio, weight.min()) + def __getitem__( + self, index: Union[slice, int, np.integer, np.ndarray] + ) -> Batch: + batch = super().__getitem__(index) + batch.weight = self.weight[index] + return batch + + +class VectorReplayBuffer(ReplayBuffer): + """VectorReplayBuffer contains n ReplayBuffer with the given size, where \ + n equals to buffer_num. + + :param float size: the size of each ReplayBuffer. + :param float buffer_num: number of ReplayBuffer needs to be handled. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed + explanation. + """ -class VecReplayBuffer(ReplayBuffer): def __init__( - self, - size: int, - buf_n: int, - **kwargs: Any, + self, size: int, buffer_num: int, **kwargs: Any ) -> None: - # TODO can size==0? - assert size > 0 - assert buf_n > 0 - if buf_n == 1: - import warnings - warnings.warn( - "VecReplayBuffer with buf_n = 1 will cause low efficiency. " - "Please consider using ReplayBuffer which is not in vector" - "form.", Warning) - _maxsize = buf_n * size - self.buf_n = buf_n - self.bufs = np.array([ReplayBuffer(size, **kwargs) - for _ in range(buf_n)]) - super().__init__(size=_maxsize, **kwargs) + + def buffer_alloc_fn( + buffer: ReplayBuffer, key: List[str], value: Any + ) -> None: + data = self._meta + for k in key[:-1]: + data = data[k] + data[key[-1]] = _create_value(value, self.maxsize) + self._set_batch_for_children() + + assert size > 0 and buffer_num > 0 + self.buffer_num = buffer_num + kwargs["alloc_fn"] = kwargs.get("alloc_fn", buffer_alloc_fn) + self.buffers = [ReplayBuffer(size, **kwargs) + for _ in range(buffer_num)] + super().__init__(size=buffer_num * size, **kwargs) def __len__(self) -> int: - return np.sum([len(b) for b in self.bufs]) + return sum([len(buf) for buf in self.buffers]) def unfinished_index(self) -> np.ndarray: - return np.concatenate( - [b.unfinished_index() for b in self.bufs]) + return np.concatenate([buf.unfinished_index() for buf in self.buffers]) + + def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: + index = np.asarray(index) + prev_indices = np.zeros_like(index) + for i, buf in enumerate(self.buffers): + lower, upper = buf.maxsize * i, buf.maxsize * (i + 1) + mask = (lower <= index) & (index < upper) + if np.any(mask): + prev_indices[mask] = buf.prev(index[mask] - lower) + lower + return prev_indices - def prev( - self, - index: Union[int, np.integer, np.ndarray], - within_episode: bool = False, - ) -> np.ndarray: - """Return one step previous index.""" - assert np.all(index >= 0) and np.all(index < self.maxsize), \ - "Input index illegal." - if not within_episode: - # it's hard to define behavior of next(within_episode = False) - # here, perhaps we don't need within_episode = False option anyway? 
- raise NotImplementedError - last_indices = np.full(index.shape, -1) - upper = 0 - lower = 0 - for b in self.bufs: - lower = upper - upper += b.maxsize - mask = last_indices >= lower and last_indices < upper - last_indices[mask] = b.prev(last_indices[mask] - lower, - within_episode=True) + lower - return last_indices - - def next( - self, - index: Union[int, np.integer, np.ndarray], - within_episode: bool = False, - ) -> np.ndarray: - """Return one step next index.""" - assert np.all(index >= 0) and np.all(index < self.maxsize), \ - "Input index illegal." - if not within_episode: - # it's hard to define behavior of next(within_episode = False) - # here, perhaps we don't need within_episode = False option anyway? - raise NotImplementedError - next_indices = np.full(index.shape, -1) - upper = 0 - lower = 0 - for b in self.bufs: - lower = upper - upper += b.maxsize - mask = next_indices >= lower and next_indices < upper - next_indices[mask] = b.next(next_indices[mask] - lower, - within_episode=True) + lower + def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: + index = np.asarray(index) + next_indices = np.zeros_like(index) + for i, buf in enumerate(self.buffers): + lower, upper = buf.maxsize * i, buf.maxsize * (i + 1) + mask = (lower <= index) & (index < upper) + if np.any(mask): + next_indices[mask] = buf.next(index[mask] - lower) + lower return next_indices - def update(self, **kwargs): + def update(self, buffer: ReplayBuffer) -> None: raise NotImplementedError def add( self, obs: Any, act: Any, - rew: Union[Number, np.number, np.ndarray], - done: Union[Number, np.number, np.bool_], - obs_next: Any = None, - info: Optional[Union[dict, Batch]] = {}, - policy: Optional[Union[dict, Batch]] = {}, - index: Optional[Union[int, np.integer, np.ndarray, List[int]]] = None, - type_check: bool = True, + rew: Union[np.ndarray], + done: Union[np.ndarray], + obs_next: Any = Batch(), + info: Optional[Batch] = Batch(), + policy: Optional[Batch] = Batch(), + buffer_index: Optional[Union[np.ndarray, List[int]]] = None, **kwargs: Any ) -> None: - if type_check: - if index is None: - index = range(self.cached_bufs_n) - index = np.atleast_1d(index).astype(np.int) - assert(index.ndim == 1) - - obs = np.atleast_1d(obs) - act = np.atleast_1d(act) - rew = np.atleast_1d(rew) - done = np.atleast_1d(done) - # TODO ugly code - if isinstance(obs_next, Batch) and obs_next.is_empty(): - obs_next = None - if isinstance(info, Batch) and info.is_empty(): - info = {} - if isinstance(policy, Batch) and policy.is_empty(): - policy = {} - obs_next = np.atleast_1d([None] * len(index)) \ - if obs_next is None else np.atleast_1d(obs_next) - info = np.atleast_1d( - [{}] * len(index)) if info == {} else np.atleast_1d(info) - policy = np.atleast_1d( - [{}] * len(index)) if policy == {} else np.atleast_1d(policy) + """Add a batch of data into VectorReplayBuffer. 
- if self._meta.is_empty(): - self._initialise(obs[0], act[0], rew[0], done[0], obs_next[0], - info[0], policy[0]) - # now we add data to selected bufs one by one - bufs_slice = self.bufs[index] - for i, b in enumerate(bufs_slice): - b.add(obs[i], act[i], rew[i], done[i], - obs_next[i], info[i], policy[i]) - - def _initialise( - self, - obs: Any, - act: Any, - rew: Union[Number, np.number, np.ndarray], - done: Union[Number, np.number, np.bool_], - obs_next: Any = None, - info: Optional[Union[dict, Batch]] = {}, - policy: Optional[Union[dict, Batch]] = {} - ) -> None: - assert(self._meta.is_empty()) - # to initialise self._meta - super().add(obs, act, rew, done, obs_next, info, policy) - super().reset() - del self._index - del self._size - # TODO check method that use these 2 value - self._set_batch_for_children() + Each of the data's length (first dimension) must equal to the length of + buffer_index. + """ + assert buffer_index is not None, \ + "buffer_index is required in VectorReplayBuffer.add()" + batch = Batch(obs=obs, act=act, rew=rew, done=done, + obs_next=obs_next, info=info, policy=policy) + assert len(buffer_index) == len(batch) + for batch_idx, buffer_idx in enumerate(buffer_index): + self.buffers[buffer_idx].add(**batch[batch_idx]) def _set_batch_for_children(self): - start = 0 - for buf in self.bufs: - end = start + buf.maxsize - buf.set_batch(self._meta[start: end]) - start = end + for i, buf in enumerate(self.buffers): + start, end = buf.maxsize * i, buf.maxsize * (i + 1) + buf.set_batch(self._meta[start:end]) - def set_batch(self, batch: "Batch"): - """Manually choose the batch you want the ReplayBuffer to manage.""" + def set_batch(self, batch: Batch): super().set_batch(batch) self._set_batch_for_children() def reset(self) -> None: - for buf in self.bufs: + for buf in self.buffers: buf.reset() - def sample_index(self, batch_size: int, **kwargs) -> np.ndarray: - avail_indexes = np.arange(len(self), dtype=np.int) - start = 0 - add = 0 - for buf in self.bufs: - end = start + len(buf) - avail_indexes[start:end] = avail_indexes[start:end] + add - start = end - add = add + buf.maxsize - len(buf) - # TODO consider making avail_indexes a seperate method - if batch_size == 0: - # data not in chronological order - return avail_indexes + def sample_index(self, batch_size: int) -> np.ndarray: + if batch_size == 0: # get all available indices + sample_num = np.zeros(self.buffer_num, np.int) else: - return np.random.choice(avail_indexes, batch_size) + buffer_lens = np.array([len(buf) for buf in self.buffers]) + buffer_idx = np.random.choice(self.buffer_num, batch_size, + p=buffer_lens / buffer_lens.sum()) + sample_num = np.bincount(buffer_idx, minlength=self.buffer_num) + # avoid batch_size > 0 and sample_num == 0 -> get child's all data + sample_num[sample_num == 0] = -1 + + indices = [] + for i, (buf, bsz) in enumerate(zip(self.buffers, sample_num)): + if bsz >= 0: + offset = buf.maxsize * i + indices.append(buf.sample_index(bsz) + offset) + return np.concatenate(indices) class CachedReplayBuffer(ReplayBuffer): @@ -702,41 +649,29 @@ class CachedReplayBuffer(ReplayBuffer): parallel collecting in collector. In CachedReplayBuffer is not organized chronologically, but standard API like start()/starts()/ends/next() are provided to help CachedReplayBuffer to be used just like ReplayBuffer. 
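# Illustrative numpy sketch of the proportional sampling idea used by
# VectorReplayBuffer.sample_index above: first decide how many samples each
# sub-buffer contributes (proportional to its current length), then shift the
# local indices into the shared flat index space.  Names here are made up;
# the real method also guards against a zero count, which its children would
# read as "return all indices".
import numpy as np

def sample_flat_indices(buffer_lens, buffer_maxsize, batch_size):
    buffer_lens = np.asarray(buffer_lens)
    buffer_num = len(buffer_lens)
    # which sub-buffer each of the batch_size samples comes from
    chosen = np.random.choice(buffer_num, batch_size,
                              p=buffer_lens / buffer_lens.sum())
    per_buffer = np.bincount(chosen, minlength=buffer_num)
    indices = []
    for i, (length, n) in enumerate(zip(buffer_lens, per_buffer)):
        if n > 0:
            local = np.random.choice(length, n)         # indices inside buffer i
            indices.append(local + i * buffer_maxsize)  # offset into flat space
    return np.concatenate(indices) if indices else np.array([], int)

# four sub-buffers of maxsize 5 currently holding 5, 3, 4 and 1 transitions
print(sample_flat_indices([5, 3, 4, 1], buffer_maxsize=5, batch_size=8))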
- #TODO finsih doc + #TODO finish doc """ def __init__( self, size: int, cached_buffer_n: int, - max_length: int, + max_episode_length: int, **kwargs: Any, ) -> None: - """ - TODO support stack in the future - """ - assert cached_buffer_n > 0 - assert max_length > 0 + assert cached_buffer_n > 0 and max_episode_length > 0 if cached_buffer_n == 1: - import warnings warnings.warn( "CachedReplayBuffer with cached_buffer_n = 1 will" "cause low efficiency. Please consider using ReplayBuffer" "which is not in cached form.", Warning) - - _maxsize = size + cached_buffer_n * max_length + maxsize = size + cached_buffer_n * max_episode_length self.cached_bufs_n = cached_buffer_n - # TODO see if we can generalize to all kinds of buffer self.main_buffer = ReplayBuffer(size, **kwargs) - # TODO cached_buffer can be consider to be replced by vector - # buffer in the future - self.cached_buffer = VecReplayBuffer( - max_length, cached_buffer_n, **kwargs) - super().__init__(size=_maxsize, **kwargs) - # TODO support, or just delete stack_num option from - # Replay buffer for now - assert self.stack_num == 1 + self.cached_buffer = VectorReplayBuffer( + max_episode_length, cached_buffer_n, **kwargs) + super().__init__(size=maxsize, **kwargs) def __len__(self) -> int: """Return len(self).""" @@ -746,47 +681,25 @@ def unfinished_index(self) -> np.ndarray: return np.concatenate([self.main_buffer.unfinished_index(), self.cached_buffer.unfinished_index()]) - def prev( - self, - index: Union[int, np.integer, np.ndarray], - within_episode: bool = False, - ) -> np.ndarray: + def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: """Return one step previous index.""" - assert np.all(index >= 0) and np.all(index < self.maxsize), \ - "Input index illegal." - if not within_episode: - # it's hard to define behavior of prev(within_episode = False) - # here, perhaps we don't need within_episode = False option anyway? - raise NotImplementedError + assert np.all(index >= 0) and np.all(index < self.maxsize) prev_indices = np.full(index.shape, -1) mask = index < self.main_buffer.maxsize - prev_indices[mask] = self.main_buffer.prev(index[mask], - within_episode=True) + prev_indices[mask] = self.main_buffer.prev(index[mask]) prev_indices[~mask] = self.cached_buffer.prev( - index[~mask] - self.main_buffer.maxsize, - within_episode=True) + \ + index[~mask] - self.main_buffer.maxsize) + \ self.main_buffer.maxsize return prev_indices - def next( - self, - index: Union[int, np.integer, np.ndarray], - within_episode: bool = False, - ) -> np.ndarray: + def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: """Return one step next index.""" - assert np.all(index >= 0) and np.all(index < self.maxsize), \ - "Input index illegal." - if not within_episode: - # it's hard to define behavior of next(within_episode = False) - # here, perhaps we don't need within_episode = False option anyway? 
- raise NotImplementedError + assert np.all(index >= 0) and np.all(index < self.maxsize) next_indices = np.full(index.shape, -1) mask = index < self.main_buffer.maxsize - next_indices[mask] = self.main_buffer.next(index[mask], - within_episode=True) + next_indices[mask] = self.main_buffer.next(index[mask]) next_indices[~mask] = self.cached_buffer.next( - index[~mask] - self.main_buffer.maxsize, - within_episode=True) + \ + index[~mask] - self.main_buffer.maxsize) + \ self.main_buffer.maxsize return next_indices From 443969d0b09bb8bb2934a08c2452a2a91d336686 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Sun, 24 Jan 2021 16:18:16 +0800 Subject: [PATCH 011/104] update cached --- test/base/test_buffer.py | 24 ++- tianshou/data/buffer.py | 331 ++++++++++++++++----------------------- 2 files changed, 151 insertions(+), 204 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 3991d6a4d..e6bd86a31 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -9,7 +9,7 @@ from tianshou.data import Batch, SegmentTree, ReplayBuffer from tianshou.data import ListReplayBuffer, PrioritizedReplayBuffer -from tianshou.data import VectorReplayBuffer +from tianshou.data import VectorReplayBuffer, CachedReplayBuffer from tianshou.data.utils.converter import to_hdf5 if __name__ == '__main__': @@ -282,7 +282,8 @@ def test_hdf5(): buffers = { "array": ReplayBuffer(size, stack_num=2), "list": ListReplayBuffer(), - "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4) + "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4), + "vector": VectorReplayBuffer(size, buffer_num=4), } buffer_types = {k: b.__class__ for k, b in buffers.items()} device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -298,6 +299,8 @@ def test_hdf5(): buffers["array"].add(**kwargs) buffers["list"].add(**kwargs) buffers["prioritized"].add(weight=np.random.rand(), **kwargs) + buffers["vector"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]), + env_ids=[0, 1, 2]) # save paths = {} @@ -345,7 +348,7 @@ def test_hdf5(): def test_vectorbuffer(): buf = VectorReplayBuffer(5, 4) buf.add(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3], - done=[0, 0, 1], buffer_index=[0, 1, 2]) + done=[0, 0, 1], env_ids=[0, 1, 2]) batch, indice = buf.sample(10) batch, indice = buf.sample(0) assert np.allclose(indice, [0, 5, 10]) @@ -353,7 +356,7 @@ def test_vectorbuffer(): assert np.allclose(indice_prev, indice), indice_prev indice_next = buf.next(indice) assert np.allclose(indice_next, indice), indice_next - buf.add(obs=[4], act=[4], rew=[4], done=[1], buffer_index=[3]) + buf.add(obs=[4], act=[4], rew=[4], done=[1], env_ids=[3]) batch, indice = buf.sample(10) batch, indice = buf.sample(0) assert np.allclose(indice, [0, 5, 10, 15]) @@ -362,13 +365,13 @@ def test_vectorbuffer(): indice_next = buf.next(indice) assert np.allclose(indice_next, indice), indice_next data = np.array([0, 0, 0, 0]) - buf.add(obs=data, act=data, rew=data, done=data, buffer_index=[0, 1, 2, 3]) + buf.add(obs=data, act=data, rew=data, done=data, env_ids=[0, 1, 2, 3]) buf.add(obs=data, act=data, rew=data, done=1 - data, - buffer_index=[0, 1, 2, 3]) + env_ids=[0, 1, 2, 3]) assert len(buf) == 12 - buf.add(obs=data, act=data, rew=data, done=data, buffer_index=[0, 1, 2, 3]) + buf.add(obs=data, act=data, rew=data, done=data, env_ids=[0, 1, 2, 3]) buf.add(obs=data, act=data, rew=data, done=[0, 1, 0, 1], - buffer_index=[0, 1, 2, 3]) + env_ids=[0, 1, 2, 3]) assert len(buf) == 20 batch, indice = buf.sample(10) indice = buf.sample_index(0) @@ -394,6 +397,11 @@ def 
test_vectorbuffer(): 15, 17, 17, 19, 19, ]) # TODO: prev/next/stack/hdf5 + # CachedReplayBuffer + buf = CachedReplayBuffer(10, 4, 5) + assert buf.sample_index(0).tolist() == [] + buf.add(obs=[1], act=[1], rew=[1], done=[1], env_ids=[1]) + print(buf) if __name__ == '__main__': diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 381418962..140910c5e 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -136,6 +136,7 @@ class ReplayBuffer: a shape of (timestep, ...) because of temporal stacking, defaults to False. """ + _reserved_keys = ("obs", "act", "rew", "done", "obs_next", "info", "policy") @@ -147,8 +148,12 @@ def __init__( save_only_last_obs: bool = False, alloc_fn: Optional[Callable[["ReplayBuffer", List[str], Any], None]] = None, + sample_avail: bool = False, ) -> None: super().__init__() + if sample_avail: + warnings.warn("sample_avail is deprecated. Please check out " + "tianshou version <= 0.3.1 if you want to use it.") self.maxsize = size assert stack_num > 0, "stack_num should greater than 0" self.stack_num = stack_num @@ -184,6 +189,9 @@ def __setstate__(self, state: Dict[str, Any]) -> None: """ self.__dict__.update(state) + def __getstate__(self) -> Dict[str, Any]: + return self.__dict__ + def __setattr__(self, key: str, value: Any) -> None: """Set self.key = value.""" assert key not in self._reserved_keys, ( @@ -213,7 +221,7 @@ def _add_to_buffer(self, name: str, inst: Any) -> None: self._alloc(self, [name, key], inst) value[self._index] = inst - def set_batch(self, batch: Batch): + def set_batch(self, batch: Batch) -> None: """Manually choose the batch you want the ReplayBuffer to manage.""" assert len(batch) == self.maxsize, \ "Input batch doesn't meet ReplayBuffer's data form requirement." @@ -221,39 +229,34 @@ def set_batch(self, batch: Batch): def unfinished_index(self) -> np.ndarray: """Return the index of unfinished episode.""" - try: - last = (self._index - 1) % self._size - except ZeroDivisionError: - return np.array([]) - return np.array([last]) if not self.done[last] else np.array([]) + last = (self._index - 1) % self._size if self._size else 0 + return np.array([last] if not self.done[last] else [], np.int) def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: """Return the index of previous transition. - The index won't be modified if it is the beginning of one episode. + The index won't be modified if it is the beginning of an episode. """ - prev_index = (index - 1) % self._size - end_flag = self.done[prev_index] | \ - np.isin(prev_index, self.unfinished_index()) - prev_index = (prev_index + end_flag) % self._size - return prev_index + index = (index - 1) % self._size + end_flag = self.done[index] | np.isin(index, self.unfinished_index()) + return (index + end_flag) % self._size def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: """Return the index of next transition. - The index won't be modified if it is the end of one episode. + The index won't be modified if it is the end of an episode. 
""" end_flag = self.done[index] | np.isin(index, self.unfinished_index()) return (index + (1 - end_flag)) % self._size def update(self, buffer: "ReplayBuffer") -> None: """Move the data from the given buffer to current buffer.""" - if len(buffer) == 0: + if len(buffer) == 0 or self.maxsize == 0: return stack_num_orig, buffer.stack_num = buffer.stack_num, 1 indices = buffer.sample_index(0) # get all available indices for i in indices: - self.add(**buffer[i]) + self.add(**buffer[i]) # type: ignore buffer.stack_num = stack_num_orig def add( @@ -300,15 +303,19 @@ def reset(self) -> None: def sample_index(self, batch_size: int) -> np.ndarray: """Get a random sample of index with size = batch_size. - Return all available indices in the buffer if batch_size is 0. + Return all available indices in the buffer if batch_size is 0; return + an empty numpy array if batch_size < 0 or no available index can be + sampled. """ if batch_size > 0: return np.random.choice(self._size, batch_size) - else: # construct current available indices + elif batch_size == 0: # construct current available indices return np.concatenate([ np.arange(self._index, self._size), np.arange(0, self._index), ]) + else: + return np.array([], np.int) def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: """Get a random sample from buffer with size = batch_size. @@ -342,9 +349,9 @@ def get( stack = [val[indice]] + stack indice = self.prev(indice) if isinstance(val, Batch): - return Batch.stack(stack, axis=indice.ndim) + return Batch.stack(stack, axis=1) else: - return np.stack(stack, axis=indice.ndim) + return np.stack(stack, axis=1) except IndexError as e: if not (isinstance(val, Batch) and val.is_empty()): raise e # val != Batch() @@ -377,7 +384,7 @@ def __getitem__( def save_hdf5(self, path: str) -> None: """Save replay buffer to HDF5 file.""" with h5py.File(path, "w") as f: - to_hdf5(self.__dict__, f) + to_hdf5(self.__getstate__(), f) @classmethod def load_hdf5( @@ -473,14 +480,11 @@ def add( def sample_index(self, batch_size: int) -> np.ndarray: assert self._size > 0, "Cannot sample a buffer with 0 size." - if batch_size == 0: - return np.concatenate([ - np.arange(self._index, self._size), - np.arange(0, self._index), - ]) - else: + if batch_size > 0: scalar = np.random.rand(batch_size) * self.weight.reduce() return self.weight.get_prefix_sum_idx(scalar) + else: + return super().sample_index(batch_size) def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: """Get a random sample from buffer with priority probability. @@ -528,8 +532,8 @@ class VectorReplayBuffer(ReplayBuffer): """VectorReplayBuffer contains n ReplayBuffer with the given size, where \ n equals to buffer_num. - :param float size: the size of each ReplayBuffer. - :param float buffer_num: number of ReplayBuffer needs to be handled. + :param int size: the size of each ReplayBuffer. + :param int buffer_num: number of ReplayBuffer needs to be handled. .. 
seealso:: @@ -553,8 +557,8 @@ def buffer_alloc_fn( assert size > 0 and buffer_num > 0 self.buffer_num = buffer_num kwargs["alloc_fn"] = kwargs.get("alloc_fn", buffer_alloc_fn) - self.buffers = [ReplayBuffer(size, **kwargs) - for _ in range(buffer_num)] + self.buffers = np.array([ReplayBuffer(size, **kwargs) + for _ in range(buffer_num)]) super().__init__(size=buffer_num * size, **kwargs) def __len__(self) -> int: @@ -584,9 +588,10 @@ def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: return next_indices def update(self, buffer: ReplayBuffer) -> None: + """The VectorReplayBuffer cannot be updated by any buffer.""" raise NotImplementedError - def add( + def add( # type: ignore self, obs: Any, act: Any, @@ -595,28 +600,30 @@ def add( obs_next: Any = Batch(), info: Optional[Batch] = Batch(), policy: Optional[Batch] = Batch(), - buffer_index: Optional[Union[np.ndarray, List[int]]] = None, + env_ids: Optional[Union[np.ndarray, List[int]]] = None, **kwargs: Any ) -> None: """Add a batch of data into VectorReplayBuffer. Each of the data's length (first dimension) must equal to the length of - buffer_index. + env_ids. """ - assert buffer_index is not None, \ - "buffer_index is required in VectorReplayBuffer.add()" + assert env_ids is not None, \ + "env_ids is required in VectorReplayBuffer.add()" + # assume each element in env_ids is unique + assert np.bincount(env_ids).max() == 1 batch = Batch(obs=obs, act=act, rew=rew, done=done, obs_next=obs_next, info=info, policy=policy) - assert len(buffer_index) == len(batch) - for batch_idx, buffer_idx in enumerate(buffer_index): - self.buffers[buffer_idx].add(**batch[batch_idx]) + assert len(env_ids) == len(batch) + for batch_idx, env_id in enumerate(env_ids): + self.buffers[env_id].add(**batch[batch_idx]) - def _set_batch_for_children(self): + def _set_batch_for_children(self) -> None: for i, buf in enumerate(self.buffers): start, end = buf.maxsize * i, buf.maxsize * (i + 1) buf.set_batch(self._meta[start:end]) - def set_batch(self, batch: Batch): + def set_batch(self, batch: Batch) -> None: super().set_batch(batch) self._set_batch_for_children() @@ -625,6 +632,8 @@ def reset(self) -> None: buf.reset() def sample_index(self, batch_size: int) -> np.ndarray: + if batch_size < 0: + return np.array([], np.int) if batch_size == 0: # get all available indices sample_num = np.zeros(self.buffer_num, np.int) else: @@ -637,44 +646,61 @@ def sample_index(self, batch_size: int) -> np.ndarray: indices = [] for i, (buf, bsz) in enumerate(zip(self.buffers, sample_num)): - if bsz >= 0: - offset = buf.maxsize * i - indices.append(buf.sample_index(bsz) + offset) + offset = buf.maxsize * i + indices.append(buf.sample_index(bsz) + offset) return np.concatenate(indices) class CachedReplayBuffer(ReplayBuffer): - """CachedReplayBuffer can be considered as a combination of one main buffer - and a list of cached_buffers. It's designed to used by collector to allow - parallel collecting in collector. In CachedReplayBuffer is not organized - chronologically, but standard API like start()/starts()/ends/next() are - provided to help CachedReplayBuffer to be used just like ReplayBuffer. - #TODO finish doc + """CachedReplayBuffer contains a ReplayBuffer with the given size as the \ + main buffer, and a VectorReplayBuffer with cached_buffer_num * \ + ReplayBuffer(size=max_episode_length) as the cached buffer. + + The data is first stored in cached buffers. 
When the episode is + terminated, the data will move to the main buffer and the corresponding + cached buffer will be reset. + + :param int size: the size of main buffer. + :param int cached_buffer_num: number of ReplayBuffer needs to be created + for cached buffer. + :param int max_episode_length: the maximum length of one episode, used in + each cached buffer's maxsize. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` or + :class:`~tianshou.data.VectorReplayBuffer` for more detailed + explanation. """ def __init__( self, size: int, - cached_buffer_n: int, + cached_buffer_num: int, max_episode_length: int, **kwargs: Any, ) -> None: - assert cached_buffer_n > 0 and max_episode_length > 0 - if cached_buffer_n == 1: - warnings.warn( - "CachedReplayBuffer with cached_buffer_n = 1 will" - "cause low efficiency. Please consider using ReplayBuffer" - "which is not in cached form.", - Warning) - maxsize = size + cached_buffer_n * max_episode_length - self.cached_bufs_n = cached_buffer_n + + def buffer_alloc_fn( + buffer: ReplayBuffer, key: List[str], value: Any + ) -> None: + data = self._meta + for k in key[:-1]: + data = data[k] + data[key[-1]] = _create_value(value, self.maxsize) + self._set_batch_for_children() + + assert cached_buffer_num > 0 and max_episode_length > 0 + self.offset = size + self.cached_buffer_num = cached_buffer_num + kwargs["alloc_fn"] = kwargs.get("alloc_fn", buffer_alloc_fn) self.main_buffer = ReplayBuffer(size, **kwargs) self.cached_buffer = VectorReplayBuffer( - max_episode_length, cached_buffer_n, **kwargs) + max_episode_length, cached_buffer_num, **kwargs) + maxsize = size + cached_buffer_num * max_episode_length super().__init__(size=maxsize, **kwargs) def __len__(self) -> int: - """Return len(self).""" return len(self.main_buffer) + len(self.cached_buffer) def unfinished_index(self) -> np.ndarray: @@ -682,165 +708,78 @@ def unfinished_index(self) -> np.ndarray: self.cached_buffer.unfinished_index()]) def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: - """Return one step previous index.""" - assert np.all(index >= 0) and np.all(index < self.maxsize) - prev_indices = np.full(index.shape, -1) - mask = index < self.main_buffer.maxsize + index = np.asarray(index) + prev_indices = np.zeros_like(index) + mask = index < self.offset prev_indices[mask] = self.main_buffer.prev(index[mask]) prev_indices[~mask] = self.cached_buffer.prev( - index[~mask] - self.main_buffer.maxsize) + \ - self.main_buffer.maxsize + index[~mask] - self.offset) + self.offset return prev_indices def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: - """Return one step next index.""" - assert np.all(index >= 0) and np.all(index < self.maxsize) - next_indices = np.full(index.shape, -1) - mask = index < self.main_buffer.maxsize + index = np.asarray(index) + next_indices = np.zeros_like(index) + mask = index < self.offset next_indices[mask] = self.main_buffer.next(index[mask]) next_indices[~mask] = self.cached_buffer.next( - index[~mask] - self.main_buffer.maxsize) + \ - self.main_buffer.maxsize + index[~mask] - self.offset) + self.offset return next_indices - def update(self, buffer: "ReplayBuffer") -> int: - """CachedReplayBuffer will only update data from buffer which is in - episode form. Return an integer which indicates the number of steps - being ignored.""" - # For now update method copy element one by one, which is too slow. 
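# [Editor's note] Illustrative sketch, not part of the patch: the flat index
# layout described above (main buffer first, then the cached buffers) can be
# reproduced with plain numpy. The function name `locate` and all of its
# parameters are local to this example and do not exist in tianshou.
import numpy as np

def locate(index, size, cached_buffer_num, max_episode_length):
    """Map a flat CachedReplayBuffer index to (sub_buffer_id, local_index).

    sub_buffer_id -1 stands for the main buffer, 0..cached_buffer_num - 1 for
    the cached buffers.
    """
    index = np.asarray(index)
    assert np.all(index < size + cached_buffer_num * max_episode_length)
    in_main = index < size
    sub_id = np.where(in_main, -1, (index - size) // max_episode_length)
    local = np.where(in_main, index, (index - size) % max_episode_length)
    return sub_id, local

# With size=10 and 4 cached buffers of length 5, flat index 17 lives in
# cached buffer 1 at local position 2:
print(locate(17, 10, 4, 5))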
- if isinstance(buffer, CachedReplayBuffer): - buffer = buffer.main_buffer - # now treat buffer like a normal ReplayBuffer and - # remove those incomplete steps - if len(buffer) == 0: - return 0 - diposed_count = 0 - # TODO use standard API now - end = (buffer._index - 1) % len(buffer) - begin = buffer._index % len(buffer) - while True: - if buffer.done[end] > 0: - break - else: - diposed_count = diposed_count + 1 - if end == begin: - assert diposed_count == len(self) - return diposed_count - end = (end - 1) % len(buffer) - while True: - self.main_buffer.add(**buffer[begin]) - if begin == end: - return diposed_count - begin = (begin + 1) % len(buffer) + def update(self, buffer: ReplayBuffer) -> None: + self.main_buffer.update(buffer) - def add( + def add( # type: ignore self, obs: Any, act: Any, rew: Union[Number, np.number, np.ndarray], done: Union[Number, np.number, np.bool_], - obs_next: Any = None, - info: Optional[Union[dict, Batch]] = {}, - policy: Optional[Union[dict, Batch]] = {}, - index: Optional[Union[int, np.integer, np.ndarray, List[int]]] = None, - **kwargs: Any + obs_next: Any = Batch(), + info: Optional[Batch] = Batch(), + policy: Optional[Batch] = Batch(), + env_ids: Optional[Union[np.ndarray, List[int]]] = None, + **kwargs: Any, ) -> None: - if index is None: - index = range(self.cached_bufs_n) - index = np.atleast_1d(index).astype(np.int) - assert(index.ndim == 1) - - obs = np.atleast_1d(obs) - act = np.atleast_1d(act) - rew = np.atleast_1d(rew) - done = np.atleast_1d(done) - # TODO ugly code - if isinstance(obs_next, Batch) and obs_next.is_empty(): - obs_next = None - if isinstance(info, Batch) and info.is_empty(): - info = {} - if isinstance(policy, Batch) and policy.is_empty(): - policy = {} - obs_next = np.atleast_1d([None] * len(index)) \ - if obs_next is None else np.atleast_1d(obs_next) - info = np.atleast_1d( - [{}] * len(index)) if info == {} else np.atleast_1d(info) - policy = np.atleast_1d( - [{}] * len(index)) if policy == {} else np.atleast_1d(policy) - - # TODO what if data is already in episodes, - # what if i want to add mutiple data ? - if self._meta.is_empty(): - self._initialise(obs[0], act[0], rew[0], done[0], obs_next[0], - info[0], policy[0]) + """Add a batch of data into CachedReplayBuffer. + Each of the data's length (first dimension) must equal to the length of + env_ids. + """ + # if env_ids is None, an exception will raise from cached_buffer self.cached_buffer.add(obs, act, rew, done, obs_next, - info, policy, index, False, **kwargs) - return self._main_buf_update() - - def _main_buf_update(self): - lens = np.zeros((self.cached_bufs_n, ), dtype=np.int) - rews = np.zeros((self.cached_bufs_n, )) - start_indexs = np.zeros((self.cached_bufs_n, ), dtype=np.int) - for i, buf in enumerate(self.cached_buffer.bufs): - if buf.done[buf._index - 1] > 0: - lens[i] = len(buf) - rews[i] = np.sum(buf.rew[:lens[i]]) - start_indexs[i] = self.main_buffer._index - if self.main_buffer.maxsize > 0: - # _maxsize of main_buffer might be 0 in test collector. 
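# [Editor's note] Sketch, not part of the patch, of the episode-flush step the
# surrounding code implements: when a cached buffer's last transition has
# done=True, its length and reward sum are recorded, the episode is appended
# to the main buffer, and the cached buffer is cleared. Plain Python lists of
# (reward, done) pairs stand in for the real buffers; `flush_finished` is a
# name local to this example.
import numpy as np

def flush_finished(main, cached_buffers):
    lengths, rewards = [], []
    for buf in cached_buffers:
        if buf and buf[-1][1]:              # episode finished in this buffer
            lengths.append(len(buf))
            rewards.append(sum(r for r, _ in buf))
            main.extend(buf)                # move the whole episode
            buf.clear()                     # reset the cached buffer
        else:                               # still collecting
            lengths.append(0)
            rewards.append(0.0)
    return np.array(lengths), np.array(rewards)

main, cached = [], [[(1.0, False), (1.0, True)], [(0.5, False)]]
print(flush_finished(main, cached))  # (array([2, 0]), array([2., 0.]))
print(len(main), cached)             # 2 [[], [(0.5, False)]]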
- self.main_buffer.update(buf) - buf.reset() - return lens, rews, start_indexs + info, policy, env_ids, **kwargs) + # find the terminated episode, move data from cached buf to main buf + for buffer_idx in np.asarray(env_ids)[np.asarray(done) > 0]: + self.main_buffer.update(self.cached_buffer.buffers[buffer_idx]) + self.cached_buffer.buffers[buffer_idx].reset() def reset(self) -> None: self.cached_buffer.reset() self.main_buffer.reset() - def sample_index(self, batch_size: int, - is_from_main_buf=False, **kwargs) -> np.ndarray: - if is_from_main_buf: - return self.main_buffer.sample_index(batch_size, **kwargs) - - avail_indexes = np.arange(len(self), dtype=np.int) - start = len(self.main_buffer) - add = self.main_buffer.maxsize - len(self.main_buffer) - for buf in self.cached_buffer.bufs: - end = start + len(buf) - avail_indexes[start:end] = avail_indexes[start:end] + add - start = end - add = add + buf.maxsize - len(buf) - assert len(avail_indexes) > 0, "No available indice can be sampled." - # TODO consider making avail_indexes a seperate method - if batch_size == 0: - # data not in chronological order - return avail_indexes - else: - return np.random.choice(avail_indexes, batch_size) + def _set_batch_for_children(self) -> None: + self.main_buffer.set_batch(self._meta[:self.offset]) + self.cached_buffer.set_batch(self._meta[self.offset:]) - def _initialise( - self, - obs: Any, - act: Any, - rew: Union[Number, np.number, np.ndarray], - done: Union[Number, np.number, np.bool_], - obs_next: Any = None, - info: Optional[Union[dict, Batch]] = {}, - policy: Optional[Union[dict, Batch]] = {} - ) -> None: - assert(self._meta.is_empty()) - # to initialise self._meta - super().add(obs, act, rew, done, obs_next, info, policy) - super().reset() # TODO delete useless varible? 
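# [Editor's note] Illustrative numpy sketch, not part of the patch: the set of
# valid global indices is each sub-buffer's occupied slots shifted by that
# buffer's offset, which is what the index bookkeeping above computes before
# drawing a random sample. The arrays below are example values only.
import numpy as np

maxsizes = np.array([10, 5, 5, 5, 5])   # main buffer plus 4 cached buffers
lengths = np.array([3, 2, 0, 1, 0])     # transitions currently stored in each
offsets = np.concatenate([[0], np.cumsum(maxsizes)[:-1]])

avail = np.concatenate(
    [off + np.arange(n) for off, n in zip(offsets, lengths)])
print(avail)                            # [ 0  1  2 10 11 20]
sample = np.random.choice(avail, 4)     # uniform over the occupied slots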
- del self._index - del self._size + def set_batch(self, batch: Batch) -> None: + super().set_batch(batch) self._set_batch_for_children() - def _set_batch_for_children(self): - self.main_buffer.set_batch(self._meta[:self.main_buffer.maxsize]) - self.cached_buffer.set_batch(self._meta[self.main_buffer.maxsize:]) + def sample_index(self, batch_size: int) -> np.ndarray: + if batch_size < 0: + return np.array([], np.int) + if batch_size == 0: # get all available indices + sample_num = np.array([0, 0], np.int) + else: + buffer_lens = np.array([ + len(self.main_buffer), len(self.cached_buffer)]) + buffer_idx = np.random.choice(2, batch_size, + p=buffer_lens / buffer_lens.sum()) + sample_num = np.bincount(buffer_idx, minlength=2) + # avoid batch_size > 0 and sample_num == 0 -> get child's all data + sample_num[sample_num == 0] = -1 - def set_batch(self, batch: "Batch"): - """Manually choose the batch you want the ReplayBuffer to manage.""" - super().set_batch(batch) - self._set_batch_for_children() + return np.concatenate([ + self.main_buffer.sample_index(sample_num[0]), + self.cached_buffer.sample_index(sample_num[1]) + self.offset, + ]) From ee51e645da5be7327375d40e15eff617b5f3df44 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Mon, 25 Jan 2021 12:45:00 +0800 Subject: [PATCH 012/104] order change, small fix --- tianshou/data/buffer.py | 163 ++++++++++++++++++++-------------------- 1 file changed, 82 insertions(+), 81 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 140910c5e..39361459d 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -198,32 +198,29 @@ def __setattr__(self, key: str, value: Any) -> None: "key '{}' is reserved and cannot be assigned".format(key)) super().__setattr__(key, value) - @staticmethod - def default_alloc_fn( - buffer: "ReplayBuffer", key: List[str], value: Any - ) -> None: - """Allocate memory on buffer._meta for new (key, value) pair.""" - data = buffer._meta - for k in key[:-1]: - data = data[k] - data[key[-1]] = _create_value(value, buffer.maxsize) + def save_hdf5(self, path: str) -> None: + """Save replay buffer to HDF5 file.""" + with h5py.File(path, "w") as f: + to_hdf5(self.__getstate__(), f) - def _add_to_buffer(self, name: str, inst: Any) -> None: - try: - value = self._meta.__dict__[name] - except KeyError: - self._alloc(self, [name], inst) - value = self._meta[name] - try: - value[self._index] = inst - except KeyError: # inst is a dict/Batch - for key in set(inst.keys()).difference(value.keys()): - self._alloc(self, [name, key], inst) - value[self._index] = inst + @classmethod + def load_hdf5( + cls, path: str, device: Optional[str] = None + ) -> "ReplayBuffer": + """Load replay buffer from HDF5 file.""" + with h5py.File(path, "r") as f: + buf = cls.__new__(cls) + buf.__setstate__(from_hdf5(f, device=device)) + return buf + + def reset(self) -> None: + """Clear all the data in replay buffer.""" + self._index = self._size = 0 def set_batch(self, batch: Batch) -> None: """Manually choose the batch you want the ReplayBuffer to manage.""" - assert len(batch) == self.maxsize, \ + assert len(batch) == self.maxsize and \ + set(batch.keys()).issubset(self._reserved_key), \ "Input batch doesn't meet ReplayBuffer's data form requirement." 
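# [Editor's note] Usage sketch, not part of the patch, assuming this patched
# version of tianshou is installed: set_batch() only accepts a Batch whose
# length equals maxsize and whose keys are a subset of the seven reserved
# keys, so a compliant storage batch can be pre-allocated like this.
# (Note: the assert as written above references self._reserved_key; the later
# revision in this series spells it self._reserved_keys, which appears to be
# the intended attribute.)
import numpy as np
from tianshou.data import Batch, ReplayBuffer

size = 4
buf = ReplayBuffer(size)
storage = Batch(
    obs=np.zeros(size), act=np.zeros(size), rew=np.zeros(size),
    done=np.zeros(size, dtype=bool), obs_next=np.zeros(size),
)
buf.set_batch(storage)   # passes the length/keys check shown above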
self._meta = batch @@ -259,6 +256,29 @@ def update(self, buffer: "ReplayBuffer") -> None: self.add(**buffer[i]) # type: ignore buffer.stack_num = stack_num_orig + @staticmethod + def default_alloc_fn( + buffer: "ReplayBuffer", key: List[str], value: Any + ) -> None: + """Allocate memory on buffer._meta for new (key, value) pair.""" + data = buffer._meta + for k in key[:-1]: + data = data[k] + data[key[-1]] = _create_value(value, buffer.maxsize) + + def _add_to_buffer(self, name: str, inst: Any) -> None: + try: + value = self._meta.__dict__[name] + except KeyError: + self._alloc(self, [name], inst) + value = self._meta[name] + try: + value[self._index] = inst + except KeyError: # inst is a dict/Batch + for key in set(inst.keys()).difference(value.keys()): + self._alloc(self, [name, key], inst) + value[self._index] = inst + def add( self, obs: Any, @@ -296,10 +316,6 @@ def add( else: # TODO: remove this after deleting ListReplayBuffer self._size = self._index = self._size + 1 - def reset(self) -> None: - """Clear all the data in replay buffer.""" - self._index = self._size = 0 - def sample_index(self, batch_size: int) -> np.ndarray: """Get a random sample of index with size = batch_size. @@ -381,21 +397,6 @@ def __getitem__( policy=self.get(index, "policy"), ) - def save_hdf5(self, path: str) -> None: - """Save replay buffer to HDF5 file.""" - with h5py.File(path, "w") as f: - to_hdf5(self.__getstate__(), f) - - @classmethod - def load_hdf5( - cls, path: str, device: Optional[str] = None - ) -> "ReplayBuffer": - """Load replay buffer from HDF5 file.""" - with h5py.File(path, "r") as f: - buf = cls.__new__(cls) - buf.__setstate__(from_hdf5(f, device=device)) - return buf - class ListReplayBuffer(ReplayBuffer): """List-based replay buffer. @@ -564,6 +565,19 @@ def buffer_alloc_fn( def __len__(self) -> int: return sum([len(buf) for buf in self.buffers]) + def reset(self) -> None: + for buf in self.buffers: + buf.reset() + + def _set_batch_for_children(self) -> None: + for i, buf in enumerate(self.buffers): + start, end = buf.maxsize * i, buf.maxsize * (i + 1) + buf.set_batch(self._meta[start:end]) + + def set_batch(self, batch: Batch) -> None: + super().set_batch(batch) + self._set_batch_for_children() + def unfinished_index(self) -> np.ndarray: return np.concatenate([buf.unfinished_index() for buf in self.buffers]) @@ -600,37 +614,23 @@ def add( # type: ignore obs_next: Any = Batch(), info: Optional[Batch] = Batch(), policy: Optional[Batch] = Batch(), - env_ids: Optional[Union[np.ndarray, List[int]]] = None, + buffer_ids: Optional[Union[np.ndarray, List[int]]] = None, **kwargs: Any ) -> None: """Add a batch of data into VectorReplayBuffer. - Each of the data's length (first dimension) must equal to the length of - env_ids. + buffer_ids. TODO buffer_ids default to all buffers in sequential order. 
""" - assert env_ids is not None, \ - "env_ids is required in VectorReplayBuffer.add()" - # assume each element in env_ids is unique - assert np.bincount(env_ids).max() == 1 + if buffer_ids is None: + buffer_ids = np.arange(self.buffer_num) + # assume each element in buffer_ids is unique + assert np.bincount(buffer_ids).max() == 1 batch = Batch(obs=obs, act=act, rew=rew, done=done, obs_next=obs_next, info=info, policy=policy) - assert len(env_ids) == len(batch) - for batch_idx, env_id in enumerate(env_ids): + assert len(buffer_ids) == len(batch) + for batch_idx, env_id in enumerate(buffer_ids): self.buffers[env_id].add(**batch[batch_idx]) - def _set_batch_for_children(self) -> None: - for i, buf in enumerate(self.buffers): - start, end = buf.maxsize * i, buf.maxsize * (i + 1) - buf.set_batch(self._meta[start:end]) - - def set_batch(self, batch: Batch) -> None: - super().set_batch(batch) - self._set_batch_for_children() - - def reset(self) -> None: - for buf in self.buffers: - buf.reset() - def sample_index(self, batch_size: int) -> np.ndarray: if batch_size < 0: return np.array([], np.int) @@ -703,6 +703,18 @@ def buffer_alloc_fn( def __len__(self) -> int: return len(self.main_buffer) + len(self.cached_buffer) + def reset(self) -> None: + self.cached_buffer.reset() + self.main_buffer.reset() + + def _set_batch_for_children(self) -> None: + self.main_buffer.set_batch(self._meta[:self.offset]) + self.cached_buffer.set_batch(self._meta[self.offset:]) + + def set_batch(self, batch: Batch) -> None: + super().set_batch(batch) + self._set_batch_for_children() + def unfinished_index(self) -> np.ndarray: return np.concatenate([self.main_buffer.unfinished_index(), self.cached_buffer.unfinished_index()]) @@ -726,7 +738,7 @@ def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: return next_indices def update(self, buffer: ReplayBuffer) -> None: - self.main_buffer.update(buffer) + raise NotImplementedError def add( # type: ignore self, @@ -737,33 +749,22 @@ def add( # type: ignore obs_next: Any = Batch(), info: Optional[Batch] = Batch(), policy: Optional[Batch] = Batch(), - env_ids: Optional[Union[np.ndarray, List[int]]] = None, + buffer_ids: Optional[Union[np.ndarray, List[int]]] = None, **kwargs: Any, ) -> None: """Add a batch of data into CachedReplayBuffer. Each of the data's length (first dimension) must equal to the length of - env_ids. + buffer_ids. 
""" - # if env_ids is None, an exception will raise from cached_buffer + # if buffer_ids is None, an exception will raise from cached_buffer self.cached_buffer.add(obs, act, rew, done, obs_next, - info, policy, env_ids, **kwargs) + info, policy, buffer_ids, **kwargs) # find the terminated episode, move data from cached buf to main buf - for buffer_idx in np.asarray(env_ids)[np.asarray(done) > 0]: + for buffer_idx in np.asarray(buffer_ids)[np.asarray(done) > 0]: self.main_buffer.update(self.cached_buffer.buffers[buffer_idx]) self.cached_buffer.buffers[buffer_idx].reset() - - def reset(self) -> None: - self.cached_buffer.reset() - self.main_buffer.reset() - - def _set_batch_for_children(self) -> None: - self.main_buffer.set_batch(self._meta[:self.offset]) - self.cached_buffer.set_batch(self._meta[self.offset:]) - - def set_batch(self, batch: Batch) -> None: - super().set_batch(batch) - self._set_batch_for_children() + # TODO retrun to previous version def sample_index(self, batch_size: int) -> np.ndarray: if batch_size < 0: From a5bc4ad81f106b57c527c1ad4aefbaf8ff8453b1 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Mon, 25 Jan 2021 18:19:25 +0800 Subject: [PATCH 013/104] try unittest --- test/base/test_buffer.py | 57 ++++++-- tianshou/data/__init__.py | 7 +- tianshou/data/buffer.py | 283 ++++++++++++++++---------------------- 3 files changed, 170 insertions(+), 177 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index e6bd86a31..18c0c27ea 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -9,7 +9,7 @@ from tianshou.data import Batch, SegmentTree, ReplayBuffer from tianshou.data import ListReplayBuffer, PrioritizedReplayBuffer -from tianshou.data import VectorReplayBuffer, CachedReplayBuffer +from tianshou.data import ReplayBuffers, CachedReplayBuffer from tianshou.data.utils.converter import to_hdf5 if __name__ == '__main__': @@ -283,7 +283,8 @@ def test_hdf5(): "array": ReplayBuffer(size, stack_num=2), "list": ListReplayBuffer(), "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4), - "vector": VectorReplayBuffer(size, buffer_num=4), + "vector": ReplayBuffers([ReplayBuffer(size) for i in range(4)]), + "cached": CachedReplayBuffer(size, 4, size) } buffer_types = {k: b.__class__ for k, b in buffers.items()} device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -293,14 +294,16 @@ def test_hdf5(): 'obs': Batch(index=np.array([i])), 'act': i, 'rew': rew, - 'done': 0, + 'done': i % 3 == 2, 'info': {"number": {"n": i}, 'extra': None}, } buffers["array"].add(**kwargs) buffers["list"].add(**kwargs) buffers["prioritized"].add(weight=np.random.rand(), **kwargs) buffers["vector"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]), - env_ids=[0, 1, 2]) + buffer_ids=[0, 1, 2]) + buffers["cached"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]), + cached_buffer_ids=[0, 1, 2]) # save paths = {} @@ -329,6 +332,33 @@ def test_hdf5(): buffers[k][:].info.number.n == _buffers[k][:].info.number.n) assert np.all( buffers[k][:].info.extra == _buffers[k][:].info.extra) + # check shallow copy in ReplayBuffers + for k in ["vector", "cached"]: + buffers[k].info.number.n[0] = -100 + assert buffers[k].buffers[0].info.number.n[0] == -100 + # check if still behave normally + for k in ["vector", "cached"]: + kwargs = { + 'obs': Batch(index=np.array([5])), + 'act': 5, + 'rew': rew, + 'done': False, + 'info': {"number": {"n": i}, 'Timelimit.truncate': True}, + } + buffers[k].add(**Batch.cat([[kwargs], [kwargs], [kwargs], [kwargs]])) + act = 
np.zeros(buffers[k].maxsize) + if k == "vector": + act[np.arange(5)] = np.array([0, 1, 2, 3, 5]) + act[np.arange(5) + size] = np.array([0, 1, 2, 3, 5]) + act[np.arange(5) + size * 2] = np.array([0, 1, 2, 3, 5]) + act[size * 3] = 5 + elif k == "cached": + act[np.arange(9)] = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2]) + act[np.arange(3) + size] = np.array([3, 5, 2]) + act[np.arange(3) + size * 2] = np.array([3, 5, 2]) + act[np.arange(3) + size * 3] = np.array([3, 5, 2]) + act[size * 4] = 5 + assert np.allclose(buffers[k].act, act) for path in paths.values(): os.remove(path) @@ -346,9 +376,9 @@ def test_hdf5(): def test_vectorbuffer(): - buf = VectorReplayBuffer(5, 4) + buf = ReplayBuffers([ReplayBuffer(size=5) for i in range(4)]) buf.add(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3], - done=[0, 0, 1], env_ids=[0, 1, 2]) + done=[0, 0, 1], buffer_ids=[0, 1, 2]) batch, indice = buf.sample(10) batch, indice = buf.sample(0) assert np.allclose(indice, [0, 5, 10]) @@ -356,7 +386,9 @@ def test_vectorbuffer(): assert np.allclose(indice_prev, indice), indice_prev indice_next = buf.next(indice) assert np.allclose(indice_next, indice), indice_next - buf.add(obs=[4], act=[4], rew=[4], done=[1], env_ids=[3]) + assert np.allclose(buf.unfinished_index(), [0, 5]) + buf.add(obs=[4], act=[4], rew=[4], done=[1], buffer_ids=[3]) + assert np.allclose(buf.unfinished_index(), [0, 5]) batch, indice = buf.sample(10) batch, indice = buf.sample(0) assert np.allclose(indice, [0, 5, 10, 15]) @@ -365,13 +397,13 @@ def test_vectorbuffer(): indice_next = buf.next(indice) assert np.allclose(indice_next, indice), indice_next data = np.array([0, 0, 0, 0]) - buf.add(obs=data, act=data, rew=data, done=data, env_ids=[0, 1, 2, 3]) + buf.add(obs=data, act=data, rew=data, done=data, buffer_ids=[0, 1, 2, 3]) buf.add(obs=data, act=data, rew=data, done=1 - data, - env_ids=[0, 1, 2, 3]) + buffer_ids=[0, 1, 2, 3]) assert len(buf) == 12 - buf.add(obs=data, act=data, rew=data, done=data, env_ids=[0, 1, 2, 3]) + buf.add(obs=data, act=data, rew=data, done=data, buffer_ids=[0, 1, 2, 3]) buf.add(obs=data, act=data, rew=data, done=[0, 1, 0, 1], - env_ids=[0, 1, 2, 3]) + buffer_ids=[0, 1, 2, 3]) assert len(buf) == 20 batch, indice = buf.sample(10) indice = buf.sample_index(0) @@ -396,11 +428,12 @@ def test_vectorbuffer(): 10, 12, 12, 14, 14, 15, 17, 17, 19, 19, ]) + assert np.allclose(buf.unfinished_index(), [4, 14]) # TODO: prev/next/stack/hdf5 # CachedReplayBuffer buf = CachedReplayBuffer(10, 4, 5) assert buf.sample_index(0).tolist() == [] - buf.add(obs=[1], act=[1], rew=[1], done=[1], env_ids=[1]) + buf.add(obs=[1], act=[1], rew=[1], done=[1], cached_buffer_ids=[1]) print(buf) diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index be0959192..1fbf1e481 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -1,9 +1,8 @@ from tianshou.data.batch import Batch from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as from tianshou.data.utils.segtree import SegmentTree -from tianshou.data.buffer import ReplayBuffer, \ - ListReplayBuffer, PrioritizedReplayBuffer, \ - VectorReplayBuffer, CachedReplayBuffer +from tianshou.data.buffer import ReplayBuffer, ListReplayBuffer, \ + PrioritizedReplayBuffer, ReplayBuffers, CachedReplayBuffer from tianshou.data.collector import Collector __all__ = [ @@ -15,7 +14,7 @@ "ReplayBuffer", "ListReplayBuffer", "PrioritizedReplayBuffer", - "VectorReplayBuffer", + "ReplayBuffers", "CachedReplayBuffer", "Collector", ] diff --git a/tianshou/data/buffer.py 
b/tianshou/data/buffer.py index 39361459d..628c85f10 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -2,7 +2,7 @@ import warnings import numpy as np from numbers import Number -from typing import Any, Dict, List, Tuple, Union, Callable, Optional +from typing import Any, Dict, List, Tuple, Union, Optional from tianshou.data.batch import _create_value from tianshou.data import Batch, SegmentTree, to_numpy @@ -146,8 +146,6 @@ def __init__( stack_num: int = 1, ignore_obs_next: bool = False, save_only_last_obs: bool = False, - alloc_fn: Optional[Callable[["ReplayBuffer", List[str], Any], None]] - = None, sample_avail: bool = False, ) -> None: super().__init__() @@ -163,7 +161,8 @@ def __init__( self._index = 0 # current index self._size = 0 # current buffer size self._meta: Batch = Batch() - self._alloc = alloc_fn or ReplayBuffer.default_alloc_fn + self._episode_reward = 0.0 + self._episode_length = 0 self.reset() def __len__(self) -> int: @@ -215,19 +214,21 @@ def load_hdf5( def reset(self) -> None: """Clear all the data in replay buffer.""" - self._index = self._size = 0 + self._index = self._size = self._episode_length = 0 + self._episode_reward = 0.0 def set_batch(self, batch: Batch) -> None: """Manually choose the batch you want the ReplayBuffer to manage.""" assert len(batch) == self.maxsize and \ - set(batch.keys()).issubset(self._reserved_key), \ + set(batch.keys()).issubset(self._reserved_keys), \ "Input batch doesn't meet ReplayBuffer's data form requirement." self._meta = batch def unfinished_index(self) -> np.ndarray: """Return the index of unfinished episode.""" last = (self._index - 1) % self._size if self._size else 0 - return np.array([last] if not self.done[last] else [], np.int) + return np.array( + [last] if not self.done[last] and self._size else [], np.int) def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: """Return the index of previous transition. @@ -256,28 +257,25 @@ def update(self, buffer: "ReplayBuffer") -> None: self.add(**buffer[i]) # type: ignore buffer.stack_num = stack_num_orig - @staticmethod - def default_alloc_fn( - buffer: "ReplayBuffer", key: List[str], value: Any - ) -> None: + def alloc_fn(self, key: List[str], value: Any) -> None: """Allocate memory on buffer._meta for new (key, value) pair.""" - data = buffer._meta + data = self._meta for k in key[:-1]: data = data[k] - data[key[-1]] = _create_value(value, buffer.maxsize) + data[key[-1]] = _create_value(value, self.maxsize) def _add_to_buffer(self, name: str, inst: Any) -> None: try: value = self._meta.__dict__[name] except KeyError: - self._alloc(self, [name], inst) + self.alloc_fn([name], inst) value = self._meta[name] try: value[self._index] = inst except KeyError: # inst is a dict/Batch for key in set(inst.keys()).difference(value.keys()): - self._alloc(self, [name, key], inst) - value[self._index] = inst + self.alloc_fn([name, key], inst[key]) + self._meta[name][self._index] = inst def add( self, @@ -289,8 +287,12 @@ def add( info: Optional[Union[dict, Batch]] = {}, policy: Optional[Union[dict, Batch]] = {}, **kwargs: Any, - ) -> None: - """Add a batch of data into replay buffer.""" + ) -> Tuple[int, float]: + """Add a batch of data into replay buffer. + + Return (episode_length, episode_reward) if one episode is terminated, + otherwise return (0, 0.0). + """ assert isinstance( info, (dict, Batch) ), "You should return a dict in the last argument of env.step()." 
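# [Editor's note] Usage sketch, not part of the patch, assuming this patched
# version of tianshou is installed: with the change above, add() itself
# reports the finished episode's length and reward sum, returning (0, 0.0)
# while an episode is still running.
from tianshou.data import ReplayBuffer

buf = ReplayBuffer(size=10)
for t in range(3):
    ep_len, ep_rew = buf.add(obs=t, act=0, rew=1.0, done=(t == 2))
# The first two calls return (0, 0.0); the terminal step returns (3, 3.0).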
@@ -299,7 +301,8 @@ def add( self._add_to_buffer("obs", obs) self._add_to_buffer("act", act) # make sure the data type of reward is float instead of int - self._add_to_buffer("rew", rew * 1.0) # type: ignore + rew = rew * 1.0 # type: ignore + self._add_to_buffer("rew", rew) self._add_to_buffer("done", bool(done)) # done should be a bool scalar if self._save_obs_next: if obs_next is None: @@ -310,12 +313,22 @@ def add( self._add_to_buffer("info", info) self._add_to_buffer("policy", policy) + self._episode_reward += rew + self._episode_length += 1 + if self.maxsize > 0: self._size = min(self._size + 1, self.maxsize) self._index = (self._index + 1) % self.maxsize else: # TODO: remove this after deleting ListReplayBuffer self._size = self._index = self._size + 1 + if done: + result = (self._episode_length, self._episode_reward) + self._episode_length, self._episode_reward = 0, 0.0 + return result + else: + return (0, 0.0) + def sample_index(self, batch_size: int) -> np.ndarray: """Get a random sample of index with size = batch_size. @@ -361,13 +374,14 @@ def get( if stack_num == 1: # the most often case return val[indice] stack: List[Any] = [] + indice = np.asarray(indice) for _ in range(stack_num): stack = [val[indice]] + stack indice = self.prev(indice) if isinstance(val, Batch): - return Batch.stack(stack, axis=1) + return Batch.stack(stack, axis=indice.ndim) else: - return np.stack(stack, axis=1) + return np.stack(stack, axis=indice.ndim) except IndexError as e: if not (isinstance(val, Batch) and val.is_empty()): raise e # val != Batch() @@ -414,7 +428,7 @@ class ListReplayBuffer(ReplayBuffer): """ def __init__(self, **kwargs: Any) -> None: - warnings.warn("ListReplayBuffer will be replaced soon.") + warnings.warn("ListReplayBuffer will be replaced in version 0.4.0.") super().__init__(size=0, ignore_obs_next=False, **kwargs) def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: @@ -468,7 +482,7 @@ def add( policy: Optional[Union[dict, Batch]] = {}, weight: Optional[Union[Number, np.number]] = None, **kwargs: Any, - ) -> None: + ) -> Tuple[int, float]: """Add a batch of data into replay buffer.""" if weight is None: weight = self._max_prio @@ -477,7 +491,8 @@ def add( self._max_prio = max(self._max_prio, weight) self._min_prio = min(self._min_prio, weight) self.weight[self._index] = weight ** self._alpha - super().add(obs, act, rew, done, obs_next, info, policy, **kwargs) + return super().add(obs, act, rew, done, obs_next, + info, policy, **kwargs) def sample_index(self, batch_size: int) -> np.ndarray: assert self._size > 0, "Cannot sample a buffer with 0 size." @@ -529,12 +544,13 @@ def __getitem__( return batch -class VectorReplayBuffer(ReplayBuffer): - """VectorReplayBuffer contains n ReplayBuffer with the given size, where \ - n equals to buffer_num. +class ReplayBuffers(ReplayBuffer): + """ReplayBuffers contains a list of ReplayBuffer. + + These replay buffers have contiguous memory layout, and the storage space + each buffer has is a shallow copy of the topmost memory. - :param int size: the size of each ReplayBuffer. - :param int buffer_num: number of ReplayBuffer needs to be handled. + :param int buffer_list: a list of ReplayBuffers needed to be handled. .. seealso:: @@ -542,25 +558,15 @@ class VectorReplayBuffer(ReplayBuffer): explanation. 
""" - def __init__( - self, size: int, buffer_num: int, **kwargs: Any - ) -> None: - - def buffer_alloc_fn( - buffer: ReplayBuffer, key: List[str], value: Any - ) -> None: - data = self._meta - for k in key[:-1]: - data = data[k] - data[key[-1]] = _create_value(value, self.maxsize) - self._set_batch_for_children() - - assert size > 0 and buffer_num > 0 - self.buffer_num = buffer_num - kwargs["alloc_fn"] = kwargs.get("alloc_fn", buffer_alloc_fn) - self.buffers = np.array([ReplayBuffer(size, **kwargs) - for _ in range(buffer_num)]) - super().__init__(size=buffer_num * size, **kwargs) + def __init__(self, buffer_list: List[ReplayBuffer]) -> None: + self.buffer_num = len(buffer_list) + self.buffers = np.array(buffer_list) + offset = 0 + for buf in self.buffers: + buf.alloc_fn = self.alloc_fn + buf.offset = offset + offset += buf.maxsize + super().__init__(size=offset) def __len__(self) -> int: return sum([len(buf) for buf in self.buffers]) @@ -571,55 +577,64 @@ def reset(self) -> None: def _set_batch_for_children(self) -> None: for i, buf in enumerate(self.buffers): - start, end = buf.maxsize * i, buf.maxsize * (i + 1) - buf.set_batch(self._meta[start:end]) + buf.set_batch(self._meta[buf.offset:buf.offset + buf.maxsize]) def set_batch(self, batch: Batch) -> None: super().set_batch(batch) self._set_batch_for_children() def unfinished_index(self) -> np.ndarray: - return np.concatenate([buf.unfinished_index() for buf in self.buffers]) + return np.concatenate([ + buf.unfinished_index() + buf.offset for buf in self.buffers]) def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: index = np.asarray(index) prev_indices = np.zeros_like(index) - for i, buf in enumerate(self.buffers): - lower, upper = buf.maxsize * i, buf.maxsize * (i + 1) - mask = (lower <= index) & (index < upper) + for buf in self.buffers: + mask = (buf.offset <= index) & (index < buf.offset + buf.maxsize) if np.any(mask): - prev_indices[mask] = buf.prev(index[mask] - lower) + lower + prev_indices[mask] = buf.prev( + index[mask] - buf.offset) + buf.offset return prev_indices def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: index = np.asarray(index) next_indices = np.zeros_like(index) - for i, buf in enumerate(self.buffers): - lower, upper = buf.maxsize * i, buf.maxsize * (i + 1) - mask = (lower <= index) & (index < upper) + for buf in self.buffers: + mask = (buf.offset <= index) & (index < buf.offset + buf.maxsize) if np.any(mask): - next_indices[mask] = buf.next(index[mask] - lower) + lower + next_indices[mask] = buf.next( + index[mask] - buf.offset) + buf.offset return next_indices def update(self, buffer: ReplayBuffer) -> None: """The VectorReplayBuffer cannot be updated by any buffer.""" raise NotImplementedError + def alloc_fn(self, key: List[str], value: Any) -> None: + super().alloc_fn(key, value) + self._set_batch_for_children() + def add( # type: ignore self, obs: Any, act: Any, - rew: Union[np.ndarray], - done: Union[np.ndarray], + rew: np.ndarray, + done: np.ndarray, obs_next: Any = Batch(), info: Optional[Batch] = Batch(), policy: Optional[Batch] = Batch(), buffer_ids: Optional[Union[np.ndarray, List[int]]] = None, **kwargs: Any - ) -> None: - """Add a batch of data into VectorReplayBuffer. + ) -> Tuple[np.ndarray, np.ndarray]: + """Add a batch of data into ReplayBuffers. + Each of the data's length (first dimension) must equal to the length of - buffer_ids. TODO buffer_ids default to all buffers in sequential order. + buffer_ids. 
By default buffer_ids is [0, 1, ..., buffer_num - 1]. + + Return the array of episode_length and episode_reward with shape + (len(buffer_ids), ...), where (episode_length[i], episode_reward[i]) + refers to the buffer_ids[i]'s corresponding episode result. """ if buffer_ids is None: buffer_ids = np.arange(self.buffer_num) @@ -628,8 +643,13 @@ def add( # type: ignore batch = Batch(obs=obs, act=act, rew=rew, done=done, obs_next=obs_next, info=info, policy=policy) assert len(buffer_ids) == len(batch) + episode_lengths = [] + episode_rewards = [] for batch_idx, env_id in enumerate(buffer_ids): - self.buffers[env_id].add(**batch[batch_idx]) + length, reward = self.buffers[env_id].add(**batch[batch_idx]) + episode_lengths.append(length) + episode_rewards.append(reward) + return np.array(episode_lengths), np.array(episode_rewards) def sample_index(self, batch_size: int) -> np.ndarray: if batch_size < 0: @@ -644,17 +664,18 @@ def sample_index(self, batch_size: int) -> np.ndarray: # avoid batch_size > 0 and sample_num == 0 -> get child's all data sample_num[sample_num == 0] = -1 - indices = [] - for i, (buf, bsz) in enumerate(zip(self.buffers, sample_num)): - offset = buf.maxsize * i - indices.append(buf.sample_index(bsz) + offset) - return np.concatenate(indices) + return np.concatenate([ + buf.sample_index(bsz) + buf.offset + for buf, bsz in zip(self.buffers, sample_num) + ]) + +class CachedReplayBuffer(ReplayBuffers): + """CachedReplayBuffer contains a ReplayBuffers with given size and n \ + cached buffers, cached_buffer_num * ReplayBuffer(size=max_episode_length). -class CachedReplayBuffer(ReplayBuffer): - """CachedReplayBuffer contains a ReplayBuffer with the given size as the \ - main buffer, and a VectorReplayBuffer with cached_buffer_num * \ - ReplayBuffer(size=max_episode_length) as the cached buffer. + The memory layout is: ``| main_buffer | cached_buffer[0] | cached_buffer[1] + | ... | cached_buffer[cached_buffer_num - 1]``. The data is first stored in cached buffers. When the episode is terminated, the data will move to the main buffer and the corresponding @@ -669,7 +690,7 @@ class CachedReplayBuffer(ReplayBuffer): .. seealso:: Please refer to :class:`~tianshou.data.ReplayBuffer` or - :class:`~tianshou.data.VectorReplayBuffer` for more detailed + :class:`~tianshou.data.ReplayBuffers` for more detailed explanation. 
""" @@ -680,107 +701,47 @@ def __init__( max_episode_length: int, **kwargs: Any, ) -> None: - - def buffer_alloc_fn( - buffer: ReplayBuffer, key: List[str], value: Any - ) -> None: - data = self._meta - for k in key[:-1]: - data = data[k] - data[key[-1]] = _create_value(value, self.maxsize) - self._set_batch_for_children() - assert cached_buffer_num > 0 and max_episode_length > 0 - self.offset = size self.cached_buffer_num = cached_buffer_num - kwargs["alloc_fn"] = kwargs.get("alloc_fn", buffer_alloc_fn) - self.main_buffer = ReplayBuffer(size, **kwargs) - self.cached_buffer = VectorReplayBuffer( - max_episode_length, cached_buffer_num, **kwargs) - maxsize = size + cached_buffer_num * max_episode_length - super().__init__(size=maxsize, **kwargs) - - def __len__(self) -> int: - return len(self.main_buffer) + len(self.cached_buffer) - - def reset(self) -> None: - self.cached_buffer.reset() - self.main_buffer.reset() - - def _set_batch_for_children(self) -> None: - self.main_buffer.set_batch(self._meta[:self.offset]) - self.cached_buffer.set_batch(self._meta[self.offset:]) - - def set_batch(self, batch: Batch) -> None: - super().set_batch(batch) - self._set_batch_for_children() - - def unfinished_index(self) -> np.ndarray: - return np.concatenate([self.main_buffer.unfinished_index(), - self.cached_buffer.unfinished_index()]) - - def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: - index = np.asarray(index) - prev_indices = np.zeros_like(index) - mask = index < self.offset - prev_indices[mask] = self.main_buffer.prev(index[mask]) - prev_indices[~mask] = self.cached_buffer.prev( - index[~mask] - self.offset) + self.offset - return prev_indices - - def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: - index = np.asarray(index) - next_indices = np.zeros_like(index) - mask = index < self.offset - next_indices[mask] = self.main_buffer.next(index[mask]) - next_indices[~mask] = self.cached_buffer.next( - index[~mask] - self.offset) + self.offset - return next_indices - - def update(self, buffer: ReplayBuffer) -> None: - raise NotImplementedError + main_buffer = ReplayBuffer(size, **kwargs) + buffers = [main_buffer] + [ReplayBuffer(max_episode_length, **kwargs) + for _ in range(cached_buffer_num)] + super().__init__(buffer_list=buffers) def add( # type: ignore self, obs: Any, act: Any, - rew: Union[Number, np.number, np.ndarray], - done: Union[Number, np.number, np.bool_], + rew: np.ndarray, + done: np.ndarray, obs_next: Any = Batch(), info: Optional[Batch] = Batch(), policy: Optional[Batch] = Batch(), - buffer_ids: Optional[Union[np.ndarray, List[int]]] = None, + cached_buffer_ids: Optional[Union[np.ndarray, List[int]]] = None, **kwargs: Any, - ) -> None: + ) -> Tuple[np.ndarray, np.ndarray]: """Add a batch of data into CachedReplayBuffer. Each of the data's length (first dimension) must equal to the length of - buffer_ids. + cached_buffer_ids. By default the cached_buffer_ids is [0, 1, ..., + cached_buffer_num - 1]. + + Return the array of episode_length and episode_reward with shape + (len(cached_buffer_ids), ...), where (episode_length[i], + episode_reward[i]) refers to the cached_buffer_ids[i]'s corresponding + episode result. 
""" - # if buffer_ids is None, an exception will raise from cached_buffer - self.cached_buffer.add(obs, act, rew, done, obs_next, - info, policy, buffer_ids, **kwargs) + if cached_buffer_ids is None: + cached_buffer_ids = np.arange(self.buffer_num - 1) + # in self.buffers, the first buffer is main_buffer + buffer_ids = np.asarray(cached_buffer_ids) + 1 + + result = super().add(obs, act, rew, done, obs_next, + info, policy, buffer_ids=buffer_ids, **kwargs) + # find the terminated episode, move data from cached buf to main buf for buffer_idx in np.asarray(buffer_ids)[np.asarray(done) > 0]: - self.main_buffer.update(self.cached_buffer.buffers[buffer_idx]) - self.cached_buffer.buffers[buffer_idx].reset() - # TODO retrun to previous version - - def sample_index(self, batch_size: int) -> np.ndarray: - if batch_size < 0: - return np.array([], np.int) - if batch_size == 0: # get all available indices - sample_num = np.array([0, 0], np.int) - else: - buffer_lens = np.array([ - len(self.main_buffer), len(self.cached_buffer)]) - buffer_idx = np.random.choice(2, batch_size, - p=buffer_lens / buffer_lens.sum()) - sample_num = np.bincount(buffer_idx, minlength=2) - # avoid batch_size > 0 and sample_num == 0 -> get child's all data - sample_num[sample_num == 0] = -1 + self.buffers[0].update(self.buffers[buffer_idx]) + self.buffers[buffer_idx].reset() - return np.concatenate([ - self.main_buffer.sample_index(sample_num[0]), - self.cached_buffer.sample_index(sample_num[1]) + self.offset, - ]) + return result From 17e3612411bb8d5f6cf6e63ba0385cbef7297111 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Mon, 25 Jan 2021 21:22:39 +0800 Subject: [PATCH 014/104] add more test and fix bugs --- test/base/test_buffer.py | 121 ++++++++++++++++++++++++++++++++++----- tianshou/data/buffer.py | 33 +++++------ 2 files changed, 121 insertions(+), 33 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 18c0c27ea..c8edd818b 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -92,27 +92,66 @@ def test_ignore_obs_next(size=10): assert data.obs_next -def test_stack(size=5, bufsize=9, stack_num=4): +def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): env = MyTestEnv(size) buf = ReplayBuffer(bufsize, stack_num=stack_num) buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True) + buf4 = CachedReplayBuffer(bufsize, cached_num, size, + stack_num=stack_num, ignore_obs_next=True) obs = env.reset(1) - for i in range(16): + for i in range(18): obs_next, rew, done, info = env.step(1) buf.add(obs, 1, rew, done, None, info) buf3.add([None, None, obs], 1, rew, done, [None, obs], info) + obs_list = np.array([obs + size * i for i in range(cached_num)]) + act_list = [1] * cached_num + rew_list = [rew] * cached_num + done_list = [done] * cached_num + obs_next_list = -obs_list + info_list = [info] * cached_num + buf4.add(obs_list, act_list, rew_list, done_list, + obs_next_list, info_list) obs = obs_next if done: obs = env.reset(1) indice = np.arange(len(buf)) assert np.allclose(buf.get(indice, 'obs')[..., 0], [ - [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], + [2, 2, 2, 2], [2, 2, 2, 3], [2, 2, 3, 4], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], - [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]]) + [1, 2, 3, 4], [1, 1, 1, 1], [1, 1, 1, 2]]) assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs')) assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs_next')) with pytest.raises(IndexError): buf[bufsize * 2] + assert np.allclose(buf4.obs.reshape(-1), [ + 12, 
13, 14, 4, 6, 7, 8, 9, 11, + 1, 2, 3, 4, 0, + 6, 7, 8, 9, 0, + 11, 12, 13, 14, 0, + ]), buf4.obs + assert np.allclose(buf4.done, [ + 0, 0, 1, 1, 0, 0, 0, 1, 0, + 0, 0, 0, 1, 0, + 0, 0, 0, 1, 0, + 0, 0, 0, 1, 0, + ]), buf4.done + assert np.allclose(buf4.unfinished_index(), [10, 15, 20]) + indice = sorted(buf4.sample_index(0)) + assert np.allclose(indice, list(range(bufsize)) + [9, 10, 14, 15, 19, 20]) + assert np.allclose(buf4[indice].obs[..., 0], [ + [11, 11, 11, 12], [11, 11, 12, 13], [11, 12, 13, 14], + [4, 4, 4, 4], [6, 6, 6, 6], [6, 6, 6, 7], + [6, 6, 7, 8], [6, 7, 8, 9], [11, 11, 11, 11], + [1, 1, 1, 1], [1, 1, 1, 2], [6, 6, 6, 6], + [6, 6, 6, 7], [11, 11, 11, 11], [11, 11, 11, 12], + ]) + assert np.allclose(buf4[indice].obs_next[..., 0], [ + [11, 11, 12, 13], [11, 12, 13, 14], [11, 12, 13, 14], + [4, 4, 4, 4], [6, 6, 6, 7], [6, 6, 7, 8], + [6, 7, 8, 9], [6, 7, 8, 9], [11, 11, 11, 12], + [1, 1, 1, 2], [1, 1, 1, 2], [6, 6, 6, 7], + [6, 6, 6, 7], [11, 11, 11, 12], [11, 11, 11, 12], + ]) def test_priortized_replaybuffer(size=32, bufsize=15): @@ -377,9 +416,11 @@ def test_hdf5(): def test_vectorbuffer(): buf = ReplayBuffers([ReplayBuffer(size=5) for i in range(4)]) - buf.add(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3], - done=[0, 0, 1], buffer_ids=[0, 1, 2]) - batch, indice = buf.sample(10) + ep_len, ep_rew = buf.add(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3], + done=[0, 0, 1], buffer_ids=[0, 1, 2]) + assert np.allclose(ep_len, [0, 0, 1]) and np.allclose(ep_rew, [0, 0, 3]) + indice = buf.sample_index(11000) + assert np.bincount(indice)[[0, 5, 10]].min() >= 3000 batch, indice = buf.sample(0) assert np.allclose(indice, [0, 5, 10]) indice_prev = buf.prev(indice) @@ -405,6 +446,8 @@ def test_vectorbuffer(): buf.add(obs=data, act=data, rew=data, done=[0, 1, 0, 1], buffer_ids=[0, 1, 2, 3]) assert len(buf) == 20 + indice = buf.sample_index(120000) + assert np.bincount(indice).min() >= 5000 batch, indice = buf.sample(10) indice = buf.sample_index(0) assert np.allclose(indice, np.arange(len(buf))) @@ -414,27 +457,75 @@ def test_vectorbuffer(): 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, ]) - indice_prev = buf.prev(indice) - assert np.allclose(indice_prev, [ + assert np.allclose(buf.prev(indice), [ 0, 0, 1, 3, 3, 5, 5, 6, 8, 8, 10, 11, 11, 13, 13, 15, 16, 16, 18, 18, ]) - indice_next = buf.next(indice) - assert np.allclose(indice_next, [ + assert np.allclose(buf.next(indice), [ 1, 2, 2, 4, 4, 6, 7, 7, 9, 9, 10, 12, 12, 14, 14, 15, 17, 17, 19, 19, ]) assert np.allclose(buf.unfinished_index(), [4, 14]) - # TODO: prev/next/stack/hdf5 + ep_len, ep_rew = buf.add(obs=[1], act=[1], rew=[1], done=[1], + buffer_ids=[2]) + assert np.allclose(ep_len, [3]) and np.allclose(ep_rew, [1]) + assert np.allclose(buf.unfinished_index(), [4]) + indice = list(sorted(buf.sample_index(0))) + assert np.allclose(indice, np.arange(len(buf))) + assert np.allclose(buf.prev(indice), [ + 0, 0, 1, 3, 3, + 5, 5, 6, 8, 8, + 14, 11, 11, 13, 13, + 15, 16, 16, 18, 18, + ]) + assert np.allclose(buf.next(indice), [ + 1, 2, 2, 4, 4, + 6, 7, 7, 9, 9, + 10, 12, 12, 14, 10, + 15, 17, 17, 19, 19, + ]) + # CachedReplayBuffer buf = CachedReplayBuffer(10, 4, 5) assert buf.sample_index(0).tolist() == [] - buf.add(obs=[1], act=[1], rew=[1], done=[1], cached_buffer_ids=[1]) - print(buf) + ep_len, ep_rew = buf.add(obs=[1], act=[1], rew=[1], done=[0], + cached_buffer_ids=[1]) + obs = np.zeros(buf.maxsize) + obs[15] = 1 + indice = buf.sample_index(0) + assert np.allclose(indice, [15]) + assert np.allclose(buf.prev(indice), [15]) + assert 
np.allclose(buf.next(indice), [15]) + assert np.allclose(buf.obs, obs) + assert np.allclose(ep_len, [0]) and np.allclose(ep_rew, [0.0]) + ep_len, ep_rew = buf.add(obs=[2], act=[2], rew=[2], done=[1], + cached_buffer_ids=[3]) + obs[[0, 25]] = 2 + indice = buf.sample_index(0) + assert np.allclose(indice, [0, 15]) + assert np.allclose(buf.prev(indice), [0, 15]) + assert np.allclose(buf.next(indice), [0, 15]) + assert np.allclose(buf.obs, obs) + assert np.allclose(ep_len, [1]) and np.allclose(ep_rew, [2.0]) + assert np.allclose(buf.unfinished_index(), [15]) + assert np.allclose(buf.sample_index(0), [0, 15]) + ep_len, ep_rew = buf.add(obs=[3, 4], act=[3, 4], rew=[3, 4], + done=[0, 1], cached_buffer_ids=[3, 1]) + assert np.allclose(ep_len, [0, 2]) and np.allclose(ep_rew, [0, 5.0]) + obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3] + assert np.allclose(buf.obs, obs) + assert np.allclose(buf.unfinished_index(), [25]) + indice = buf.sample_index(0) + assert np.allclose(indice, [0, 1, 2, 25]) + assert np.allclose(buf.done[indice], [1, 0, 1, 0]) + assert np.allclose(buf.prev(indice), [0, 1, 1, 25]) + assert np.allclose(buf.next(indice), [0, 2, 2, 25]) + indice = buf.sample_index(10000) + assert np.bincount(indice)[[0, 1, 2, 25]].min() > 2000 if __name__ == '__main__': @@ -442,9 +533,9 @@ def test_vectorbuffer(): test_hdf5() test_replaybuffer() test_ignore_obs_next() - test_update() test_stack() test_pickle() test_segtree() test_priortized_replaybuffer() test_priortized_replaybuffer(233333, 200000) + test_update() diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 628c85f10..f41991b3a 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -150,8 +150,8 @@ def __init__( ) -> None: super().__init__() if sample_avail: - warnings.warn("sample_avail is deprecated. Please check out " - "tianshou version <= 0.3.1 if you want to use it.") + warnings.warn("sample_avail is deprecated in 0.4.0. Please check " + "out version <= 0.3.1 if you want to use it.") self.maxsize = size assert stack_num > 0, "stack_num should greater than 0" self.stack_num = stack_num @@ -214,8 +214,8 @@ def load_hdf5( def reset(self) -> None: """Clear all the data in replay buffer.""" - self._index = self._size = self._episode_length = 0 - self._episode_reward = 0.0 + self._index = self._size = 0 + self._episode_length, self._episode_reward = 0, 0.0 def set_batch(self, batch: Batch) -> None: """Manually choose the batch you want the ReplayBuffer to manage.""" @@ -434,9 +434,7 @@ def __init__(self, **kwargs: Any) -> None: def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: raise NotImplementedError("ListReplayBuffer cannot be sampled!") - def _add_to_buffer( - self, name: str, inst: Union[dict, Batch, np.ndarray, float, int, bool] - ) -> None: + def _add_to_buffer(self, name: str, inst: Any) -> None: if self._meta.get(name) is None: self._meta.__dict__[name] = [] self._meta[name].append(inst) @@ -558,15 +556,15 @@ class ReplayBuffers(ReplayBuffer): explanation. 
""" - def __init__(self, buffer_list: List[ReplayBuffer]) -> None: + def __init__(self, buffer_list: List[ReplayBuffer], **kwargs: Any) -> None: self.buffer_num = len(buffer_list) - self.buffers = np.array(buffer_list) + self.buffers = buffer_list offset = 0 for buf in self.buffers: - buf.alloc_fn = self.alloc_fn + buf.alloc_fn = self.alloc_fn # type: ignore buf.offset = offset offset += buf.maxsize - super().__init__(size=offset) + super().__init__(size=offset, **kwargs) def __len__(self) -> int: return sum([len(buf) for buf in self.buffers]) @@ -608,7 +606,7 @@ def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: return next_indices def update(self, buffer: ReplayBuffer) -> None: - """The VectorReplayBuffer cannot be updated by any buffer.""" + """The ReplayBuffers cannot be updated by any buffer.""" raise NotImplementedError def alloc_fn(self, key: List[str], value: Any) -> None: @@ -643,8 +641,8 @@ def add( # type: ignore batch = Batch(obs=obs, act=act, rew=rew, done=done, obs_next=obs_next, info=info, policy=policy) assert len(buffer_ids) == len(batch) - episode_lengths = [] - episode_rewards = [] + episode_lengths = [] # (len(buffer_ids),) + episode_rewards = [] # (len(buffer_ids), ...) for batch_idx, env_id in enumerate(buffer_ids): length, reward = self.buffers[env_id].add(**batch[batch_idx]) episode_lengths.append(length) @@ -702,11 +700,10 @@ def __init__( **kwargs: Any, ) -> None: assert cached_buffer_num > 0 and max_episode_length > 0 - self.cached_buffer_num = cached_buffer_num main_buffer = ReplayBuffer(size, **kwargs) buffers = [main_buffer] + [ReplayBuffer(max_episode_length, **kwargs) for _ in range(cached_buffer_num)] - super().__init__(buffer_list=buffers) + super().__init__(buffer_list=buffers, **kwargs) def add( # type: ignore self, @@ -728,8 +725,8 @@ def add( # type: ignore Return the array of episode_length and episode_reward with shape (len(cached_buffer_ids), ...), where (episode_length[i], - episode_reward[i]) refers to the cached_buffer_ids[i]'s corresponding - episode result. + episode_reward[i]) refers to the cached_buffer_ids[i]th cached buffer's + corresponding episode result. 
""" if cached_buffer_ids is None: cached_buffer_ids = np.arange(self.buffer_num - 1) From 7eba23d00ea738b7ddd83655bedb4543ee59705c Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Mon, 25 Jan 2021 22:03:38 +0800 Subject: [PATCH 015/104] fix a bug and add some corner-case tests --- test/base/test_buffer.py | 10 ++++++++++ tianshou/data/buffer.py | 4 ++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index c8edd818b..a20ed7c61 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -419,6 +419,8 @@ def test_vectorbuffer(): ep_len, ep_rew = buf.add(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3], done=[0, 0, 1], buffer_ids=[0, 1, 2]) assert np.allclose(ep_len, [0, 0, 1]) and np.allclose(ep_rew, [0, 0, 3]) + with pytest.raises(NotImplementedError): + buf.update(buf) indice = buf.sample_index(11000) assert np.bincount(indice)[[0, 5, 10]].min() >= 3000 batch, indice = buf.sample(0) @@ -488,6 +490,14 @@ def test_vectorbuffer(): 10, 12, 12, 14, 10, 15, 17, 17, 19, 19, ]) + # corner case: list, int and -1 + assert buf.prev(-1) == buf.prev([buf.maxsize - 1])[0] + assert buf.next(-1) == buf.next([buf.maxsize - 1])[0] + batch = buf._meta + batch.info.n = np.ones(buf.maxsize) + buf.set_batch(batch) + assert np.allclose(buf.buffers[-1].info.n, [1] * 5) + assert buf.sample_index(-1).tolist() == [] # CachedReplayBuffer buf = CachedReplayBuffer(10, 4, 5) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index f41991b3a..82279bdf6 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -586,7 +586,7 @@ def unfinished_index(self) -> np.ndarray: buf.unfinished_index() + buf.offset for buf in self.buffers]) def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: - index = np.asarray(index) + index = np.asarray(index) % self.maxsize prev_indices = np.zeros_like(index) for buf in self.buffers: mask = (buf.offset <= index) & (index < buf.offset + buf.maxsize) @@ -596,7 +596,7 @@ def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: return prev_indices def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: - index = np.asarray(index) + index = np.asarray(index) % self.maxsize next_indices = np.zeros_like(index) for buf in self.buffers: mask = (buf.offset <= index) & (index < buf.offset + buf.maxsize) From b9f4f2a1e2cf79e7f95372c5b8ac19920568e2ce Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 26 Jan 2021 09:49:06 +0800 Subject: [PATCH 016/104] re-implement sample_avail function and add test for CachedReplayBuffer(size=0, ) --- test/base/test_buffer.py | 40 +++++++++++++++++++++++++++++++- tianshou/data/buffer.py | 49 ++++++++++++++++++++++++++++------------ 2 files changed, 74 insertions(+), 15 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index a20ed7c61..77b6524ad 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -95,13 +95,17 @@ def test_ignore_obs_next(size=10): def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): env = MyTestEnv(size) buf = ReplayBuffer(bufsize, stack_num=stack_num) + buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True) buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True) buf4 = CachedReplayBuffer(bufsize, cached_num, size, stack_num=stack_num, ignore_obs_next=True) + buf5 = CachedReplayBuffer(bufsize, cached_num, size, stack_num=stack_num, + ignore_obs_next=True, sample_avail=True) obs = env.reset(1) for i in range(18): 
obs_next, rew, done, info = env.step(1) buf.add(obs, 1, rew, done, None, info) + buf2.add(obs, 1, rew, done, None, info) buf3.add([None, None, obs], 1, rew, done, [None, obs], info) obs_list = np.array([obs + size * i for i in range(cached_num)]) act_list = [1] * cached_num @@ -111,6 +115,8 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): info_list = [info] * cached_num buf4.add(obs_list, act_list, rew_list, done_list, obs_next_list, info_list) + buf5.add(obs_list, act_list, rew_list, done_list, + obs_next_list, info_list) obs = obs_next if done: obs = env.reset(1) @@ -121,6 +127,10 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): [1, 2, 3, 4], [1, 1, 1, 1], [1, 1, 1, 2]]) assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs')) assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs_next')) + _, indice = buf2.sample(0) + assert indice.tolist() == [6] + _, indice = buf2.sample(1) + assert indice.tolist() == [6] with pytest.raises(IndexError): buf[bufsize * 2] assert np.allclose(buf4.obs.reshape(-1), [ @@ -152,6 +162,16 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): [1, 1, 1, 2], [1, 1, 1, 2], [6, 6, 6, 7], [6, 6, 6, 7], [11, 11, 11, 12], [11, 11, 11, 12], ]) + assert np.all(buf4.done == buf5.done) + indice = buf5.sample_index(0) + assert np.allclose(sorted(indice), [2, 7]) + assert np.all(np.isin(buf5.sample_index(100), indice)) + # manually change the stack num + buf5.stack_num = 2 + for buf in buf5.buffers: + buf.stack_num = 2 + indice = buf5.sample_index(0) + assert np.allclose(sorted(indice), [0, 1, 2, 5, 6, 7, 10, 15, 20]) def test_priortized_replaybuffer(size=32, bufsize=15): @@ -498,7 +518,7 @@ def test_vectorbuffer(): buf.set_batch(batch) assert np.allclose(buf.buffers[-1].info.n, [1] * 5) assert buf.sample_index(-1).tolist() == [] - + assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == np.object # CachedReplayBuffer buf = CachedReplayBuffer(10, 4, 5) assert buf.sample_index(0).tolist() == [] @@ -536,6 +556,24 @@ def test_vectorbuffer(): assert np.allclose(buf.next(indice), [0, 2, 2, 25]) indice = buf.sample_index(10000) assert np.bincount(indice)[[0, 1, 2, 25]].min() > 2000 + # cached buffer with main_buffer size == 0 + buf = CachedReplayBuffer(0, 4, 5, sample_avail=True) # no effect + data = np.array([0, 0, 0, 0]) + buf.add(obs=data, act=data, rew=data, done=[0, 0, 1, 1], obs_next=data) + buf.add(obs=data, act=data, rew=data, done=[0, 0, 0, 0], obs_next=data) + buf.add(obs=data, act=data, rew=data, done=[1, 1, 1, 1], obs_next=data) + buf.add(obs=data, act=data, rew=data, done=[0, 0, 0, 0], obs_next=data) + buf.add(obs=data, act=data, rew=data, done=[0, 1, 0, 1], obs_next=data) + assert np.allclose(buf.done, [ + 0, 0, 1, 0, 0, + 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, + ]) + indice = buf.sample_index(0) + assert np.allclose(indice, [0, 1, 10, 11]) + assert np.allclose(buf.prev(indice), [0, 0, 10, 10]) + assert np.allclose(buf.next(indice), [1, 1, 11, 11]) if __name__ == '__main__': diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 82279bdf6..6448c22aa 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -135,6 +135,9 @@ class ReplayBuffer: :param bool save_only_last_obs: only save the last obs/obs_next when it has a shape of (timestep, ...) because of temporal stacking, defaults to False. + :param bool sample_avail: the parameter indicating sampling only available + index when using frame-stack sampling method, defaults to False. 
+ This feature is not supported in Prioritized Replay Buffer currently. """ _reserved_keys = ("obs", "act", "rew", "done", @@ -149,15 +152,13 @@ def __init__( sample_avail: bool = False, ) -> None: super().__init__() - if sample_avail: - warnings.warn("sample_avail is deprecated in 0.4.0. Please check " - "out version <= 0.3.1 if you want to use it.") self.maxsize = size assert stack_num > 0, "stack_num should greater than 0" self.stack_num = stack_num self._indices = np.arange(size) self._save_obs_next = not ignore_obs_next self._save_only_last_obs = save_only_last_obs + self._sample_avail = sample_avail self._index = 0 # current index self._size = 0 # current buffer size self._meta: Batch = Batch() @@ -251,11 +252,11 @@ def update(self, buffer: "ReplayBuffer") -> None: """Move the data from the given buffer to current buffer.""" if len(buffer) == 0 or self.maxsize == 0: return - stack_num_orig, buffer.stack_num = buffer.stack_num, 1 + stack_num, buffer.stack_num = buffer.stack_num, 1 indices = buffer.sample_index(0) # get all available indices for i in indices: self.add(**buffer[i]) # type: ignore - buffer.stack_num = stack_num_orig + buffer.stack_num = stack_num def alloc_fn(self, key: List[str], value: Any) -> None: """Allocate memory on buffer._meta for new (key, value) pair.""" @@ -336,15 +337,27 @@ def sample_index(self, batch_size: int) -> np.ndarray: an empty numpy array if batch_size < 0 or no available index can be sampled. """ - if batch_size > 0: - return np.random.choice(self._size, batch_size) - elif batch_size == 0: # construct current available indices - return np.concatenate([ - np.arange(self._index, self._size), - np.arange(0, self._index), - ]) + if self.stack_num == 1 or not self._sample_avail: # most often case + if batch_size > 0: + return np.random.choice(self._size, batch_size) + elif batch_size == 0: # construct current available indices + return np.concatenate([ + np.arange(self._index, self._size), + np.arange(self._index)]) + else: + return np.array([], np.int) else: - return np.array([], np.int) + if batch_size < 0: + return np.array([], np.int) + all_indices = prev_indices = np.concatenate([ + np.arange(self._index, self._size), np.arange(self._index)]) + for _ in range(self.stack_num - 2): + prev_indices = self.prev(prev_indices) + all_indices = all_indices[prev_indices != self.prev(prev_indices)] + if batch_size > 0: + return np.random.choice(all_indices, batch_size) + else: + return all_indices def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: """Get a random sample from buffer with size = batch_size. 
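A tiny hand-checkable case of what the re-implemented sample_avail branch is intended to return (a sketch assuming the prev() semantics shown in the concepts.rst example below: an index is available only if a full frame stack can be built inside its own episode)::

    import numpy as np
    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=10, stack_num=3, sample_avail=True)
    # one 4-step episode followed by one 2-step episode
    for i, done in enumerate([0, 0, 0, 1, 0, 1]):
        buf.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1)
    # only indices 2 and 3 have stack_num - 1 = 2 predecessors inside the
    # same episode, so only they can produce a full 3-frame stack
    assert np.allclose(buf.sample_index(0), [2, 3])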
@@ -397,12 +410,13 @@ def __getitem__( """ if isinstance(index, slice): # change slice to np array index = self._indices[:len(self)][index] + obs = self.get(index, "obs") if self._save_obs_next: obs_next = self.get(index, "obs_next") else: obs_next = self.get(self.next(index), "obs") return Batch( - obs=self.get(index, "obs"), + obs=obs, act=self.act[index], rew=self.rew[index], done=self.done[index], @@ -652,6 +666,13 @@ def add( # type: ignore def sample_index(self, batch_size: int) -> np.ndarray: if batch_size < 0: return np.array([], np.int) + if self._sample_avail and self.stack_num > 1: + all_indices = np.concatenate([ + buf.sample_index(0) + buf.offset for buf in self.buffers]) + if batch_size == 0: + return all_indices + else: + return np.random.choice(all_indices, batch_size) if batch_size == 0: # get all available indices sample_num = np.zeros(self.buffer_num, np.int) else: From 8fe85f88133f0897cd88c3e6e0f7f203c25209d8 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 26 Jan 2021 10:24:01 +0800 Subject: [PATCH 017/104] improve documents --- docs/tutorials/concepts.rst | 154 +++++++++++++++++++++++++++++++++++- tianshou/data/buffer.py | 116 +-------------------------- 2 files changed, 151 insertions(+), 119 deletions(-) diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index 8ea2d272e..1294cf1b3 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -53,11 +53,157 @@ In short, you can define a :class:`~tianshou.data.Batch` with any key-value pair Buffer ------ -.. automodule:: tianshou.data.ReplayBuffer - :members: - :noindex: +:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. ReplayBuffer can be considered as a specialized form (or management) of Batch. -Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out :class:`~tianshou.data.ReplayBuffer` for more detail. +The current implementation of Tianshou typically use 7 reserved keys in +:class:`~tianshou.data.Batch`: + +* ``obs`` the observation of step :math:`t` ; +* ``act`` the action of step :math:`t` ; +* ``rew`` the reward of step :math:`t` ; +* ``done`` the done flag of step :math:`t` ; +* ``obs_next`` the observation of step :math:`t+1` ; +* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` function returns 4 arguments, and the last one is ``info``); +* ``policy`` the data computed by policy in step :math:`t`; + +The following code snippet illustrates its usage: + +:: + + >>> import pickle, numpy as np + >>> from tianshou.data import ReplayBuffer + >>> buf = ReplayBuffer(size=20) + >>> for i in range(3): + ... buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}) + + >>> buf.obs + # since we set size = 20, len(buf.obs) == 20. + array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + >>> # but there are only three valid items, so len(buf) == 3. + >>> len(buf) + 3 + >>> # save to file "buf.pkl" + >>> pickle.dump(buf, open('buf.pkl', 'wb')) + >>> # save to HDF5 file + >>> buf.save_hdf5('buf.hdf5') + + >>> buf2 = ReplayBuffer(size=10) + >>> for i in range(15): + ... done = i % 4 == 0 + ... buf2.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={}) + >>> len(buf2) + 10 + >>> buf2.obs + # since its size = 10, it only stores the last 10 steps' result. 
+ array([10, 11, 12, 13, 14, 5, 6, 7, 8, 9]) + + >>> # move buf2's result into buf (meanwhile keep it chronologically) + >>> buf.update(buf2) + >>> buf.obs + array([ 0, 1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, 0, 0, + 0, 0, 0, 0]) + + >>> # get all available index by using batch_size = 0 + >>> indice = buf.sample_index(0) + >>> indice + array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) + >>> # get one step previous/next transition + >>> buf.prev(indice) + array([ 0, 0, 1, 2, 3, 4, 5, 7, 7, 8, 9, 11, 11]) + >>> buf.next(indice) + array([ 1, 2, 3, 4, 5, 6, 6, 8, 9, 10, 10, 12, 12]) + + >>> # get a random sample from buffer + >>> # the batch_data is equal to buf[indice]. + >>> batch_data, indice = buf.sample(batch_size=4) + >>> batch_data.obs == buf[indice].obs + array([ True, True, True, True]) + >>> len(buf) + 13 + + >>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl" + >>> len(buf) + 3 + >>> # load complete buffer from HDF5 file + >>> buf = ReplayBuffer.load_hdf5('buf.hdf5') + >>> len(buf) + 3 + +:class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling (typically for RNN usage, see issue#19), ignoring storing the next observation (save memory in Atari tasks), and multi-modal observation (see issue#38): + +.. raw:: html + +
+ Advance usage of ReplayBuffer + +.. code-block:: python + + >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True) + >>> for i in range(16): + ... done = i % 5 == 0 + ... ep_len, ep_rew = buf.add(obs={'id': i}, act=i, rew=i, + ... done=done, obs_next={'id': i + 1}) + ... print(i, ep_len, ep_rew) + 0 1 0.0 + 1 0 0.0 + 2 0 0.0 + 3 0 0.0 + 4 0 0.0 + 5 5 15.0 + 6 0 0.0 + 7 0 0.0 + 8 0 0.0 + 9 0 0.0 + 10 5 40.0 + 11 0 0.0 + 12 0 0.0 + 13 0 0.0 + 14 0 0.0 + 15 5 65.0 + >>> print(buf) # you can see obs_next is not saved in buf + ReplayBuffer( + obs: Batch( + id: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8]), + ), + act: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8]), + rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]), + done: array([False, True, False, False, False, False, True, False, + False]), + info: Batch(), + policy: Batch(), + ) + >>> index = np.arange(len(buf)) + >>> print(buf.get(index, 'obs').id) + [[ 7 7 8 9] + [ 7 8 9 10] + [11 11 11 11] + [11 11 11 12] + [11 11 12 13] + [11 12 13 14] + [12 13 14 15] + [ 7 7 7 7] + [ 7 7 7 8]] + >>> # here is another way to get the stacked data + >>> # (stack only for obs and obs_next) + >>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum() + 0 + >>> # we can get obs_next through __getitem__, even if it doesn't exist + >>> print(buf[:].obs_next.id) + [[ 7 8 9 10] + [ 7 8 9 10] + [11 11 11 12] + [11 11 12 13] + [11 12 13 14] + [12 13 14 15] + [12 13 14 15] + [ 7 7 7 8] + [ 7 7 8 9]] + +.. raw:: html + +

+ +Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``), :class:`~tianshou.data.CachedReplayBuffer` (add different episodes' data but without losing chronological order). Check out :class:`~tianshou.data.ReplayBuffer` for more detail. Policy diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 6448c22aa..82e903642 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -12,121 +12,7 @@ class ReplayBuffer: """:class:`~tianshou.data.ReplayBuffer` stores data generated from \ interaction between the policy and environment. ReplayBuffer can be \ - considered as a specialized form(management) of Batch. - - The current implementation of Tianshou typically use 7 reserved keys in - :class:`~tianshou.data.Batch`: - - * ``obs`` the observation of step :math:`t` ; - * ``act`` the action of step :math:`t` ; - * ``rew`` the reward of step :math:`t` ; - * ``done`` the done flag of step :math:`t` ; - * ``obs_next`` the observation of step :math:`t+1` ; - * ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` \ - function returns 4 arguments, and the last one is ``info``); - * ``policy`` the data computed by policy in step :math:`t`; - - The following code snippet illustrates its usage: - :: - - >>> import pickle, numpy as np - >>> from tianshou.data import ReplayBuffer - >>> buf = ReplayBuffer(size=20) - >>> for i in range(3): - ... buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={}) - >>> buf.obs - # since we set size = 20, len(buf.obs) == 20. - array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0.]) - >>> # but there are only three valid items, so len(buf) == 3. - >>> len(buf) - 3 - >>> # save to file "buf.pkl" - >>> pickle.dump(buf, open('buf.pkl', 'wb')) - >>> # save to HDF5 file - >>> buf.save_hdf5('buf.hdf5') - >>> buf2 = ReplayBuffer(size=10) - >>> for i in range(15): - ... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={}) - >>> len(buf2) - 10 - >>> buf2.obs - # since its size = 10, it only stores the last 10 steps' result. - array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.]) - - >>> # move buf2's result into buf (meanwhile keep it chronologically) - >>> buf.update(buf2) - array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., - 0., 0., 0., 0., 0., 0., 0.]) - - >>> # get a random sample from buffer - >>> # the batch_data is equal to buf[indice]. - >>> batch_data, indice = buf.sample(batch_size=4) - >>> batch_data.obs == buf[indice].obs - array([ True, True, True, True]) - >>> len(buf) - 13 - >>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl" - >>> len(buf) - 3 - >>> # load complete buffer from HDF5 file - >>> buf = ReplayBuffer.load_hdf5('buf.hdf5') - >>> len(buf) - 3 - >>> # load contents of HDF5 file into existing buffer - >>> # (only possible if size of buffer and data in file match) - >>> buf.load_contents_hdf5('buf.hdf5') - >>> len(buf) - 3 - - :class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling - (typically for RNN usage, see issue#19), ignoring storing the next - observation (save memory in atari tasks), and multi-modal observation (see - issue#38): - :: - - >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True) - >>> for i in range(16): - ... done = i % 5 == 0 - ... buf.add(obs={'id': i}, act=i, rew=i, done=done, - ... 
obs_next={'id': i + 1}) - >>> print(buf) # you can see obs_next is not saved in buf - ReplayBuffer( - act: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]), - done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]), - info: Batch(), - obs: Batch( - id: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]), - ), - policy: Batch(), - rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]), - ) - >>> index = np.arange(len(buf)) - >>> print(buf.get(index, 'obs').id) - [[ 7. 7. 8. 9.] - [ 7. 8. 9. 10.] - [11. 11. 11. 11.] - [11. 11. 11. 12.] - [11. 11. 12. 13.] - [11. 12. 13. 14.] - [12. 13. 14. 15.] - [ 7. 7. 7. 7.] - [ 7. 7. 7. 8.]] - >>> # here is another way to get the stacked data - >>> # (stack only for obs and obs_next) - >>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum() - 0.0 - >>> # we can get obs_next through __getitem__, even if it doesn't exist - >>> print(buf[:].obs_next.id) - [[ 7. 8. 9. 10.] - [ 7. 8. 9. 10.] - [11. 11. 11. 12.] - [11. 11. 12. 13.] - [11. 12. 13. 14.] - [12. 13. 14. 15.] - [12. 13. 14. 15.] - [ 7. 7. 7. 8.] - [ 7. 7. 8. 9.]] + considered as a specialized form (or management) of Batch. :param int size: the size of replay buffer. :param int stack_num: the frame-stack sampling argument, should be greater From f59a53024e3f71a151acd37feb4fc4ab84c0fe53 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 26 Jan 2021 20:19:39 +0800 Subject: [PATCH 018/104] ReplayBuffers._offset --- tianshou/data/buffer.py | 39 +++++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 82e903642..b0f0c40b9 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -1,4 +1,5 @@ import h5py +import torch import warnings import numpy as np from numbers import Number @@ -393,8 +394,7 @@ def add( info, policy, **kwargs) def sample_index(self, batch_size: int) -> np.ndarray: - assert self._size > 0, "Cannot sample a buffer with 0 size." - if batch_size > 0: + if batch_size > 0 and self._size > 0: scalar = np.random.rand(batch_size) * self.weight.reduce() return self.weight.get_prefix_sum_idx(scalar) else: @@ -422,7 +422,7 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: def update_weight( self, indice: Union[np.ndarray], - new_weight: np.ndarray + new_weight: Union[np.ndarray, torch.Tensor], ) -> None: """Update priority weight by indice in this buffer. 
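For reference, the weight-update cycle of the prioritized buffer looks roughly like this (a sketch; td_error is a random stand-in for the real TD error a policy would compute)::

    import numpy as np
    from tianshou.data import PrioritizedReplayBuffer

    buf = PrioritizedReplayBuffer(8, 0.6, 0.4)   # size, alpha, beta
    for i in range(5):
        buf.add(obs=i, act=i, rew=i, done=i == 4, obs_next=i + 1)
    # transitions are drawn with probability proportional to their priority;
    # batch.weight carries the importance-sampling correction
    batch, indices = buf.sample(4)
    td_error = np.random.rand(len(indices))      # stand-in for |TD error|
    buf.update_weight(indices, td_error)         # new priority from (|td| + eps) ** alpha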
@@ -459,10 +459,13 @@ class ReplayBuffers(ReplayBuffer): def __init__(self, buffer_list: List[ReplayBuffer], **kwargs: Any) -> None: self.buffer_num = len(buffer_list) self.buffers = buffer_list + self._offset = [] offset = 0 for buf in self.buffers: + # overwrite sub-buffers' alloc_fn so that the top buffer can + # allocate new memory for all buffers buf.alloc_fn = self.alloc_fn # type: ignore - buf.offset = offset + self._offset.append(offset) offset += buf.maxsize super().__init__(size=offset, **kwargs) @@ -474,8 +477,8 @@ def reset(self) -> None: buf.reset() def _set_batch_for_children(self) -> None: - for i, buf in enumerate(self.buffers): - buf.set_batch(self._meta[buf.offset:buf.offset + buf.maxsize]) + for offset, buf in zip(self._offset, self.buffers): + buf.set_batch(self._meta[offset:offset + buf.maxsize]) def set_batch(self, batch: Batch) -> None: super().set_batch(batch) @@ -483,26 +486,25 @@ def set_batch(self, batch: Batch) -> None: def unfinished_index(self) -> np.ndarray: return np.concatenate([ - buf.unfinished_index() + buf.offset for buf in self.buffers]) + buf.unfinished_index() + offset + for offset, buf in zip(self._offset, self.buffers)]) def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: index = np.asarray(index) % self.maxsize prev_indices = np.zeros_like(index) - for buf in self.buffers: - mask = (buf.offset <= index) & (index < buf.offset + buf.maxsize) + for offset, buf in zip(self._offset, self.buffers): + mask = (offset <= index) & (index < offset + buf.maxsize) if np.any(mask): - prev_indices[mask] = buf.prev( - index[mask] - buf.offset) + buf.offset + prev_indices[mask] = buf.prev(index[mask] - offset) + offset return prev_indices def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: index = np.asarray(index) % self.maxsize next_indices = np.zeros_like(index) - for buf in self.buffers: - mask = (buf.offset <= index) & (index < buf.offset + buf.maxsize) + for offset, buf in zip(self._offset, self.buffers): + mask = (offset <= index) & (index < offset + buf.maxsize) if np.any(mask): - next_indices[mask] = buf.next( - index[mask] - buf.offset) + buf.offset + next_indices[mask] = buf.next(index[mask] - offset) + offset return next_indices def update(self, buffer: ReplayBuffer) -> None: @@ -554,7 +556,8 @@ def sample_index(self, batch_size: int) -> np.ndarray: return np.array([], np.int) if self._sample_avail and self.stack_num > 1: all_indices = np.concatenate([ - buf.sample_index(0) + buf.offset for buf in self.buffers]) + buf.sample_index(0) + offset + for offset, buf in zip(self._offset, self.buffers)]) if batch_size == 0: return all_indices else: @@ -570,8 +573,8 @@ def sample_index(self, batch_size: int) -> np.ndarray: sample_num[sample_num == 0] = -1 return np.concatenate([ - buf.sample_index(bsz) + buf.offset - for buf, bsz in zip(self.buffers, sample_num) + buf.sample_index(bsz) + offset + for offset, buf, bsz in zip(self._offset, self.buffers, sample_num) ]) From 425c2bd9a121efcac10bbabfaf4b62a5493934cf Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Wed, 27 Jan 2021 13:21:35 +0800 Subject: [PATCH 019/104] fix atari-style update; support CachedBuffer with main_buffer==PrioBuffer --- test/base/test_buffer.py | 71 +++++++++++++++++-------- tianshou/data/buffer.py | 108 ++++++++++++++++++++++++++------------- 2 files changed, 121 insertions(+), 58 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 77b6524ad..4a980d5de 100644 --- a/test/base/test_buffer.py +++ 
b/test/base/test_buffer.py @@ -7,10 +7,10 @@ import numpy as np from timeit import timeit +from tianshou.data.utils.converter import to_hdf5 from tianshou.data import Batch, SegmentTree, ReplayBuffer from tianshou.data import ListReplayBuffer, PrioritizedReplayBuffer from tianshou.data import ReplayBuffers, CachedReplayBuffer -from tianshou.data.utils.converter import to_hdf5 if __name__ == '__main__': from env import MyTestEnv @@ -97,10 +97,13 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): buf = ReplayBuffer(bufsize, stack_num=stack_num) buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True) buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True) - buf4 = CachedReplayBuffer(bufsize, cached_num, size, - stack_num=stack_num, ignore_obs_next=True) - buf5 = CachedReplayBuffer(bufsize, cached_num, size, stack_num=stack_num, - ignore_obs_next=True, sample_avail=True) + buf4 = CachedReplayBuffer( + ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True), + cached_num, size) + buf5 = CachedReplayBuffer( + PrioritizedReplayBuffer(bufsize, 0.6, 0.4, stack_num=stack_num, + ignore_obs_next=True, sample_avail=True), + cached_num, size) obs = env.reset(1) for i in range(18): obs_next, rew, done, info = env.step(1) @@ -172,6 +175,30 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): buf.stack_num = 2 indice = buf5.sample_index(0) assert np.allclose(sorted(indice), [0, 1, 2, 5, 6, 7, 10, 15, 20]) + batch, _ = buf5.sample(0) + assert np.allclose(buf5[np.arange(buf5.maxsize)].weight, 1) + buf5.update_weight(indice, batch.weight * 0) + weight = buf5[np.arange(buf5.maxsize)].weight + modified_weight = weight[[0, 1, 2, 5, 6, 7]] + assert modified_weight.min() == modified_weight.max() + assert modified_weight.max() < 1 + unmodified_weight = weight[[3, 4, 8]] + assert unmodified_weight.min() == unmodified_weight.max() + assert unmodified_weight.max() < 1 + cached_weight = weight[9:] + assert cached_weight.min() == cached_weight.max() == 1 + # test Atari + buf6 = CachedReplayBuffer( + ReplayBuffer(bufsize, stack_num=stack_num, + save_only_last_obs=True, ignore_obs_next=True), + cached_num, size) + obs = np.random.rand(size, 4, 84, 84) + buf6.add(obs=[obs[2], obs[0]], act=[1, 1], rew=[0, 0], done=[0, 1], + obs_next=[obs[3], obs[1]], cached_buffer_ids=[1, 2]) + assert buf6.obs.shape == (buf6.maxsize, 84, 84) + assert np.allclose(buf6.obs[0], obs[0, -1]) + assert np.allclose(buf6.obs[14], obs[2, -1]) + assert np.allclose(buf6.obs[19], obs[0, -1]) def test_priortized_replaybuffer(size=32, bufsize=15): @@ -314,8 +341,7 @@ def test_pickle(): vbuf = ReplayBuffer(size, stack_num=2) lbuf = ListReplayBuffer() pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4) - device = 'cuda' if torch.cuda.is_available() else 'cpu' - rew = torch.tensor([1.]).to(device) + rew = np.array([1, 1]) for i in range(4): vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0) for i in range(3): @@ -343,18 +369,18 @@ def test_hdf5(): "list": ListReplayBuffer(), "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4), "vector": ReplayBuffers([ReplayBuffer(size) for i in range(4)]), - "cached": CachedReplayBuffer(size, 4, size) + "cached": CachedReplayBuffer(ReplayBuffer(size), 4, size) } buffer_types = {k: b.__class__ for k, b in buffers.items()} device = 'cuda' if torch.cuda.is_available() else 'cpu' - rew = torch.tensor([1.]).to(device) + info_t = torch.tensor([1.]).to(device) for i in range(4): kwargs = { 'obs': Batch(index=np.array([i])), 'act': i, - 'rew': rew, + 
'rew': np.array([1, 2]), 'done': i % 3 == 2, - 'info': {"number": {"n": i}, 'extra': None}, + 'info': {"number": {"n": i, "t": info_t}, 'extra': None}, } buffers["array"].add(**kwargs) buffers["list"].add(**kwargs) @@ -400,7 +426,7 @@ def test_hdf5(): kwargs = { 'obs': Batch(index=np.array([5])), 'act': 5, - 'rew': rew, + 'rew': np.array([2, 1]), 'done': False, 'info': {"number": {"n": i}, 'Timelimit.truncate': True}, } @@ -520,7 +546,7 @@ def test_vectorbuffer(): assert buf.sample_index(-1).tolist() == [] assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == np.object # CachedReplayBuffer - buf = CachedReplayBuffer(10, 4, 5) + buf = CachedReplayBuffer(ReplayBuffer(10), 4, 5) assert buf.sample_index(0).tolist() == [] ep_len, ep_rew = buf.add(obs=[1], act=[1], rew=[1], done=[0], cached_buffer_ids=[1]) @@ -557,13 +583,14 @@ def test_vectorbuffer(): indice = buf.sample_index(10000) assert np.bincount(indice)[[0, 1, 2, 25]].min() > 2000 # cached buffer with main_buffer size == 0 - buf = CachedReplayBuffer(0, 4, 5, sample_avail=True) # no effect - data = np.array([0, 0, 0, 0]) - buf.add(obs=data, act=data, rew=data, done=[0, 0, 1, 1], obs_next=data) - buf.add(obs=data, act=data, rew=data, done=[0, 0, 0, 0], obs_next=data) - buf.add(obs=data, act=data, rew=data, done=[1, 1, 1, 1], obs_next=data) - buf.add(obs=data, act=data, rew=data, done=[0, 0, 0, 0], obs_next=data) - buf.add(obs=data, act=data, rew=data, done=[0, 1, 0, 1], obs_next=data) + buf = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5) + data = np.zeros(4) + rew = np.ones([4, 4]) + buf.add(obs=data, act=data, rew=rew, done=[0, 0, 1, 1], obs_next=data) + buf.add(obs=data, act=data, rew=rew, done=[0, 0, 0, 0], obs_next=data) + buf.add(obs=data, act=data, rew=rew, done=[1, 1, 1, 1], obs_next=data) + buf.add(obs=data, act=data, rew=rew, done=[0, 0, 0, 0], obs_next=data) + buf.add(obs=data, act=data, rew=rew, done=[0, 1, 0, 1], obs_next=data) assert np.allclose(buf.done, [ 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, @@ -577,8 +604,6 @@ def test_vectorbuffer(): if __name__ == '__main__': - test_vectorbuffer() - test_hdf5() test_replaybuffer() test_ignore_obs_next() test_stack() @@ -587,3 +612,5 @@ def test_vectorbuffer(): test_priortized_replaybuffer() test_priortized_replaybuffer(233333, 200000) test_update() + test_vectorbuffer() + test_hdf5() diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index b0f0c40b9..389f219b4 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -15,7 +15,7 @@ class ReplayBuffer: interaction between the policy and environment. ReplayBuffer can be \ considered as a specialized form (or management) of Batch. - :param int size: the size of replay buffer. + :param int size: the maximum size of replay buffer. :param int stack_num: the frame-stack sampling argument, should be greater than or equal to 1, defaults to 1 (no stacking). :param bool ignore_obs_next: whether to store obs_next, defaults to False. @@ -24,7 +24,6 @@ class ReplayBuffer: False. :param bool sample_avail: the parameter indicating sampling only available index when using frame-stack sampling method, defaults to False. - This feature is not supported in Prioritized Replay Buffer currently. 
""" _reserved_keys = ("obs", "act", "rew", "done", @@ -38,6 +37,12 @@ def __init__( save_only_last_obs: bool = False, sample_avail: bool = False, ) -> None: + self.options: Dict[str, Any] = { + "stack_num": stack_num, + "ignore_obs_next": ignore_obs_next, + "save_only_last_obs": save_only_last_obs, + "sample_avail": sample_avail, + } super().__init__() self.maxsize = size assert stack_num > 0, "stack_num should greater than 0" @@ -140,10 +145,13 @@ def update(self, buffer: "ReplayBuffer") -> None: if len(buffer) == 0 or self.maxsize == 0: return stack_num, buffer.stack_num = buffer.stack_num, 1 + save_only_last_obs, self._save_only_last_obs = \ + self._save_only_last_obs, False indices = buffer.sample_index(0) # get all available indices for i in indices: self.add(**buffer[i]) # type: ignore buffer.stack_num = stack_num + self._save_only_last_obs = save_only_last_obs def alloc_fn(self, key: List[str], value: Any) -> None: """Allocate memory on buffer._meta for new (key, value) pair.""" @@ -158,6 +166,12 @@ def _add_to_buffer(self, name: str, inst: Any) -> None: except KeyError: self.alloc_fn([name], inst) value = self._meta[name] + if isinstance(inst, (torch.Tensor, np.ndarray)): + if inst.shape != value.shape[1:]: + raise ValueError( + "Cannot add data to a buffer with different shape with key" + f" {name}, expect {value.shape[1:]}, given {inst.shape}." + ) try: value[self._index] = inst except KeyError: # inst is a dict/Batch @@ -175,7 +189,7 @@ def add( info: Optional[Union[dict, Batch]] = {}, policy: Optional[Union[dict, Batch]] = {}, **kwargs: Any, - ) -> Tuple[int, float]: + ) -> Tuple[np.ndarray, np.ndarray]: """Add a batch of data into replay buffer. Return (episode_length, episode_reward) if one episode is terminated, @@ -211,11 +225,13 @@ def add( self._size = self._index = self._size + 1 if done: - result = (self._episode_length, self._episode_reward) + result = np.array(self._episode_length), \ + np.array(self._episode_reward) self._episode_length, self._episode_reward = 0, 0.0 return result else: - return (0, 0.0) + return np.zeros_like(self._episode_length), \ + np.zeros_like(self._episode_reward) def sample_index(self, batch_size: int) -> np.ndarray: """Get a random sample of index with size = batch_size. @@ -258,7 +274,7 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: def get( self, - indice: Union[int, np.integer, np.ndarray], + index: Union[int, np.integer, np.ndarray], key: str, stack_num: Optional[int] = None, ) -> Union[Batch, np.ndarray]: @@ -272,9 +288,9 @@ def get( val = self._meta[key] try: if stack_num == 1: # the most often case - return val[indice] + return val[index] stack: List[Any] = [] - indice = np.asarray(indice) + indice = np.asarray(index) for _ in range(stack_num): stack = [val[indice]] + stack indice = self.prev(indice) @@ -381,8 +397,7 @@ def add( policy: Optional[Union[dict, Batch]] = {}, weight: Optional[Union[Number, np.number]] = None, **kwargs: Any, - ) -> Tuple[int, float]: - """Add a batch of data into replay buffer.""" + ) -> Tuple[np.ndarray, np.ndarray]: if weight is None: weight = self._max_prio else: @@ -400,37 +415,32 @@ def sample_index(self, batch_size: int) -> np.ndarray: else: return super().sample_index(batch_size) - def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: - """Get a random sample from buffer with priority probability. - - Return all the data in the buffer if batch_size is 0. - - :return: Sample data and its corresponding index inside the buffer. 
+ def get_weight( + self, index: Union[slice, int, np.integer, np.ndarray] + ) -> np.ndarray: + """Get the importance sampling weight. The "weight" in the returned Batch is the weight on loss function to de-bias the sampling process (some transition tuples are sampled more often so their losses are weighted less). """ - indice = self.sample_index(batch_size) - batch = self[indice] # important sampling weight calculation # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta)) # simplified formula: (p_j/p_min)**(-beta) - batch.weight = (batch.weight / self._min_prio) ** (-self._beta) - return batch, indice + return (self.weight[index] / self._min_prio) ** (-self._beta) def update_weight( self, - indice: Union[np.ndarray], + index: np.ndarray, new_weight: Union[np.ndarray, torch.Tensor], ) -> None: - """Update priority weight by indice in this buffer. + """Update priority weight by index in this buffer. - :param np.ndarray indice: indice you want to update weight. + :param np.ndarray index: index you want to update weight. :param np.ndarray new_weight: new priority weight you want to update. """ weight = np.abs(to_numpy(new_weight)) + self.__eps - self.weight[indice] = weight ** self._alpha + self.weight[index] = weight ** self._alpha self._max_prio = max(self._max_prio, weight.max()) self._min_prio = min(self._min_prio, weight.min()) @@ -438,7 +448,7 @@ def __getitem__( self, index: Union[slice, int, np.integer, np.ndarray] ) -> Batch: batch = super().__getitem__(index) - batch.weight = self.weight[index] + batch.weight = self.get_weight(index) return batch @@ -549,7 +559,7 @@ def add( # type: ignore length, reward = self.buffers[env_id].add(**batch[batch_idx]) episode_lengths.append(length) episode_rewards.append(reward) - return np.array(episode_lengths), np.array(episode_rewards) + return np.stack(episode_lengths), np.stack(episode_rewards) def sample_index(self, batch_size: int) -> np.ndarray: if batch_size < 0: @@ -579,8 +589,8 @@ def sample_index(self, batch_size: int) -> np.ndarray: class CachedReplayBuffer(ReplayBuffers): - """CachedReplayBuffer contains a ReplayBuffers with given size and n \ - cached buffers, cached_buffer_num * ReplayBuffer(size=max_episode_length). + """CachedReplayBuffer contains a given main buffer and n cached buffers, \ + cached_buffer_num * ReplayBuffer(size=max_episode_length). The memory layout is: ``| main_buffer | cached_buffer[0] | cached_buffer[1] | ... | cached_buffer[cached_buffer_num - 1]``. @@ -589,7 +599,8 @@ class CachedReplayBuffer(ReplayBuffers): terminated, the data will move to the main buffer and the corresponding cached buffer will be reset. - :param int size: the size of main buffer. + :param ReplayBuffer main_buffer: the main buffer whose ``.update()`` + function behaves normally. :param int cached_buffer_num: number of ReplayBuffer needs to be created for cached buffer. 
:param int max_episode_length: the maximum length of one episode, used in @@ -604,13 +615,14 @@ class CachedReplayBuffer(ReplayBuffers): def __init__( self, - size: int, + main_buffer: ReplayBuffer, cached_buffer_num: int, max_episode_length: int, - **kwargs: Any, ) -> None: assert cached_buffer_num > 0 and max_episode_length > 0 - main_buffer = ReplayBuffer(size, **kwargs) + self.main_buffer = main_buffer + self._is_prioritized = isinstance(main_buffer, PrioritizedReplayBuffer) + kwargs = self.main_buffer.options buffers = [main_buffer] + [ReplayBuffer(max_episode_length, **kwargs) for _ in range(cached_buffer_num)] super().__init__(buffer_list=buffers, **kwargs) @@ -643,12 +655,36 @@ def add( # type: ignore # in self.buffers, the first buffer is main_buffer buffer_ids = np.asarray(cached_buffer_ids) + 1 - result = super().add(obs, act, rew, done, obs_next, - info, policy, buffer_ids=buffer_ids, **kwargs) + result = super().add(obs, act, rew, done, obs_next, info, + policy, buffer_ids=buffer_ids, **kwargs) # find the terminated episode, move data from cached buf to main buf for buffer_idx in np.asarray(buffer_ids)[np.asarray(done) > 0]: - self.buffers[0].update(self.buffers[buffer_idx]) + self.main_buffer.update(self.buffers[buffer_idx]) self.buffers[buffer_idx].reset() - return result + + def __getitem__( + self, index: Union[slice, int, np.integer, np.ndarray] + ) -> Batch: + batch = super().__getitem__(index) + if self._is_prioritized: + indice = self._indices[index] + mask = indice < self.main_buffer.maxsize + batch.weight = np.ones(len(indice)) + batch.weight[mask] = self.main_buffer.get_weight(indice[mask]) + return batch + + def update_weight( + self, + index: np.ndarray, + new_weight: Union[np.ndarray, torch.Tensor], + ) -> None: + """Update priority weight by index in main buffer. + + :param np.ndarray index: index you want to update weight. + :param np.ndarray new_weight: new priority weight you want to update. 
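Putting the pieces together, the prioritized-main-buffer case roughly behaves as in the following sketch (condensed from the buf5 checks in the test above; the sizes are arbitrary)::

    import numpy as np
    from tianshou.data import PrioritizedReplayBuffer, CachedReplayBuffer

    # prioritized main buffer of size 9 plus 3 cached buffers of length 5
    buf = CachedReplayBuffer(PrioritizedReplayBuffer(9, 0.6, 0.4), 3, 5)
    # a 3-step episode collected through cached buffer 0 is flushed into the
    # main buffer (global indices 0, 1, 2) once done is seen
    for i in range(3):
        buf.add(obs=[i], act=[i], rew=[i], done=[i == 2], obs_next=[i + 1],
                cached_buffer_ids=[0])
    idx = buf.sample_index(0)
    assert np.allclose(idx, [0, 1, 2])
    # update_weight only re-prioritizes main-buffer indices; samples that are
    # still sitting in a cached buffer keep the default weight of 1
    buf.update_weight(idx, np.full(len(idx), 5.0))
    assert np.allclose(buf[np.arange(9, buf.maxsize)].weight, 1.0)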
+ """ + if self._is_prioritized: + mask = index < self.main_buffer.maxsize + self.main_buffer.update_weight(index[mask], new_weight[mask]) From 2361755eba9317f3accd9b468f5f568ede5baadd Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Wed, 27 Jan 2021 16:39:19 +0800 Subject: [PATCH 020/104] assert _meta.is_empty() in ReplayBuffers init --- tianshou/data/buffer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 389f219b4..7e0450ffd 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -475,6 +475,7 @@ def __init__(self, buffer_list: List[ReplayBuffer], **kwargs: Any) -> None: # overwrite sub-buffers' alloc_fn so that the top buffer can # allocate new memory for all buffers buf.alloc_fn = self.alloc_fn # type: ignore + assert buf._meta.is_empty() self._offset.append(offset) offset += buf.maxsize super().__init__(size=offset, **kwargs) From 75d581b62d038dad8060cbe10b34b5585726ec12 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 28 Jan 2021 10:11:45 +0800 Subject: [PATCH 021/104] small fix --- test/base/test_buffer.py | 4 +++- tianshou/data/buffer.py | 21 +++++++++++++++------ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 4a980d5de..2c70682a5 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -53,6 +53,8 @@ def test_replaybuffer(size=10, bufsize=20): b = ListReplayBuffer() with pytest.raises(NotImplementedError): b.sample(0) + with pytest.raises(NotImplementedError): + b.update(b) def test_ignore_obs_next(size=10): @@ -407,9 +409,9 @@ def test_hdf5(): assert np.allclose(_buffers[k].act, buffers[k].act) assert _buffers[k].stack_num == buffers[k].stack_num assert _buffers[k].maxsize == buffers[k].maxsize - assert _buffers[k]._index == buffers[k]._index assert np.all(_buffers[k]._indices == buffers[k]._indices) for k in ["array", "prioritized"]: + assert _buffers[k]._index == buffers[k]._index assert isinstance(buffers[k].get(0, "info"), Batch) assert isinstance(_buffers[k].get(0, "info"), Batch) for k in ["array"]: diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 7e0450ffd..bc03d6e42 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -80,6 +80,8 @@ def __setstate__(self, state: Dict[str, Any]) -> None: ("buffer.__getattr__" is customized). """ self.__dict__.update(state) + # compatible with previous version HDF5 + self._indices = np.arange(self.maxsize) def __getstate__(self) -> Dict[str, Any]: return self.__dict__ @@ -362,6 +364,10 @@ def reset(self) -> None: if isinstance(self._meta[k], list): self._meta.__dict__[k] = [] + def update(self, buffer: ReplayBuffer) -> None: + """The ListReplayBuffer cannot be updated by any buffer.""" + raise NotImplementedError + class PrioritizedReplayBuffer(ReplayBuffer): """Implementation of Prioritized Experience Replay. arXiv:1511.05952. @@ -479,6 +485,8 @@ def __init__(self, buffer_list: List[ReplayBuffer], **kwargs: Any) -> None: self._offset.append(offset) offset += buf.maxsize super().__init__(size=offset, **kwargs) + # delete useless variables + del self._index, self._size, self._episode_reward, self._episode_length def __len__(self) -> int: return sum([len(buf) for buf in self.buffers]) @@ -556,8 +564,8 @@ def add( # type: ignore assert len(buffer_ids) == len(batch) episode_lengths = [] # (len(buffer_ids),) episode_rewards = [] # (len(buffer_ids), ...) 
- for batch_idx, env_id in enumerate(buffer_ids): - length, reward = self.buffers[env_id].add(**batch[batch_idx]) + for batch_idx, buffer_id in enumerate(buffer_ids): + length, reward = self.buffers[buffer_id].add(**batch[batch_idx]) episode_lengths.append(length) episode_rewards.append(reward) return np.stack(episode_lengths), np.stack(episode_rewards) @@ -621,11 +629,12 @@ def __init__( max_episode_length: int, ) -> None: assert cached_buffer_num > 0 and max_episode_length > 0 - self.main_buffer = main_buffer self._is_prioritized = isinstance(main_buffer, PrioritizedReplayBuffer) - kwargs = self.main_buffer.options - buffers = [main_buffer] + [ReplayBuffer(max_episode_length, **kwargs) - for _ in range(cached_buffer_num)] + kwargs = main_buffer.options + self.main_buffer = main_buffer + self.cached_buffer = [ReplayBuffer(max_episode_length, **kwargs) + for _ in range(cached_buffer_num)] + buffers = [main_buffer] + self.cached_buffer super().__init__(buffer_list=buffers, **kwargs) def add( # type: ignore From b5d93f365b361d4ef58619c8084f14e6af6addda Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 28 Jan 2021 10:26:06 +0800 Subject: [PATCH 022/104] small fix --- tianshou/data/buffer.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index bc03d6e42..b8c689532 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -227,8 +227,8 @@ def add( self._size = self._index = self._size + 1 if done: - result = np.array(self._episode_length), \ - np.array(self._episode_reward) + result = np.asarray(self._episode_length), \ + np.asarray(self._episode_reward) self._episode_length, self._episode_reward = 0, 0.0 return result else: @@ -315,6 +315,7 @@ def __getitem__( """ if isinstance(index, slice): # change slice to np array index = self._indices[:len(self)][index] + # raise KeyError first instead of AttributeError, to support np.array obs = self.get(index, "obs") if self._save_obs_next: obs_next = self.get(index, "obs_next") @@ -478,14 +479,14 @@ def __init__(self, buffer_list: List[ReplayBuffer], **kwargs: Any) -> None: self._offset = [] offset = 0 for buf in self.buffers: - # overwrite sub-buffers' alloc_fn so that the top buffer can - # allocate new memory for all buffers + # overwrite sub-buffers' alloc_fn so that + # the top buffer can allocate new memory for all sub-buffers buf.alloc_fn = self.alloc_fn # type: ignore assert buf._meta.is_empty() self._offset.append(offset) offset += buf.maxsize super().__init__(size=offset, **kwargs) - # delete useless variables + # delete useless variables in ReplayBuffer del self._index, self._size, self._episode_reward, self._episode_length def __len__(self) -> int: @@ -631,11 +632,11 @@ def __init__( assert cached_buffer_num > 0 and max_episode_length > 0 self._is_prioritized = isinstance(main_buffer, PrioritizedReplayBuffer) kwargs = main_buffer.options - self.main_buffer = main_buffer - self.cached_buffer = [ReplayBuffer(max_episode_length, **kwargs) - for _ in range(cached_buffer_num)] - buffers = [main_buffer] + self.cached_buffer + buffers = [main_buffer] + [ReplayBuffer(max_episode_length, **kwargs) + for _ in range(cached_buffer_num)] super().__init__(buffer_list=buffers, **kwargs) + self.main_buffer = self.buffers[0] + self.cached_buffer = self.buffers[1:] def add( # type: ignore self, @@ -669,7 +670,7 @@ def add( # type: ignore policy, buffer_ids=buffer_ids, **kwargs) # find the terminated episode, move data from cached buf to main buf - 
for buffer_idx in np.asarray(buffer_ids)[np.asarray(done) > 0]: + for buffer_idx in buffer_ids[np.asarray(done) > 0]: self.main_buffer.update(self.buffers[buffer_idx]) self.buffers[buffer_idx].reset() return result From c8f27c983b9e5cd66bac525450727b9ac6969962 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 28 Jan 2021 10:31:18 +0800 Subject: [PATCH 023/104] improve coverage --- test/base/test_buffer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 2c70682a5..f09d6c59e 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -39,6 +39,8 @@ def test_replaybuffer(size=10, bufsize=20): assert (data.obs < size).all() assert (0 <= data.done).all() and (data.done <= 1).all() b = ReplayBuffer(size=10) + # neg bsz should return empty index + assert b.sample_index(-1).tolist() == [] b.add(1, 1, 1, 1, 'str', {'a': 3, 'b': {'c': 5.0}}) assert b.obs[0] == 1 assert b.obs_next[0] == 'str' From 9c879f213bd8cef0a5cef8718a493da8fd2ad7a8 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Thu, 28 Jan 2021 17:12:48 +0800 Subject: [PATCH 024/104] small buffer change --- tianshou/data/buffer.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index b8c689532..9053a6988 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -602,8 +602,8 @@ class CachedReplayBuffer(ReplayBuffers): """CachedReplayBuffer contains a given main buffer and n cached buffers, \ cached_buffer_num * ReplayBuffer(size=max_episode_length). - The memory layout is: ``| main_buffer | cached_buffer[0] | cached_buffer[1] - | ... | cached_buffer[cached_buffer_num - 1]``. + The memory layout is: ``| main_buffer | cached_buffers[0] | cached_buffers[1] + | ... | cached_buffers[cached_buffer_num - 1]``. The data is first stored in cached buffers. When the episode is terminated, the data will move to the main buffer and the corresponding @@ -636,7 +636,8 @@ def __init__( for _ in range(cached_buffer_num)] super().__init__(buffer_list=buffers, **kwargs) self.main_buffer = self.buffers[0] - self.cached_buffer = self.buffers[1:] + self.cached_buffers = self.buffers[1:] + self.cached_buffer_num = cached_buffer_num def add( # type: ignore self, @@ -662,17 +663,15 @@ def add( # type: ignore corresponding episode result. 
""" if cached_buffer_ids is None: - cached_buffer_ids = np.arange(self.buffer_num - 1) + cached_buffer_ids = np.arange(self.cached_buffer_num) # in self.buffers, the first buffer is main_buffer buffer_ids = np.asarray(cached_buffer_ids) + 1 - result = super().add(obs, act, rew, done, obs_next, info, policy, buffer_ids=buffer_ids, **kwargs) - # find the terminated episode, move data from cached buf to main buf - for buffer_idx in buffer_ids[np.asarray(done) > 0]: - self.main_buffer.update(self.buffers[buffer_idx]) - self.buffers[buffer_idx].reset() + for buffer_idx in cached_buffer_ids[np.asarray(done) > 0]: + self.main_buffer.update(self.cached_buffers[buffer_idx]) + self.cached_buffers[buffer_idx].reset() return result def __getitem__( From 679fe27750bbb3a05a09b6f7afc7bf446178a295 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Thu, 28 Jan 2021 17:25:57 +0800 Subject: [PATCH 025/104] draft of step_collector, not finished yet --- tianshou/data/collector.py | 311 +++++++++++++++---------------------- tianshou/trainer/utils.py | 4 +- 2 files changed, 130 insertions(+), 185 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 25f268dc0..018992e01 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -11,10 +11,12 @@ from tianshou.exploration import BaseNoise from tianshou.data.batch import _create_value from tianshou.env import BaseVectorEnv, DummyVectorEnv -from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy +from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, \ + CachedReplayBuffer, to_numpy class Collector(object): + #TODO change doc """Collector enables the policy to interact with different types of envs. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` @@ -75,40 +77,66 @@ class Collector(object): Please make sure the given environment has a time limitation. 
""" - def __init__( self, policy: BasePolicy, env: Union[gym.Env, BaseVectorEnv], - buffer: Optional[ReplayBuffer] = None, + buffer: Optional[CachedReplayBuffer] = None, preprocess_fn: Optional[Callable[..., Batch]] = None, - action_noise: Optional[BaseNoise] = None, + training = False, reward_metric: Optional[Callable[[np.ndarray], float]] = None, ) -> None: + # TODO determine whether we need start_idxs + # TODO support not only cacahedbuffer + # TODO remove listreplaybuffer + # TODO update training in all test/examples, remove action noise + # TODO buffer need to be CachedReplayBuffer now, update + # examples/docs/ after supporting all types of buffers super().__init__() if not isinstance(env, BaseVectorEnv): env = DummyVectorEnv([lambda: env]) + # TODO support or seperate async + assert env.is_async == False self.env = env self.env_num = len(env) - # environments that are available in step() - # this means all environments in synchronous simulation - # but only a subset of environments in asynchronous simulation - self._ready_env_ids = np.arange(self.env_num) - # self.async is a flag to indicate whether this collector works - # with asynchronous simulation - self.is_async = env.is_async - # need cache buffers before storing in the main buffer - self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)] self.buffer = buffer + self._check_buffer() self.policy = policy self.preprocess_fn = preprocess_fn - self.process_fn = policy.process_fn self._action_space = env.action_space - self._action_noise = action_noise - self._rew_metric = reward_metric or Collector._default_rew_metric + self._rew_metric = reward_metric or BasicCollector._default_rew_metric + self.training = training # avoid creating attribute outside __init__ self.reset() + def _check_buffer(self): + max_episode_steps = self.env._max_episode_steps[0] + # TODO default to Replay_buffers + # TODO support replaybuffer when self.env_num == 1 + # TODO support Replay_buffers when self.env_num == 0 + if self.buffer is None: + self.buffer = CachedReplayBuffer(size = 0, + cached_buf_n = self.env_num, max_length = max_episode_steps) + else: + assert isinstance(self.buffer, CachedReplayBuffer), \ + "BasicCollector reuqires CachedReplayBuffer as buffer input." + assert self.buffer.cached_bufs_n == self.env_num + + if self.buffer.main_buffer.maxsize < self.buffer.maxsize//2: + warnings.warn( + "The size of buffer is suggested to be larger than " + "(cached buffer number) * max_length. Otherwise you might" + "loss data of episodes you just collected, and statistics " + "might even be incorrect.", + Warning) + if self.buffer.cached_buffer[0].maxsize < max_episode_steps: + warnings.warn( + "The size of cached_buf is suggested to be larger than " + "max episode length. 
Otherwise you might" + "loss data of episodes you just collected, and statistics " + "might even be incorrect.", + Warning) + @staticmethod def _default_rew_metric( x: Union[Number, np.number] @@ -130,31 +158,25 @@ def reset(self) -> None: self.reset_env() self.reset_buffer() self.reset_stat() - if self._action_noise is not None: - self._action_noise.reset() def reset_stat(self) -> None: """Reset the statistic variables.""" - self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0 + self.collect_step, self.collect_episode = 0, 0 def reset_buffer(self) -> None: """Reset the main data buffer.""" - if self.buffer is not None: - self.buffer.reset() - - def get_env_num(self) -> int: - """Return the number of environments the collector have.""" - return self.env_num + self.buffer.reset() def reset_env(self) -> None: """Reset all of the environment(s)' states and the cache buffers.""" - self._ready_env_ids = np.arange(self.env_num) + #should not be exposed obs = self.env.reset() if self.preprocess_fn: obs = self.preprocess_fn(obs=obs).get("obs", obs) self.data.obs = obs - for b in self._cached_buf: - b.reset() + # TODO different kind of buffers, + for buf in self.buffer.cached_buffer: + buf.reset() def _reset_state(self, id: Union[int, List[int]]) -> None: """Reset the hidden state: self.data.state[id].""" @@ -165,15 +187,16 @@ def _reset_state(self, id: Union[int, List[int]]) -> None: state[id] = None if state.dtype == np.object else 0 elif isinstance(state, Batch): state.empty_(id) - + def collect( self, n_step: Optional[int] = None, - n_episode: Optional[Union[int, List[int]]] = None, + n_episode: Optional[int] = None, random: bool = False, render: Optional[float] = None, no_grad: bool = True, ) -> Dict[str, float]: + #TODO doc update """Collect a specified number of step or episode. :param int n_step: how many steps you want to collect. @@ -202,202 +225,124 @@ def collect( * ``rew`` the mean reward over collected episodes. * ``len`` the mean length over collected episodes. """ - assert (n_step is not None and n_episode is None and n_step > 0) or ( - n_step is None and n_episode is not None and np.sum(n_episode) > 0 - ), "Only one of n_step or n_episode is allowed in Collector.collect, " - f"got n_step = {n_step}, n_episode = {n_episode}." - start_time = time.time() + #collect at least n_step or n_episode + if n_step is not None: + assert n_episode is None, "Only one of n_step or n_episode is allowed " + f"in Collector.collect, got n_step = {n_step}, n_episode = {n_episode}." + assert n_step > 0 and n_step % self.env_num == 0, \ + "n_step must not be 0, and should be an integral multiple of #envs" + else: + assert isinstance(n_episode, int) and n_episode > 0 + step_count = 0 # episode of each environment - episode_count = np.zeros(self.env_num) - # If n_episode is a list, and some envs have collected the required - # number of episodes, these envs will be recorded in this list, and - # they will not be stepped. 
- finished_env_ids = [] - rewards = [] - whole_data = Batch() - if isinstance(n_episode, list): - assert len(n_episode) == self.get_env_num() - finished_env_ids = [ - i for i in self._ready_env_ids if n_episode[i] <= 0] - self._ready_env_ids = np.array( - [x for x in self._ready_env_ids if x not in finished_env_ids]) + episode_count = 0 + episode_rews = [] + episode_lens = [] + # start_idxs = [] + cached_buffer_ids = [i for i in range(self.env_num)] + while True: - if step_count >= 100000 and episode_count.sum() == 0: + if step_count >= 100000 and episode_count == 0: warnings.warn( "There are already many steps in an episode. " "You should add a time limitation to your environment!", Warning) - is_async = self.is_async or len(finished_env_ids) > 0 - if is_async: - # self.data are the data for all environments in async - # simulation or some envs have finished, - # **only a subset of data are disposed**, - # so we store the whole data in ``whole_data``, let self.data - # to be the data available in ready environments, and finally - # set these back into all the data - whole_data = self.data - self.data = self.data[self._ready_env_ids] - # restore the state and the input data last_state = self.data.state if isinstance(last_state, Batch) and last_state.is_empty(): last_state = None - self.data.update(state=Batch(), obs_next=Batch(), policy=Batch()) - # calculate the next action + # calculate the next action and update state, act & policy into self.data if random: spaces = self._action_space result = Batch( - act=[spaces[i].sample() for i in self._ready_env_ids]) + act=[spaces[i].sample() for i in range(self.env_num)]) else: if no_grad: with torch.no_grad(): # faster than retain_grad version + # self.data.obs will be used by agent to get result(mainly action) result = self.policy(self.data, last_state) else: result = self.policy(self.data, last_state) state = result.get("state", Batch()) - # convert None to Batch(), since None is reserved for 0-init + policy = result.get("policy", Batch()) + act = to_numpy(result.act) if state is None: + # convert None to Batch(), since None is reserved for 0-init state = Batch() - self.data.update(state=state, policy=result.get("policy", Batch())) - # save hidden state to policy._state, in order to save into buffer if not (isinstance(state, Batch) and state.is_empty()): - self.data.policy._state = self.data.state - - self.data.act = to_numpy(result.act) - if self._action_noise is not None: - assert isinstance(self.data.act, np.ndarray) - self.data.act += self._action_noise(self.data.act.shape) + # save hidden state to policy._state, in order to save into buffer + policy._state = state + + # TODO discuss and change policy's add_exp_noise behavior + if self.training and not random and hasattr(self.policy, 'add_exp_noise'): + act = self.policy.add_exp_noise(act) + self.data.update(state=state, policy = policy, act = act) # step in env - if not is_async: - obs_next, rew, done, info = self.env.step(self.data.act) - else: - # store computed actions, states, etc - _batch_set_item( - whole_data, self._ready_env_ids, self.data, self.env_num) - # fetch finished data - obs_next, rew, done, info = self.env.step( - self.data.act, id=self._ready_env_ids) - self._ready_env_ids = np.array([i["env_id"] for i in info]) - # get the stepped data - self.data = whole_data[self._ready_env_ids] - # move data to self.data - self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) + obs_next, rew, done, info = self.env.step(act) + + result = {"obs_next":obs_next, "rew":rew, 
"done":done, "info":info} + if self.preprocess_fn: + result = self.preprocess_fn(**result) # type: ignore + + # update obs_next, rew, done, & info into self.data + self.data.update(result) if render: self.env.render() time.sleep(render) # add data into the buffer - if self.preprocess_fn: - result = self.preprocess_fn(**self.data) # type: ignore - self.data.update(result) - - for j, i in enumerate(self._ready_env_ids): - # j is the index in current ready_env_ids - # i is the index in all environments - if self.buffer is None: - # users do not want to store data, so we store - # small fake data here to make the code clean - self._cached_buf[i].add(obs=0, act=0, rew=rew[j], done=0) - else: - self._cached_buf[i].add(**self.data[j]) - - if done[j]: - if not (isinstance(n_episode, list) - and episode_count[i] >= n_episode[i]): - episode_count[i] += 1 - rewards.append(self._rew_metric( - np.sum(self._cached_buf[i].rew, axis=0))) - step_count += len(self._cached_buf[i]) - if self.buffer is not None: - self.buffer.update(self._cached_buf[i]) - if isinstance(n_episode, list) and \ - episode_count[i] >= n_episode[i]: - # env i has collected enough data, it has finished - finished_env_ids.append(i) - self._cached_buf[i].reset() - self._reset_state(j) - obs_next = self.data.obs_next + data_t = self.data + if n_episode and len(cached_buffer_ids) < self.env_num: + data_t = self.data[cached_buffer_ids] + # lens, rews, idxs = self.buffer.add(**data_t, index = cached_buffer_ids) + lens, rews = self.buffer.add(**data_t, index = cached_buffer_ids) + # collect statistics + step_count += len(cached_buffer_ids) + for i in cached_buffer_ids(np.where(lens == 0)[0]): + episode_count += 1 + episode_lens.append(lens[i]) + episode_rews.append(self._rew_metric(rews[i])) + # start_idxs.append(idxs[i]) + if sum(done): - env_ind_local = np.where(done)[0] - env_ind_global = self._ready_env_ids[env_ind_local] - obs_reset = self.env.reset(env_ind_global) + finised_env_ind = np.where(done)[0] + # now we copy obs_next to obs, but since there might be finished episodes, + # we have to reset finished envs first. + # TODO might auto reset help? 
+ obs_reset = self.env.reset(finised_env_ind) if self.preprocess_fn: obs_reset = self.preprocess_fn( obs=obs_reset).get("obs", obs_reset) - obs_next[env_ind_local] = obs_reset - self.data.obs = obs_next - if is_async: - # set data back - whole_data = deepcopy(whole_data) # avoid reference in ListBuf - _batch_set_item( - whole_data, self._ready_env_ids, self.data, self.env_num) - # let self.data be the data in all environments again - self.data = whole_data - self._ready_env_ids = np.array( - [x for x in self._ready_env_ids if x not in finished_env_ids]) - if n_step: - if step_count >= n_step: - break - else: - if isinstance(n_episode, int) and \ - episode_count.sum() >= n_episode: - break - if isinstance(n_episode, list) and \ - (episode_count >= n_episode).all(): - break - - # finished envs are ready, and can be used for the next collection - self._ready_env_ids = np.array( - self._ready_env_ids.tolist() + finished_env_ids) + self.data.obs_next[finised_env_ind] = obs_reset + for i in finised_env_ind: + self._reset_state(i) + if n_episode and n_episode - episode_count < self.env_num: + try: + cached_buffer_ids.remove(i) + except ValueError: + pass + self.data.obs[:] = self.data.obs_next + + if (n_step and step_count >= n_step) or \ + (n_episode and episode_count >= n_episode): + break # generate the statistics - episode_count = sum(episode_count) - duration = max(time.time() - start_time, 1e-9) self.collect_step += step_count self.collect_episode += episode_count - self.collect_time += duration + if n_episode: + self.reset_env() + # TODO change api in trainer and other collector usage return { "n/ep": episode_count, "n/st": step_count, - "v/st": step_count / duration, - "v/ep": episode_count / duration, - "rew": np.mean(rewards), - "rew_std": np.std(rewards), - "len": step_count / episode_count, - } - - -def _batch_set_item( - source: Batch, indices: np.ndarray, target: Batch, size: int -) -> None: - # for any key chain k, there are four cases - # 1. source[k] is non-reserved, but target[k] does not exist or is reserved - # 2. source[k] does not exist or is reserved, but target[k] is non-reserved - # 3. both source[k] and target[k] are non-reserved - # 4. both source[k] and target[k] do not exist or are reserved, do nothing. - # A special case in case 4, if target[k] is reserved but source[k] does - # not exist, make source[k] reserved, too. 
- for k, vt in target.items(): - if not isinstance(vt, Batch) or not vt.is_empty(): - # target[k] is non-reserved - vs = source.get(k, Batch()) - if isinstance(vs, Batch): - if vs.is_empty(): - # case 2, use __dict__ to avoid many type checks - source.__dict__[k] = _create_value(vt[0], size) - else: - assert isinstance(vt, Batch) - _batch_set_item(source.__dict__[k], indices, vt, size) - else: - # target[k] is reserved - # case 1 or special case of case 4 - if k not in source.__dict__: - source.__dict__[k] = Batch() - continue - source.__dict__[k][indices] = vt + "rews": np.array(episode_rews), + "lens": np.array(episode_lens), + # "idxs": np.array(start_idxs) + } \ No newline at end of file diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index dfffd71a4..7ec617d7f 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -22,8 +22,8 @@ def test_episode( policy.eval() if test_fn: test_fn(epoch, global_step) - if collector.get_env_num() > 1 and isinstance(n_episode, int): - n = collector.get_env_num() + if collector.env_num > 1 and isinstance(n_episode, int): + n = collector.env_num n_ = np.zeros(n) + n_episode // n n_[:n_episode % n] += 1 n_episode = list(n_) From 26bb74c80d51bf8f9dd3827ec0a6205fd35917ee Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Thu, 28 Jan 2021 17:41:08 +0800 Subject: [PATCH 026/104] pep8 fix --- tianshou/data/buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 9053a6988..8311ba87a 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -602,8 +602,8 @@ class CachedReplayBuffer(ReplayBuffers): """CachedReplayBuffer contains a given main buffer and n cached buffers, \ cached_buffer_num * ReplayBuffer(size=max_episode_length). - The memory layout is: ``| main_buffer | cached_buffers[0] | cached_buffers[1] - | ... | cached_buffers[cached_buffer_num - 1]``. + The memory layout is: ``| main_buffer | cached_buffers[0] | + cached_buffers[1] | ... | cached_buffers[cached_buffer_num - 1]``. The data is first stored in cached buffers. 
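A small illustrative sketch of this behaviour and of the episode migration described next (keyword names such as ``cached_buffer_ids`` follow the buffer-side ``add()`` in this series; the collector commits above still pass ``index``)::

    from tianshou.data import ReplayBuffer, CachedReplayBuffer

    # main buffer of size 10 plus 4 cached buffers of size 5 each
    buf = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
    # one parallel step from envs 0 and 1; done=1 finishes env 1's episode
    ep_len, ep_rew = buf.add(obs=[1, 2], act=[1, 1], rew=[0.0, 1.0],
                             done=[0, 1], cached_buffer_ids=[0, 1])
    # ep_len / ep_rew stay 0 for unfinished envs; for env 1 they hold the
    # finished episode's length and return, and that episode has already been
    # moved from cached_buffers[1] into the main buffer (cached_buffers[1] is reset)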
When the episode is terminated, the data will move to the main buffer and the corresponding From 74df1c5c7d3f0dc6a9cf6062676da3474be91779 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 28 Jan 2021 18:16:57 +0800 Subject: [PATCH 027/104] fix ci --- tianshou/data/buffer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 8311ba87a..d68a21388 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -664,12 +664,14 @@ def add( # type: ignore """ if cached_buffer_ids is None: cached_buffer_ids = np.arange(self.cached_buffer_num) + else: # make sure it is np.ndarray + cached_buffer_ids = np.asarray(cached_buffer_ids) # in self.buffers, the first buffer is main_buffer - buffer_ids = np.asarray(cached_buffer_ids) + 1 + buffer_ids = cached_buffer_ids + 1 # type: ignore result = super().add(obs, act, rew, done, obs_next, info, policy, buffer_ids=buffer_ids, **kwargs) # find the terminated episode, move data from cached buf to main buf - for buffer_idx in cached_buffer_ids[np.asarray(done) > 0]: + for buffer_idx in cached_buffer_ids[np.asarray(done, np.bool_)]: self.main_buffer.update(self.cached_buffers[buffer_idx]) self.cached_buffers[buffer_idx].reset() return result From 39463ba682a48c567ffbdc7797a86706bdf31378 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 28 Jan 2021 20:31:36 +0800 Subject: [PATCH 028/104] recover speed to 2000+ --- tianshou/data/buffer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index d68a21388..8917d2dce 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -232,8 +232,7 @@ def add( self._episode_length, self._episode_reward = 0, 0.0 return result else: - return np.zeros_like(self._episode_length), \ - np.zeros_like(self._episode_reward) + return self._episode_length * 0, self._episode_reward * 0. def sample_index(self, batch_size: int) -> np.ndarray: """Get a random sample of index with size = batch_size. From 31f0c94a5454c366c52a414c612df9c160f3bd36 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 28 Jan 2021 20:51:31 +0800 Subject: [PATCH 029/104] improve documents --- docs/tutorials/concepts.rst | 10 ++++++++-- tianshou/data/buffer.py | 25 ++++++++++++++----------- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index 1294cf1b3..a314cdedb 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -53,7 +53,7 @@ In short, you can define a :class:`~tianshou.data.Batch` with any key-value pair Buffer ------ -:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. ReplayBuffer can be considered as a specialized form (or management) of Batch. +:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. ReplayBuffer can be considered as a specialized form (or management) of Batch. It stores all the data in a batch with circular-queue style. 
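For example, a size-3 buffer behaves like a ring: once it is full, new transitions overwrite the oldest ones. A minimal sketch (the full snippet referenced later in this section covers more of the API)::

    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=3)
    for i in range(5):
        buf.add(obs=i, act=i, rew=i, done=(i == 4), obs_next=i + 1)
    # len(buf) == 3 and the obs slots now hold [3, 4, 2]: writes wrap around
    # buffers can also be saved to and restored from disk:
    buf.save_hdf5("buf.hdf5")
    buf2 = ReplayBuffer.load_hdf5("buf.hdf5")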
The current implementation of Tianshou typically use 7 reserved keys in :class:`~tianshou.data.Batch`: @@ -66,7 +66,13 @@ The current implementation of Tianshou typically use 7 reserved keys in * ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` function returns 4 arguments, and the last one is ``info``); * ``policy`` the data computed by policy in step :math:`t`; -The following code snippet illustrates its usage: +The following code snippet illustrates its usage, including: + +- the basic data storage: ``add()``; +- get attribute, get slicing data, ...; +- sample from buffer: ``sample_index(batch_size)`` and ``sample(batch_size)``; +- get previous/next transition index within episodes: ``prev(index)`` and ``next(index)``; +- save/load data from buffer: pickle and HDF5; :: diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 8917d2dce..c6f741e61 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -12,8 +12,13 @@ class ReplayBuffer: """:class:`~tianshou.data.ReplayBuffer` stores data generated from \ - interaction between the policy and environment. ReplayBuffer can be \ - considered as a specialized form (or management) of Batch. + interaction between the policy and environment. + + ReplayBuffer can be considered as a specialized form (or management) of + Batch. It stores all the data in a batch with circular-queue style. + + For the example usage of ReplayBuffer, please check out Section Buffer in + :doc:`/tutorials/concepts`. :param int size: the maximum size of replay buffer. :param int stack_num: the frame-stack sampling argument, should be greater @@ -80,12 +85,9 @@ def __setstate__(self, state: Dict[str, Any]) -> None: ("buffer.__getattr__" is customized). """ self.__dict__.update(state) - # compatible with previous version HDF5 + # compatible with version == 0.3.1's HDF5 data format self._indices = np.arange(self.maxsize) - def __getstate__(self) -> Dict[str, Any]: - return self.__dict__ - def __setattr__(self, key: str, value: Any) -> None: """Set self.key = value.""" assert key not in self._reserved_keys, ( @@ -95,7 +97,7 @@ def __setattr__(self, key: str, value: Any) -> None: def save_hdf5(self, path: str) -> None: """Save replay buffer to HDF5 file.""" with h5py.File(path, "w") as f: - to_hdf5(self.__getstate__(), f) + to_hdf5(self.__dict__, f) @classmethod def load_hdf5( @@ -205,6 +207,7 @@ def add( self._add_to_buffer("obs", obs) self._add_to_buffer("act", act) # make sure the data type of reward is float instead of int + # but rew may be np.ndarray, so that we cannot use float(rew) rew = rew * 1.0 # type: ignore self._add_to_buffer("rew", rew) self._add_to_buffer("done", bool(done)) # done should be a bool scalar @@ -217,15 +220,15 @@ def add( self._add_to_buffer("info", info) self._add_to_buffer("policy", policy) - self._episode_reward += rew - self._episode_length += 1 - if self.maxsize > 0: self._size = min(self._size + 1, self.maxsize) self._index = (self._index + 1) % self.maxsize else: # TODO: remove this after deleting ListReplayBuffer self._size = self._index = self._size + 1 + self._episode_reward += rew + self._episode_length += 1 + if done: result = np.asarray(self._episode_length), \ np.asarray(self._episode_reward) @@ -282,7 +285,7 @@ def get( """Return the stacked result. E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the - indice. + index. 
""" if stack_num is None: stack_num = self.stack_num From 720da29982291c450d6276c6008804d5f5c7f8e2 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 28 Jan 2021 21:58:52 +0800 Subject: [PATCH 030/104] ReplayBuffers -> ReplayBufferManager; alloc_fn -> _buffer_allocator; self._episode_rew/len defaults to 0.0/0 instead of np.array --- test/base/test_buffer.py | 41 +++++++++++++++++++------------- tianshou/data/__init__.py | 4 ++-- tianshou/data/buffer.py | 49 +++++++++++++++++---------------------- 3 files changed, 48 insertions(+), 46 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index f09d6c59e..d89c636a5 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -10,7 +10,7 @@ from tianshou.data.utils.converter import to_hdf5 from tianshou.data import Batch, SegmentTree, ReplayBuffer from tianshou.data import ListReplayBuffer, PrioritizedReplayBuffer -from tianshou.data import ReplayBuffers, CachedReplayBuffer +from tianshou.data import ReplayBufferManager, CachedReplayBuffer if __name__ == '__main__': from env import MyTestEnv @@ -43,6 +43,7 @@ def test_replaybuffer(size=10, bufsize=20): assert b.sample_index(-1).tolist() == [] b.add(1, 1, 1, 1, 'str', {'a': 3, 'b': {'c': 5.0}}) assert b.obs[0] == 1 + assert b.done[0] assert b.obs_next[0] == 'str' assert np.all(b.obs[1:] == 0) assert np.all(b.obs_next[1:] == np.array(None)) @@ -101,9 +102,12 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): buf = ReplayBuffer(bufsize, stack_num=stack_num) buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True) buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True) + # test if CachedReplayBuffer can handle stack_num + ignore_obs_next buf4 = CachedReplayBuffer( ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True), cached_num, size) + # test if CachedReplayBuffer can handle super corner case: + # prio-buffer + stack_num + ignore_obs_next + sample_avail buf5 = CachedReplayBuffer( PrioritizedReplayBuffer(bufsize, 0.6, 0.4, stack_num=stack_num, ignore_obs_next=True, sample_avail=True), @@ -140,17 +144,18 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): assert indice.tolist() == [6] with pytest.raises(IndexError): buf[bufsize * 2] + # check the `add` order is correct assert np.allclose(buf4.obs.reshape(-1), [ - 12, 13, 14, 4, 6, 7, 8, 9, 11, - 1, 2, 3, 4, 0, - 6, 7, 8, 9, 0, - 11, 12, 13, 14, 0, + 12, 13, 14, 4, 6, 7, 8, 9, 11, # main_buffer + 1, 2, 3, 4, 0, # cached_buffer[0] + 6, 7, 8, 9, 0, # cached_buffer[1] + 11, 12, 13, 14, 0, # cached_buffer[2] ]), buf4.obs assert np.allclose(buf4.done, [ - 0, 0, 1, 1, 0, 0, 0, 1, 0, - 0, 0, 0, 1, 0, - 0, 0, 0, 1, 0, - 0, 0, 0, 1, 0, + 0, 0, 1, 1, 0, 0, 0, 1, 0, # main_buffer + 0, 0, 0, 1, 0, # cached_buffer[0] + 0, 0, 0, 1, 0, # cached_buffer[1] + 0, 0, 0, 1, 0, # cached_buffer[2] ]), buf4.done assert np.allclose(buf4.unfinished_index(), [10, 15, 20]) indice = sorted(buf4.sample_index(0)) @@ -191,7 +196,7 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): assert unmodified_weight.max() < 1 cached_weight = weight[9:] assert cached_weight.min() == cached_weight.max() == 1 - # test Atari + # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next buf6 = CachedReplayBuffer( ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True, ignore_obs_next=True), @@ -372,7 +377,7 @@ def test_hdf5(): "array": ReplayBuffer(size, stack_num=2), "list": ListReplayBuffer(), "prioritized": PrioritizedReplayBuffer(size, 0.6, 
0.4), - "vector": ReplayBuffers([ReplayBuffer(size) for i in range(4)]), + "vector": ReplayBufferManager([ReplayBuffer(size) for i in range(4)]), "cached": CachedReplayBuffer(ReplayBuffer(size), 4, size) } buffer_types = {k: b.__class__ for k, b in buffers.items()} @@ -421,7 +426,7 @@ def test_hdf5(): buffers[k][:].info.number.n == _buffers[k][:].info.number.n) assert np.all( buffers[k][:].info.extra == _buffers[k][:].info.extra) - # check shallow copy in ReplayBuffers + # check shallow copy in ReplayBufferManager for k in ["vector", "cached"]: buffers[k].info.number.n[0] = -100 assert buffers[k].buffers[0].info.number.n[0] == -100 @@ -464,12 +469,13 @@ def test_hdf5(): to_hdf5(data, grp) -def test_vectorbuffer(): - buf = ReplayBuffers([ReplayBuffer(size=5) for i in range(4)]) +def test_replaybuffermanager(): + buf = ReplayBufferManager([ReplayBuffer(size=5) for i in range(4)]) ep_len, ep_rew = buf.add(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3], done=[0, 0, 1], buffer_ids=[0, 1, 2]) assert np.allclose(ep_len, [0, 0, 1]) and np.allclose(ep_rew, [0, 0, 3]) with pytest.raises(NotImplementedError): + # ReplayBufferManager cannot be updated buf.update(buf) indice = buf.sample_index(11000) assert np.bincount(indice)[[0, 5, 10]].min() >= 3000 @@ -549,7 +555,9 @@ def test_vectorbuffer(): assert np.allclose(buf.buffers[-1].info.n, [1] * 5) assert buf.sample_index(-1).tolist() == [] assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == np.object - # CachedReplayBuffer + + +def test_cachedbuffer(): buf = CachedReplayBuffer(ReplayBuffer(10), 4, 5) assert buf.sample_index(0).tolist() == [] ep_len, ep_rew = buf.add(obs=[1], act=[1], rew=[1], done=[0], @@ -616,5 +624,6 @@ def test_vectorbuffer(): test_priortized_replaybuffer() test_priortized_replaybuffer(233333, 200000) test_update() - test_vectorbuffer() + test_replaybuffermanager() + test_cachedbuffer() test_hdf5() diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index 1fbf1e481..368427a19 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -2,7 +2,7 @@ from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as from tianshou.data.utils.segtree import SegmentTree from tianshou.data.buffer import ReplayBuffer, ListReplayBuffer, \ - PrioritizedReplayBuffer, ReplayBuffers, CachedReplayBuffer + PrioritizedReplayBuffer, ReplayBufferManager, CachedReplayBuffer from tianshou.data.collector import Collector __all__ = [ @@ -14,7 +14,7 @@ "ReplayBuffer", "ListReplayBuffer", "PrioritizedReplayBuffer", - "ReplayBuffers", + "ReplayBufferManager", "CachedReplayBuffer", "Collector", ] diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index c6f741e61..47a3697a2 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -56,11 +56,7 @@ def __init__( self._save_obs_next = not ignore_obs_next self._save_only_last_obs = save_only_last_obs self._sample_avail = sample_avail - self._index = 0 # current index - self._size = 0 # current buffer size self._meta: Batch = Batch() - self._episode_reward = 0.0 - self._episode_length = 0 self.reset() def __len__(self) -> int: @@ -110,7 +106,7 @@ def load_hdf5( return buf def reset(self) -> None: - """Clear all the data in replay buffer.""" + """Clear all the data in replay buffer and episode statistics.""" self._index = self._size = 0 self._episode_length, self._episode_reward = 0, 0.0 @@ -149,15 +145,15 @@ def update(self, buffer: "ReplayBuffer") -> None: if len(buffer) == 0 or self.maxsize == 0: return stack_num, buffer.stack_num = 
buffer.stack_num, 1 - save_only_last_obs, self._save_only_last_obs = \ - self._save_only_last_obs, False + save_only_last_obs = self._save_only_last_obs + self._save_only_last_obs = False indices = buffer.sample_index(0) # get all available indices for i in indices: self.add(**buffer[i]) # type: ignore buffer.stack_num = stack_num self._save_only_last_obs = save_only_last_obs - def alloc_fn(self, key: List[str], value: Any) -> None: + def _buffer_allocator(self, key: List[str], value: Any) -> None: """Allocate memory on buffer._meta for new (key, value) pair.""" data = self._meta for k in key[:-1]: @@ -168,7 +164,7 @@ def _add_to_buffer(self, name: str, inst: Any) -> None: try: value = self._meta.__dict__[name] except KeyError: - self.alloc_fn([name], inst) + self._buffer_allocator([name], inst) value = self._meta[name] if isinstance(inst, (torch.Tensor, np.ndarray)): if inst.shape != value.shape[1:]: @@ -180,7 +176,7 @@ def _add_to_buffer(self, name: str, inst: Any) -> None: value[self._index] = inst except KeyError: # inst is a dict/Batch for key in set(inst.keys()).difference(value.keys()): - self.alloc_fn([name, key], inst[key]) + self._buffer_allocator([name, key], inst[key]) self._meta[name][self._index] = inst def add( @@ -193,7 +189,7 @@ def add( info: Optional[Union[dict, Batch]] = {}, policy: Optional[Union[dict, Batch]] = {}, **kwargs: Any, - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> Tuple[int, Union[float, np.ndarray]]: """Add a batch of data into replay buffer. Return (episode_length, episode_reward) if one episode is terminated, @@ -230,12 +226,11 @@ def add( self._episode_length += 1 if done: - result = np.asarray(self._episode_length), \ - np.asarray(self._episode_reward) + result = self._episode_length, self._episode_reward self._episode_length, self._episode_reward = 0, 0.0 return result else: - return self._episode_length * 0, self._episode_reward * 0. + return self._episode_length * 0, self._episode_reward * 0.0 def sample_index(self, batch_size: int) -> np.ndarray: """Get a random sample of index with size = batch_size. @@ -362,7 +357,7 @@ def _add_to_buffer(self, name: str, inst: Any) -> None: self._meta[name].append(inst) def reset(self) -> None: - self._index = self._size = 0 + super().reset() for k in self._meta.keys(): if isinstance(self._meta[k], list): self._meta.__dict__[k] = [] @@ -406,7 +401,7 @@ def add( policy: Optional[Union[dict, Batch]] = {}, weight: Optional[Union[Number, np.number]] = None, **kwargs: Any, - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> Tuple[int, Union[float, np.ndarray]]: if weight is None: weight = self._max_prio else: @@ -461,8 +456,8 @@ def __getitem__( return batch -class ReplayBuffers(ReplayBuffer): - """ReplayBuffers contains a list of ReplayBuffer. +class ReplayBufferManager(ReplayBuffer): + """ReplayBufferManager contains a list of ReplayBuffer. These replay buffers have contiguous memory layout, and the storage space each buffer has is a shallow copy of the topmost memory. 
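A small sketch of the intended interface, mirroring the new tests; ``buffer_ids`` selects which sub-buffer each row of the input batch is written to::

    from tianshou.data import ReplayBuffer, ReplayBufferManager

    buf = ReplayBufferManager([ReplayBuffer(size=5) for _ in range(4)])
    # globally, sub-buffer i occupies indices [5 * i, 5 * i + 5)
    ep_len, ep_rew = buf.add(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3],
                             done=[0, 0, 1], buffer_ids=[0, 1, 2])
    # only the third entry finished an episode, so
    # ep_len == [0, 0, 1] and ep_rew == [0, 0, 3]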
@@ -481,15 +476,13 @@ def __init__(self, buffer_list: List[ReplayBuffer], **kwargs: Any) -> None: self._offset = [] offset = 0 for buf in self.buffers: - # overwrite sub-buffers' alloc_fn so that + # overwrite sub-buffers' _buffer_allocator so that # the top buffer can allocate new memory for all sub-buffers - buf.alloc_fn = self.alloc_fn # type: ignore + buf._buffer_allocator = self._buffer_allocator # type: ignore assert buf._meta.is_empty() self._offset.append(offset) offset += buf.maxsize super().__init__(size=offset, **kwargs) - # delete useless variables in ReplayBuffer - del self._index, self._size, self._episode_reward, self._episode_length def __len__(self) -> int: return sum([len(buf) for buf in self.buffers]) @@ -530,11 +523,11 @@ def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: return next_indices def update(self, buffer: ReplayBuffer) -> None: - """The ReplayBuffers cannot be updated by any buffer.""" + """The ReplayBufferManager cannot be updated by any buffer.""" raise NotImplementedError - def alloc_fn(self, key: List[str], value: Any) -> None: - super().alloc_fn(key, value) + def _buffer_allocator(self, key: List[str], value: Any) -> None: + super()._buffer_allocator(key, value) self._set_batch_for_children() def add( # type: ignore @@ -549,7 +542,7 @@ def add( # type: ignore buffer_ids: Optional[Union[np.ndarray, List[int]]] = None, **kwargs: Any ) -> Tuple[np.ndarray, np.ndarray]: - """Add a batch of data into ReplayBuffers. + """Add a batch of data into ReplayBufferManager. Each of the data's length (first dimension) must equal to the length of buffer_ids. By default buffer_ids is [0, 1, ..., buffer_num - 1]. @@ -600,7 +593,7 @@ def sample_index(self, batch_size: int) -> np.ndarray: ]) -class CachedReplayBuffer(ReplayBuffers): +class CachedReplayBuffer(ReplayBufferManager): """CachedReplayBuffer contains a given main buffer and n cached buffers, \ cached_buffer_num * ReplayBuffer(size=max_episode_length). @@ -621,7 +614,7 @@ class CachedReplayBuffer(ReplayBuffers): .. seealso:: Please refer to :class:`~tianshou.data.ReplayBuffer` or - :class:`~tianshou.data.ReplayBuffers` for more detailed + :class:`~tianshou.data.ReplayBufferManager` for more detailed explanation. 
""" From b9385d8d3478e10650982cf9ac10aeae860a5172 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 28 Jan 2021 22:43:35 +0800 Subject: [PATCH 031/104] re-organize test_buffer.py --- test/base/test_buffer.py | 325 +++++++++++++++++++++++---------------- 1 file changed, 193 insertions(+), 132 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index d89c636a5..db79594a1 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -56,8 +56,6 @@ def test_replaybuffer(size=10, bufsize=20): b = ListReplayBuffer() with pytest.raises(NotImplementedError): b.sample(0) - with pytest.raises(NotImplementedError): - b.update(b) def test_ignore_obs_next(size=10): @@ -102,112 +100,28 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): buf = ReplayBuffer(bufsize, stack_num=stack_num) buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True) buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True) - # test if CachedReplayBuffer can handle stack_num + ignore_obs_next - buf4 = CachedReplayBuffer( - ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True), - cached_num, size) - # test if CachedReplayBuffer can handle super corner case: - # prio-buffer + stack_num + ignore_obs_next + sample_avail - buf5 = CachedReplayBuffer( - PrioritizedReplayBuffer(bufsize, 0.6, 0.4, stack_num=stack_num, - ignore_obs_next=True, sample_avail=True), - cached_num, size) obs = env.reset(1) - for i in range(18): + for i in range(16): obs_next, rew, done, info = env.step(1) buf.add(obs, 1, rew, done, None, info) buf2.add(obs, 1, rew, done, None, info) buf3.add([None, None, obs], 1, rew, done, [None, obs], info) - obs_list = np.array([obs + size * i for i in range(cached_num)]) - act_list = [1] * cached_num - rew_list = [rew] * cached_num - done_list = [done] * cached_num - obs_next_list = -obs_list - info_list = [info] * cached_num - buf4.add(obs_list, act_list, rew_list, done_list, - obs_next_list, info_list) - buf5.add(obs_list, act_list, rew_list, done_list, - obs_next_list, info_list) obs = obs_next if done: obs = env.reset(1) indice = np.arange(len(buf)) assert np.allclose(buf.get(indice, 'obs')[..., 0], [ - [2, 2, 2, 2], [2, 2, 2, 3], [2, 2, 3, 4], + [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], - [1, 2, 3, 4], [1, 1, 1, 1], [1, 1, 1, 2]]) + [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]]) assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs')) assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs_next')) _, indice = buf2.sample(0) - assert indice.tolist() == [6] + assert indice.tolist() == [2, 6] _, indice = buf2.sample(1) - assert indice.tolist() == [6] + assert indice[0] in [2, 6] with pytest.raises(IndexError): buf[bufsize * 2] - # check the `add` order is correct - assert np.allclose(buf4.obs.reshape(-1), [ - 12, 13, 14, 4, 6, 7, 8, 9, 11, # main_buffer - 1, 2, 3, 4, 0, # cached_buffer[0] - 6, 7, 8, 9, 0, # cached_buffer[1] - 11, 12, 13, 14, 0, # cached_buffer[2] - ]), buf4.obs - assert np.allclose(buf4.done, [ - 0, 0, 1, 1, 0, 0, 0, 1, 0, # main_buffer - 0, 0, 0, 1, 0, # cached_buffer[0] - 0, 0, 0, 1, 0, # cached_buffer[1] - 0, 0, 0, 1, 0, # cached_buffer[2] - ]), buf4.done - assert np.allclose(buf4.unfinished_index(), [10, 15, 20]) - indice = sorted(buf4.sample_index(0)) - assert np.allclose(indice, list(range(bufsize)) + [9, 10, 14, 15, 19, 20]) - assert np.allclose(buf4[indice].obs[..., 0], [ - [11, 11, 11, 12], [11, 11, 12, 13], [11, 12, 13, 14], - [4, 4, 4, 4], [6, 6, 
6, 6], [6, 6, 6, 7], - [6, 6, 7, 8], [6, 7, 8, 9], [11, 11, 11, 11], - [1, 1, 1, 1], [1, 1, 1, 2], [6, 6, 6, 6], - [6, 6, 6, 7], [11, 11, 11, 11], [11, 11, 11, 12], - ]) - assert np.allclose(buf4[indice].obs_next[..., 0], [ - [11, 11, 12, 13], [11, 12, 13, 14], [11, 12, 13, 14], - [4, 4, 4, 4], [6, 6, 6, 7], [6, 6, 7, 8], - [6, 7, 8, 9], [6, 7, 8, 9], [11, 11, 11, 12], - [1, 1, 1, 2], [1, 1, 1, 2], [6, 6, 6, 7], - [6, 6, 6, 7], [11, 11, 11, 12], [11, 11, 11, 12], - ]) - assert np.all(buf4.done == buf5.done) - indice = buf5.sample_index(0) - assert np.allclose(sorted(indice), [2, 7]) - assert np.all(np.isin(buf5.sample_index(100), indice)) - # manually change the stack num - buf5.stack_num = 2 - for buf in buf5.buffers: - buf.stack_num = 2 - indice = buf5.sample_index(0) - assert np.allclose(sorted(indice), [0, 1, 2, 5, 6, 7, 10, 15, 20]) - batch, _ = buf5.sample(0) - assert np.allclose(buf5[np.arange(buf5.maxsize)].weight, 1) - buf5.update_weight(indice, batch.weight * 0) - weight = buf5[np.arange(buf5.maxsize)].weight - modified_weight = weight[[0, 1, 2, 5, 6, 7]] - assert modified_weight.min() == modified_weight.max() - assert modified_weight.max() < 1 - unmodified_weight = weight[[3, 4, 8]] - assert unmodified_weight.min() == unmodified_weight.max() - assert unmodified_weight.max() < 1 - cached_weight = weight[9:] - assert cached_weight.min() == cached_weight.max() == 1 - # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next - buf6 = CachedReplayBuffer( - ReplayBuffer(bufsize, stack_num=stack_num, - save_only_last_obs=True, ignore_obs_next=True), - cached_num, size) - obs = np.random.rand(size, 4, 84, 84) - buf6.add(obs=[obs[2], obs[0]], act=[1, 1], rew=[0, 0], done=[0, 1], - obs_next=[obs[3], obs[1]], cached_buffer_ids=[1, 2]) - assert buf6.obs.shape == (buf6.maxsize, 84, 84) - assert np.allclose(buf6.obs[0], obs[0, -1]) - assert np.allclose(buf6.obs[14], obs[2, -1]) - assert np.allclose(buf6.obs[19], obs[0, -1]) def test_priortized_replaybuffer(size=32, bufsize=15): @@ -242,6 +156,12 @@ def test_update(): assert len(buf1) == len(buf2) assert (buf2[0].obs == buf1[1].obs).all() assert (buf2[-1].obs == buf1[0].obs).all() + b = ListReplayBuffer() + with pytest.raises(NotImplementedError): + b.update(b) + b = CachedReplayBuffer(ReplayBuffer(10), 4, 5) + with pytest.raises(NotImplementedError): + b.update(b) def test_segtree(): @@ -377,8 +297,6 @@ def test_hdf5(): "array": ReplayBuffer(size, stack_num=2), "list": ListReplayBuffer(), "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4), - "vector": ReplayBufferManager([ReplayBuffer(size) for i in range(4)]), - "cached": CachedReplayBuffer(ReplayBuffer(size), 4, size) } buffer_types = {k: b.__class__ for k, b in buffers.items()} device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -394,10 +312,6 @@ def test_hdf5(): buffers["array"].add(**kwargs) buffers["list"].add(**kwargs) buffers["prioritized"].add(weight=np.random.rand(), **kwargs) - buffers["vector"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]), - buffer_ids=[0, 1, 2]) - buffers["cached"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]), - cached_buffer_ids=[0, 1, 2]) # save paths = {} @@ -426,36 +340,6 @@ def test_hdf5(): buffers[k][:].info.number.n == _buffers[k][:].info.number.n) assert np.all( buffers[k][:].info.extra == _buffers[k][:].info.extra) - # check shallow copy in ReplayBufferManager - for k in ["vector", "cached"]: - buffers[k].info.number.n[0] = -100 - assert buffers[k].buffers[0].info.number.n[0] == -100 - # check if still behave normally - 
for k in ["vector", "cached"]: - kwargs = { - 'obs': Batch(index=np.array([5])), - 'act': 5, - 'rew': np.array([2, 1]), - 'done': False, - 'info': {"number": {"n": i}, 'Timelimit.truncate': True}, - } - buffers[k].add(**Batch.cat([[kwargs], [kwargs], [kwargs], [kwargs]])) - act = np.zeros(buffers[k].maxsize) - if k == "vector": - act[np.arange(5)] = np.array([0, 1, 2, 3, 5]) - act[np.arange(5) + size] = np.array([0, 1, 2, 3, 5]) - act[np.arange(5) + size * 2] = np.array([0, 1, 2, 3, 5]) - act[size * 3] = 5 - elif k == "cached": - act[np.arange(9)] = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2]) - act[np.arange(3) + size] = np.array([3, 5, 2]) - act[np.arange(3) + size * 2] = np.array([3, 5, 2]) - act[np.arange(3) + size * 3] = np.array([3, 5, 2]) - act[size * 4] = 5 - assert np.allclose(buffers[k].act, act) - - for path in paths.values(): - os.remove(path) # raise exception when value cannot be pickled data = {"not_supported": lambda x: x * x} @@ -477,8 +361,9 @@ def test_replaybuffermanager(): with pytest.raises(NotImplementedError): # ReplayBufferManager cannot be updated buf.update(buf) + # sample index / prev / next / unfinished_index indice = buf.sample_index(11000) - assert np.bincount(indice)[[0, 5, 10]].min() >= 3000 + assert np.bincount(indice)[[0, 5, 10]].min() >= 3000 # uniform sample batch, indice = buf.sample(0) assert np.allclose(indice, [0, 5, 10]) indice_prev = buf.prev(indice) @@ -509,6 +394,7 @@ def test_replaybuffermanager(): batch, indice = buf.sample(10) indice = buf.sample_index(0) assert np.allclose(indice, np.arange(len(buf))) + # check the actual data stored in buf._meta assert np.allclose(buf.done, [ 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, @@ -560,6 +446,7 @@ def test_replaybuffermanager(): def test_cachedbuffer(): buf = CachedReplayBuffer(ReplayBuffer(10), 4, 5) assert buf.sample_index(0).tolist() == [] + # check the normal function/usage/storage in CachedReplayBuffer ep_len, ep_rew = buf.add(obs=[1], act=[1], rew=[1], done=[0], cached_buffer_ids=[1]) obs = np.zeros(buf.maxsize) @@ -593,8 +480,9 @@ def test_cachedbuffer(): assert np.allclose(buf.prev(indice), [0, 1, 1, 25]) assert np.allclose(buf.next(indice), [0, 2, 2, 25]) indice = buf.sample_index(10000) - assert np.bincount(indice)[[0, 1, 2, 25]].min() > 2000 - # cached buffer with main_buffer size == 0 + assert np.bincount(indice)[[0, 1, 2, 25]].min() > 2000 # uniform sample + # cached buffer with main_buffer size == 0 (no update) + # used in test_collector buf = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5) data = np.zeros(4) rew = np.ones([4, 4]) @@ -615,15 +503,188 @@ def test_cachedbuffer(): assert np.allclose(buf.next(indice), [1, 1, 11, 11]) +def test_multibuf_stack(): + size = 5 + bufsize = 9 + stack_num = 4 + cached_num = 3 + env = MyTestEnv(size) + # test if CachedReplayBuffer can handle stack_num + ignore_obs_next + buf4 = CachedReplayBuffer( + ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True), + cached_num, size) + # test if CachedReplayBuffer can handle super corner case: + # prio-buffer + stack_num + ignore_obs_next + sample_avail + buf5 = CachedReplayBuffer( + PrioritizedReplayBuffer(bufsize, 0.6, 0.4, stack_num=stack_num, + ignore_obs_next=True, sample_avail=True), + cached_num, size) + obs = env.reset(1) + for i in range(18): + obs_next, rew, done, info = env.step(1) + obs_list = np.array([obs + size * i for i in range(cached_num)]) + act_list = [1] * cached_num + rew_list = [rew] * cached_num + done_list = [done] * cached_num + obs_next_list = -obs_list + info_list = [info] * 
cached_num + buf4.add(obs_list, act_list, rew_list, done_list, + obs_next_list, info_list) + buf5.add(obs_list, act_list, rew_list, done_list, + obs_next_list, info_list) + obs = obs_next + if done: + obs = env.reset(1) + # check the `add` order is correct + assert np.allclose(buf4.obs.reshape(-1), [ + 12, 13, 14, 4, 6, 7, 8, 9, 11, # main_buffer + 1, 2, 3, 4, 0, # cached_buffer[0] + 6, 7, 8, 9, 0, # cached_buffer[1] + 11, 12, 13, 14, 0, # cached_buffer[2] + ]), buf4.obs + assert np.allclose(buf4.done, [ + 0, 0, 1, 1, 0, 0, 0, 1, 0, # main_buffer + 0, 0, 0, 1, 0, # cached_buffer[0] + 0, 0, 0, 1, 0, # cached_buffer[1] + 0, 0, 0, 1, 0, # cached_buffer[2] + ]), buf4.done + assert np.allclose(buf4.unfinished_index(), [10, 15, 20]) + indice = sorted(buf4.sample_index(0)) + assert np.allclose(indice, list(range(bufsize)) + [9, 10, 14, 15, 19, 20]) + assert np.allclose(buf4[indice].obs[..., 0], [ + [11, 11, 11, 12], [11, 11, 12, 13], [11, 12, 13, 14], + [4, 4, 4, 4], [6, 6, 6, 6], [6, 6, 6, 7], + [6, 6, 7, 8], [6, 7, 8, 9], [11, 11, 11, 11], + [1, 1, 1, 1], [1, 1, 1, 2], [6, 6, 6, 6], + [6, 6, 6, 7], [11, 11, 11, 11], [11, 11, 11, 12], + ]) + assert np.allclose(buf4[indice].obs_next[..., 0], [ + [11, 11, 12, 13], [11, 12, 13, 14], [11, 12, 13, 14], + [4, 4, 4, 4], [6, 6, 6, 7], [6, 6, 7, 8], + [6, 7, 8, 9], [6, 7, 8, 9], [11, 11, 11, 12], + [1, 1, 1, 2], [1, 1, 1, 2], [6, 6, 6, 7], + [6, 6, 6, 7], [11, 11, 11, 12], [11, 11, 11, 12], + ]) + assert np.all(buf4.done == buf5.done) + indice = buf5.sample_index(0) + assert np.allclose(sorted(indice), [2, 7]) + assert np.all(np.isin(buf5.sample_index(100), indice)) + # manually change the stack num + buf5.stack_num = 2 + for buf in buf5.buffers: + buf.stack_num = 2 + indice = buf5.sample_index(0) + assert np.allclose(sorted(indice), [0, 1, 2, 5, 6, 7, 10, 15, 20]) + batch, _ = buf5.sample(0) + assert np.allclose(buf5[np.arange(buf5.maxsize)].weight, 1) + buf5.update_weight(indice, batch.weight * 0) + weight = buf5[np.arange(buf5.maxsize)].weight + modified_weight = weight[[0, 1, 2, 5, 6, 7]] + assert modified_weight.min() == modified_weight.max() + assert modified_weight.max() < 1 + unmodified_weight = weight[[3, 4, 8]] + assert unmodified_weight.min() == unmodified_weight.max() + assert unmodified_weight.max() < 1 + cached_weight = weight[9:] + assert cached_weight.min() == cached_weight.max() == 1 + # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next + buf6 = CachedReplayBuffer( + ReplayBuffer(bufsize, stack_num=stack_num, + save_only_last_obs=True, ignore_obs_next=True), + cached_num, size) + obs = np.random.rand(size, 4, 84, 84) + buf6.add(obs=[obs[2], obs[0]], act=[1, 1], rew=[0, 0], done=[0, 1], + obs_next=[obs[3], obs[1]], cached_buffer_ids=[1, 2]) + assert buf6.obs.shape == (buf6.maxsize, 84, 84) + assert np.allclose(buf6.obs[0], obs[0, -1]) + assert np.allclose(buf6.obs[14], obs[2, -1]) + assert np.allclose(buf6.obs[19], obs[0, -1]) + assert buf6[0].obs.shape == (4, 84, 84) + + +def test_multibuf_hdf5(): + size = 100 + buffers = { + "vector": ReplayBufferManager([ReplayBuffer(size) for i in range(4)]), + "cached": CachedReplayBuffer(ReplayBuffer(size), 4, size) + } + buffer_types = {k: b.__class__ for k, b in buffers.items()} + device = 'cuda' if torch.cuda.is_available() else 'cpu' + info_t = torch.tensor([1.]).to(device) + for i in range(4): + kwargs = { + 'obs': Batch(index=np.array([i])), + 'act': i, + 'rew': np.array([1, 2]), + 'done': i % 3 == 2, + 'info': {"number": {"n": i, "t": info_t}, 'extra': None}, + } + 
buffers["vector"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]), + buffer_ids=[0, 1, 2]) + buffers["cached"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]), + cached_buffer_ids=[0, 1, 2]) + + # save + paths = {} + for k, buf in buffers.items(): + f, path = tempfile.mkstemp(suffix='.hdf5') + os.close(f) + buf.save_hdf5(path) + paths[k] = path + + # load replay buffer + _buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths.keys()} + + # compare + for k in buffers.keys(): + assert len(_buffers[k]) == len(buffers[k]) + assert np.allclose(_buffers[k].act, buffers[k].act) + assert _buffers[k].stack_num == buffers[k].stack_num + assert _buffers[k].maxsize == buffers[k].maxsize + assert np.all(_buffers[k]._indices == buffers[k]._indices) + # check shallow copy in ReplayBufferManager + for k in ["vector", "cached"]: + buffers[k].info.number.n[0] = -100 + assert buffers[k].buffers[0].info.number.n[0] == -100 + # check if still behave normally + for k in ["vector", "cached"]: + kwargs = { + 'obs': Batch(index=np.array([5])), + 'act': 5, + 'rew': np.array([2, 1]), + 'done': False, + 'info': {"number": {"n": i}, 'Timelimit.truncate': True}, + } + buffers[k].add(**Batch.cat([[kwargs], [kwargs], [kwargs], [kwargs]])) + act = np.zeros(buffers[k].maxsize) + if k == "vector": + act[np.arange(5)] = np.array([0, 1, 2, 3, 5]) + act[np.arange(5) + size] = np.array([0, 1, 2, 3, 5]) + act[np.arange(5) + size * 2] = np.array([0, 1, 2, 3, 5]) + act[size * 3] = 5 + elif k == "cached": + act[np.arange(9)] = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2]) + act[np.arange(3) + size] = np.array([3, 5, 2]) + act[np.arange(3) + size * 2] = np.array([3, 5, 2]) + act[np.arange(3) + size * 3] = np.array([3, 5, 2]) + act[size * 4] = 5 + assert np.allclose(buffers[k].act, act) + + for path in paths.values(): + os.remove(path) + + if __name__ == '__main__': test_replaybuffer() test_ignore_obs_next() test_stack() - test_pickle() test_segtree() test_priortized_replaybuffer() test_priortized_replaybuffer(233333, 200000) test_update() + test_pickle() + test_hdf5() test_replaybuffermanager() test_cachedbuffer() - test_hdf5() + test_multibuf_stack() + test_multibuf_hdf5() From 16bb42e05b4f5234e417b0658d2bfd2aec766d0e Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Fri, 29 Jan 2021 08:57:01 +0800 Subject: [PATCH 032/104] improve test --- test/base/test_buffer.py | 2 ++ test/discrete/test_qrdqn.py | 2 +- tianshou/data/buffer.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index db79594a1..04b1928e2 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -120,6 +120,8 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): assert indice.tolist() == [2, 6] _, indice = buf2.sample(1) assert indice[0] in [2, 6] + batch, indice = buf2.sample(-1) # neg bsz -> no data + assert indice.tolist() == [] and len(batch) == 0 with pytest.raises(IndexError): buf[bufsize * 2] diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 7020df275..0d03fce97 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -16,7 +16,7 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='CartPole-v0') - parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--seed', type=int, default=1) parser.add_argument('--eps-test', type=float, default=0.05) parser.add_argument('--eps-train', type=float, default=0.1) parser.add_argument('--buffer-size', type=int, 
default=20000) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 47a3697a2..4c1ea14a6 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -230,7 +230,7 @@ def add( self._episode_length, self._episode_reward = 0, 0.0 return result else: - return self._episode_length * 0, self._episode_reward * 0.0 + return 0, self._episode_reward * 0.0 def sample_index(self, batch_size: int) -> np.ndarray: """Get a random sample of index with size = batch_size. From cada7cc1fae7568ceb16c21ca3abe6bf3d08fa20 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Fri, 29 Jan 2021 09:18:08 +0800 Subject: [PATCH 033/104] test if can be faster --- test/continuous/test_ppo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index d0987cdbf..1a48aaee9 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -18,7 +18,7 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='Pendulum-v0') - parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--seed', type=int, default=1) parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.99) From 13683a0d1f1232f19e7c1b495fb95ff72dcd5eec Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Fri, 29 Jan 2021 10:36:14 +0800 Subject: [PATCH 034/104] update tianshou/trainer --- tianshou/data/collector.py | 4 +++- tianshou/trainer/offline.py | 6 +++--- tianshou/trainer/offpolicy.py | 19 +++++++++---------- tianshou/trainer/onpolicy.py | 20 +++++++++----------- tianshou/trainer/utils.py | 5 ----- 5 files changed, 24 insertions(+), 30 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 018992e01..362892f4c 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -161,7 +161,7 @@ def reset(self) -> None: def reset_stat(self) -> None: """Reset the statistic variables.""" - self.collect_step, self.collect_episode = 0, 0 + self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 def reset_buffer(self) -> None: """Reset the main data buffer.""" @@ -233,6 +233,7 @@ def collect( "n_step must not be 0, and should be an integral multiple of #envs" else: assert isinstance(n_episode, int) and n_episode > 0 + start_time = time.time() step_count = 0 # episode of each environment @@ -336,6 +337,7 @@ def collect( # generate the statistics self.collect_step += step_count self.collect_episode += episode_count + self.collect_time += max(time.time() - start_time, 1e-9) if n_episode: self.reset_env() # TODO change api in trainer and other collector usage diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index e69364135..f46cd79e7 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -82,14 +82,14 @@ def offline_trainer( # test result = test_episode(policy, test_collector, test_fn, epoch, episode_per_test, writer, gradient_step) - if best_epoch == -1 or best_reward < result["rew"]: - best_reward, best_reward_std = result["rew"], result["rew_std"] + if best_epoch == -1 or best_reward < result["rews"].mean(): + best_reward, best_reward_std = result["rews"].mean(), result['rews'].std() best_epoch = epoch if save_fn: save_fn(policy) if verbose: print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± " - f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± " + 
f"{result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " f"{best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index f34f5b281..fd1271c38 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -93,22 +93,20 @@ def offpolicy_trainer( env_step += int(result["n/st"]) data = { "env_step": str(env_step), - "rew": f"{result['rew']:.2f}", - "len": str(int(result["len"])), + "rew": f"{result['rews'].mean():.2f}", + "len": str(int(result["lens"].mean())), "n/ep": str(int(result["n/ep"])), "n/st": str(int(result["n/st"])), - "v/ep": f"{result['v/ep']:.2f}", - "v/st": f"{result['v/st']:.2f}", } if writer and env_step % log_interval == 0: for k in result.keys(): writer.add_scalar( "train/" + k, result[k], global_step=env_step) - if test_in_train and stop_fn and stop_fn(result["rew"]): + if test_in_train and stop_fn and stop_fn(result["rews"].mean()): test_result = test_episode( policy, test_collector, test_fn, epoch, episode_per_test, writer, env_step) - if stop_fn(test_result["rew"]): + if stop_fn(test_result["rews"].mean()): if save_fn: save_fn(policy) for k in result.keys(): @@ -116,7 +114,7 @@ def offpolicy_trainer( t.set_postfix(**data) return gather_info( start_time, train_collector, test_collector, - test_result["rew"], test_result["rew_std"]) + test_result["rews"].mean(), test_result["rew_std"].std()) else: policy.train() for i in range(update_per_step * min( @@ -136,16 +134,17 @@ def offpolicy_trainer( # test result = test_episode(policy, test_collector, test_fn, epoch, episode_per_test, writer, env_step) - if best_epoch == -1 or best_reward < result["rew"]: - best_reward, best_reward_std = result["rew"], result["rew_std"] + if best_epoch == -1 or best_reward < result["rews"].mean(): + best_reward, best_reward_std = result["rews"].mean(), result["rew_std"].std() best_epoch = epoch if save_fn: save_fn(policy) if verbose: - print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± " + print(f"Epoch #{epoch}: test_reward: {result["rews"].mean():.6f} ± " f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± " f"{best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break return gather_info(start_time, train_collector, test_collector, best_reward, best_reward_std) + diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index f094ddd7d..9435f5fe6 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -93,22 +93,20 @@ def onpolicy_trainer( env_step += int(result["n/st"]) data = { "env_step": str(env_step), - "rew": f"{result['rew']:.2f}", - "len": str(int(result["len"])), + "rew": f"{result["rews"].mean():.2f}", + "len": str(int(result["lens"].mean())), "n/ep": str(int(result["n/ep"])), "n/st": str(int(result["n/st"])), - "v/ep": f"{result['v/ep']:.2f}", - "v/st": f"{result['v/st']:.2f}", } if writer and env_step % log_interval == 0: for k in result.keys(): writer.add_scalar( "train/" + k, result[k], global_step=env_step) - if test_in_train and stop_fn and stop_fn(result["rew"]): + if test_in_train and stop_fn and stop_fn(result["rews"].mean()): test_result = test_episode( policy, test_collector, test_fn, epoch, episode_per_test, writer, env_step) - if stop_fn(test_result["rew"]): + if stop_fn(test_result["rews"].mean()): if save_fn: save_fn(policy) for k in result.keys(): @@ -116,7 +114,7 @@ def onpolicy_trainer( t.set_postfix(**data) return gather_info( start_time, train_collector, 
test_collector, - test_result["rew"], test_result["rew_std"]) + test_result["rews"].mean(), test_result["rew_std"]) else: policy.train() losses = policy.update( @@ -139,14 +137,14 @@ def onpolicy_trainer( # test result = test_episode(policy, test_collector, test_fn, epoch, episode_per_test, writer, env_step) - if best_epoch == -1 or best_reward < result["rew"]: - best_reward, best_reward_std = result["rew"], result["rew_std"] + if best_epoch == -1 or best_reward < result["rews"].mean(): + best_reward, best_reward_std = result["rews"].mean(), result["rew_std"] best_epoch = epoch if save_fn: save_fn(policy) if verbose: - print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± " - f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± " + print(f"Epoch #{epoch}: test_reward: {result['rews'].mean():.6f} ± " + f"{result["rews"].std():.6f}, best_reward: {best_reward:.6f} ± " f"{best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 7ec617d7f..cb9a2ef2a 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -22,11 +22,6 @@ def test_episode( policy.eval() if test_fn: test_fn(epoch, global_step) - if collector.env_num > 1 and isinstance(n_episode, int): - n = collector.env_num - n_ = np.zeros(n) + n_episode // n - n_[:n_episode % n] += 1 - n_episode = list(n_) result = collector.collect(n_episode=n_episode) if writer is not None and global_step is not None: for k in result.keys(): From 1fed4a683a4f4543d26e71b8336cd0765312b7fc Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Fri, 29 Jan 2021 14:59:46 +0800 Subject: [PATCH 035/104] small change --- tianshou/data/collector.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 362892f4c..68d7159e6 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -87,7 +87,7 @@ def __init__( reward_metric: Optional[Callable[[np.ndarray], float]] = None, ) -> None: # TODO determine whether we need start_idxs - # TODO support not only cacahedbuffer + # TODO support not only cacahedbuffer,(maybe auto change) # TODO remove listreplaybuffer # TODO update training in all test/examples, remove action noise # TODO buffer need to be CachedReplayBuffer now, update @@ -169,7 +169,6 @@ def reset_buffer(self) -> None: def reset_env(self) -> None: """Reset all of the environment(s)' states and the cache buffers.""" - #should not be exposed obs = self.env.reset() if self.preprocess_fn: obs = self.preprocess_fn(obs=obs).get("obs", obs) From cf8a738953e3e26e39f073a541406825c5c3f9b2 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Fri, 29 Jan 2021 23:11:31 +0800 Subject: [PATCH 036/104] collector's buffer type check --- tianshou/data/buffer.py | 3 +- tianshou/data/collector.py | 64 ++++++++++++++++++++++---------------- 2 files changed, 39 insertions(+), 28 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 4c1ea14a6..7fdfb316a 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -345,7 +345,8 @@ class ListReplayBuffer(ReplayBuffer): """ def __init__(self, **kwargs: Any) -> None: - warnings.warn("ListReplayBuffer will be replaced in version 0.4.0.") + # TODO + warnings.warn("ListReplayBuffer will be removed in version 0.4.0.") super().__init__(size=0, ignore_obs_next=False, **kwargs) def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: diff --git a/tianshou/data/collector.py 
b/tianshou/data/collector.py index 68d7159e6..b633a0125 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -12,7 +12,7 @@ from tianshou.data.batch import _create_value from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, \ - CachedReplayBuffer, to_numpy + ReplayBufferManager, CachedReplayBuffer, to_numpy class Collector(object): @@ -81,14 +81,13 @@ def __init__( self, policy: BasePolicy, env: Union[gym.Env, BaseVectorEnv], - buffer: Optional[CachedReplayBuffer] = None, + buffer: Optional[ReplayBuffer] = None, preprocess_fn: Optional[Callable[..., Batch]] = None, training = False, reward_metric: Optional[Callable[[np.ndarray], float]] = None, ) -> None: # TODO determine whether we need start_idxs # TODO support not only cacahedbuffer,(maybe auto change) - # TODO remove listreplaybuffer # TODO update training in all test/examples, remove action noise # TODO buffer need to be CachedReplayBuffer now, update # examples/docs/ after supporting all types of buffers @@ -111,31 +110,35 @@ def __init__( def _check_buffer(self): max_episode_steps = self.env._max_episode_steps[0] - # TODO default to Replay_buffers - # TODO support replaybuffer when self.env_num == 1 - # TODO support Replay_buffers when self.env_num == 0 if self.buffer is None: - self.buffer = CachedReplayBuffer(size = 0, - cached_buf_n = self.env_num, max_length = max_episode_steps) - else: - assert isinstance(self.buffer, CachedReplayBuffer), \ - "BasicCollector reuqires CachedReplayBuffer as buffer input." - assert self.buffer.cached_bufs_n == self.env_num - - if self.buffer.main_buffer.maxsize < self.buffer.maxsize//2: + if self.training: + warnings.warn("ReplayBufferManager is not suggested to be used" + "in training mode, consider using CachedReplayBuffer, instead.") + self.buffer = ReplayBufferManager( + [ReplayBuffer(max_episode_steps)] * self.env_num) + elif isinstance(self.buffer, ReplayBufferManager): + if type(self.buffer) == ReplayBufferManager: + if self.training: + warnings.warn("ReplayBufferManager is not suggested to be used" + "in training mode, make sure you know how it works.") + self.buffer.cached_buffers = self.buffer.buffers + self.buffer.cached_buffer_num = self.buffer.buffer_num + assert self.buffer.cached_buffer_num == self.env_num + if self.buffer.cached_buffers[0].maxsize < max_episode_steps: warnings.warn( - "The size of buffer is suggested to be larger than " - "(cached buffer number) * max_length. Otherwise you might" - "loss data of episodes you just collected, and statistics " - "might even be incorrect.", - Warning) - if self.buffer.cached_buffer[0].maxsize < max_episode_steps: - warnings.warn( - "The size of cached_buf is suggested to be larger than " + "The size of cached_buffer is suggested to be larger than " "max episode length. Otherwise you might" "loss data of episodes you just collected, and statistics " - "might even be incorrect.", - Warning) + "might even be incorrect.", Warning) + else: #type ReplayBuffer + assert self.buffer.maxsize > 0 + if self.env_num != 1: + warnings.warn( + "CachedReplayBuffer/ReplayBufferManager rather than ReplayBuffer" + "is required in collector when #env > 1. 
Input buffer is switched" + "to CachedReplayBuffer.", Warning) + self.buffer = CachedReplayBuffer(self.buffer, + self.env_num, max_episode_steps) @staticmethod def _default_rew_metric( @@ -173,9 +176,9 @@ def reset_env(self) -> None: if self.preprocess_fn: obs = self.preprocess_fn(obs=obs).get("obs", obs) self.data.obs = obs - # TODO different kind of buffers, - for buf in self.buffer.cached_buffer: - buf.reset() + if hasattr(self.buffer, "cached_buffers"): + for buf in self.buffer.cached_buffers: + buf.reset() def _reset_state(self, id: Union[int, List[int]]) -> None: """Reset the hidden state: self.data.state[id].""" @@ -300,8 +303,15 @@ def collect( data_t = self.data if n_episode and len(cached_buffer_ids) < self.env_num: data_t = self.data[cached_buffer_ids] + if type(self.buffer) == ReplayBuffer: + data_t = data_t[0] # lens, rews, idxs = self.buffer.add(**data_t, index = cached_buffer_ids) + # rews need to be array for ReplayBuffer lens, rews = self.buffer.add(**data_t, index = cached_buffer_ids) + if type(self.buffer) == ReplayBuffer: + lens = np.asarray(lens) + rews = np.asarray(rews) + # idxs = np.asarray(idxs) # collect statistics step_count += len(cached_buffer_ids) for i in cached_buffer_ids(np.where(lens == 0)[0]): From 174f037629873f0df802e6a36e7776aba033d44a Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Sat, 30 Jan 2021 10:21:10 +0800 Subject: [PATCH 037/104] fix syntax err --- tianshou/data/collector.py | 34 +++++++++++++++++----------------- tianshou/trainer/offpolicy.py | 2 +- tianshou/trainer/onpolicy.py | 4 ++-- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index b633a0125..4d1d80687 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -12,7 +12,7 @@ from tianshou.data.batch import _create_value from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, \ - ReplayBufferManager, CachedReplayBuffer, to_numpy + ReplayBufferManager, CachedReplayBuffer, to_numpy class Collector(object): @@ -83,7 +83,7 @@ def __init__( env: Union[gym.Env, BaseVectorEnv], buffer: Optional[ReplayBuffer] = None, preprocess_fn: Optional[Callable[..., Batch]] = None, - training = False, + training: bool = False, reward_metric: Optional[Callable[[np.ndarray], float]] = None, ) -> None: # TODO determine whether we need start_idxs @@ -95,28 +95,28 @@ def __init__( if not isinstance(env, BaseVectorEnv): env = DummyVectorEnv([lambda: env]) # TODO support or seperate async - assert env.is_async == False + assert env.is_async is False self.env = env self.env_num = len(env) - self.buffer = buffer - self._check_buffer() + self.training = training + self._assign_buffer(buffer) self.policy = policy self.preprocess_fn = preprocess_fn self._action_space = env.action_space self._rew_metric = reward_metric or BasicCollector._default_rew_metric - self.training = training # avoid creating attribute outside __init__ self.reset() - def _check_buffer(self): - max_episode_steps = self.env._max_episode_steps[0] - if self.buffer is None: - if self.training: - warnings.warn("ReplayBufferManager is not suggested to be used" - "in training mode, consider using CachedReplayBuffer, instead.") - self.buffer = ReplayBufferManager( - [ReplayBuffer(max_episode_steps)] * self.env_num) - elif isinstance(self.buffer, ReplayBufferManager): + def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None: + if not hasattr(self.env, "_max_episode_steps"): + 
warnings.warn("No time limit found in given env, set to 100000.") + max_episode_steps = 100000 + else: + max_episode_steps = self.env._max_episode_steps[0] + if buffer is None: + self.buffer = CachedReplayBuffer( + ReplayBuffer(0), self.env_num, max_episode_steps) + elif isinstance(buffer, ReplayBufferManager): if type(self.buffer) == ReplayBufferManager: if self.training: warnings.warn("ReplayBufferManager is not suggested to be used" @@ -130,14 +130,14 @@ def _check_buffer(self): "max episode length. Otherwise you might" "loss data of episodes you just collected, and statistics " "might even be incorrect.", Warning) - else: #type ReplayBuffer + else: # type ReplayBuffer assert self.buffer.maxsize > 0 if self.env_num != 1: warnings.warn( "CachedReplayBuffer/ReplayBufferManager rather than ReplayBuffer" "is required in collector when #env > 1. Input buffer is switched" "to CachedReplayBuffer.", Warning) - self.buffer = CachedReplayBuffer(self.buffer, + self.buffer = CachedReplayBuffer(buffer, self.env_num, max_episode_steps) @staticmethod diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index fd1271c38..d1d487fe1 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -140,7 +140,7 @@ def offpolicy_trainer( if save_fn: save_fn(policy) if verbose: - print(f"Epoch #{epoch}: test_reward: {result["rews"].mean():.6f} ± " + print(f"Epoch #{epoch}: test_reward: {result['rews'].mean():.6f} ± " f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± " f"{best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 9435f5fe6..584c6a5bc 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -93,7 +93,7 @@ def onpolicy_trainer( env_step += int(result["n/st"]) data = { "env_step": str(env_step), - "rew": f"{result["rews"].mean():.2f}", + "rew": f"{result['rews'].mean():.2f}", "len": str(int(result["lens"].mean())), "n/ep": str(int(result["n/ep"])), "n/st": str(int(result["n/st"])), @@ -144,7 +144,7 @@ def onpolicy_trainer( save_fn(policy) if verbose: print(f"Epoch #{epoch}: test_reward: {result['rews'].mean():.6f} ± " - f"{result["rews"].std():.6f}, best_reward: {best_reward:.6f} ± " + f"{result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " f"{best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break From 8810c23ed691d65a26402c9461a5fa0752b8298f Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Sat, 30 Jan 2021 10:37:20 +0800 Subject: [PATCH 038/104] change collector API in all files --- docs/tutorials/dqn.rst | 6 +++--- docs/tutorials/tictactoe.rst | 2 +- examples/atari/runnable/pong_a2c.py | 2 +- examples/atari/runnable/pong_ppo.py | 2 +- examples/box2d/acrobot_dualdqn.py | 2 +- examples/box2d/bipedal_hardcore_sac.py | 2 +- examples/box2d/lunarlander_dqn.py | 2 +- examples/box2d/mcc_sac.py | 2 +- examples/mujoco/runnable/ant_v2_ddpg.py | 2 +- examples/mujoco/runnable/ant_v2_td3.py | 2 +- examples/mujoco/runnable/halfcheetahBullet_v0_sac.py | 2 +- examples/mujoco/runnable/point_maze_td3.py | 2 +- test/base/test_collector.py | 10 +++++----- test/continuous/test_ddpg.py | 2 +- test/continuous/test_ppo.py | 2 +- test/continuous/test_sac_with_il.py | 4 ++-- test/continuous/test_td3.py | 2 +- test/discrete/test_a2c_with_il.py | 4 ++-- test/discrete/test_c51.py | 2 +- test/discrete/test_dqn.py | 2 +- test/discrete/test_drqn.py | 2 +- test/discrete/test_il_bcq.py | 2 +- 
test/discrete/test_pg.py | 2 +- test/discrete/test_ppo.py | 2 +- test/discrete/test_qrdqn.py | 2 +- test/discrete/test_sac.py | 2 +- test/modelbase/test_psrl.py | 2 +- test/multiagent/Gomoku.py | 2 +- test/multiagent/tic_tac_toe.py | 2 +- tianshou/data/collector.py | 7 ------- tianshou/trainer/offline.py | 2 +- 31 files changed, 38 insertions(+), 45 deletions(-) diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index faea6a869..edd39355b 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -215,11 +215,11 @@ Tianshou supports user-defined training code. Here is the code snippet: # once if the collected episodes' mean returns reach the threshold, # or every 1000 steps, we test it on test_collector - if collect_result['rew'] >= env.spec.reward_threshold or i % 1000 == 0: + if collect_result['rews'].mean() >= env.spec.reward_threshold or i % 1000 == 0: policy.set_eps(0.05) result = test_collector.collect(n_episode=100) - if result['rew'] >= env.spec.reward_threshold: - print(f'Finished training! Test mean returns: {result["rew"]}') + if result['rews'].mean() >= env.spec.reward_threshold: + print(f'Finished training! Test mean returns: {result["rews"].mean()}') break else: # back to training eps diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst index ac9adc118..bda6ff17a 100644 --- a/docs/tutorials/tictactoe.rst +++ b/docs/tutorials/tictactoe.rst @@ -294,7 +294,7 @@ With the above preparation, we are close to the first learned agent. The followi policy.policies[args.agent_id - 1].set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') if args.watch: watch(args) exit(0) diff --git a/examples/atari/runnable/pong_a2c.py b/examples/atari/runnable/pong_a2c.py index 100ae24a6..4da98ce67 100644 --- a/examples/atari/runnable/pong_a2c.py +++ b/examples/atari/runnable/pong_a2c.py @@ -98,7 +98,7 @@ def stop_fn(mean_rewards): env = create_atari_environment(args.task) collector = Collector(policy, env, preprocess_fn=preprocess_fn) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/examples/atari/runnable/pong_ppo.py b/examples/atari/runnable/pong_ppo.py index 55219da68..3141a4bf2 100644 --- a/examples/atari/runnable/pong_ppo.py +++ b/examples/atari/runnable/pong_ppo.py @@ -102,7 +102,7 @@ def stop_fn(mean_rewards): env = create_atari_environment(args.task) collector = Collector(policy, env, preprocess_fn=preprocess_fn) result = collector.collect(n_step=2000, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 071990069..a0b8f7788 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -116,7 +116,7 @@ def test_fn(epoch, env_step): test_collector.reset() result = test_collector.collect(n_episode=[1] * args.test_num, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') if 
__name__ == '__main__': diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 58eb98ec9..87539e46e 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -153,7 +153,7 @@ def stop_fn(mean_rewards): test_collector.reset() result = test_collector.collect(n_episode=[1] * args.test_num, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 5c98fb779..781ae63ec 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -112,7 +112,7 @@ def test_fn(epoch, env_step): test_collector.reset() result = test_collector.collect(n_episode=[1] * args.test_num, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 47ca4e25c..a700e2f24 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -121,7 +121,7 @@ def stop_fn(mean_rewards): test_collector.reset() result = test_collector.collect(n_episode=[1] * args.test_num, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/examples/mujoco/runnable/ant_v2_ddpg.py b/examples/mujoco/runnable/ant_v2_ddpg.py index 13192dbc9..e1bcd4964 100644 --- a/examples/mujoco/runnable/ant_v2_ddpg.py +++ b/examples/mujoco/runnable/ant_v2_ddpg.py @@ -96,7 +96,7 @@ def stop_fn(mean_rewards): test_collector.reset() result = test_collector.collect(n_episode=[1] * args.test_num, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/examples/mujoco/runnable/ant_v2_td3.py b/examples/mujoco/runnable/ant_v2_td3.py index 5ed45506a..6e37c61f4 100644 --- a/examples/mujoco/runnable/ant_v2_td3.py +++ b/examples/mujoco/runnable/ant_v2_td3.py @@ -105,7 +105,7 @@ def stop_fn(mean_rewards): test_collector.reset() result = test_collector.collect(n_episode=[1] * args.test_num, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py index b669d264a..ced529913 100644 --- a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py +++ b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py @@ -107,7 +107,7 @@ def stop_fn(mean_rewards): test_collector.reset() result = test_collector.collect(n_episode=[1] * args.test_num, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/examples/mujoco/runnable/point_maze_td3.py b/examples/mujoco/runnable/point_maze_td3.py index dbb612fc8..781ec9734 100644 --- a/examples/mujoco/runnable/point_maze_td3.py +++ b/examples/mujoco/runnable/point_maze_td3.py @@ -113,7 +113,7 @@ def 
stop_fn(mean_rewards): test_collector.reset() result = test_collector.collect(n_episode=[1] * args.test_num, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/test/base/test_collector.py b/test/base/test_collector.py index e7e11759f..ec799e640 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -223,18 +223,18 @@ def reward_metric(x): c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn, reward_metric=reward_metric) # n_step=3 will collect a full episode - r = c0.collect(n_step=3)['rew'] + r = c0.collect(n_step=3)['rews'].mean() assert np.asanyarray(r).size == 1 and r == 4. - r = c0.collect(n_episode=2)['rew'] + r = c0.collect(n_episode=2)['rews'].mean() assert np.asanyarray(r).size == 1 and r == 4. env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]] envs = DummyVectorEnv(env_fns) c1 = Collector(policy, envs, ReplayBuffer(size=100), Logger.single_preprocess_fn, reward_metric=reward_metric) - r = c1.collect(n_step=10)['rew'] + r = c1.collect(n_step=10)['rews'].mean() assert np.asanyarray(r).size == 1 and r == 4. - r = c1.collect(n_episode=[2, 1, 1, 2])['rew'] + r = c1.collect(n_episode=[2, 1, 1, 2])['rews'].mean() assert np.asanyarray(r).size == 1 and r == 4. batch, _ = c1.buffer.sample(10) print(batch) @@ -250,7 +250,7 @@ def reward_metric(x): [[x] * 4 for x in rew]) c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4), Logger.single_preprocess_fn, reward_metric=reward_metric) - r = c2.collect(n_episode=[0, 0, 0, 10])['rew'] + r = c2.collect(n_episode=[0, 0, 0, 10])['rews'].mean() assert np.asanyarray(r).size == 1 and r == 4. 
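The hunks in this commit consistently replace the pre-averaged scalars result['rew'] / result['len'] with per-episode arrays result['rews'] / result['lens'], so every caller now aggregates explicitly. A minimal illustrative sketch of the convention these call sites assume (`policy` and `env` stand in for any constructed policy / vectorized-env pair and are placeholders, not names from the patch):

from tianshou.data import Collector

collector = Collector(policy, env)       # `policy`, `env` are placeholders
result = collector.collect(n_episode=4)
# "rews" and "lens" are numpy arrays (one entry per collected episode),
# while "n/ep" and "n/st" remain integer counters, as in the trainer diffs above.
print(f'reward: {result["rews"].mean():.2f} ± {result["rews"].std():.2f}, '
      f'length: {result["lens"].mean():.2f}, '
      f'episodes: {result["n/ep"]}, steps: {result["n/st"]}')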
batch, _ = c2.buffer.sample(10) diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 52257342d..911cd8bf8 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -112,7 +112,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 1a48aaee9..a6d7469d3 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -130,7 +130,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index f35e18497..20eb63e88 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -118,7 +118,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') # here we define an imitation collector with a trivial policy policy.eval() @@ -149,7 +149,7 @@ def stop_fn(mean_rewards): il_policy.eval() collector = Collector(il_policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 32c0efd43..8eb0c8fe2 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -123,7 +123,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 90f6681fd..7c5eff4c0 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -105,7 +105,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') policy.eval() # here we define an imitation collector with a trivial policy @@ -134,7 +134,7 @@ def stop_fn(mean_rewards): il_policy.eval() collector = Collector(il_policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 32a41d0df..e726471f0 100644 --- a/test/discrete/test_c51.py +++ 
b/test/discrete/test_c51.py @@ -124,7 +124,7 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') def test_pc51(args=get_args()): diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 1e9f08984..94b488886 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -127,7 +127,7 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') # save buffer in pickle format, for imitation learning unittest buf = ReplayBuffer(args.buffer_size) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index f3f00e69f..c695aa916 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -103,7 +103,7 @@ def test_fn(epoch, env_step): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py index e9e857ef1..697fab23b 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/discrete/test_il_bcq.py @@ -104,7 +104,7 @@ def stop_fn(mean_rewards): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == "__main__": diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index b26213c59..75447b768 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -91,7 +91,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 231ad5032..50539642c 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -117,7 +117,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 0d03fce97..138f18f41 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -122,7 +122,7 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') def test_pqrdqn(args=get_args()): diff --git 
a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 67da3fb57..4c0c51625 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -114,7 +114,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 29dfb6b8c..9795dd483 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -88,7 +88,7 @@ def stop_fn(mean_rewards): test_collector.reset() result = test_collector.collect(n_episode=[1] * args.test_num, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') elif env.spec.reward_threshold: assert result["best_reward"] >= env.spec.reward_threshold diff --git a/test/multiagent/Gomoku.py b/test/multiagent/Gomoku.py index 53652a2e4..0c0dcf58c 100644 --- a/test/multiagent/Gomoku.py +++ b/test/multiagent/Gomoku.py @@ -46,7 +46,7 @@ def env_func(): policy.replace_policy(opponent, 3 - args.agent_id) test_collector = Collector(policy, test_envs) results = test_collector.collect(n_episode=100) - rews.append(results['rew']) + rews.append(results['rews'].mean()) rews = np.array(rews) # weight opponent by their difficulty level rews = np.exp(-rews * 10.0) diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index f9d6af104..fe85213e4 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -178,4 +178,4 @@ def watch( policy.policies[args.agent_id - 1].set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index b633a0125..d1532bc4b 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -87,10 +87,7 @@ def __init__( reward_metric: Optional[Callable[[np.ndarray], float]] = None, ) -> None: # TODO determine whether we need start_idxs - # TODO support not only cacahedbuffer,(maybe auto change) # TODO update training in all test/examples, remove action noise - # TODO buffer need to be CachedReplayBuffer now, update - # examples/docs/ after supporting all types of buffers super().__init__() if not isinstance(env, BaseVectorEnv): env = DummyVectorEnv([lambda: env]) @@ -222,8 +219,6 @@ def collect( * ``n/ep`` the collected number of episodes. * ``n/st`` the collected number of steps. - * ``v/st`` the speed of steps per second. - * ``v/ep`` the speed of episode per second. * ``rew`` the mean reward over collected episodes. * ``len`` the mean length over collected episodes. """ @@ -324,7 +319,6 @@ def collect( finised_env_ind = np.where(done)[0] # now we copy obs_next to obs, but since there might be finished episodes, # we have to reset finished envs first. - # TODO might auto reset help? 
obs_reset = self.env.reset(finised_env_ind) if self.preprocess_fn: obs_reset = self.preprocess_fn( @@ -349,7 +343,6 @@ def collect( self.collect_time += max(time.time() - start_time, 1e-9) if n_episode: self.reset_env() - # TODO change api in trainer and other collector usage return { "n/ep": episode_count, "n/st": step_count, diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index f46cd79e7..2547a2507 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -88,7 +88,7 @@ def offline_trainer( if save_fn: save_fn(policy) if verbose: - print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± " + print(f"Epoch #{epoch}: test_reward: {result['rews'].mean():.6f} ± " f"{result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " f"{best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): From 7b62804d6da9ae38edc031228fbb303845d61211 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Sat, 30 Jan 2021 16:54:38 +0800 Subject: [PATCH 039/104] rewrite multibuf.add and buffer.update --- test/base/test_buffer.py | 3 ++- tianshou/data/buffer.py | 56 ++++++++++++++++++++++++++++------------ 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 04b1928e2..0b57e6e73 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -534,6 +534,8 @@ def test_multibuf_stack(): obs_next_list, info_list) buf5.add(obs_list, act_list, rew_list, done_list, obs_next_list, info_list) + assert np.all(buf4.obs == buf5.obs) + assert np.all(buf4.done == buf5.done) obs = obs_next if done: obs = env.reset(1) @@ -567,7 +569,6 @@ def test_multibuf_stack(): [1, 1, 1, 2], [1, 1, 1, 2], [6, 6, 6, 7], [6, 6, 6, 7], [11, 11, 11, 12], [11, 11, 11, 12], ]) - assert np.all(buf4.done == buf5.done) indice = buf5.sample_index(0) assert np.allclose(sorted(indice), [2, 7]) assert np.all(np.isin(buf5.sample_index(100), indice)) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 7fdfb316a..d47600857 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -140,21 +140,31 @@ def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: end_flag = self.done[index] | np.isin(index, self.unfinished_index()) return (index + (1 - end_flag)) % self._size - def update(self, buffer: "ReplayBuffer") -> None: + def update(self, buffer: "ReplayBuffer") -> np.ndarray: """Move the data from the given buffer to current buffer.""" if len(buffer) == 0 or self.maxsize == 0: return stack_num, buffer.stack_num = buffer.stack_num, 1 - save_only_last_obs = self._save_only_last_obs - self._save_only_last_obs = False - indices = buffer.sample_index(0) # get all available indices - for i in indices: - self.add(**buffer[i]) # type: ignore + from_indices = buffer.sample_index(0) # get all available indices buffer.stack_num = stack_num - self._save_only_last_obs = save_only_last_obs + if len(from_indices) == 0: + return + to_indices = [] + for _ in range(len(from_indices)): + to_indices.append(self._index) + self._index = (self._index + 1) % self.maxsize + self._size = min(self._size + 1, self.maxsize) + to_indices = np.array(to_indices) + if self._meta.is_empty(): + self._buffer_allocator([], buffer._meta[from_indices[0]]) + self._meta[to_indices] = buffer._meta[from_indices] + return to_indices def _buffer_allocator(self, key: List[str], value: Any) -> None: """Allocate memory on buffer._meta for new (key, value) pair.""" + if key == []: + self._meta = _create_value(value, 
self.maxsize) + return data = self._meta for k in key[:-1]: data = data[k] @@ -363,7 +373,7 @@ def reset(self) -> None: if isinstance(self._meta[k], list): self._meta.__dict__[k] = [] - def update(self, buffer: ReplayBuffer) -> None: + def update(self, buffer: ReplayBuffer) -> np.ndarray: """The ListReplayBuffer cannot be updated by any buffer.""" raise NotImplementedError @@ -391,6 +401,10 @@ def __init__( self.weight = SegmentTree(size) self.__eps = np.finfo(np.float32).eps.item() + def update(self, buffer: ReplayBuffer) -> np.ndarray: + indices = super().update(buffer) + self.weight[indices] = self._max_prio ** self._alpha + def add( self, obs: Any, @@ -523,7 +537,7 @@ def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: next_indices[mask] = buf.next(index[mask] - offset) + offset return next_indices - def update(self, buffer: ReplayBuffer) -> None: + def update(self, buffer: ReplayBuffer) -> np.ndarray: """The ReplayBufferManager cannot be updated by any buffer.""" raise NotImplementedError @@ -554,18 +568,28 @@ def add( # type: ignore """ if buffer_ids is None: buffer_ids = np.arange(self.buffer_num) - # assume each element in buffer_ids is unique - assert np.bincount(buffer_ids).max() == 1 - batch = Batch(obs=obs, act=act, rew=rew, done=done, - obs_next=obs_next, info=info, policy=policy) - assert len(buffer_ids) == len(batch) episode_lengths = [] # (len(buffer_ids),) episode_rewards = [] # (len(buffer_ids), ...) + is_obs_next_empty = isinstance(obs_next, Batch) and obs_next.is_empty() + is_info_empty = isinstance(info, Batch) and info.is_empty() + is_policy_empty = isinstance(policy, Batch) and policy.is_empty() for batch_idx, buffer_id in enumerate(buffer_ids): - length, reward = self.buffers[buffer_id].add(**batch[batch_idx]) + kwargs = { + "obs": obs[batch_idx], + "act": act[batch_idx], + "rew": rew[batch_idx], + "done": done[batch_idx], + } + if not is_obs_next_empty: + kwargs["obs_next"] = obs_next[batch_idx] + if not is_info_empty: + kwargs["info"] = info[batch_idx] + if not is_policy_empty: + kwargs["policy"] = policy[batch_idx] + length, reward = self.buffers[buffer_id].add(**kwargs) episode_lengths.append(length) episode_rewards.append(reward) - return np.stack(episode_lengths), np.stack(episode_rewards) + return np.array(episode_lengths), np.array(episode_rewards) def sample_index(self, batch_size: int) -> np.ndarray: if batch_size < 0: From 3f621013f151429cf193288438b2df755e2c4a8d Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Mon, 1 Feb 2021 13:06:02 +0800 Subject: [PATCH 040/104] buffer update (draft) --- test/base/test_buffer.py | 16 +- tianshou/data/__init__.py | 5 +- tianshou/data/batch.py | 19 +-- tianshou/data/buffer.py | 333 +++++++++++++------------------------- 4 files changed, 129 insertions(+), 244 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 0b57e6e73..ceeb04f40 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -9,7 +9,7 @@ from tianshou.data.utils.converter import to_hdf5 from tianshou.data import Batch, SegmentTree, ReplayBuffer -from tianshou.data import ListReplayBuffer, PrioritizedReplayBuffer +from tianshou.data import PrioritizedReplayBuffer from tianshou.data import ReplayBufferManager, CachedReplayBuffer if __name__ == '__main__': @@ -27,7 +27,7 @@ def test_replaybuffer(size=10, bufsize=20): action_list = [1] * 5 + [0] * 10 + [1] * 10 for i, a in enumerate(action_list): obs_next, rew, done, info = env.step(a) - buf.add(obs, [a], rew, done, obs_next, info) + 
buf.add(Batch(obs=obs, act=[a], rew=rew, done=done, obs_next=obs_next, info=info)) obs = obs_next assert len(buf) == min(bufsize, i + 1) with pytest.raises(ValueError): @@ -53,9 +53,6 @@ def test_replaybuffer(size=10, bufsize=20): assert np.all(b.info.b.c[1:] == 0.0) with pytest.raises(IndexError): b[22] - b = ListReplayBuffer() - with pytest.raises(NotImplementedError): - b.sample(0) def test_ignore_obs_next(size=10): @@ -158,9 +155,6 @@ def test_update(): assert len(buf1) == len(buf2) assert (buf2[0].obs == buf1[1].obs).all() assert (buf2[-1].obs == buf1[0].obs).all() - b = ListReplayBuffer() - with pytest.raises(NotImplementedError): - b.update(b) b = CachedReplayBuffer(ReplayBuffer(10), 4, 5) with pytest.raises(NotImplementedError): b.update(b) @@ -270,22 +264,17 @@ def sample_tree(): def test_pickle(): size = 100 vbuf = ReplayBuffer(size, stack_num=2) - lbuf = ListReplayBuffer() pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4) rew = np.array([1, 1]) for i in range(4): vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0) - for i in range(3): - lbuf.add(obs=Batch(index=np.array([i])), act=1, rew=rew, done=0) for i in range(5): pbuf.add(obs=Batch(index=np.array([i])), act=2, rew=rew, done=0, weight=np.random.rand()) # save & load _vbuf = pickle.loads(pickle.dumps(vbuf)) - _lbuf = pickle.loads(pickle.dumps(lbuf)) _pbuf = pickle.loads(pickle.dumps(pbuf)) assert len(_vbuf) == len(vbuf) and np.allclose(_vbuf.act, vbuf.act) - assert len(_lbuf) == len(lbuf) and np.allclose(_lbuf.act, lbuf.act) assert len(_pbuf) == len(pbuf) and np.allclose(_pbuf.act, pbuf.act) # make sure the meta var is identical assert _vbuf.stack_num == vbuf.stack_num @@ -297,7 +286,6 @@ def test_hdf5(): size = 100 buffers = { "array": ReplayBuffer(size, stack_num=2), - "list": ListReplayBuffer(), "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4), } buffer_types = {k: b.__class__ for k, b in buffers.items()} diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index 368427a19..1d7a68493 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -1,8 +1,8 @@ from tianshou.data.batch import Batch from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as from tianshou.data.utils.segtree import SegmentTree -from tianshou.data.buffer import ReplayBuffer, ListReplayBuffer, \ - PrioritizedReplayBuffer, ReplayBufferManager, CachedReplayBuffer +from tianshou.data.buffer import ReplayBuffer, PrioritizedReplayBuffer, \ + ReplayBufferManager, CachedReplayBuffer from tianshou.data.collector import Collector __all__ = [ @@ -12,7 +12,6 @@ "to_torch_as", "SegmentTree", "ReplayBuffer", - "ListReplayBuffer", "PrioritizedReplayBuffer", "ReplayBufferManager", "CachedReplayBuffer", diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index fe3516046..9fe27e86f 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -181,17 +181,18 @@ def __init__( copy: bool = False, **kwargs: Any, ) -> None: + if batch_dict is None: + if len(kwargs) == 0: + return + batch_dict = kwargs if copy: batch_dict = deepcopy(batch_dict) - if batch_dict is not None: - if isinstance(batch_dict, (dict, Batch)): - _assert_type_keys(batch_dict.keys()) - for k, v in batch_dict.items(): - self.__dict__[k] = _parse_value(v) - elif _is_batch_set(batch_dict): - self.stack_(batch_dict) - if len(kwargs) > 0: - self.__init__(kwargs, copy=copy) # type: ignore + if isinstance(batch_dict, (dict, Batch)): + _assert_type_keys(batch_dict.keys()) + for k, v in batch_dict.items(): + self.__dict__[k] = 
_parse_value(v) + elif _is_batch_set(batch_dict): + self.stack_(batch_dict) def __setattr__(self, key: str, value: Any) -> None: """Set self.key = value.""" diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index d47600857..5829a316e 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -108,7 +108,7 @@ def load_hdf5( def reset(self) -> None: """Clear all the data in replay buffer and episode statistics.""" self._index = self._size = 0 - self._episode_length, self._episode_reward = 0, 0.0 + self._ep_len, self._ep_rew, self._ep_idx = 0, 0.0, 0 def set_batch(self, batch: Batch) -> None: """Manually choose the batch you want the ReplayBuffer to manage.""" @@ -156,91 +156,57 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray: self._size = min(self._size + 1, self.maxsize) to_indices = np.array(to_indices) if self._meta.is_empty(): - self._buffer_allocator([], buffer._meta[from_indices[0]]) + self._meta = _create_value(buffer._meta[0], self.maxsize) self._meta[to_indices] = buffer._meta[from_indices] return to_indices - def _buffer_allocator(self, key: List[str], value: Any) -> None: - """Allocate memory on buffer._meta for new (key, value) pair.""" - if key == []: - self._meta = _create_value(value, self.maxsize) - return - data = self._meta - for k in key[:-1]: - data = data[k] - data[key[-1]] = _create_value(value, self.maxsize) + def add_index( + self, rew: Union[float, np.ndarray], done: bool + ) -> Tuple[int, int, Union[float, np.ndarray], int]: + """TODO""" + index = self._index + self._size = min(self._size + 1, self.maxsize) + self._index = (self._index + 1) % self.maxsize - def _add_to_buffer(self, name: str, inst: Any) -> None: - try: - value = self._meta.__dict__[name] - except KeyError: - self._buffer_allocator([name], inst) - value = self._meta[name] - if isinstance(inst, (torch.Tensor, np.ndarray)): - if inst.shape != value.shape[1:]: - raise ValueError( - "Cannot add data to a buffer with different shape with key" - f" {name}, expect {value.shape[1:]}, given {inst.shape}." - ) - try: - value[self._index] = inst - except KeyError: # inst is a dict/Batch - for key in set(inst.keys()).difference(value.keys()): - self._buffer_allocator([name, key], inst[key]) - self._meta[name][self._index] = inst + self._ep_rew += rew + self._ep_len += 1 - def add( - self, - obs: Any, - act: Any, - rew: Union[Number, np.number, np.ndarray], - done: Union[Number, np.number, np.bool_], - obs_next: Any = None, - info: Optional[Union[dict, Batch]] = {}, - policy: Optional[Union[dict, Batch]] = {}, - **kwargs: Any, - ) -> Tuple[int, Union[float, np.ndarray]]: + if done: + result = index, self._ep_len, self._ep_rew, self._ep_idx + self._ep_len, self._ep_rew, self._ep_idx = 0, 0.0, self._index + return result + else: + return index, 0, self._ep_rew * 0.0, self._ep_idx + + def add(self, batch: Batch) -> Tuple[int, Union[float, np.ndarray], int]: """Add a batch of data into replay buffer. Return (episode_length, episode_reward) if one episode is terminated, otherwise return (0, 0.0). """ - assert isinstance( - info, (dict, Batch) - ), "You should return a dict in the last argument of env.step()." 
+ # preprocess batch + for key in set(batch.keys()).difference(self._reserved_keys): + batch.pop(key) # save only reserved keys + assert set(["obs", "act", "rew", "done"]).issubset(batch.keys()) if self._save_only_last_obs: - obs = obs[-1] - self._add_to_buffer("obs", obs) - self._add_to_buffer("act", act) - # make sure the data type of reward is float instead of int - # but rew may be np.ndarray, so that we cannot use float(rew) - rew = rew * 1.0 # type: ignore - self._add_to_buffer("rew", rew) - self._add_to_buffer("done", bool(done)) # done should be a bool scalar - if self._save_obs_next: - if obs_next is None: - obs_next = Batch() - elif self._save_only_last_obs: - obs_next = obs_next[-1] - self._add_to_buffer("obs_next", obs_next) - self._add_to_buffer("info", info) - self._add_to_buffer("policy", policy) - - if self.maxsize > 0: - self._size = min(self._size + 1, self.maxsize) - self._index = (self._index + 1) % self.maxsize - else: # TODO: remove this after deleting ListReplayBuffer - self._size = self._index = self._size + 1 - - self._episode_reward += rew - self._episode_length += 1 - - if done: - result = self._episode_length, self._episode_reward - self._episode_length, self._episode_reward = 0, 0.0 - return result - else: - return 0, self._episode_reward * 0.0 + batch.__dict__["obs"] = batch.obs[-1] + batch.__dict__["rew"] = batch.rew.astype(np.float) + batch.__dict__["done"] = batch.done.astype(np.bool_) + if not self._save_obs_next: + batch.pop("obs_next", None) + elif self._save_only_last_obs: + batch.__dict__["obs_next"] = batch.obs_next[-1] + # get ptr + index, ep_len, ep_rew, ep_idx = self.add_index(batch.rew, batch.done) + try: + self._meta[index] = batch + except KeyError: # extra keys + if self._meta.is_empty(): + self._meta = _create_value(batch, self.maxsize) + else: # TODO: dynamic key pops up + pass + self._meta[index] = batch + return ep_len, ep_rew, ep_idx def sample_index(self, batch_size: int) -> np.ndarray: """Get a random sample of index with size = batch_size. @@ -339,45 +305,6 @@ def __getitem__( ) -class ListReplayBuffer(ReplayBuffer): - """List-based replay buffer. - - The function of :class:`~tianshou.data.ListReplayBuffer` is almost the same - as :class:`~tianshou.data.ReplayBuffer`. The only difference is that - :class:`~tianshou.data.ListReplayBuffer` is based on list. Therefore, - it does not support advanced indexing, which means you cannot sample a - batch of data out of it. It is typically used for storing data. - - .. seealso:: - - Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed - explanation. - """ - - def __init__(self, **kwargs: Any) -> None: - # TODO - warnings.warn("ListReplayBuffer will be removed in version 0.4.0.") - super().__init__(size=0, ignore_obs_next=False, **kwargs) - - def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: - raise NotImplementedError("ListReplayBuffer cannot be sampled!") - - def _add_to_buffer(self, name: str, inst: Any) -> None: - if self._meta.get(name) is None: - self._meta.__dict__[name] = [] - self._meta[name].append(inst) - - def reset(self) -> None: - super().reset() - for k in self._meta.keys(): - if isinstance(self._meta[k], list): - self._meta.__dict__[k] = [] - - def update(self, buffer: ReplayBuffer) -> np.ndarray: - """The ListReplayBuffer cannot be updated by any buffer.""" - raise NotImplementedError - - class PrioritizedReplayBuffer(ReplayBuffer): """Implementation of Prioritized Experience Replay. arXiv:1511.05952. 
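From this commit on, ReplayBuffer.add() takes a single Batch holding the reserved keys rather than separate keyword arguments, and keys outside that set are dropped by the preprocessing step shown above. A rough usage sketch matching the updated tests (`env` and `action_list` are placeholders; the three-value unpacking follows the return statement in this hunk):

from tianshou.data import Batch, ReplayBuffer

buf = ReplayBuffer(size=20)
obs = env.reset()                # `env` is any gym-style env (placeholder)
for a in action_list:            # `action_list` is a placeholder
    obs_next, rew, done, info = env.step(a)
    # obs/act/rew/done (plus optional obs_next/info/policy) go in as one Batch;
    # non-zero episode bookkeeping comes back once `done` is True.
    ep_len, ep_rew, ep_idx = buf.add(Batch(obs=obs, act=a, rew=rew, done=done,
                                           obs_next=obs_next, info=info))
    obs = env.reset() if done else obs_next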
@@ -401,31 +328,16 @@ def __init__( self.weight = SegmentTree(size) self.__eps = np.finfo(np.float32).eps.item() + def init_weight(self, index: Union[int, np.ndarray]) -> None: + self.weight[index] = self._max_prio ** self._alpha + def update(self, buffer: ReplayBuffer) -> np.ndarray: indices = super().update(buffer) - self.weight[indices] = self._max_prio ** self._alpha + self.init_weight(indices) - def add( - self, - obs: Any, - act: Any, - rew: Union[Number, np.number, np.ndarray], - done: Union[Number, np.number, np.bool_], - obs_next: Any = None, - info: Optional[Union[dict, Batch]] = {}, - policy: Optional[Union[dict, Batch]] = {}, - weight: Optional[Union[Number, np.number]] = None, - **kwargs: Any, - ) -> Tuple[int, Union[float, np.ndarray]]: - if weight is None: - weight = self._max_prio - else: - weight = np.abs(weight) - self._max_prio = max(self._max_prio, weight) - self._min_prio = min(self._min_prio, weight) - self.weight[self._index] = weight ** self._alpha - return super().add(obs, act, rew, done, obs_next, - info, policy, **kwargs) + def add(self, batch: Batch) -> Tuple[int, Union[float, np.ndarray], int]: + self.init_weight(self._index) + return super().add(batch) def sample_index(self, batch_size: int) -> np.ndarray: if batch_size > 0 and self._size > 0: @@ -487,17 +399,17 @@ class ReplayBufferManager(ReplayBuffer): def __init__(self, buffer_list: List[ReplayBuffer], **kwargs: Any) -> None: self.buffer_num = len(buffer_list) - self.buffers = buffer_list - self._offset = [] - offset = 0 + self.buffers = np.array(buffer_list) + offset, is_prio, size = [], [], 0 for buf in self.buffers: - # overwrite sub-buffers' _buffer_allocator so that - # the top buffer can allocate new memory for all sub-buffers - buf._buffer_allocator = self._buffer_allocator # type: ignore assert buf._meta.is_empty() - self._offset.append(offset) - offset += buf.maxsize - super().__init__(size=offset, **kwargs) + offset.append(size) + is_prio.append(isinstance(buf, PrioritizedReplayBuffer)) + size += buf.maxsize + self._offset = np.array(offset) + self._is_prio = np.array(is_prio) + self._is_any_prio = np.any(is_prio) + super().__init__(size=size, **kwargs) def __len__(self) -> int: return sum([len(buf) for buf in self.buffers]) @@ -541,22 +453,11 @@ def update(self, buffer: ReplayBuffer) -> np.ndarray: """The ReplayBufferManager cannot be updated by any buffer.""" raise NotImplementedError - def _buffer_allocator(self, key: List[str], value: Any) -> None: - super()._buffer_allocator(key, value) - self._set_batch_for_children() - - def add( # type: ignore + def add( self, - obs: Any, - act: Any, - rew: np.ndarray, - done: np.ndarray, - obs_next: Any = Batch(), - info: Optional[Batch] = Batch(), - policy: Optional[Batch] = Batch(), + batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None, - **kwargs: Any - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into ReplayBufferManager. Each of the data's length (first dimension) must equal to the length of @@ -568,28 +469,26 @@ def add( # type: ignore """ if buffer_ids is None: buffer_ids = np.arange(self.buffer_num) - episode_lengths = [] # (len(buffer_ids),) - episode_rewards = [] # (len(buffer_ids), ...) 
- is_obs_next_empty = isinstance(obs_next, Batch) and obs_next.is_empty() - is_info_empty = isinstance(info, Batch) and info.is_empty() - is_policy_empty = isinstance(policy, Batch) and policy.is_empty() + indices, ep_lens, ep_rews, ep_idxs = [], [], [], [] for batch_idx, buffer_id in enumerate(buffer_ids): - kwargs = { - "obs": obs[batch_idx], - "act": act[batch_idx], - "rew": rew[batch_idx], - "done": done[batch_idx], - } - if not is_obs_next_empty: - kwargs["obs_next"] = obs_next[batch_idx] - if not is_info_empty: - kwargs["info"] = info[batch_idx] - if not is_policy_empty: - kwargs["policy"] = policy[batch_idx] - length, reward = self.buffers[buffer_id].add(**kwargs) - episode_lengths.append(length) - episode_rewards.append(reward) - return np.array(episode_lengths), np.array(episode_rewards) + index, ep_len, ep_rew, ep_idx = self.buffers[buffer_id].add_index( + batch.rew[batch_idx], batch.done[batch_idx]) + indices.append(index + self._offset[buffer_id]) + ep_lens.append(ep_len) + ep_rews.append(ep_rew) + ep_idxs.append(ep_idx + self._offset[buffer_id]) + if self._is_prio[buffer_id]: + self.buffers[buffer_id].init_weight(index) + indices = np.array(indices) + try: + self._meta[indices] = batch + except KeyError: + if self._meta.is_empty(): + self._meta = _create_value(batch[0], self.maxsize) + else: # TODO: dynamic key pops up + pass + self._set_batch_for_children() + return np.array(ep_lens), np.array(ep_rews), np.array(ep_idxs) def sample_index(self, batch_size: int) -> np.ndarray: if batch_size < 0: @@ -617,6 +516,37 @@ def sample_index(self, batch_size: int) -> np.ndarray: for offset, buf, bsz in zip(self._offset, self.buffers, sample_num) ]) + def __getitem__( + self, index: Union[slice, int, np.integer, np.ndarray] + ) -> Batch: + batch = super().__getitem__(index) + if self._is_any_prio: + indice = self._indices[index] + batch.weight = np.ones(len(indice)) + for offset, buf in zip( + self._offset[self._is_prio], self.buffers[self._is_prio]): + mask = (offset <= indice) & (indice < offset + buf.maxsize) + if np.any(mask): + batch.weight[mask] = buf.get_weight(indice[mask]) + return batch + + def update_weight( + self, + index: np.ndarray, + new_weight: Union[np.ndarray, torch.Tensor], + ) -> None: + """Update priority weight if any PrioritizedReplayBuffer is in buffers. + + :param np.ndarray index: index you want to update weight. + :param np.ndarray new_weight: new priority weight you want to update. + """ + if self._is_any_prio: + for offset, buf in zip( + self._offset[self._is_prio], self.buffers[self._is_prio]): + mask = (offset <= index) & (index < offset + buf.maxsize) + if np.any(mask): + buf.update_weight(index[mask], new_weight[mask]) + class CachedReplayBuffer(ReplayBufferManager): """CachedReplayBuffer contains a given main buffer and n cached buffers, \ @@ -659,18 +589,11 @@ def __init__( self.cached_buffers = self.buffers[1:] self.cached_buffer_num = cached_buffer_num - def add( # type: ignore + def add( self, - obs: Any, - act: Any, - rew: np.ndarray, - done: np.ndarray, - obs_next: Any = Batch(), - info: Optional[Batch] = Batch(), - policy: Optional[Batch] = Batch(), + batch: Batch, cached_buffer_ids: Optional[Union[np.ndarray, List[int]]] = None, - **kwargs: Any, - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into CachedReplayBuffer. 
Each of the data's length (first dimension) must equal to the length of @@ -688,35 +611,9 @@ def add( # type: ignore cached_buffer_ids = np.asarray(cached_buffer_ids) # in self.buffers, the first buffer is main_buffer buffer_ids = cached_buffer_ids + 1 # type: ignore - result = super().add(obs, act, rew, done, obs_next, info, - policy, buffer_ids=buffer_ids, **kwargs) + result = super().add(batch, buffer_ids=buffer_ids) # find the terminated episode, move data from cached buf to main buf - for buffer_idx in cached_buffer_ids[np.asarray(done, np.bool_)]: + for buffer_idx in cached_buffer_ids[batch.done.astype(np.bool_)]: self.main_buffer.update(self.cached_buffers[buffer_idx]) self.cached_buffers[buffer_idx].reset() return result - - def __getitem__( - self, index: Union[slice, int, np.integer, np.ndarray] - ) -> Batch: - batch = super().__getitem__(index) - if self._is_prioritized: - indice = self._indices[index] - mask = indice < self.main_buffer.maxsize - batch.weight = np.ones(len(indice)) - batch.weight[mask] = self.main_buffer.get_weight(indice[mask]) - return batch - - def update_weight( - self, - index: np.ndarray, - new_weight: Union[np.ndarray, torch.Tensor], - ) -> None: - """Update priority weight by index in main buffer. - - :param np.ndarray index: index you want to update weight. - :param np.ndarray new_weight: new priority weight you want to update. - """ - if self._is_prioritized: - mask = index < self.main_buffer.maxsize - self.main_buffer.update_weight(index[mask], new_weight[mask]) From 036cf149f59699cefe322eb7c185c54098152638 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 2 Feb 2021 11:20:14 +0800 Subject: [PATCH 041/104] buffer update: fix test --- test/base/test_batch.py | 4 +- test/base/test_buffer.py | 179 ++++++++++++++++------------- tianshou/data/batch.py | 17 ++- tianshou/data/buffer.py | 227 +++++++++++++++++++++---------------- tianshou/data/collector.py | 2 +- 5 files changed, 239 insertions(+), 190 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 650d56080..3e1bd5078 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -143,8 +143,10 @@ def test_batch(): assert batch3.a.d.e[0] == 4.0 batch3.a.d[0] = Batch(f=5.0) assert batch3.a.d.f[0] == 5.0 - with pytest.raises(KeyError): + with pytest.raises(ValueError): batch3.a.d[0] = Batch(f=5.0, g=0.0) + with pytest.raises(ValueError): + batch3[0] = Batch(a={"c": 2, "e": 1}) # auto convert batch4 = Batch(a=np.array(['a', 'b'])) assert batch4.a.dtype == np.object # auto convert to np.object diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index ceeb04f40..64fd29193 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -1,9 +1,9 @@ import os +import h5py import torch import pickle import pytest import tempfile -import h5py import numpy as np from timeit import timeit @@ -27,13 +27,12 @@ def test_replaybuffer(size=10, bufsize=20): action_list = [1] * 5 + [0] * 10 + [1] * 10 for i, a in enumerate(action_list): obs_next, rew, done, info = env.step(a) - buf.add(Batch(obs=obs, act=[a], rew=rew, done=done, obs_next=obs_next, info=info)) + buf.add(Batch(obs=obs, act=[a], rew=rew, + done=done, obs_next=obs_next, info=info)) obs = obs_next assert len(buf) == min(bufsize, i + 1) - with pytest.raises(ValueError): - buf._add_to_buffer('rew', np.array([1, 2, 3])) - assert buf.act.dtype == np.object - assert isinstance(buf.act[0], list) + assert buf.act.dtype == np.int + assert buf.act.shape == (bufsize, 1) data, indice = buf.sample(bufsize * 
2) assert (indice < len(buf)).all() assert (data.obs < size).all() @@ -41,7 +40,8 @@ def test_replaybuffer(size=10, bufsize=20): b = ReplayBuffer(size=10) # neg bsz should return empty index assert b.sample_index(-1).tolist() == [] - b.add(1, 1, 1, 1, 'str', {'a': 3, 'b': {'c': 5.0}}) + b.add(Batch(obs=1, act=1, rew=1, done=1, + obs_next='str', info={'a': 3, 'b': {'c': 5.0}})) assert b.obs[0] == 1 assert b.done[0] assert b.obs_next[0] == 'str' @@ -51,6 +51,13 @@ def test_replaybuffer(size=10, bufsize=20): assert np.all(b.info.a[1:] == 0) assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact assert np.all(b.info.b.c[1:] == 0.0) + # test extra keys pop up, the buffer should handle it dynamically + b.add(Batch(obs=2, act=2, rew=2, done=0, obs_next="str2", + info={"a": 4, "d": {"e": -np.inf}})) + info_keys = ["a", "b", "d"] + assert set(b.info.keys()) == set(info_keys) + assert b.info.a[1] == 4 and b.info.b.c[1] == 0 + assert b.info.d.e[1] == -np.inf with pytest.raises(IndexError): b[22] @@ -59,14 +66,14 @@ def test_ignore_obs_next(size=10): # Issue 82 buf = ReplayBuffer(size, ignore_obs_next=True) for i in range(size): - buf.add(obs={'mask1': np.array([i, 1, 1, 0, 0]), - 'mask2': np.array([i + 4, 0, 1, 0, 0]), - 'mask': i}, - act={'act_id': i, - 'position_id': i + 3}, - rew=i, - done=i % 3 == 0, - info={'if': i}) + buf.add(Batch(obs={'mask1': np.array([i, 1, 1, 0, 0]), + 'mask2': np.array([i + 4, 0, 1, 0, 0]), + 'mask': i}, + act={'act_id': i, + 'position_id': i + 3}, + rew=i, + done=i % 3 == 0, + info={'if': i})) indice = np.arange(len(buf)) orig = np.arange(len(buf)) data = buf[indice] @@ -100,9 +107,10 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): obs = env.reset(1) for i in range(16): obs_next, rew, done, info = env.step(1) - buf.add(obs, 1, rew, done, None, info) - buf2.add(obs, 1, rew, done, None, info) - buf3.add([None, None, obs], 1, rew, done, [None, obs], info) + buf.add(Batch(obs=obs, act=1, rew=rew, done=done, info=info)) + buf2.add(Batch(obs=obs, act=1, rew=rew, done=done, info=info)) + buf3.add(Batch(obs=[obs, obs, obs], act=1, rew=rew, + done=done, obs_next=[obs, obs], info=info)) obs = obs_next if done: obs = env.reset(1) @@ -130,7 +138,8 @@ def test_priortized_replaybuffer(size=32, bufsize=15): action_list = [1] * 5 + [0] * 10 + [1] * 10 for i, a in enumerate(action_list): obs_next, rew, done, info = env.step(a) - buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5) + buf.add(Batch(obs=obs, act=a, rew=rew, done=done, obs_next=obs_next, + info=info, policy=np.random.randn() - 0.5)) obs = obs_next data, indice = buf.sample(len(buf) // 2) if len(buf) // 2 == 0: @@ -148,8 +157,8 @@ def test_update(): buf1 = ReplayBuffer(4, stack_num=2) buf2 = ReplayBuffer(4, stack_num=2) for i in range(5): - buf1.add(obs=np.array([i]), act=float(i), rew=i * i, - done=i % 2 == 0, info={'incident': 'found'}) + buf1.add(Batch(obs=np.array([i]), act=float(i), rew=i * i, + done=i % 2 == 0, info={'incident': 'found'})) assert len(buf1) > len(buf2) buf2.update(buf1) assert len(buf1) == len(buf2) @@ -267,10 +276,10 @@ def test_pickle(): pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4) rew = np.array([1, 1]) for i in range(4): - vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0) + vbuf.add(Batch(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0)) for i in range(5): - pbuf.add(obs=Batch(index=np.array([i])), - act=2, rew=rew, done=0, weight=np.random.rand()) + pbuf.add(Batch(obs=Batch(index=np.array([i])), + act=2, rew=rew, done=0, 
info=np.random.rand())) # save & load _vbuf = pickle.loads(pickle.dumps(vbuf)) _pbuf = pickle.loads(pickle.dumps(pbuf)) @@ -299,9 +308,8 @@ def test_hdf5(): 'done': i % 3 == 2, 'info': {"number": {"n": i, "t": info_t}, 'extra': None}, } - buffers["array"].add(**kwargs) - buffers["list"].add(**kwargs) - buffers["prioritized"].add(weight=np.random.rand(), **kwargs) + buffers["array"].add(Batch(kwargs)) + buffers["prioritized"].add(Batch(kwargs)) # save paths = {} @@ -345,9 +353,10 @@ def test_hdf5(): def test_replaybuffermanager(): buf = ReplayBufferManager([ReplayBuffer(size=5) for i in range(4)]) - ep_len, ep_rew = buf.add(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3], - done=[0, 0, 1], buffer_ids=[0, 1, 2]) - assert np.allclose(ep_len, [0, 0, 1]) and np.allclose(ep_rew, [0, 0, 3]) + batch = Batch(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3], done=[0, 0, 1]) + ptr, ep_rew, ep_len, ep_idx = buf.add(batch, buffer_ids=[0, 1, 2]) + assert np.all(ep_len == [0, 0, 1]) and np.all(ep_rew == [0, 0, 3]) + assert np.all(ptr == [0, 5, 10]) and np.all(ep_idx == [0, 5, 10]) with pytest.raises(NotImplementedError): # ReplayBufferManager cannot be updated buf.update(buf) @@ -361,7 +370,7 @@ def test_replaybuffermanager(): indice_next = buf.next(indice) assert np.allclose(indice_next, indice), indice_next assert np.allclose(buf.unfinished_index(), [0, 5]) - buf.add(obs=[4], act=[4], rew=[4], done=[1], buffer_ids=[3]) + buf.add(Batch(obs=[4], act=[4], rew=[4], done=[1]), buffer_ids=[3]) assert np.allclose(buf.unfinished_index(), [0, 5]) batch, indice = buf.sample(10) batch, indice = buf.sample(0) @@ -371,12 +380,14 @@ def test_replaybuffermanager(): indice_next = buf.next(indice) assert np.allclose(indice_next, indice), indice_next data = np.array([0, 0, 0, 0]) - buf.add(obs=data, act=data, rew=data, done=data, buffer_ids=[0, 1, 2, 3]) - buf.add(obs=data, act=data, rew=data, done=1 - data, + buf.add(Batch(obs=data, act=data, rew=data, done=data), + buffer_ids=[0, 1, 2, 3]) + buf.add(Batch(obs=data, act=data, rew=data, done=1 - data), buffer_ids=[0, 1, 2, 3]) assert len(buf) == 12 - buf.add(obs=data, act=data, rew=data, done=data, buffer_ids=[0, 1, 2, 3]) - buf.add(obs=data, act=data, rew=data, done=[0, 1, 0, 1], + buf.add(Batch(obs=data, act=data, rew=data, done=data), + buffer_ids=[0, 1, 2, 3]) + buf.add(Batch(obs=data, act=data, rew=data, done=[0, 1, 0, 1]), buffer_ids=[0, 1, 2, 3]) assert len(buf) == 20 indice = buf.sample_index(120000) @@ -404,9 +415,10 @@ def test_replaybuffermanager(): 15, 17, 17, 19, 19, ]) assert np.allclose(buf.unfinished_index(), [4, 14]) - ep_len, ep_rew = buf.add(obs=[1], act=[1], rew=[1], done=[1], - buffer_ids=[2]) - assert np.allclose(ep_len, [3]) and np.allclose(ep_rew, [1]) + ptr, ep_rew, ep_len, ep_idx = buf.add( + Batch(obs=[1], act=[1], rew=[1], done=[1]), buffer_ids=[2]) + assert np.all(ep_len == [3]) and np.all(ep_rew == [1]) + assert np.all(ptr == [10]) and np.all(ep_idx == [13]) assert np.allclose(buf.unfinished_index(), [4]) indice = list(sorted(buf.sample_index(0))) assert np.allclose(indice, np.arange(len(buf))) @@ -426,9 +438,9 @@ def test_replaybuffermanager(): assert buf.prev(-1) == buf.prev([buf.maxsize - 1])[0] assert buf.next(-1) == buf.next([buf.maxsize - 1])[0] batch = buf._meta - batch.info.n = np.ones(buf.maxsize) + batch.info = np.ones(buf.maxsize) buf.set_batch(batch) - assert np.allclose(buf.buffers[-1].info.n, [1] * 5) + assert np.allclose(buf.buffers[-1].info, [1] * 5) assert buf.sample_index(-1).tolist() == [] assert np.array([ReplayBuffer(0, 
ignore_obs_next=True)]).dtype == np.object @@ -437,8 +449,8 @@ def test_cachedbuffer(): buf = CachedReplayBuffer(ReplayBuffer(10), 4, 5) assert buf.sample_index(0).tolist() == [] # check the normal function/usage/storage in CachedReplayBuffer - ep_len, ep_rew = buf.add(obs=[1], act=[1], rew=[1], done=[0], - cached_buffer_ids=[1]) + ptr, ep_rew, ep_len, ep_idx = buf.add( + Batch(obs=[1], act=[1], rew=[1], done=[0]), cached_buffer_ids=[1]) obs = np.zeros(buf.maxsize) obs[15] = 1 indice = buf.sample_index(0) @@ -446,21 +458,25 @@ def test_cachedbuffer(): assert np.allclose(buf.prev(indice), [15]) assert np.allclose(buf.next(indice), [15]) assert np.allclose(buf.obs, obs) - assert np.allclose(ep_len, [0]) and np.allclose(ep_rew, [0.0]) - ep_len, ep_rew = buf.add(obs=[2], act=[2], rew=[2], done=[1], - cached_buffer_ids=[3]) + assert np.all(ep_len == [0]) and np.all(ep_rew == [0.0]) + assert np.all(ptr == [15]) and np.all(ep_idx == [15]) + ptr, ep_rew, ep_len, ep_idx = buf.add( + Batch(obs=[2], act=[2], rew=[2], done=[1]), cached_buffer_ids=[3]) obs[[0, 25]] = 2 indice = buf.sample_index(0) assert np.allclose(indice, [0, 15]) assert np.allclose(buf.prev(indice), [0, 15]) assert np.allclose(buf.next(indice), [0, 15]) assert np.allclose(buf.obs, obs) - assert np.allclose(ep_len, [1]) and np.allclose(ep_rew, [2.0]) + assert np.all(ep_len == [1]) and np.all(ep_rew == [2.0]) + assert np.all(ptr == [0]) and np.all(ep_idx == [0]) assert np.allclose(buf.unfinished_index(), [15]) assert np.allclose(buf.sample_index(0), [0, 15]) - ep_len, ep_rew = buf.add(obs=[3, 4], act=[3, 4], rew=[3, 4], - done=[0, 1], cached_buffer_ids=[3, 1]) - assert np.allclose(ep_len, [0, 2]) and np.allclose(ep_rew, [0, 5.0]) + ptr, ep_rew, ep_len, ep_idx = buf.add( + Batch(obs=[3, 4], act=[3, 4], rew=[3, 4], done=[0, 1]), + cached_buffer_ids=[3, 1]) + assert np.all(ep_len == [0, 2]) and np.all(ep_rew == [0, 5.0]) + assert np.all(ptr == [25, 2]) and np.all(ep_idx == [25, 1]) obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3] assert np.allclose(buf.obs, obs) assert np.allclose(buf.unfinished_index(), [25]) @@ -476,11 +492,15 @@ def test_cachedbuffer(): buf = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5) data = np.zeros(4) rew = np.ones([4, 4]) - buf.add(obs=data, act=data, rew=rew, done=[0, 0, 1, 1], obs_next=data) - buf.add(obs=data, act=data, rew=rew, done=[0, 0, 0, 0], obs_next=data) - buf.add(obs=data, act=data, rew=rew, done=[1, 1, 1, 1], obs_next=data) - buf.add(obs=data, act=data, rew=rew, done=[0, 0, 0, 0], obs_next=data) - buf.add(obs=data, act=data, rew=rew, done=[0, 1, 0, 1], obs_next=data) + buf.add(Batch(obs=data, act=data, rew=rew, done=[0, 0, 1, 1])) + buf.add(Batch(obs=data, act=data, rew=rew, done=[0, 0, 0, 0])) + buf.add(Batch(obs=data, act=data, rew=rew, done=[1, 1, 1, 1])) + buf.add(Batch(obs=data, act=data, rew=rew, done=[0, 0, 0, 0])) + ptr, ep_rew, ep_len, ep_idx = buf.add( + Batch(obs=data, act=data, rew=rew, done=[0, 1, 0, 1])) + assert np.all(ptr == [1, -1, 11, -1]) and np.all(ep_idx == [0, -1, 10, -1]) + assert np.all(ep_len == [0, 2, 0, 2]) + assert np.all(ep_rew == [data, data + 2, data, data + 2]) assert np.allclose(buf.done, [ 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, @@ -503,11 +523,11 @@ def test_multibuf_stack(): buf4 = CachedReplayBuffer( ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True), cached_num, size) - # test if CachedReplayBuffer can handle super corner case: - # prio-buffer + stack_num + ignore_obs_next + sample_avail + # test if CachedReplayBuffer can handle corner case: + 
# buffer + stack_num + ignore_obs_next + sample_avail buf5 = CachedReplayBuffer( - PrioritizedReplayBuffer(bufsize, 0.6, 0.4, stack_num=stack_num, - ignore_obs_next=True, sample_avail=True), + ReplayBuffer(bufsize, stack_num=stack_num, + ignore_obs_next=True, sample_avail=True), cached_num, size) obs = env.reset(1) for i in range(18): @@ -518,10 +538,10 @@ def test_multibuf_stack(): done_list = [done] * cached_num obs_next_list = -obs_list info_list = [info] * cached_num - buf4.add(obs_list, act_list, rew_list, done_list, - obs_next_list, info_list) - buf5.add(obs_list, act_list, rew_list, done_list, - obs_next_list, info_list) + batch = Batch(obs=obs_list, act=act_list, rew=rew_list, + done=done_list, obs_next=obs_next_list, info=info_list) + buf5.add(batch) + buf4.add(batch) assert np.all(buf4.obs == buf5.obs) assert np.all(buf4.done == buf5.done) obs = obs_next @@ -567,25 +587,26 @@ def test_multibuf_stack(): indice = buf5.sample_index(0) assert np.allclose(sorted(indice), [0, 1, 2, 5, 6, 7, 10, 15, 20]) batch, _ = buf5.sample(0) - assert np.allclose(buf5[np.arange(buf5.maxsize)].weight, 1) - buf5.update_weight(indice, batch.weight * 0) - weight = buf5[np.arange(buf5.maxsize)].weight - modified_weight = weight[[0, 1, 2, 5, 6, 7]] - assert modified_weight.min() == modified_weight.max() - assert modified_weight.max() < 1 - unmodified_weight = weight[[3, 4, 8]] - assert unmodified_weight.min() == unmodified_weight.max() - assert unmodified_weight.max() < 1 - cached_weight = weight[9:] - assert cached_weight.min() == cached_weight.max() == 1 + # the below test code should move to PrioritizedReplayBufferManager + # assert np.allclose(buf5[np.arange(buf5.maxsize)].weight, 1) + # buf5.update_weight(indice, batch.weight * 0) + # weight = buf5[np.arange(buf5.maxsize)].weight + # modified_weight = weight[[0, 1, 2, 5, 6, 7]] + # assert modified_weight.min() == modified_weight.max() + # assert modified_weight.max() < 1 + # unmodified_weight = weight[[3, 4, 8]] + # assert unmodified_weight.min() == unmodified_weight.max() + # assert unmodified_weight.max() < 1 + # cached_weight = weight[9:] + # assert cached_weight.min() == cached_weight.max() == 1 # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next buf6 = CachedReplayBuffer( ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True, ignore_obs_next=True), cached_num, size) obs = np.random.rand(size, 4, 84, 84) - buf6.add(obs=[obs[2], obs[0]], act=[1, 1], rew=[0, 0], done=[0, 1], - obs_next=[obs[3], obs[1]], cached_buffer_ids=[1, 2]) + buf6.add(Batch(obs=[obs[2], obs[0]], act=[1, 1], rew=[0, 0], done=[0, 1], + obs_next=[obs[3], obs[1]]), cached_buffer_ids=[1, 2]) assert buf6.obs.shape == (buf6.maxsize, 84, 84) assert np.allclose(buf6.obs[0], obs[0, -1]) assert np.allclose(buf6.obs[14], obs[2, -1]) @@ -610,9 +631,9 @@ def test_multibuf_hdf5(): 'done': i % 3 == 2, 'info': {"number": {"n": i, "t": info_t}, 'extra': None}, } - buffers["vector"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]), + buffers["vector"].add(Batch.cat([[kwargs], [kwargs], [kwargs]]), buffer_ids=[0, 1, 2]) - buffers["cached"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]), + buffers["cached"].add(Batch.cat([[kwargs], [kwargs], [kwargs]]), cached_buffer_ids=[0, 1, 2]) # save @@ -646,7 +667,7 @@ def test_multibuf_hdf5(): 'done': False, 'info': {"number": {"n": i}, 'Timelimit.truncate': True}, } - buffers[k].add(**Batch.cat([[kwargs], [kwargs], [kwargs], [kwargs]])) + buffers[k].add(Batch.cat([[kwargs], [kwargs], [kwargs], [kwargs]])) act = 
np.zeros(buffers[k].maxsize) if k == "vector": act[np.arange(5)] = np.array([0, 1, 2, 3, 5]) @@ -660,6 +681,8 @@ def test_multibuf_hdf5(): act[np.arange(3) + size * 3] = np.array([3, 5, 2]) act[size * 4] = 5 assert np.allclose(buffers[k].act, act) + info_keys = ["number", "extra", "Timelimit.truncate"] + assert set(buffers[k].info.keys()) == set(info_keys) for path in paths.values(): os.remove(path) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 9fe27e86f..96dca7d40 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -8,12 +8,6 @@ from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \ Sequence -# Disable pickle warning related to torch, since it has been removed -# on torch master branch. See Pull Request #39003 for details: -# https://github.com/pytorch/pytorch/pull/39003 -warnings.filterwarnings( - "ignore", message="pickle support for Storage will be removed in 1.5.") - def _is_batch_set(data: Any) -> bool: # Batch set is a list/tuple of dict/Batch objects, @@ -91,6 +85,8 @@ def _create_value( has_shape = isinstance(inst, (np.ndarray, torch.Tensor)) is_scalar = _is_scalar(inst) if not stack and is_scalar: + if isinstance(inst, Batch) and inst.is_empty(recurse=True): + return inst # should never hit since it has already checked in Batch.cat_ # here we do not consider scalar types, following the behavior of numpy # which does not support concatenation of zero-dimensional arrays @@ -250,15 +246,14 @@ def __setitem__( value: Any, ) -> None: """Assign value to self[index].""" - value = _parse_value(value) if isinstance(index, str): - self.__dict__[index] = value + self.__dict__[index] = _parse_value(value) return - if not isinstance(value, Batch): + if not isinstance(value, (dict, Batch)): raise ValueError("Batch does not supported tensor assignment. " "Use a compatible Batch or dict instead.") if not set(value.keys()).issubset(self.__dict__.keys()): - raise KeyError( + raise ValueError( "Creating keys is not supported by item assignment.") for key, val in self.items(): try: @@ -497,6 +492,8 @@ def stack_( self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0 ) -> None: """Stack a list of Batch object into current batch.""" + batches = [x for x in batches if isinstance( + x, dict) and x or isinstance(x, Batch) and not x.is_empty()] if len(batches) == 0: return batches = [x if isinstance(x, Batch) else Batch(x) for x in batches] diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 5829a316e..dd7c083a6 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -10,6 +10,17 @@ from tianshou.data.utils.converter import to_hdf5, from_hdf5 +def _alloc_by_keys_diff( + meta: Batch, batch: Batch, size: int, stack: bool = True +) -> None: + for key in batch.keys(): + if key in meta.keys(): + if isinstance(meta[key], Batch) and isinstance(batch[key], Batch): + _alloc_by_keys_diff(meta[key], batch[key], size, stack) + else: + meta[key] = _create_value(batch[key], size, stack) + + class ReplayBuffer: """:class:`~tianshou.data.ReplayBuffer` stores data generated from \ interaction between the policy and environment. 
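A minimal sketch (illustration only, not part of the patch) of the dynamic-key behaviour that _alloc_by_keys_diff enables, following the new test_replaybuffer case above and assuming the post-patch add() API:

from tianshou.data import Batch, ReplayBuffer

buf = ReplayBuffer(size=4)
buf.add(Batch(obs=1, act=1, rew=1.0, done=0, info={"a": 3}))
# "d" was unknown at the first add(); instead of raising, the buffer
# allocates a fresh (maxsize,) column for it, leaving earlier rows at 0.
buf.add(Batch(obs=2, act=2, rew=2.0, done=1, info={"a": 4, "d": 5}))
print(set(buf.info.keys()))  # {"a", "d"}
print(buf.info.d.shape)      # (4,)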
@@ -108,7 +119,7 @@ def load_hdf5( def reset(self) -> None: """Clear all the data in replay buffer and episode statistics.""" self._index = self._size = 0 - self._ep_len, self._ep_rew, self._ep_idx = 0, 0.0, 0 + self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, 0 def set_batch(self, batch: Batch) -> None: """Manually choose the batch you want the ReplayBuffer to manage.""" @@ -141,14 +152,17 @@ def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: return (index + (1 - end_flag)) % self._size def update(self, buffer: "ReplayBuffer") -> np.ndarray: - """Move the data from the given buffer to current buffer.""" + """Move the data from the given buffer to current buffer. + + Return the updated indices. If update fails, return an empty array. + """ if len(buffer) == 0 or self.maxsize == 0: - return + return np.array([], np.int) stack_num, buffer.stack_num = buffer.stack_num, 1 from_indices = buffer.sample_index(0) # get all available indices buffer.stack_num = stack_num if len(from_indices) == 0: - return + return np.array([], np.int) to_indices = [] for _ in range(len(from_indices)): to_indices.append(self._index) @@ -156,15 +170,19 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray: self._size = min(self._size + 1, self.maxsize) to_indices = np.array(to_indices) if self._meta.is_empty(): - self._meta = _create_value(buffer._meta[0], self.maxsize) + self._meta = _create_value(buffer._meta, self.maxsize, stack=False) self._meta[to_indices] = buffer._meta[from_indices] return to_indices - def add_index( + def _add_index( self, rew: Union[float, np.ndarray], done: bool - ) -> Tuple[int, int, Union[float, np.ndarray], int]: - """TODO""" - index = self._index + ) -> Tuple[int, Union[float, np.ndarray], int, int]: + """Maintain the buffer's state after adding one data batch. + + Return (index_to_be_modified, episode_reward, episode_length, + episode_start_index). + """ + ptr = self._index self._size = min(self._size + 1, self.maxsize) self._index = (self._index + 1) % self.maxsize @@ -172,41 +190,43 @@ def add_index( self._ep_len += 1 if done: - result = index, self._ep_len, self._ep_rew, self._ep_idx - self._ep_len, self._ep_rew, self._ep_idx = 0, 0.0, self._index + result = ptr, self._ep_rew, self._ep_len, self._ep_idx + self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, self._index return result else: - return index, 0, self._ep_rew * 0.0, self._ep_idx + return ptr, self._ep_rew * 0.0, 0, self._ep_idx - def add(self, batch: Batch) -> Tuple[int, Union[float, np.ndarray], int]: + def add( + self, batch: Batch + ) -> Tuple[int, Union[float, np.ndarray], int, int]: """Add a batch of data into replay buffer. - Return (episode_length, episode_reward) if one episode is terminated, - otherwise return (0, 0.0). + Return (current_index, episode_reward, episode_length, + episode_start_index). If the episode is not finished, the return value + of episode_length and episode_reward is 0. 
""" # preprocess batch - for key in set(batch.keys()).difference(self._reserved_keys): - batch.pop(key) # save only reserved keys + assert set(batch.keys()).issubset(self._reserved_keys) assert set(["obs", "act", "rew", "done"]).issubset(batch.keys()) if self._save_only_last_obs: - batch.__dict__["obs"] = batch.obs[-1] - batch.__dict__["rew"] = batch.rew.astype(np.float) - batch.__dict__["done"] = batch.done.astype(np.bool_) + batch.obs = batch.obs[-1] + batch.rew = batch.rew.astype(np.float) + batch.done = batch.done.astype(np.bool_) if not self._save_obs_next: batch.pop("obs_next", None) elif self._save_only_last_obs: - batch.__dict__["obs_next"] = batch.obs_next[-1] + batch.obs_next = batch.obs_next[-1] # get ptr - index, ep_len, ep_rew, ep_idx = self.add_index(batch.rew, batch.done) + ptr, ep_rew, ep_len, ep_idx = self._add_index(batch.rew, batch.done) try: - self._meta[index] = batch - except KeyError: # extra keys + self._meta[ptr] = batch + except ValueError: if self._meta.is_empty(): self._meta = _create_value(batch, self.maxsize) - else: # TODO: dynamic key pops up - pass - self._meta[index] = batch - return ep_len, ep_rew, ep_idx + else: # dynamic key pops up in batch + _alloc_by_keys_diff(self._meta, batch, self.maxsize) + self._meta[ptr] = batch + return ptr, ep_rew, ep_len, ep_idx def sample_index(self, batch_size: int) -> np.ndarray: """Get a random sample of index with size = batch_size. @@ -251,6 +271,7 @@ def get( self, index: Union[int, np.integer, np.ndarray], key: str, + default_value: Optional[Any] = None, stack_num: Optional[int] = None, ) -> Union[Batch, np.ndarray]: """Return the stacked result. @@ -258,9 +279,11 @@ def get( E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the index. """ + if key not in self._meta and default_value is not None: + return default_value + val = self._meta[key] if stack_num is None: stack_num = self.stack_num - val = self._meta[key] try: if stack_num == 1: # the most often case return val[index] @@ -291,17 +314,17 @@ def __getitem__( # raise KeyError first instead of AttributeError, to support np.array obs = self.get(index, "obs") if self._save_obs_next: - obs_next = self.get(index, "obs_next") + obs_next = self.get(index, "obs_next", Batch()) else: - obs_next = self.get(self.next(index), "obs") + obs_next = self.get(self.next(index), "obs", Batch()) return Batch( obs=obs, act=self.act[index], rew=self.rew[index], done=self.done[index], obs_next=obs_next, - info=self.get(index, "info"), - policy=self.get(index, "policy"), + info=self.get(index, "info", Batch()), + policy=self.get(index, "policy", Batch()), ) @@ -327,6 +350,7 @@ def __init__( # save weight directly in this class instead of self._meta self.weight = SegmentTree(size) self.__eps = np.finfo(np.float32).eps.item() + self.options.update(alpha=alpha, beta=beta) def init_weight(self, index: Union[int, np.ndarray]) -> None: self.weight[index] = self._max_prio ** self._alpha @@ -335,12 +359,15 @@ def update(self, buffer: ReplayBuffer) -> np.ndarray: indices = super().update(buffer) self.init_weight(indices) - def add(self, batch: Batch) -> Tuple[int, Union[float, np.ndarray], int]: - self.init_weight(self._index) - return super().add(batch) + def add( + self, batch: Batch + ) -> Tuple[int, Union[float, np.ndarray], int, int]: + ptr, ep_rew, ep_len, ep_idx = super().add(batch) + self.init_weight(ptr) + return ptr, ep_rew, ep_len, ep_idx def sample_index(self, batch_size: int) -> np.ndarray: - if batch_size > 0 and self._size > 0: + if batch_size > 0 and len(self) > 
0: scalar = np.random.rand(batch_size) * self.weight.reduce() return self.weight.get_prefix_sum_idx(scalar) else: @@ -384,12 +411,13 @@ def __getitem__( class ReplayBufferManager(ReplayBuffer): - """ReplayBufferManager contains a list of ReplayBuffer. + """ReplayBufferManager contains a list of ReplayBuffer with exactly the \ + same configuration. These replay buffers have contiguous memory layout, and the storage space each buffer has is a shallow copy of the topmost memory. - :param int buffer_list: a list of ReplayBuffers needed to be handled. + :param int buffer_list: a list of ReplayBuffer needed to be handled. .. seealso:: @@ -397,18 +425,18 @@ class ReplayBufferManager(ReplayBuffer): explanation. """ - def __init__(self, buffer_list: List[ReplayBuffer], **kwargs: Any) -> None: + def __init__(self, buffer_list: List[ReplayBuffer]) -> None: self.buffer_num = len(buffer_list) self.buffers = np.array(buffer_list) - offset, is_prio, size = [], [], 0 + offset, size = [], 0 + buffer_type = type(self.buffers[0]) + kwargs = self.buffers[0].options for buf in self.buffers: assert buf._meta.is_empty() + assert isinstance(buf, buffer_type) and buf.options == kwargs offset.append(size) - is_prio.append(isinstance(buf, PrioritizedReplayBuffer)) size += buf.maxsize self._offset = np.array(offset) - self._is_prio = np.array(is_prio) - self._is_any_prio = np.any(is_prio) super().__init__(size=size, **kwargs) def __len__(self) -> int: @@ -457,38 +485,49 @@ def add( self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into ReplayBufferManager. Each of the data's length (first dimension) must equal to the length of buffer_ids. By default buffer_ids is [0, 1, ..., buffer_num - 1]. - Return the array of episode_length and episode_reward with shape - (len(buffer_ids), ...), where (episode_length[i], episode_reward[i]) - refers to the buffer_ids[i]'s corresponding episode result. + Return (current_index, episode_reward, episode_length, + episode_start_index). If the episode is not finished, the return value + of episode_length and episode_reward is 0. 
""" + # preprocess batch + assert set(batch.keys()).issubset(self._reserved_keys) + assert set(["obs", "act", "rew", "done"]).issubset(batch.keys()) + if self._save_only_last_obs: + batch.obs = batch.obs[:, -1] + batch.rew = batch.rew.astype(np.float) + batch.done = batch.done.astype(np.bool_) + if not self._save_obs_next: + batch.pop("obs_next", None) + elif self._save_only_last_obs: + batch.obs_next = batch.obs_next[:, -1] + # get index if buffer_ids is None: buffer_ids = np.arange(self.buffer_num) - indices, ep_lens, ep_rews, ep_idxs = [], [], [], [] + ptrs, ep_lens, ep_rews, ep_idxs = [], [], [], [] for batch_idx, buffer_id in enumerate(buffer_ids): - index, ep_len, ep_rew, ep_idx = self.buffers[buffer_id].add_index( + ptr, ep_rew, ep_len, ep_idx = self.buffers[buffer_id]._add_index( batch.rew[batch_idx], batch.done[batch_idx]) - indices.append(index + self._offset[buffer_id]) + ptrs.append(ptr + self._offset[buffer_id]) ep_lens.append(ep_len) ep_rews.append(ep_rew) ep_idxs.append(ep_idx + self._offset[buffer_id]) - if self._is_prio[buffer_id]: - self.buffers[buffer_id].init_weight(index) - indices = np.array(indices) + ptrs = np.array(ptrs) try: - self._meta[indices] = batch - except KeyError: + self._meta[ptrs] = batch + except ValueError: if self._meta.is_empty(): - self._meta = _create_value(batch[0], self.maxsize) - else: # TODO: dynamic key pops up - pass + self._meta = _create_value(batch, self.maxsize, stack=False) + else: # dynamic key pops up in batch + _alloc_by_keys_diff(self._meta, batch, self.maxsize, False) self._set_batch_for_children() - return np.array(ep_lens), np.array(ep_rews), np.array(ep_idxs) + self._meta[ptrs] = batch + return ptrs, np.array(ep_rews), np.array(ep_lens), np.array(ep_idxs) def sample_index(self, batch_size: int) -> np.ndarray: if batch_size < 0: @@ -516,36 +555,16 @@ def sample_index(self, batch_size: int) -> np.ndarray: for offset, buf, bsz in zip(self._offset, self.buffers, sample_num) ]) - def __getitem__( - self, index: Union[slice, int, np.integer, np.ndarray] - ) -> Batch: - batch = super().__getitem__(index) - if self._is_any_prio: - indice = self._indices[index] - batch.weight = np.ones(len(indice)) - for offset, buf in zip( - self._offset[self._is_prio], self.buffers[self._is_prio]): - mask = (offset <= indice) & (indice < offset + buf.maxsize) - if np.any(mask): - batch.weight[mask] = buf.get_weight(indice[mask]) - return batch - - def update_weight( - self, - index: np.ndarray, - new_weight: Union[np.ndarray, torch.Tensor], - ) -> None: - """Update priority weight if any PrioritizedReplayBuffer is in buffers. - :param np.ndarray index: index you want to update weight. - :param np.ndarray new_weight: new priority weight you want to update. 
- """ - if self._is_any_prio: - for offset, buf in zip( - self._offset[self._is_prio], self.buffers[self._is_prio]): - mask = (offset <= index) & (index < offset + buf.maxsize) - if np.any(mask): - buf.update_weight(index[mask], new_weight[mask]) +class PrioritizedReplayBufferManager( + PrioritizedReplayBuffer, ReplayBufferManager +): + def __init__(self, buffer_list: List[PrioritizedReplayBuffer]) -> None: + ReplayBufferManager.__init__(self, buffer_list) + kwargs = buffer_list[0].options + for buf in buffer_list: + del buf.weight + PrioritizedReplayBuffer.__init__(self, self.maxsize, **kwargs) class CachedReplayBuffer(ReplayBufferManager): @@ -580,11 +599,11 @@ def __init__( max_episode_length: int, ) -> None: assert cached_buffer_num > 0 and max_episode_length > 0 - self._is_prioritized = isinstance(main_buffer, PrioritizedReplayBuffer) + assert type(main_buffer) == ReplayBuffer kwargs = main_buffer.options buffers = [main_buffer] + [ReplayBuffer(max_episode_length, **kwargs) for _ in range(cached_buffer_num)] - super().__init__(buffer_list=buffers, **kwargs) + super().__init__(buffer_list=buffers) self.main_buffer = self.buffers[0] self.cached_buffers = self.buffers[1:] self.cached_buffer_num = cached_buffer_num @@ -593,17 +612,18 @@ def add( self, batch: Batch, cached_buffer_ids: Optional[Union[np.ndarray, List[int]]] = None, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into CachedReplayBuffer. Each of the data's length (first dimension) must equal to the length of cached_buffer_ids. By default the cached_buffer_ids is [0, 1, ..., cached_buffer_num - 1]. - Return the array of episode_length and episode_reward with shape - (len(cached_buffer_ids), ...), where (episode_length[i], - episode_reward[i]) refers to the cached_buffer_ids[i]th cached buffer's - corresponding episode result. + Return (current_index, episode_reward, episode_length, + episode_start_index) with each of the shape (len(cached_buffer_ids), + ...), where (current_index[i], episode_reward[i], episode_length[i], + episode_start_index[i]) refers to the cached_buffer_ids[i]th cached + buffer's corresponding episode result. 
""" if cached_buffer_ids is None: cached_buffer_ids = np.arange(self.cached_buffer_num) @@ -611,9 +631,16 @@ def add( cached_buffer_ids = np.asarray(cached_buffer_ids) # in self.buffers, the first buffer is main_buffer buffer_ids = cached_buffer_ids + 1 # type: ignore - result = super().add(batch, buffer_ids=buffer_ids) + ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids=buffer_ids) # find the terminated episode, move data from cached buf to main buf - for buffer_idx in cached_buffer_ids[batch.done.astype(np.bool_)]: - self.main_buffer.update(self.cached_buffers[buffer_idx]) + updated_ptr, updated_ep_idx = [], [] + for buffer_idx in cached_buffer_ids[batch.done]: + index = self.main_buffer.update(self.cached_buffers[buffer_idx]) + if len(index) == 0: # unsuccessful move, replace with -1 + index = [-1] + updated_ep_idx.append(index[0]) + updated_ptr.append(index[-1]) self.cached_buffers[buffer_idx].reset() - return result + ptr[batch.done] = updated_ptr + ep_idx[batch.done] = updated_ep_idx + return ptr, ep_rew, ep_len, ep_idx diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index e6508527c..0c33c7f6e 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -11,7 +11,7 @@ from tianshou.exploration import BaseNoise from tianshou.data.batch import _create_value from tianshou.env import BaseVectorEnv, DummyVectorEnv -from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, \ +from tianshou.data import Batch, ReplayBuffer, \ ReplayBufferManager, CachedReplayBuffer, to_numpy From 4b86274fe11069797f9a749d73d71530b1999c5f Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 2 Feb 2021 17:55:06 +0800 Subject: [PATCH 042/104] VectorBuffer --- tianshou/data/buffer.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index dd7c083a6..70873b68d 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -567,6 +567,16 @@ def __init__(self, buffer_list: List[PrioritizedReplayBuffer]) -> None: PrioritizedReplayBuffer.__init__(self, self.maxsize, **kwargs) +class VectorReplayBuffer(ReplayBufferManager): + def __init__( + self, total_size: int, buffer_num: int, **kwargs: Any + ) -> None: + sizes = [total_size // buffer_num + (i < total_size % buffer_num) + for i in range(buffer_num)] + buffer_list = [ReplayBuffer(size, **kwargs) for size in sizes] + super().__init__(buffer_list) + + class CachedReplayBuffer(ReplayBufferManager): """CachedReplayBuffer contains a given main buffer and n cached buffers, \ cached_buffer_num * ReplayBuffer(size=max_episode_length). 
From 4ffdb8233bf07b333e2bcafd8adf8739f627cb18 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 2 Feb 2021 19:29:05 +0800 Subject: [PATCH 043/104] vectorbuf --- test/base/test_buffer.py | 8 ++++---- tianshou/data/__init__.py | 6 +++++- tianshou/data/buffer.py | 11 +++++++++++ tianshou/data/collector.py | 6 ++++-- 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 64fd29193..756efa751 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -10,7 +10,7 @@ from tianshou.data.utils.converter import to_hdf5 from tianshou.data import Batch, SegmentTree, ReplayBuffer from tianshou.data import PrioritizedReplayBuffer -from tianshou.data import ReplayBufferManager, CachedReplayBuffer +from tianshou.data import VectorReplayBuffer, CachedReplayBuffer if __name__ == '__main__': from env import MyTestEnv @@ -352,7 +352,7 @@ def test_hdf5(): def test_replaybuffermanager(): - buf = ReplayBufferManager([ReplayBuffer(size=5) for i in range(4)]) + buf = VectorReplayBuffer(20, 4) batch = Batch(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3], done=[0, 0, 1]) ptr, ep_rew, ep_len, ep_idx = buf.add(batch, buffer_ids=[0, 1, 2]) assert np.all(ep_len == [0, 0, 1]) and np.all(ep_rew == [0, 0, 3]) @@ -617,7 +617,7 @@ def test_multibuf_stack(): def test_multibuf_hdf5(): size = 100 buffers = { - "vector": ReplayBufferManager([ReplayBuffer(size) for i in range(4)]), + "vector": VectorReplayBuffer(size * 4, 4), "cached": CachedReplayBuffer(ReplayBuffer(size), 4, size) } buffer_types = {k: b.__class__ for k, b in buffers.items()} @@ -654,7 +654,7 @@ def test_multibuf_hdf5(): assert _buffers[k].stack_num == buffers[k].stack_num assert _buffers[k].maxsize == buffers[k].maxsize assert np.all(_buffers[k]._indices == buffers[k]._indices) - # check shallow copy in ReplayBufferManager + # check shallow copy in VectorReplayBuffer for k in ["vector", "cached"]: buffers[k].info.number.n[0] = -100 assert buffers[k].buffers[0].info.number.n[0] == -100 diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index 1d7a68493..f32603d87 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -2,7 +2,8 @@ from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as from tianshou.data.utils.segtree import SegmentTree from tianshou.data.buffer import ReplayBuffer, PrioritizedReplayBuffer, \ - ReplayBufferManager, CachedReplayBuffer + ReplayBufferManager, PrioritizedReplayBufferManager, \ + VectorReplayBuffer, PrioritizedVectorReplayBuffer, CachedReplayBuffer from tianshou.data.collector import Collector __all__ = [ @@ -14,6 +15,9 @@ "ReplayBuffer", "PrioritizedReplayBuffer", "ReplayBufferManager", + "PrioritizedReplayBufferManager", + "VectorReplayBuffer", + "PrioritizedVectorReplayBuffer", "CachedReplayBuffer", "Collector", ] diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 70873b68d..32a9b9b6e 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -577,6 +577,17 @@ def __init__( super().__init__(buffer_list) +class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager): + def __init__( + self, total_size: int, buffer_num: int, **kwargs: Any + ) -> None: + sizes = [total_size // buffer_num + (i < total_size % buffer_num) + for i in range(buffer_num)] + buffer_list = [PrioritizedReplayBuffer(size, **kwargs) + for size in sizes] + super().__init__(buffer_list) + + class CachedReplayBuffer(ReplayBufferManager): """CachedReplayBuffer contains a given main buffer and n cached 
buffers, \ cached_buffer_num * ReplayBuffer(size=max_episode_length). diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 0c33c7f6e..6991254ae 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -11,8 +11,10 @@ from tianshou.exploration import BaseNoise from tianshou.data.batch import _create_value from tianshou.env import BaseVectorEnv, DummyVectorEnv -from tianshou.data import Batch, ReplayBuffer, \ - ReplayBufferManager, CachedReplayBuffer, to_numpy +from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \ + ReplayBufferManager, PrioritizedReplayBufferManager, \ + VectorReplayBuffer, PrioritizedVectorReplayBuffer, CachedReplayBuffer, \ + to_numpy class Collector(object): From 13a41fa42adfef184ce1ef5339b104a50fbdf3c0 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Wed, 3 Feb 2021 13:31:47 +0800 Subject: [PATCH 044/104] ReplayBuffer now support add batch-style data when buffer_ids is not None --- test/base/test_buffer.py | 43 ++++++++--- tianshou/data/buffer.py | 77 ++++++++++++-------- tianshou/data/collector.py | 141 ++++++++++++++++--------------------- 3 files changed, 140 insertions(+), 121 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 3c2c3218d..3fd09590f 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -40,8 +40,9 @@ def test_replaybuffer(size=10, bufsize=20): b = ReplayBuffer(size=10) # neg bsz should return empty index assert b.sample_index(-1).tolist() == [] - b.add(Batch(obs=1, act=1, rew=1, done=1, - obs_next='str', info={'a': 3, 'b': {'c': 5.0}})) + ptr, ep_rew, ep_len, ep_idx = b.add( + Batch(obs=1, act=1, rew=1, done=1, obs_next='str', + info={'a': 3, 'b': {'c': 5.0}})) assert b.obs[0] == 1 assert b.done[0] assert b.obs_next[0] == 'str' @@ -51,13 +52,29 @@ def test_replaybuffer(size=10, bufsize=20): assert np.all(b.info.a[1:] == 0) assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact assert np.all(b.info.b.c[1:] == 0.0) + assert ptr.shape == (1,) and ptr[0] == 0 + assert ep_rew.shape == (1,) and ep_rew[0] == 1 + assert ep_len.shape == (1,) and ep_len[0] == 1 + assert ep_idx.shape == (1,) and ep_idx[0] == 0 # test extra keys pop up, the buffer should handle it dynamically - b.add(Batch(obs=2, act=2, rew=2, done=0, obs_next="str2", - info={"a": 4, "d": {"e": -np.inf}})) + batch = Batch(obs=2, act=2, rew=2, done=0, obs_next="str2", + info={"a": 4, "d": {"e": -np.inf}}) + b.add(batch) info_keys = ["a", "b", "d"] assert set(b.info.keys()) == set(info_keys) assert b.info.a[1] == 4 and b.info.b.c[1] == 0 assert b.info.d.e[1] == -np.inf + # test batch-style adding method, where len(batch) == 1 + batch.done = 1 + batch.info.e = np.zeros([1, 4]) + batch = Batch.stack([batch]) + ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0]) + assert ptr.shape == (1,) and ptr[0] == 2 + assert ep_rew.shape == (1,) and ep_rew[0] == 4 + assert ep_len.shape == (1,) and ep_len[0] == 2 + assert ep_idx.shape == (1,) and ep_idx[0] == 1 + assert set(b.info.keys()) == set(info_keys + ["e"]) + assert b.info.e.shape == (b.maxsize, 1, 4) with pytest.raises(IndexError): b[22] @@ -138,8 +155,9 @@ def test_priortized_replaybuffer(size=32, bufsize=15): action_list = [1] * 5 + [0] * 10 + [1] * 10 for i, a in enumerate(action_list): obs_next, rew, done, info = env.step(a) - buf.add(Batch(obs=obs, act=a, rew=rew, done=done, obs_next=obs_next, - info=info, policy=np.random.randn() - 0.5)) + batch = Batch(obs=obs, act=a, rew=rew, done=done, obs_next=obs_next, + info=info, 
policy=np.random.randn() - 0.5) + buf.add(Batch.stack([batch]), buffer_ids=[0]) obs = obs_next data, indice = buf.sample(len(buf) // 2) if len(buf) // 2 == 0: @@ -147,6 +165,9 @@ def test_priortized_replaybuffer(size=32, bufsize=15): else: assert len(data) == len(buf) // 2 assert len(buf) == min(bufsize, i + 1) + assert buf.info.key.shape == (buf.maxsize,) + assert buf.rew.dtype == np.float + assert buf.done.dtype == np.bool_ data, indice = buf.sample(len(buf) // 2) buf.update_weight(indice, -data.weight / 2) assert np.allclose( @@ -450,7 +471,7 @@ def test_cachedbuffer(): assert buf.sample_index(0).tolist() == [] # check the normal function/usage/storage in CachedReplayBuffer ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=[1], act=[1], rew=[1], done=[0]), cached_buffer_ids=[1]) + Batch(obs=[1], act=[1], rew=[1], done=[0]), buffer_ids=[1]) obs = np.zeros(buf.maxsize) obs[15] = 1 indice = buf.sample_index(0) @@ -461,7 +482,7 @@ def test_cachedbuffer(): assert np.all(ep_len == [0]) and np.all(ep_rew == [0.0]) assert np.all(ptr == [15]) and np.all(ep_idx == [15]) ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=[2], act=[2], rew=[2], done=[1]), cached_buffer_ids=[3]) + Batch(obs=[2], act=[2], rew=[2], done=[1]), buffer_ids=[3]) obs[[0, 25]] = 2 indice = buf.sample_index(0) assert np.allclose(indice, [0, 15]) @@ -474,7 +495,7 @@ def test_cachedbuffer(): assert np.allclose(buf.sample_index(0), [0, 15]) ptr, ep_rew, ep_len, ep_idx = buf.add( Batch(obs=[3, 4], act=[3, 4], rew=[3, 4], done=[0, 1]), - cached_buffer_ids=[3, 1]) + buffer_ids=[3, 1]) assert np.all(ep_len == [0, 2]) and np.all(ep_rew == [0, 5.0]) assert np.all(ptr == [25, 2]) and np.all(ep_idx == [25, 1]) obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3] @@ -606,7 +627,7 @@ def test_multibuf_stack(): cached_num, size) obs = np.random.rand(size, 4, 84, 84) buf6.add(Batch(obs=[obs[2], obs[0]], act=[1, 1], rew=[0, 0], done=[0, 1], - obs_next=[obs[3], obs[1]]), cached_buffer_ids=[1, 2]) + obs_next=[obs[3], obs[1]]), buffer_ids=[1, 2]) assert buf6.obs.shape == (buf6.maxsize, 84, 84) assert np.allclose(buf6.obs[0], obs[0, -1]) assert np.allclose(buf6.obs[14], obs[2, -1]) @@ -634,7 +655,7 @@ def test_multibuf_hdf5(): buffers["vector"].add(Batch.stack([kwargs, kwargs, kwargs]), buffer_ids=[0, 1, 2]) buffers["cached"].add(Batch.stack([kwargs, kwargs, kwargs]), - cached_buffer_ids=[0, 1, 2]) + buffer_ids=[0, 1, 2]) # save paths = {} diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 32a9b9b6e..c09feb210 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -197,10 +197,18 @@ def _add_index( return ptr, self._ep_rew * 0.0, 0, self._ep_idx def add( - self, batch: Batch - ) -> Tuple[int, Union[float, np.ndarray], int, int]: + self, + batch: Batch, + buffer_ids: Optional[Union[np.ndarray, List[int]]] = None, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into replay buffer. + :param Batch batch: the input data batch. Its keys must belong to the 7 + reserved keys, and "obs", "act", "rew", "done" is required. + :param buffer_ids: to make consistent with other buffer's add function; + if it is not None, we assume the input batch's first dimension is + always 1. + Return (current_index, episode_reward, episode_length, episode_start_index). If the episode is not finished, the return value of episode_length and episode_reward is 0. 
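The add() signature introduced just above also accepts a stacked batch whose first dimension is 1 when buffer_ids is given (kept, per the docstring, for consistency with the manager-style buffers so a collector can call any buffer the same way); the body change in the next hunk implements it. A short sketch of that call pattern, mirroring the updated test and assuming the post-patch API:

import numpy as np
from tianshou.data import Batch, ReplayBuffer

buf = ReplayBuffer(size=10)
transition = Batch(obs=np.zeros(4), act=0, rew=1.0, done=0)
# buffer_ids=[0] signals that the batch already has a leading dim of length 1
ptr, ep_rew, ep_len, ep_idx = buf.add(Batch.stack([transition]),
                                      buffer_ids=[0])
print(ptr, ep_len)  # [0] [0] -- episode stats stay 0 until done is reached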
@@ -208,23 +216,33 @@ def add( # preprocess batch assert set(batch.keys()).issubset(self._reserved_keys) assert set(["obs", "act", "rew", "done"]).issubset(batch.keys()) + stacked_batch = buffer_ids is not None + if stacked_batch: + assert len(batch) == 1 if self._save_only_last_obs: - batch.obs = batch.obs[-1] - batch.rew = batch.rew.astype(np.float) - batch.done = batch.done.astype(np.bool_) + batch.obs = batch.obs[:, -1] if stacked_batch else batch.obs[-1] if not self._save_obs_next: batch.pop("obs_next", None) elif self._save_only_last_obs: - batch.obs_next = batch.obs_next[-1] + batch.obs_next = \ + batch.obs_next[:, -1] if stacked_batch else batch.obs_next[-1] # get ptr - ptr, ep_rew, ep_len, ep_idx = self._add_index(batch.rew, batch.done) + if stacked_batch: + rew, done = batch.rew[0], batch.done[0] + else: + rew, done = batch.rew, batch.done + ptr, ep_rew, ep_len, ep_idx = list(map( + lambda x: np.array([x]), self._add_index(rew, done))) try: self._meta[ptr] = batch except ValueError: + stack = not stacked_batch + batch.rew = batch.rew.astype(np.float) + batch.done = batch.done.astype(np.bool_) if self._meta.is_empty(): - self._meta = _create_value(batch, self.maxsize) + self._meta = _create_value(batch, self.maxsize, stack) else: # dynamic key pops up in batch - _alloc_by_keys_diff(self._meta, batch, self.maxsize) + _alloc_by_keys_diff(self._meta, batch, self.maxsize, stack) self._meta[ptr] = batch return ptr, ep_rew, ep_len, ep_idx @@ -360,9 +378,11 @@ def update(self, buffer: ReplayBuffer) -> np.ndarray: self.init_weight(indices) def add( - self, batch: Batch - ) -> Tuple[int, Union[float, np.ndarray], int, int]: - ptr, ep_rew, ep_len, ep_idx = super().add(batch) + self, + batch: Batch, + buffer_ids: Optional[Union[np.ndarray, List[int]]] = None, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids) self.init_weight(ptr) return ptr, ep_rew, ep_len, ep_idx @@ -500,8 +520,6 @@ def add( assert set(["obs", "act", "rew", "done"]).issubset(batch.keys()) if self._save_only_last_obs: batch.obs = batch.obs[:, -1] - batch.rew = batch.rew.astype(np.float) - batch.done = batch.done.astype(np.bool_) if not self._save_obs_next: batch.pop("obs_next", None) elif self._save_only_last_obs: @@ -521,6 +539,8 @@ def add( try: self._meta[ptrs] = batch except ValueError: + batch.rew = batch.rew.astype(np.float) + batch.done = batch.done.astype(np.bool_) if self._meta.is_empty(): self._meta = _create_value(batch, self.maxsize, stack=False) else: # dynamic key pops up in batch @@ -632,36 +652,35 @@ def __init__( def add( self, batch: Batch, - cached_buffer_ids: Optional[Union[np.ndarray, List[int]]] = None, + buffer_ids: Optional[Union[np.ndarray, List[int]]] = None, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into CachedReplayBuffer. Each of the data's length (first dimension) must equal to the length of - cached_buffer_ids. By default the cached_buffer_ids is [0, 1, ..., - cached_buffer_num - 1]. + buffer_ids. By default the buffer_ids is [0, 1, ..., cached_buffer_num + - 1]. 
Return (current_index, episode_reward, episode_length, - episode_start_index) with each of the shape (len(cached_buffer_ids), - ...), where (current_index[i], episode_reward[i], episode_length[i], + episode_start_index) with each of the shape (len(buffer_ids), ...), + where (current_index[i], episode_reward[i], episode_length[i], episode_start_index[i]) refers to the cached_buffer_ids[i]th cached buffer's corresponding episode result. """ - if cached_buffer_ids is None: - cached_buffer_ids = np.arange(self.cached_buffer_num) + if buffer_ids is None: + buffer_ids = np.arange(1, 1 + self.cached_buffer_num) else: # make sure it is np.ndarray - cached_buffer_ids = np.asarray(cached_buffer_ids) - # in self.buffers, the first buffer is main_buffer - buffer_ids = cached_buffer_ids + 1 # type: ignore + buffer_ids = np.asarray(buffer_ids) + 1 ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids=buffer_ids) # find the terminated episode, move data from cached buf to main buf updated_ptr, updated_ep_idx = [], [] - for buffer_idx in cached_buffer_ids[batch.done]: - index = self.main_buffer.update(self.cached_buffers[buffer_idx]) + done = batch.done.astype(np.bool_) + for buffer_idx in buffer_ids[done]: + index = self.main_buffer.update(self.buffers[buffer_idx]) if len(index) == 0: # unsuccessful move, replace with -1 index = [-1] updated_ep_idx.append(index[0]) updated_ptr.append(index[-1]) - self.cached_buffers[buffer_idx].reset() - ptr[batch.done] = updated_ptr - ep_idx[batch.done] = updated_ep_idx + self.buffers[buffer_idx].reset() + ptr[done] = updated_ptr + ep_idx[done] = updated_ep_idx return ptr, ep_rew, ep_len, ep_idx diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 6991254ae..3feb66dc1 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -12,13 +12,11 @@ from tianshou.data.batch import _create_value from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \ - ReplayBufferManager, PrioritizedReplayBufferManager, \ - VectorReplayBuffer, PrioritizedVectorReplayBuffer, CachedReplayBuffer, \ - to_numpy + ReplayBufferManager, CachedReplayBuffer, to_numpy class Collector(object): - #TODO change doc + # TODO change doc """Collector enables the policy to interact with different types of envs. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` @@ -79,16 +77,15 @@ class Collector(object): Please make sure the given environment has a time limitation. 
""" + def __init__( self, policy: BasePolicy, env: Union[gym.Env, BaseVectorEnv], buffer: Optional[ReplayBuffer] = None, preprocess_fn: Optional[Callable[..., Batch]] = None, - training: bool = False, reward_metric: Optional[Callable[[np.ndarray], float]] = None, ) -> None: - # TODO determine whether we need start_idxs # TODO update training in all test/examples, remove action noise super().__init__() if not isinstance(env, BaseVectorEnv): @@ -97,7 +94,7 @@ def __init__( assert env.is_async is False self.env = env self.env_num = len(env) - self.training = training + self._save_data = buffer is not None self._assign_buffer(buffer) self.policy = policy self.preprocess_fn = preprocess_fn @@ -107,37 +104,27 @@ def __init__( self.reset() def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None: - if not hasattr(self.env, "_max_episode_steps"): - warnings.warn("No time limit found in given env, set to 100000.") - max_episode_steps = 100000 - else: - max_episode_steps = self.env._max_episode_steps[0] if buffer is None: - self.buffer = CachedReplayBuffer( - ReplayBuffer(0), self.env_num, max_episode_steps) + self.buffer = VectorReplayBuffer(self.env_num, self.env_num) elif isinstance(buffer, ReplayBufferManager): - if type(self.buffer) == ReplayBufferManager: - if self.training: - warnings.warn("ReplayBufferManager is not suggested to be used" - "in training mode, make sure you know how it works.") - self.buffer.cached_buffers = self.buffer.buffers - self.buffer.cached_buffer_num = self.buffer.buffer_num - assert self.buffer.cached_buffer_num == self.env_num - if self.buffer.cached_buffers[0].maxsize < max_episode_steps: - warnings.warn( - "The size of cached_buffer is suggested to be larger than " - "max episode length. Otherwise you might" - "loss data of episodes you just collected, and statistics " - "might even be incorrect.", Warning) - else: # type ReplayBuffer + assert self.buffer.buffer_num >= self.env_num + if isinstance(buffer, CachedReplayBuffer): + assert self.buffer.cached_buffer_num >= self.env_num + else: # ReplayBuffer or PrioritizedReplayBuffer assert self.buffer.maxsize > 0 - if self.env_num != 1: - warnings.warn( - "CachedReplayBuffer/ReplayBufferManager rather than ReplayBuffer" - "is required in collector when #env > 1. Input buffer is switched" - "to CachedReplayBuffer.", Warning) - self.buffer = CachedReplayBuffer(buffer, - self.env_num, max_episode_steps) + if self.env_num > 1: + if type(buffer) == ReplayBuffer: + buffer_type = "ReplayBuffer" + vector_type = "VectorReplayBuffer" + else: + buffer_type = "PrioritizedReplayBuffer" + vector_type = "PrioritizedVectorReplayBuffer" + raise TypeError( + f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) to " + f"collect {self.env_num} envs, please use {vector_type}(" + f"total_size={buffer.maxsize}, buffer_num={self.env_num}, " + "...) instead.") + self.buffer = buffer @staticmethod def _default_rew_metric( @@ -147,16 +134,15 @@ def _default_rew_metric( # for multi-agent RL, a reward_metric must be provided assert np.asanyarray(x).size == 1, ( "Please specify the reward_metric " - "since the reward is not a scalar." 
- ) + "since the reward is not a scalar.") return x def reset(self) -> None: """Reset all related variables in the collector.""" - # use empty Batch for ``state`` so that ``self.data`` supports slicing + # use empty Batch for "state" so that self.data supports slicing # convert empty Batch to None when passing data to policy - self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={}, - obs_next={}, policy={}) + self.data = Batch(obs={}, act={}, rew={}, done={}, obs_next={}, + info={}, policy={}) self.reset_env() self.reset_buffer() self.reset_stat() @@ -166,7 +152,7 @@ def reset_stat(self) -> None: self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 def reset_buffer(self) -> None: - """Reset the main data buffer.""" + """Reset the data buffer.""" self.buffer.reset() def reset_env(self) -> None: @@ -175,7 +161,7 @@ def reset_env(self) -> None: if self.preprocess_fn: obs = self.preprocess_fn(obs=obs).get("obs", obs) self.data.obs = obs - if hasattr(self.buffer, "cached_buffers"): + if isinstance(self.buffer, CachedReplayBuffer): for buf in self.buffer.cached_buffers: buf.reset() @@ -188,7 +174,7 @@ def _reset_state(self, id: Union[int, List[int]]) -> None: state[id] = None if state.dtype == np.object else 0 elif isinstance(state, Batch): state.empty_(id) - + def collect( self, n_step: Optional[int] = None, @@ -197,20 +183,16 @@ def collect( render: Optional[float] = None, no_grad: bool = True, ) -> Dict[str, float]: - #TODO doc update """Collect a specified number of step or episode. :param int n_step: how many steps you want to collect. - :param n_episode: how many episodes you want to collect. If it is an - int, it means to collect at lease ``n_episode`` episodes; if it is - a list, it means to collect exactly ``n_episode[i]`` episodes in - the i-th environment - :param bool random: whether to use random policy for collecting data, - defaults to False. + :param n_episode: how many episodes you want to collect. + :param bool random: whether to use random policy for collecting data. + Default to False. :param float render: the sleep time between rendering consecutive - frames, defaults to None (no rendering). - :param bool no_grad: whether to retain gradient in policy.forward, - defaults to True (no gradient retaining). + frames. Default to None (no rendering). + :param bool no_grad: whether to retain gradient in policy.forward. + Default to True (no gradient retaining). .. note:: @@ -221,17 +203,21 @@ def collect( * ``n/ep`` the collected number of episodes. * ``n/st`` the collected number of steps. - * ``rew`` the mean reward over collected episodes. - * ``len`` the mean length over collected episodes. + * ``rews`` the list of episode reward over collected episodes. + * ``lens`` the list of episode length over collected episodes. """ - #collect at least n_step or n_episode + # collect at least n_step or n_episode if n_step is not None: - assert n_episode is None, "Only one of n_step or n_episode is allowed " - f"in Collector.collect, got n_step = {n_step}, n_episode = {n_episode}." 
- assert n_step > 0 and n_step % self.env_num == 0, \ - "n_step must not be 0, and should be an integral multiple of #envs" + assert n_episode is None, ( + "Only one of n_step or n_episode is allowed in " + f"Collector.collect, got n_step = {n_step}, " + f"n_episode = {n_episode}.") + assert n_step > 0 + assert n_step % self.env_num == 0, \ + "n_step should be a multiple of #envs" else: assert isinstance(n_episode, int) and n_episode > 0 + start_time = time.time() step_count = 0 @@ -239,35 +225,27 @@ def collect( episode_count = 0 episode_rews = [] episode_lens = [] - # start_idxs = [] - cached_buffer_ids = [i for i in range(self.env_num)] + episode_start_indices = [] + buffer_ids = list(range(self.env_num)) while True: - if step_count >= 100000 and episode_count == 0: - warnings.warn( - "There are already many steps in an episode. " - "You should add a time limitation to your environment!", - Warning) - # restore the state and the input data last_state = self.data.state if isinstance(last_state, Batch) and last_state.is_empty(): last_state = None - # calculate the next action and update state, act & policy into self.data + # calc the next action / update state / act / policy into self.data if random: - spaces = self._action_space - result = Batch( - act=[spaces[i].sample() for i in range(self.env_num)]) + result = Batch(act=[a.sample() for a in self._action_space]) else: if no_grad: with torch.no_grad(): # faster than retain_grad version - # self.data.obs will be used by agent to get result(mainly action) + # self.data.obs will be used by agent to get result result = self.policy(self.data, last_state) else: result = self.policy(self.data, last_state) - state = result.get("state", Batch()) + state = result.get("state", None) policy = result.get("policy", Batch()) act = to_numpy(result.act) if state is None: @@ -276,16 +254,17 @@ def collect( if not (isinstance(state, Batch) and state.is_empty()): # save hidden state to policy._state, in order to save into buffer policy._state = state - + # TODO discuss and change policy's add_exp_noise behavior if self.training and not random and hasattr(self.policy, 'add_exp_noise'): act = self.policy.add_exp_noise(act) - self.data.update(state=state, policy = policy, act = act) + self.data.update(state=state, policy=policy, act=act) # step in env obs_next, rew, done, info = self.env.step(act) - result = {"obs_next":obs_next, "rew":rew, "done":done, "info":info} + result = {"obs_next": obs_next, + "rew": rew, "done": done, "info": info} if self.preprocess_fn: result = self.preprocess_fn(**result) # type: ignore @@ -299,12 +278,12 @@ def collect( # add data into the buffer data_t = self.data if n_episode and len(cached_buffer_ids) < self.env_num: - data_t = self.data[cached_buffer_ids] + data_t = self.data[cached_buffer_ids] if type(self.buffer) == ReplayBuffer: data_t = data_t[0] # lens, rews, idxs = self.buffer.add(**data_t, index = cached_buffer_ids) # rews need to be array for ReplayBuffer - lens, rews = self.buffer.add(**data_t, index = cached_buffer_ids) + lens, rews = self.buffer.add(**data_t, index=cached_buffer_ids) if type(self.buffer) == ReplayBuffer: lens = np.asarray(lens) rews = np.asarray(rews) @@ -337,7 +316,7 @@ def collect( if (n_step and step_count >= n_step) or \ (n_episode and episode_count >= n_episode): - break + break # generate the statistics self.collect_step += step_count @@ -351,4 +330,4 @@ def collect( "rews": np.array(episode_rews), "lens": np.array(episode_lens), # "idxs": np.array(start_idxs) - } \ No newline at end of file 
+ } From 20f254977adb4e8a968bc6bf8ba446d2d5824128 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 4 Feb 2021 09:43:30 +0800 Subject: [PATCH 045/104] finish step collector, time to write unittest --- test/base/test_collector.py | 17 ++-- tianshou/data/buffer.py | 10 +- tianshou/data/collector.py | 176 +++++++++++++++++------------------- tianshou/policy/base.py | 6 ++ 4 files changed, 106 insertions(+), 103 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index ec799e640..31702402d 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -43,15 +43,16 @@ def __init__(self, writer): def preprocess_fn(self, **kwargs): # modify info before adding into the buffer, and recorded into tfb # if only obs exist -> reset - # if obs/act/rew/done/... exist -> normal step + # if obs_next/rew/done/info exist -> normal step if 'rew' in kwargs: - n = len(kwargs['obs']) + n = len(kwargs['rew']) info = kwargs['info'] for i in range(n): info[i].update(rew=kwargs['rew'][i]) - if 'key' in info.keys(): - self.writer.add_scalar('key', np.mean( - info['key']), global_step=self.cnt) + if 'key' in info[0].keys(): + key = np.array([i['key'] for i in info]) + self.writer.add_scalar( + 'key', np.mean(key), global_step=self.cnt) self.cnt += 1 return Batch(info=info) else: @@ -61,7 +62,7 @@ def preprocess_fn(self, **kwargs): def single_preprocess_fn(**kwargs): # same as above, without tfb if 'rew' in kwargs: - n = len(kwargs['obs']) + n = len(kwargs['rew']) info = kwargs['info'] for i in range(n): info[i].update(rew=kwargs['rew'][i]) @@ -82,8 +83,8 @@ def test_collector(): c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False), logger.preprocess_fn) c0.collect(n_step=3) - assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 1]) - assert np.allclose(c0.buffer[:4].obs_next[..., 0], [1, 2, 1, 2]) + assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0]) + assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1]) c0.collect(n_episode=3) assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]) assert np.allclose(c0.buffer[:10].obs_next[..., 0], diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index c09feb210..ebf34b34e 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -214,7 +214,10 @@ def add( of episode_length and episode_reward is 0. """ # preprocess batch - assert set(batch.keys()).issubset(self._reserved_keys) + b = Batch() + for key in set(self._reserved_keys).intersection(batch.keys()): + b.__dict__[key] = batch[key] + batch = b assert set(["obs", "act", "rew", "done"]).issubset(batch.keys()) stacked_batch = buffer_ids is not None if stacked_batch: @@ -516,7 +519,10 @@ def add( of episode_length and episode_reward is 0. 
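The key filtering introduced just below (replacing the old subset assertion) keeps only the reserved keys of the incoming batch and silently drops anything else. A rough dict-based equivalent of that filtering, for illustration only:

reserved = {"obs", "act", "rew", "done", "obs_next", "info", "policy"}
incoming = {"obs": 1, "act": 0, "rew": 1.0, "done": False, "extra": 42}
filtered = {k: incoming[k] for k in reserved & incoming.keys()}
assert "extra" not in filtered
assert {"obs", "act", "rew", "done"} <= filtered.keys()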
""" # preprocess batch - assert set(batch.keys()).issubset(self._reserved_keys) + b = Batch() + for key in set(self._reserved_keys).intersection(batch.keys()): + b.__dict__[key] = batch[key] + batch = b assert set(["obs", "act", "rew", "done"]).issubset(batch.keys()) if self._save_only_last_obs: batch.obs = batch.obs[:, -1] diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 3feb66dc1..e9d2db080 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -60,8 +60,6 @@ class Collector(object): # collect 3 episodes collector.collect(n_episode=3) - # collect 1 episode for the first env, 3 for the third env - collector.collect(n_episode=[1, 0, 3]) # collect at least 2 steps collector.collect(n_step=2) # collect episodes with visual rendering (the render argument is the @@ -84,22 +82,21 @@ def __init__( env: Union[gym.Env, BaseVectorEnv], buffer: Optional[ReplayBuffer] = None, preprocess_fn: Optional[Callable[..., Batch]] = None, - reward_metric: Optional[Callable[[np.ndarray], float]] = None, + training: bool = True, ) -> None: # TODO update training in all test/examples, remove action noise super().__init__() if not isinstance(env, BaseVectorEnv): env = DummyVectorEnv([lambda: env]) - # TODO support or seperate async - assert env.is_async is False + assert env.is_async is False, "Please use AsyncCollector if ..." self.env = env self.env_num = len(env) + self.training = training self._save_data = buffer is not None self._assign_buffer(buffer) self.policy = policy self.preprocess_fn = preprocess_fn self._action_space = env.action_space - self._rew_metric = reward_metric or BasicCollector._default_rew_metric # avoid creating attribute outside __init__ self.reset() @@ -107,11 +104,12 @@ def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None: if buffer is None: self.buffer = VectorReplayBuffer(self.env_num, self.env_num) elif isinstance(buffer, ReplayBufferManager): - assert self.buffer.buffer_num >= self.env_num + assert buffer.buffer_num >= self.env_num if isinstance(buffer, CachedReplayBuffer): - assert self.buffer.cached_buffer_num >= self.env_num + assert buffer.cached_buffer_num >= self.env_num + self.buffer = buffer else: # ReplayBuffer or PrioritizedReplayBuffer - assert self.buffer.maxsize > 0 + assert buffer.maxsize > 0 if self.env_num > 1: if type(buffer) == ReplayBuffer: buffer_type = "ReplayBuffer" @@ -126,23 +124,24 @@ def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None: "...) 
instead.") self.buffer = buffer - @staticmethod - def _default_rew_metric( - x: Union[Number, np.number] - ) -> Union[Number, np.number]: - # this internal function is designed for single-agent RL - # for multi-agent RL, a reward_metric must be provided - assert np.asanyarray(x).size == 1, ( - "Please specify the reward_metric " - "since the reward is not a scalar.") - return x + # TODO move to trainer + # @staticmethod + # def _default_rew_metric( + # x: Union[Number, np.number] + # ) -> Union[Number, np.number]: + # # this internal function is designed for single-agent RL + # # for multi-agent RL, a reward_metric must be provided + # assert np.asanyarray(x).size == 1, ( + # "Please specify the reward_metric " + # "since the reward is not a scalar.") + # return x def reset(self) -> None: """Reset all related variables in the collector.""" # use empty Batch for "state" so that self.data supports slicing # convert empty Batch to None when passing data to policy - self.data = Batch(obs={}, act={}, rew={}, done={}, obs_next={}, - info={}, policy={}) + self.data = Batch(obs={}, act={}, rew={}, done={}, + obs_next={}, info={}, policy={}) self.reset_env() self.reset_buffer() self.reset_stat() @@ -157,6 +156,7 @@ def reset_buffer(self) -> None: def reset_env(self) -> None: """Reset all of the environment(s)' states and the cache buffers.""" + self._ready_env_ids = np.arange(self.env_num) obs = self.env.reset() if self.preprocess_fn: obs = self.preprocess_fn(obs=obs).get("obs", obs) @@ -167,13 +167,14 @@ def reset_env(self) -> None: def _reset_state(self, id: Union[int, List[int]]) -> None: """Reset the hidden state: self.data.state[id].""" - state = self.data.state # it is a reference - if isinstance(state, torch.Tensor): - state[id].zero_() - elif isinstance(state, np.ndarray): - state[id] = None if state.dtype == np.object else 0 - elif isinstance(state, Batch): - state.empty_(id) + if hasattr(self.data.policy, "hidden_state"): + state = self.data.policy.hidden_state # it is a reference + if isinstance(state, torch.Tensor): + state[id].zero_() + elif isinstance(state, np.ndarray): + state[id] = None if state.dtype == np.object else 0 + elif isinstance(state, Batch): + state.empty_(id) def collect( self, @@ -205,13 +206,13 @@ def collect( * ``n/st`` the collected number of steps. * ``rews`` the list of episode reward over collected episodes. * ``lens`` the list of episode length over collected episodes. + * ``idxs`` the list of episode start index over collected episodes. """ # collect at least n_step or n_episode if n_step is not None: assert n_episode is None, ( - "Only one of n_step or n_episode is allowed in " - f"Collector.collect, got n_step = {n_step}, " - f"n_episode = {n_episode}.") + "Only one of n_step or n_episode is allowed in Collector." 
+ f"collect, got n_step={n_step}, n_episode={n_episode}.") assert n_step > 0 assert n_step % self.env_num == 0, \ "n_step should be a multiple of #envs" @@ -226,17 +227,16 @@ def collect( episode_rews = [] episode_lens = [] episode_start_indices = [] - buffer_ids = list(range(self.env_num)) while True: - # restore the state and the input data - last_state = self.data.state - if isinstance(last_state, Batch) and last_state.is_empty(): - last_state = None + assert len(self.data) == len(self._ready_env_ids) + # restore the state: if the last state is None, it won't store + last_state = self.data.policy.pop("hidden_state", None) - # calc the next action / update state / act / policy into self.data + # get the next action if random: - result = Batch(act=[a.sample() for a in self._action_space]) + result = Batch(act=[self._action_space[i].sample() + for i in self._ready_env_ids]) else: if no_grad: with torch.no_grad(): # faster than retain_grad version @@ -244,90 +244,80 @@ def collect( result = self.policy(self.data, last_state) else: result = self.policy(self.data, last_state) - - state = result.get("state", None) + # update state / act / policy into self.data policy = result.get("policy", Batch()) + assert isinstance(policy, Batch) + state = result.get("state", None) + if state is not None: + policy.hidden_state = state # save state into buffer act = to_numpy(result.act) - if state is None: - # convert None to Batch(), since None is reserved for 0-init - state = Batch() - if not (isinstance(state, Batch) and state.is_empty()): - # save hidden state to policy._state, in order to save into buffer - policy._state = state - - # TODO discuss and change policy's add_exp_noise behavior - if self.training and not random and hasattr(self.policy, 'add_exp_noise'): - act = self.policy.add_exp_noise(act) - self.data.update(state=state, policy=policy, act=act) + if self.training and not random: + act = self.policy.exploration_noise(act) + self.data.update(policy=policy, act=act) # step in env obs_next, rew, done, info = self.env.step(act) - result = {"obs_next": obs_next, - "rew": rew, "done": done, "info": info} + result = {"obs_next": obs_next, "rew": rew, + "done": done, "info": info} if self.preprocess_fn: - result = self.preprocess_fn(**result) # type: ignore - - # update obs_next, rew, done, & info into self.data + result.update(self.preprocess_fn(**result)) self.data.update(result) if render: self.env.render() - time.sleep(render) + if render > 0 and not np.isclose(render, 0): + time.sleep(render) # add data into the buffer - data_t = self.data - if n_episode and len(cached_buffer_ids) < self.env_num: - data_t = self.data[cached_buffer_ids] - if type(self.buffer) == ReplayBuffer: - data_t = data_t[0] - # lens, rews, idxs = self.buffer.add(**data_t, index = cached_buffer_ids) - # rews need to be array for ReplayBuffer - lens, rews = self.buffer.add(**data_t, index=cached_buffer_ids) - if type(self.buffer) == ReplayBuffer: - lens = np.asarray(lens) - rews = np.asarray(rews) - # idxs = np.asarray(idxs) + ptr, ep_rew, ep_len, ep_idx = self.buffer.add( + self.data, buffer_ids=self._ready_env_ids) + # collect statistics - step_count += len(cached_buffer_ids) - for i in cached_buffer_ids(np.where(lens == 0)[0]): - episode_count += 1 - episode_lens.append(lens[i]) - episode_rews.append(self._rew_metric(rews[i])) - # start_idxs.append(idxs[i]) - - if sum(done): - finised_env_ind = np.where(done)[0] - # now we copy obs_next to obs, but since there might be finished episodes, - # we have to reset finished 
envs first. - obs_reset = self.env.reset(finised_env_ind) + step_count += len(self._ready_env_ids) + + if np.any(done): + env_ind_local = np.where(done)[0] + episode_count += len(env_ind_local) + episode_lens.append(ep_len[env_ind_local]) + episode_rews.append(ep_rew[env_ind_local]) + episode_start_indices.append(ep_idx[env_ind_local]) + + # now we copy obs_next to obs, but since there might be + # finished episodes, we have to reset finished envs first. + env_ind_global = self._ready_env_ids[env_ind_local] + obs_reset = self.env.reset(env_ind_global) if self.preprocess_fn: obs_reset = self.preprocess_fn( obs=obs_reset).get("obs", obs_reset) - self.data.obs_next[finised_env_ind] = obs_reset - for i in finised_env_ind: + self.data.obs_next[env_ind_local] = obs_reset + for i in env_ind_local: self._reset_state(i) - if n_episode and n_episode - episode_count < self.env_num: - try: - cached_buffer_ids.remove(i) - except ValueError: - pass - self.data.obs[:] = self.data.obs_next + if n_episode and n_episode - episode_count < self.env_num: + mask = ~np.isin(self._ready_env_ids, env_ind_global) + self._ready_env_ids = self._ready_env_ids[mask] + self.data = self.data[mask] + + self.data.obs = self.data.obs_next if (n_step and step_count >= n_step) or \ (n_episode and episode_count >= n_episode): break - # generate the statistics + # generate statistics self.collect_step += step_count self.collect_episode += episode_count self.collect_time += max(time.time() - start_time, 1e-9) + if n_episode: + self.data = Batch(obs={}, act={}, rew={}, done={}, + obs_next={}, info={}, policy={}) self.reset_env() + return { "n/ep": episode_count, "n/st": step_count, - "rews": np.array(episode_rews), - "lens": np.array(episode_lens), - # "idxs": np.array(start_idxs) + "rews": np.concatenate(episode_rews), + "lens": np.concatenate(episode_lens), + "idxs": np.concatenate(episode_start_indices), } diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 99e16544b..9a4dfbd8e 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -67,6 +67,12 @@ def set_agent_id(self, agent_id: int) -> None: """Set self.agent_id = agent_id, for MARL.""" self.agent_id = agent_id + def exploration_noise( + self, act: Union[np.ndarray, Batch] + ) -> Union[np.ndarray, Batch]: + """Modify the action from policy.forward with exploration noise.""" + return act + @abstractmethod def forward( self, From 03a1d2fa558e37c7cb2d00bc52f1490724e12e2c Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 4 Feb 2021 11:23:14 +0800 Subject: [PATCH 046/104] add test --- test/base/test_collector.py | 88 ++++++++++++++++++++----------------- tianshou/data/buffer.py | 5 ++- tianshou/data/collector.py | 35 +++++++++------ 3 files changed, 72 insertions(+), 56 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 31702402d..d4f77eca7 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -1,9 +1,10 @@ +import pytest import numpy as np from torch.utils.tensorboard import SummaryWriter from tianshou.policy import BasePolicy from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.data import Collector, Batch, ReplayBuffer +from tianshou.data import Collector, Batch, ReplayBuffer, VectorReplayBuffer if __name__ == '__main__': from env import MyTestEnv @@ -12,7 +13,7 @@ class MyPolicy(BasePolicy): - def __init__(self, dict_state: bool = False, need_state: bool = True): + def __init__(self, dict_state=False, need_state=True): """ :param bool dict_state: if the 
observation of the environment is a dict :param bool need_state: if the policy needs the hidden state (for RNN) @@ -45,14 +46,11 @@ def preprocess_fn(self, **kwargs): # if only obs exist -> reset # if obs_next/rew/done/info exist -> normal step if 'rew' in kwargs: - n = len(kwargs['rew']) info = kwargs['info'] - for i in range(n): - info[i].update(rew=kwargs['rew'][i]) - if 'key' in info[0].keys(): - key = np.array([i['key'] for i in info]) + info.rew = kwargs['rew'] + if 'key' in info.keys(): self.writer.add_scalar( - 'key', np.mean(key), global_step=self.cnt) + 'key', np.mean(info.key), global_step=self.cnt) self.cnt += 1 return Batch(info=info) else: @@ -62,10 +60,8 @@ def preprocess_fn(self, **kwargs): def single_preprocess_fn(**kwargs): # same as above, without tfb if 'rew' in kwargs: - n = len(kwargs['rew']) info = kwargs['info'] - for i in range(n): - info[i].update(rew=kwargs['rew'][i]) + info.rew = kwargs['rew'] return Batch(info=info) else: return Batch() @@ -80,40 +76,52 @@ def test_collector(): dum = DummyVectorEnv(env_fns) policy = MyPolicy() env = env_fns[0]() - c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False), - logger.preprocess_fn) + c0 = Collector(policy, env, ReplayBuffer(size=100), logger.preprocess_fn) c0.collect(n_step=3) + assert len(c0.buffer) == 3 assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0]) assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1]) c0.collect(n_episode=3) - assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]) - assert np.allclose(c0.buffer[:10].obs_next[..., 0], - [1, 2, 1, 2, 1, 2, 1, 2, 1, 2]) + assert len(c0.buffer) == 8 + assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0]) + assert np.allclose(c0.buffer[:].obs_next[..., 0], + [1, 2, 1, 2, 1, 2, 1, 2]) c0.collect(n_step=3, random=True) - c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False), - logger.preprocess_fn) - c1.collect(n_step=6) - assert np.allclose(c1.buffer.obs[:11, 0], - [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3]) - assert np.allclose(c1.buffer[:11].obs_next[..., 0], - [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4]) - c1.collect(n_episode=2) - assert np.allclose(c1.buffer.obs[11:21, 0], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2]) - assert np.allclose(c1.buffer[11:21].obs_next[..., 0], - [1, 2, 3, 4, 5, 1, 2, 1, 2, 3]) - c1.collect(n_episode=3, random=True) - c2 = Collector(policy, dum, ReplayBuffer(size=100, ignore_obs_next=False), - logger.preprocess_fn) - c2.collect(n_episode=[1, 2, 2, 2]) + c1 = Collector( + policy, venv, + VectorReplayBuffer(total_size=100, buffer_num=4), + logger.preprocess_fn) + with pytest.raises(AssertionError): + c1.collect(n_step=6) + c1.collect(n_step=8) + obs = np.zeros(100) + obs[[0, 1, 25, 26, 50, 51, 75, 76]] = [0, 1, 0, 1, 0, 1, 0, 1] + + assert np.allclose(c1.buffer.obs[:, 0], obs) + assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) + with pytest.raises(AssertionError): + c1.collect(n_episode=2) + c1.collect(n_episode=4) + assert len(c1.buffer) == 16 + obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4] + assert np.allclose(c1.buffer.obs[:, 0], obs) + assert np.allclose(c1.buffer[:].obs_next[..., 0], + [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]) + c1.collect(n_episode=4, random=True) + c2 = Collector( + policy, dum, + VectorReplayBuffer(total_size=100, buffer_num=4), + logger.preprocess_fn) + c2.collect(n_episode=7) assert np.allclose(c2.buffer.obs_next[:26, 0], [ 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]) 
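The index patterns asserted in these tests come from the segmented layout of VectorReplayBuffer: total_size=100 split over buffer_num=4 gives each env its own 25-slot segment, so env i writes at offsets 25*i, 25*i + 1, and so on. A numpy-only sketch of that offset arithmetic (variable names are illustrative, not part of the test):

import numpy as np

total_size, buffer_num = 100, 4
size = total_size // buffer_num          # 25 slots per sub-buffer
offsets = np.arange(buffer_num) * size   # [0, 25, 50, 75]
# after 8 steps from 4 envs (2 per env), the filled slots are the first two
# positions of every segment, matching the obs indices checked above
filled = np.sort(np.concatenate([offsets, offsets + 1]))
assert filled.tolist() == [0, 1, 25, 26, 50, 51, 75, 76]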
c2.reset_env() - c2.collect(n_episode=[2, 2, 2, 2]) + c2.collect(n_episode=8) assert np.allclose(c2.buffer.obs_next[26:54, 0], [ 1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]) - c2.collect(n_episode=[1, 1, 1, 1], random=True) + c2.collect(n_episode=4, random=True) def test_collector_with_exact_episodes(): @@ -150,9 +158,9 @@ def test_collector_with_async(): venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) policy = MyPolicy() - c1 = Collector(policy, venv, - ReplayBuffer(size=1000, ignore_obs_next=False), - logger.preprocess_fn) + c1 = Collector( + policy, venv, ReplayBuffer(size=1000, ignore_obs_next=False), + logger.preprocess_fn) c1.collect(n_episode=10) # check if the data in the buffer is chronological # i.e. data in the buffer are full episodes, and each episode is @@ -217,12 +225,10 @@ def test_collector_with_dict_state(): def test_collector_with_ma(): - def reward_metric(x): - return x.sum() env = MyTestEnv(size=5, sleep=0, ma_rew=4) policy = MyPolicy() c0 = Collector(policy, env, ReplayBuffer(size=100), - Logger.single_preprocess_fn, reward_metric=reward_metric) + Logger.single_preprocess_fn) # n_step=3 will collect a full episode r = c0.collect(n_step=3)['rews'].mean() assert np.asanyarray(r).size == 1 and r == 4. @@ -232,7 +238,7 @@ def reward_metric(x): for i in [2, 3, 4, 5]] envs = DummyVectorEnv(env_fns) c1 = Collector(policy, envs, ReplayBuffer(size=100), - Logger.single_preprocess_fn, reward_metric=reward_metric) + Logger.single_preprocess_fn) r = c1.collect(n_step=10)['rews'].mean() assert np.asanyarray(r).size == 1 and r == 4. r = c1.collect(n_episode=[2, 1, 1, 2])['rews'].mean() @@ -250,7 +256,7 @@ def reward_metric(x): assert np.allclose(c0.buffer[:len(c0.buffer)].rew, [[x] * 4 for x in rew]) c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4), - Logger.single_preprocess_fn, reward_metric=reward_metric) + Logger.single_preprocess_fn) r = c2.collect(n_episode=[0, 0, 0, 10])['rews'].mean() assert np.asanyarray(r).size == 1 and r == 4. batch, _ = c2.buffer.sample(10) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index ebf34b34e..edfd9e5e8 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -331,7 +331,10 @@ def __getitem__( shape (batch, len, ...). """ if isinstance(index, slice): # change slice to np array - index = self._indices[:len(self)][index] + if index == slice(None): # buffer[:] will get all available data + index = self.sample_index(0) + else: + index = self._indices[:len(self)][index] # raise KeyError first instead of AttributeError, to support np.array obs = self.get(index, "obs") if self._save_obs_next: diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index e9d2db080..1580b3907 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -119,9 +119,9 @@ def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None: vector_type = "PrioritizedVectorReplayBuffer" raise TypeError( f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) to " - f"collect {self.env_num} envs, please use {vector_type}(" - f"total_size={buffer.maxsize}, buffer_num={self.env_num}, " - "...) instead.") + f"collect {self.env_num} envs,\n\tplease use {vector_type}" + f"(total_size={buffer.maxsize}, buffer_num={self.env_num}," + " ...) instead.") self.buffer = buffer # TODO move to trainer @@ -209,6 +209,7 @@ def collect( * ``idxs`` the list of episode start index over collected episodes. 
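A small usage sketch of the dict documented above: the per-episode arrays are flat, so callers can reduce them however they like (the values here are made up purely for illustration):

import numpy as np

result = {"n/ep": 3, "n/st": 11,
          "rews": np.array([1.0, 1.0, 1.0]),
          "lens": np.array([3, 4, 4]),
          "idxs": np.array([0, 25, 50])}
mean_rew = result["rews"].mean() if result["n/ep"] > 0 else 0.0
mean_len = result["lens"].mean() if result["n/ep"] > 0 else 0.0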
""" # collect at least n_step or n_episode + # TODO: modify docs, tell the constraints if n_step is not None: assert n_episode is None, ( "Only one of n_step or n_episode is allowed in Collector." @@ -217,7 +218,7 @@ def collect( assert n_step % self.env_num == 0, \ "n_step should be a multiple of #envs" else: - assert isinstance(n_episode, int) and n_episode > 0 + assert isinstance(n_episode, int) and n_episode >= self.env_num start_time = time.time() @@ -256,13 +257,17 @@ def collect( self.data.update(policy=policy, act=act) # step in env - obs_next, rew, done, info = self.env.step(act) + obs_next, rew, done, info = self.env.step( + act, id=self._ready_env_ids) - result = {"obs_next": obs_next, "rew": rew, - "done": done, "info": info} + self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) if self.preprocess_fn: - result.update(self.preprocess_fn(**result)) - self.data.update(result) + self.data.update(self.preprocess_fn( + obs_next=self.data.obs_next, + rew=self.data.rew, + done=self.data.done, + info=self.data.info, + )) if render: self.env.render() @@ -314,10 +319,12 @@ def collect( obs_next={}, info={}, policy={}) self.reset_env() + rews = np.concatenate(episode_rews) if episode_rews else np.array([]) + lens = np.concatenate(episode_lens) if episode_lens else np.array([]) + idxs = np.concatenate(episode_start_indices) if episode_start_indices \ + else np.array([]) + return { - "n/ep": episode_count, - "n/st": step_count, - "rews": np.concatenate(episode_rews), - "lens": np.concatenate(episode_lens), - "idxs": np.concatenate(episode_start_indices), + "n/ep": episode_count, "n/st": step_count, + "rews": rews, "lens": lens, "idxs": idxs, } From 43e69a72d4922afc6c1a3024b866e37119c44772 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 4 Feb 2021 12:59:41 +0800 Subject: [PATCH 047/104] finish step_collector test --- test/base/test_collector.py | 161 ++++++++++++++++++++---------------- tianshou/data/collector.py | 27 +++--- 2 files changed, 109 insertions(+), 79 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index d4f77eca7..64a10f0df 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -113,40 +113,28 @@ def test_collector(): VectorReplayBuffer(total_size=100, buffer_num=4), logger.preprocess_fn) c2.collect(n_episode=7) - assert np.allclose(c2.buffer.obs_next[:26, 0], [ - 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5, - 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]) + obs1 = obs.copy() + obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2] + obs2 = obs.copy() + obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3] + c2obs = c2.buffer.obs[:, 0] + assert np.all(c2obs == obs1) or np.all(c2obs == obs2) c2.reset_env() - c2.collect(n_episode=8) - assert np.allclose(c2.buffer.obs_next[26:54, 0], [ - 1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5, - 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]) + c2.reset_buffer() + assert c2.collect(n_episode=8)['n/ep'] == 8 + obs[[4, 5, 28, 29, 30, 54, 55, 56, 57]] = [0, 1, 0, 1, 2, 0, 1, 2, 3] + assert np.all(c2.buffer.obs[:, 0] == obs) c2.collect(n_episode=4, random=True) - - -def test_collector_with_exact_episodes(): - env_lens = [2, 6, 3, 10] - writer = SummaryWriter('log/exact_collector') - logger = Logger(writer) - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.1, random_sleep=True) - for i in env_lens] - - venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) - policy = MyPolicy() - c1 = Collector(policy, venv, - ReplayBuffer(size=1000, ignore_obs_next=False), - logger.preprocess_fn) - 
n_episode1 = [2, 0, 5, 1] - n_episode2 = [1, 3, 2, 0] - c1.collect(n_episode=n_episode1) - expected_steps = sum([a * b for a, b in zip(env_lens, n_episode1)]) - actual_steps = sum(venv.steps) - assert expected_steps == actual_steps - c1.collect(n_episode=n_episode2) - expected_steps = sum( - [a * (b + c) for a, b, c in zip(env_lens, n_episode1, n_episode2)]) - actual_steps = sum(venv.steps) - assert expected_steps == actual_steps + env_fns = [lambda x=i: MyTestEnv(size=x) for i in np.arange(2, 11)] + dum = DummyVectorEnv(env_fns) + num = len(env_fns) + c3 = Collector(policy, dum, + VectorReplayBuffer(total_size=40000, buffer_num=num)) + for i in range(num, 400): + c3.reset() + result = c3.collect(n_episode=i) + assert result['n/ep'] == i + assert result['n/st'] == len(c3.buffer) def test_collector_with_async(): @@ -201,26 +189,49 @@ def test_collector_with_dict_state(): Logger.single_preprocess_fn) c0.collect(n_step=3) c0.collect(n_episode=2) + assert len(c0.buffer) == 10 env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True) for i in [2, 3, 4, 5]] envs = DummyVectorEnv(env_fns) envs.seed(666) obs = envs.reset() assert not np.isclose(obs[0]['rand'], obs[1]['rand']) - c1 = Collector(policy, envs, ReplayBuffer(size=100), - Logger.single_preprocess_fn) - c1.collect(n_step=10) - c1.collect(n_episode=[2, 1, 1, 2]) + c1 = Collector( + policy, envs, + VectorReplayBuffer(total_size=100, buffer_num=4), + Logger.single_preprocess_fn) + with pytest.raises(AssertionError): + c1.collect(n_step=10) + c1.collect(n_step=12) + result = c1.collect(n_episode=8) + assert result['n/ep'] == 8 + lens = np.bincount(result['lens']) + assert result['n/st'] == 21 and np.all(lens == [0, 0, 2, 2, 2, 2]) or \ + result['n/st'] == 20 and np.all(lens == [0, 0, 3, 1, 2, 2]) batch, _ = c1.buffer.sample(10) - print(batch) c0.buffer.update(c1.buffer) - assert np.allclose(c0.buffer[:len(c0.buffer)].obs.index[..., 0], [ - 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., - 0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0., - 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.]) - c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4), - Logger.single_preprocess_fn) - c2.collect(n_episode=[0, 0, 0, 10]) + assert len(c0.buffer) in [42, 43] + if len(c0.buffer) == 42: + assert np.all(c0.buffer[:].obs.index[..., 0] == [ + 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, + 0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 2, 0, 1, 2, + 0, 1, 2, 3, 0, 1, 2, 3, + 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, + ]), c0.buffer[:].obs.index[..., 0] + else: + assert np.all(c0.buffer[:].obs.index[..., 0] == [ + 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, + 0, 1, 0, 1, 0, 1, + 0, 1, 2, 0, 1, 2, 0, 1, 2, + 0, 1, 2, 3, 0, 1, 2, 3, + 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, + ]), c0.buffer[:].obs.index[..., 0] + c2 = Collector( + policy, envs, + VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), + Logger.single_preprocess_fn) + c2.collect(n_episode=10) batch, _ = c2.buffer.sample(10) @@ -230,35 +241,48 @@ def test_collector_with_ma(): c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn) # n_step=3 will collect a full episode - r = c0.collect(n_step=3)['rews'].mean() - assert np.asanyarray(r).size == 1 and r == 4. - r = c0.collect(n_episode=2)['rews'].mean() - assert np.asanyarray(r).size == 1 and r == 4. 
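The replacement asserts below reflect that the collector no longer applies a reward_metric: for a 4-agent env the returned rews keep their agent dimension, shape (n_episode, n_agents), and any reduction is left to the caller. A numpy-only illustration (the array stands in for what collect(n_episode=2)['rews'] looks like with ma_rew=4):

import numpy as np

rews = np.ones((2, 4))            # 2 finished episodes, 4 agents, reward 1 each
per_episode = rews.sum(axis=1)    # caller-side reduction, e.g. sum over agents
assert rews.shape == (2, 4) and np.all(per_episode == 4)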
+ r = c0.collect(n_step=3)['rews'] + assert len(r) == 0 + r = c0.collect(n_episode=2)['rews'] + assert r.shape == (2, 4) and np.all(r == 1) env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]] envs = DummyVectorEnv(env_fns) - c1 = Collector(policy, envs, ReplayBuffer(size=100), - Logger.single_preprocess_fn) - r = c1.collect(n_step=10)['rews'].mean() - assert np.asanyarray(r).size == 1 and r == 4. - r = c1.collect(n_episode=[2, 1, 1, 2])['rews'].mean() - assert np.asanyarray(r).size == 1 and r == 4. + c1 = Collector( + policy, envs, + VectorReplayBuffer(total_size=100, buffer_num=4), + Logger.single_preprocess_fn) + r = c1.collect(n_step=12)['rews'] + assert r.shape == (2, 4) and np.all(r == 1), r + r = c1.collect(n_episode=8)['rews'] + assert r.shape == (8, 4) and np.all(r == 1) batch, _ = c1.buffer.sample(10) print(batch) c0.buffer.update(c1.buffer) - assert np.allclose(c0.buffer[:len(c0.buffer)].obs[..., 0], [ - 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., - 0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0., - 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.]) - rew = [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, - 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, - 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1] - assert np.allclose(c0.buffer[:len(c0.buffer)].rew, - [[x] * 4 for x in rew]) - c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4), - Logger.single_preprocess_fn) - r = c2.collect(n_episode=[0, 0, 0, 10])['rews'].mean() - assert np.asanyarray(r).size == 1 and r == 4. + assert len(c0.buffer) in [42, 43] + if len(c0.buffer) == 42: + rew = [ + 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1, + 0, 0, 1, 0, 0, 1, + 0, 0, 0, 1, 0, 0, 0, 1, + 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, + ] + else: + rew = [ + 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, + 0, 1, 0, 1, 0, 1, + 0, 0, 1, 0, 0, 1, 0, 0, 1, + 0, 0, 0, 1, 0, 0, 0, 1, + 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, + ] + assert np.all(c0.buffer[:].rew == [[x] * 4 for x in rew]) + c2 = Collector( + policy, envs, + VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), + Logger.single_preprocess_fn) + r = c2.collect(n_episode=10)['rews'] + assert r.shape == (10, 4) and np.all(r == 1) batch, _ = c2.buffer.sample(10) @@ -266,5 +290,4 @@ def test_collector_with_ma(): test_collector() test_collector_with_dict_state() test_collector_with_ma() - test_collector_with_async() - test_collector_with_exact_episodes() + # test_collector_with_async() diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 1580b3907..2bf7dff45 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -283,14 +283,9 @@ def collect( if np.any(done): env_ind_local = np.where(done)[0] - episode_count += len(env_ind_local) - episode_lens.append(ep_len[env_ind_local]) - episode_rews.append(ep_rew[env_ind_local]) - episode_start_indices.append(ep_idx[env_ind_local]) - + env_ind_global = self._ready_env_ids[env_ind_local] # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. 
- env_ind_global = self._ready_env_ids[env_ind_local] obs_reset = self.env.reset(env_ind_global) if self.preprocess_fn: obs_reset = self.preprocess_fn( @@ -298,10 +293,22 @@ def collect( self.data.obs_next[env_ind_local] = obs_reset for i in env_ind_local: self._reset_state(i) - if n_episode and n_episode - episode_count < self.env_num: - mask = ~np.isin(self._ready_env_ids, env_ind_global) - self._ready_env_ids = self._ready_env_ids[mask] - self.data = self.data[mask] + if n_episode: + total = episode_count + len(env_ind_local) + self.env_num + if n_episode < total: + if len(env_ind_global) > total - n_episode: + env_ind_global = np.random.choice( + env_ind_global, + total - n_episode, + replace=False) + mask = np.isin(self._ready_env_ids, env_ind_global) + self._ready_env_ids = self._ready_env_ids[~mask] + self.data = self.data[~mask] + + episode_count += len(env_ind_local) + episode_lens.append(ep_len[env_ind_local]) + episode_rews.append(ep_rew[env_ind_local]) + episode_start_indices.append(ep_idx[env_ind_local]) self.data.obs = self.data.obs_next From 311fba6048f474fc81b8c5da1db1747024e46be8 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 4 Feb 2021 18:13:28 +0800 Subject: [PATCH 048/104] add more test about atari-style buffer setting and CachedReplayBuffer --- test/base/env.py | 18 ++++-- test/base/test_collector.py | 118 +++++++++++++++++++++++++++++++++++- tianshou/data/collector.py | 35 ++++++----- 3 files changed, 148 insertions(+), 23 deletions(-) diff --git a/test/base/env.py b/test/base/env.py index f0907f8e1..e71957a7e 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -10,15 +10,16 @@ class MyTestEnv(gym.Env): """ def __init__(self, size, sleep=0, dict_state=False, recurse_state=False, - ma_rew=0, multidiscrete_action=False, random_sleep=False): - assert not ( - dict_state and recurse_state), \ - "dict_state and recurse_state cannot both be true" + ma_rew=0, multidiscrete_action=False, random_sleep=False, + array_state=False): + assert dict_state + recurse_state + array_state <= 1, \ + "dict_state / recurse_state / array_state can be only one true" self.size = size self.sleep = sleep self.random_sleep = random_sleep self.dict_state = dict_state self.recurse_state = recurse_state + self.array_state = array_state self.ma_rew = ma_rew self._md_action = multidiscrete_action # how many steps this env has stepped @@ -36,6 +37,8 @@ def __init__(self, size, sleep=0, dict_state=False, recurse_state=False, "rand": Box(shape=(1, 2), low=0, high=1, dtype=np.float64)}) }) + elif array_state: + self.observation_space = Box(shape=(4, 84, 84), low=0, high=255) else: self.observation_space = Box(shape=(1, ), low=0, high=size - 1) if multidiscrete_action: @@ -72,6 +75,13 @@ def _get_state(self): 'dict': {"tuple": (np.array([1], dtype=np.int64), self.rng.rand(2)), "rand": self.rng.rand(1, 2)}} + elif self.array_state: + img = np.zeros([4, 84, 84], np.int) + img[3, np.arange(84), np.arange(84)] = self.index + img[2, np.arange(84)] = self.index + img[1, :, np.arange(84)] = self.index + img[0] = self.index + return img else: return np.array([self.index], dtype=np.float32) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 64a10f0df..8d2e8cbbe 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -4,7 +4,8 @@ from tianshou.policy import BasePolicy from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.data import Collector, Batch, ReplayBuffer, VectorReplayBuffer +from tianshou.data import Collector, Batch, 
ReplayBuffer, VectorReplayBuffer, \ + CachedReplayBuffer if __name__ == '__main__': from env import MyTestEnv @@ -277,6 +278,7 @@ def test_collector_with_ma(): 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, ] assert np.all(c0.buffer[:].rew == [[x] * 4 for x in rew]) + assert np.all(c0.buffer[:].done == rew) c2 = Collector( policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), @@ -286,8 +288,122 @@ def test_collector_with_ma(): batch, _ = c2.buffer.sample(10) +def test_collector_with_atari_setting(): + reference_obs = np.zeros([6, 4, 84, 84]) + for i in range(6): + reference_obs[i, 3, np.arange(84), np.arange(84)] = i + reference_obs[i, 2, np.arange(84)] = i + reference_obs[i, 1, :, np.arange(84)] = i + reference_obs[i, 0] = i + + # atari single buffer + env = MyTestEnv(size=5, sleep=0, array_state=True) + policy = MyPolicy() + c0 = Collector(policy, env, ReplayBuffer(size=100)) + c0.collect(n_step=6) + c0.collect(n_episode=2) + assert c0.buffer.obs.shape == (100, 4, 84, 84) + assert c0.buffer.obs_next.shape == (100, 4, 84, 84) + assert len(c0.buffer) == 15 + obs = np.zeros_like(c0.buffer.obs) + obs[np.arange(15)] = reference_obs[np.arange(15) % 5] + assert np.all(obs == c0.buffer.obs) + + c1 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=True)) + c1.collect(n_episode=3) + assert np.allclose(c0.buffer.obs, c1.buffer.obs) + with pytest.raises(AttributeError): + c1.buffer.obs_next + assert np.all(reference_obs[[1, 2, 3, 4, 4] * 3] == c1.buffer[:].obs_next) + + c2 = Collector( + policy, env, + ReplayBuffer(size=100, ignore_obs_next=True, save_only_last_obs=True)) + c2.collect(n_step=8) + assert c2.buffer.obs.shape == (100, 84, 84) + obs = np.zeros_like(c2.buffer.obs) + obs[np.arange(8)] = reference_obs[[0, 1, 2, 3, 4, 0, 1, 2], -1] + assert np.all(c2.buffer.obs == obs) + assert np.allclose(c2.buffer[:].obs_next, + reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1]) + + # atari multi buffer + env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, array_state=True) + for i in [2, 3, 4, 5]] + envs = DummyVectorEnv(env_fns) + c3 = Collector( + policy, envs, + VectorReplayBuffer(total_size=100, buffer_num=4)) + c3.collect(n_step=12) + result = c3.collect(n_episode=9) + assert result["n/ep"] == 9 and result["n/st"] == 23 + assert c3.buffer.obs.shape == (100, 4, 84, 84) + obs = np.zeros_like(c3.buffer.obs) + obs[np.arange(8)] = reference_obs[[0, 1, 0, 1, 0, 1, 0, 1]] + obs[np.arange(25, 34)] = reference_obs[[0, 1, 2, 0, 1, 2, 0, 1, 2]] + obs[np.arange(50, 58)] = reference_obs[[0, 1, 2, 3, 0, 1, 2, 3]] + obs[np.arange(75, 85)] = reference_obs[[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]] + assert np.all(obs == c3.buffer.obs) + obs_next = np.zeros_like(c3.buffer.obs_next) + obs_next[np.arange(8)] = reference_obs[[1, 2, 1, 2, 1, 2, 1, 2]] + obs_next[np.arange(25, 34)] = reference_obs[[1, 2, 3, 1, 2, 3, 1, 2, 3]] + obs_next[np.arange(50, 58)] = reference_obs[[1, 2, 3, 4, 1, 2, 3, 4]] + obs_next[np.arange(75, 85)] = reference_obs[[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]] + assert np.all(obs_next == c3.buffer.obs_next) + c4 = Collector( + policy, envs, + VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4, + ignore_obs_next=True, save_only_last_obs=True)) + c4.collect(n_step=12) + result = c4.collect(n_episode=9) + assert result["n/ep"] == 9 and result["n/st"] == 23 + assert c4.buffer.obs.shape == (100, 84, 84) + obs = np.zeros_like(c4.buffer.obs) + slice_obs = reference_obs[:, -1] + obs[np.arange(8)] = slice_obs[[0, 1, 0, 1, 0, 1, 0, 1]] + obs[np.arange(25, 34)] = slice_obs[[0, 1, 2, 0, 1, 2, 0, 1, 2]] + 
obs[np.arange(50, 58)] = slice_obs[[0, 1, 2, 3, 0, 1, 2, 3]] + obs[np.arange(75, 85)] = slice_obs[[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]] + assert np.all(c4.buffer.obs == obs) + obs_next = np.zeros([len(c4.buffer), 4, 84, 84]) + ref_index = np.array([ + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 2, 2, 1, 2, 2, 1, 2, 2, + 1, 2, 3, 3, 1, 2, 3, 3, + 1, 2, 3, 4, 4, 1, 2, 3, 4, 4, + ]) + obs_next[:, -1] = slice_obs[ref_index] + ref_index -= 1 + ref_index[ref_index < 0] = 0 + obs_next[:, -2] = slice_obs[ref_index] + ref_index -= 1 + ref_index[ref_index < 0] = 0 + obs_next[:, -3] = slice_obs[ref_index] + ref_index -= 1 + ref_index[ref_index < 0] = 0 + obs_next[:, -4] = slice_obs[ref_index] + assert np.all(obs_next == c4.buffer[:].obs_next) + + buf = ReplayBuffer(100, stack_num=4, ignore_obs_next=True, + save_only_last_obs=True) + c5 = Collector(policy, envs, CachedReplayBuffer(buf, 4, 10)) + c5.collect(n_step=12) + assert len(buf) == 5 and len(c5.buffer) == 12 + result = c5.collect(n_episode=9) + assert result["n/ep"] == 9 and result["n/st"] == 23 + assert len(buf) == 35 + assert np.all(buf.obs[:len(buf)] == slice_obs[[ + 0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 3, 4, + 0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4]]) + assert np.all(buf[:].obs_next[:, -1] == slice_obs[[ + 1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 3, 4, 4, + 1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 2, 1, 2, 3, 4, 4]]) + assert len(buf) == len(c5.buffer) + + if __name__ == '__main__': test_collector() test_collector_with_dict_state() test_collector_with_ma() + test_collector_with_atari_setting() # test_collector_with_async() diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 2bf7dff45..027608a1f 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -284,6 +284,10 @@ def collect( if np.any(done): env_ind_local = np.where(done)[0] env_ind_global = self._ready_env_ids[env_ind_local] + episode_count += len(env_ind_local) + episode_lens.append(ep_len[env_ind_local]) + episode_rews.append(ep_rew[env_ind_local]) + episode_start_indices.append(ep_idx[env_ind_local]) # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. 
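A stripped-down numpy sketch of the reset-then-copy pattern used in the collector hunk below: finished envs get their reset observation written into obs_next before obs_next is copied back into obs for the next iteration (reset_fn is a hypothetical stand-in for self.env.reset on the finished env ids):

import numpy as np

def reset_fn(ind):                      # hypothetical stand-in for env.reset(ind)
    return np.zeros(len(ind))

obs_next = np.array([5.0, 3.0, 7.0, 2.0])
done = np.array([False, True, False, True])
env_ind = np.where(done)[0]
obs_next[env_ind] = reset_fn(env_ind)   # overwrite terminal obs with reset obs
obs = obs_next                          # mirrors self.data.obs = self.data.obs_next
assert np.allclose(obs, [5.0, 0.0, 7.0, 0.0])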
obs_reset = self.env.reset(env_ind_global) @@ -294,21 +298,14 @@ def collect( for i in env_ind_local: self._reset_state(i) if n_episode: - total = episode_count + len(env_ind_local) + self.env_num - if n_episode < total: - if len(env_ind_global) > total - n_episode: + diff = episode_count + self.env_num - n_episode + if diff > 0: + if len(env_ind_global) > diff: env_ind_global = np.random.choice( - env_ind_global, - total - n_episode, - replace=False) - mask = np.isin(self._ready_env_ids, env_ind_global) - self._ready_env_ids = self._ready_env_ids[~mask] - self.data = self.data[~mask] - - episode_count += len(env_ind_local) - episode_lens.append(ep_len[env_ind_local]) - episode_rews.append(ep_rew[env_ind_local]) - episode_start_indices.append(ep_idx[env_ind_local]) + env_ind_global, diff, replace=False) + mask = ~np.isin(self._ready_env_ids, env_ind_global) + self._ready_env_ids = self._ready_env_ids[mask] + self.data = self.data[mask] self.data.obs = self.data.obs_next @@ -326,10 +323,12 @@ def collect( obs_next={}, info={}, policy={}) self.reset_env() - rews = np.concatenate(episode_rews) if episode_rews else np.array([]) - lens = np.concatenate(episode_lens) if episode_lens else np.array([]) - idxs = np.concatenate(episode_start_indices) if episode_start_indices \ - else np.array([]) + if episode_count > 0: + rews, lens, idxs = list(map(np.concatenate, [ + episode_rews, episode_lens, episode_start_indices])) + else: + rews, lens, idxs = \ + np.array([]), np.array([], np.int), np.array([], np.int) return { "n/ep": episode_count, "n/st": step_count, From a99a6132367df9c9be6184769d675eb96b37319f Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 4 Feb 2021 19:11:46 +0800 Subject: [PATCH 049/104] fix Collector(buffer=None) --- test/base/test_collector.py | 11 ++++++++++- tianshou/data/buffer.py | 2 -- tianshou/data/collector.py | 8 ++------ 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 8d2e8cbbe..9cefc3d1a 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -387,7 +387,7 @@ def test_collector_with_atari_setting(): buf = ReplayBuffer(100, stack_num=4, ignore_obs_next=True, save_only_last_obs=True) c5 = Collector(policy, envs, CachedReplayBuffer(buf, 4, 10)) - c5.collect(n_step=12) + result_ = c5.collect(n_step=12) assert len(buf) == 5 and len(c5.buffer) == 12 result = c5.collect(n_episode=9) assert result["n/ep"] == 9 and result["n/st"] == 23 @@ -400,6 +400,15 @@ def test_collector_with_atari_setting(): 1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 2, 1, 2, 3, 4, 4]]) assert len(buf) == len(c5.buffer) + # test buffer=None + c6 = Collector(policy, envs) + result1 = c6.collect(n_step=12) + for key in ["n/ep", "n/st", "rews", "lens"]: + assert np.allclose(result1[key], result_[key]) + result2 = c6.collect(n_episode=9) + for key in ["n/ep", "n/st", "rews", "lens"]: + assert np.allclose(result2[key], result[key]) + if __name__ == '__main__': test_collector() diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index edfd9e5e8..d3bd4cce2 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -1,8 +1,6 @@ import h5py import torch -import warnings import numpy as np -from numbers import Number from typing import Any, Dict, List, Tuple, Union, Optional from tianshou.data.batch import _create_value diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 027608a1f..32c20b3ed 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -3,16 
+3,12 @@ import torch import warnings import numpy as np -from copy import deepcopy -from numbers import Number from typing import Dict, List, Union, Optional, Callable from tianshou.policy import BasePolicy -from tianshou.exploration import BaseNoise -from tianshou.data.batch import _create_value from tianshou.env import BaseVectorEnv, DummyVectorEnv -from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \ - ReplayBufferManager, CachedReplayBuffer, to_numpy +from tianshou.data import Batch, ReplayBuffer, ReplayBufferManager, \ + VectorReplayBuffer, CachedReplayBuffer, to_numpy class Collector(object): From 220854fc3af53c589e4d38e2badfffbe9d13093c Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Thu, 4 Feb 2021 23:50:23 +0800 Subject: [PATCH 050/104] collector enhance --- tianshou/data/buffer.py | 12 +++++------ tianshou/data/collector.py | 44 +++++++++++++++++--------------------- 2 files changed, 26 insertions(+), 30 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index edfd9e5e8..4a13488de 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -600,9 +600,9 @@ class VectorReplayBuffer(ReplayBufferManager): def __init__( self, total_size: int, buffer_num: int, **kwargs: Any ) -> None: - sizes = [total_size // buffer_num + (i < total_size % buffer_num) - for i in range(buffer_num)] - buffer_list = [ReplayBuffer(size, **kwargs) for size in sizes] + assert buffer_num > 0 and total_size % buffer_num == 0 + size = total_size // buffer_num + buffer_list = [ReplayBuffer(size, **kwargs) for _ in range(buffer_num)] super().__init__(buffer_list) @@ -610,10 +610,10 @@ class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager): def __init__( self, total_size: int, buffer_num: int, **kwargs: Any ) -> None: - sizes = [total_size // buffer_num + (i < total_size % buffer_num) - for i in range(buffer_num)] + assert buffer_num > 0 and total_size % buffer_num == 0 + size = total_size // buffer_num buffer_list = [PrioritizedReplayBuffer(size, **kwargs) - for size in sizes] + for _ in range(buffer_num)] super().__init__(buffer_list) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 027608a1f..944079a4a 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -92,7 +92,6 @@ def __init__( self.env = env self.env_num = len(env) self.training = training - self._save_data = buffer is not None self._assign_buffer(buffer) self.policy = policy self.preprocess_fn = preprocess_fn @@ -101,13 +100,14 @@ def __init__( self.reset() def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None: + max_episode_steps = self.env._max_episode_steps[0] if buffer is None: - self.buffer = VectorReplayBuffer(self.env_num, self.env_num) + buffer = VectorReplayBuffer( + self.env_num * 1, self.env_num) elif isinstance(buffer, ReplayBufferManager): assert buffer.buffer_num >= self.env_num if isinstance(buffer, CachedReplayBuffer): assert buffer.cached_buffer_num >= self.env_num - self.buffer = buffer else: # ReplayBuffer or PrioritizedReplayBuffer assert buffer.maxsize > 0 if self.env_num > 1: @@ -122,7 +122,7 @@ def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None: f"collect {self.env_num} envs,\n\tplease use {vector_type}" f"(total_size={buffer.maxsize}, buffer_num={self.env_num}," " ...) 
instead.") - self.buffer = buffer + self.buffer =buffer # TODO move to trainer # @staticmethod @@ -156,7 +156,6 @@ def reset_buffer(self) -> None: def reset_env(self) -> None: """Reset all of the environment(s)' states and the cache buffers.""" - self._ready_env_ids = np.arange(self.env_num) obs = self.env.reset() if self.preprocess_fn: obs = self.preprocess_fn(obs=obs).get("obs", obs) @@ -218,26 +217,27 @@ def collect( assert n_step % self.env_num == 0, \ "n_step should be a multiple of #envs" else: - assert isinstance(n_episode, int) and n_episode >= self.env_num - + assert n_episode > 0 start_time = time.time() step_count = 0 - # episode of each environment episode_count = 0 episode_rews = [] episode_lens = [] episode_start_indices = [] + ready_env_ids = np.arange(min(self.env_num, n_episode)) + self.data = self.data[:min(self.env_num, n_episode)] + while True: - assert len(self.data) == len(self._ready_env_ids) + assert len(self.data) == len(ready_env_ids) # restore the state: if the last state is None, it won't store last_state = self.data.policy.pop("hidden_state", None) # get the next action if random: result = Batch(act=[self._action_space[i].sample() - for i in self._ready_env_ids]) + for i in ready_env_ids]) else: if no_grad: with torch.no_grad(): # faster than retain_grad version @@ -258,7 +258,7 @@ def collect( # step in env obs_next, rew, done, info = self.env.step( - act, id=self._ready_env_ids) + act, id=ready_env_ids) self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) if self.preprocess_fn: @@ -276,14 +276,14 @@ def collect( # add data into the buffer ptr, ep_rew, ep_len, ep_idx = self.buffer.add( - self.data, buffer_ids=self._ready_env_ids) + self.data, buffer_ids=ready_env_ids) # collect statistics - step_count += len(self._ready_env_ids) + step_count += len(ready_env_ids) if np.any(done): env_ind_local = np.where(done)[0] - env_ind_global = self._ready_env_ids[env_ind_local] + env_ind_global = ready_env_ids[env_ind_local] episode_count += len(env_ind_local) episode_lens.append(ep_len[env_ind_local]) episode_rews.append(ep_rew[env_ind_local]) @@ -297,16 +297,12 @@ def collect( self.data.obs_next[env_ind_local] = obs_reset for i in env_ind_local: self._reset_state(i) - if n_episode: - diff = episode_count + self.env_num - n_episode - if diff > 0: - if len(env_ind_global) > diff: - env_ind_global = np.random.choice( - env_ind_global, diff, replace=False) - mask = ~np.isin(self._ready_env_ids, env_ind_global) - self._ready_env_ids = self._ready_env_ids[mask] - self.data = self.data[mask] - + surplus_env_n = len(ready_env_ids) - (n_episode - episode_count) + if n_episode and surplus_env_n > 0: + mask = np.ones_like(ready_env_ids, np.bool) + mask[env_ind_local[:surplus_env_n]] = False + ready_env_ids = ready_env_ids[mask] + self.data = self.data[mask] self.data.obs = self.data.obs_next if (n_step and step_count >= n_step) or \ From 1b8275e597475808b804a565e57e517daa12bd6b Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Sun, 7 Feb 2021 18:55:19 +0800 Subject: [PATCH 051/104] add expl nosie in collector --- tianshou/data/collector.py | 2 +- tianshou/policy/base.py | 14 +++++++++++-- tianshou/policy/modelfree/ddpg.py | 13 +++++++++--- tianshou/policy/modelfree/discrete_sac.py | 7 +++++++ tianshou/policy/modelfree/dqn.py | 24 +++++++++++++++-------- tianshou/policy/modelfree/sac.py | 4 +--- 6 files changed, 47 insertions(+), 17 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 16467a9bd..7623eb86a 100644 --- 
a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -249,7 +249,7 @@ def collect( policy.hidden_state = state # save state into buffer act = to_numpy(result.act) if self.training and not random: - act = self.policy.exploration_noise(act) + act = self.policy.exploration_noise(act, self.data) self.data.update(policy=policy, act=act) # step in env diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 9a4dfbd8e..5152660be 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -68,9 +68,19 @@ def set_agent_id(self, agent_id: int) -> None: self.agent_id = agent_id def exploration_noise( - self, act: Union[np.ndarray, Batch] + self, + act: Union[np.ndarray, Batch], + batch: Batch, ) -> Union[np.ndarray, Batch]: - """Modify the action from policy.forward with exploration noise.""" + """Modify the action from policy.forward with exploration noise. + + :param act: a data batch or numpy.ndarray which is the action taken by + policy.forward. + :param batch: the input batch for policy.forward, kept for advanced usage. + + :return: action in the same form of input 'act' but with added + exploration noise. + """ return act @abstractmethod diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 1ac293c88..bb81433b6 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -141,9 +141,6 @@ def forward( obs = batch[input] actions, h = model(obs, state=state, info=batch.info) actions += self._action_bias - if self._noise and not self.updating: - actions += to_torch_as(self._noise(actions.shape), actions) - actions = actions.clamp(self._range[0], self._range[1]) return Batch(act=actions, state=h) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: @@ -166,3 +163,13 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: "loss/actor": actor_loss.item(), "loss/critic": critic_loss.item(), } + + def exploration_noise( + self, + act: np.ndarray, + batch: Batch, + ) -> np.ndarray: + if self._noise: + act = act + self._noise(act.shape) + act = act.clip(self._range[0], self._range[1]) + return act \ No newline at end of file diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 4db2dc9dc..02781a32d 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -145,3 +145,10 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: result["alpha"] = self._alpha.item() # type: ignore return result + + def exploration_noise( + self, + act: Union[np.ndarray, Batch], + batch: Batch, + ) -> Union[np.ndarray, Batch]: + return act diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 54397915c..54b570971 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -148,20 +148,14 @@ def forward( obs_ = obs.obs if hasattr(obs, "obs") else obs logits, h = model(obs_, state=state, info=batch.info) q = self.compute_q_value(logits) + if not hasattr(self, "max_action_num"): + self.max_action_num = q.shape[1] act: np.ndarray = to_numpy(q.max(dim=1)[1]) if hasattr(obs, "mask"): # some of actions are masked, they cannot be selected q_: np.ndarray = to_numpy(q) q_[~obs.mask] = -np.inf act = q_.argmax(axis=1) - # add eps to act in training or testing phase - if not self.updating and not np.isclose(self.eps, 0.0): - for i in range(len(q)): - if np.random.rand() < self.eps: - q_ = np.random.rand(*q[i].shape) - if hasattr(obs, "mask"): - 
q_[~obs.mask[i]] = -np.inf - act[i] = q_.argmax() return Batch(logits=logits, act=act, state=h) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: @@ -179,3 +173,17 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: self.optim.step() self._iter += 1 return {"loss": loss.item()} + + def exploration_noise( + self, + act: np.ndarray, + batch: Batch, + ) -> np.ndarray: + if not np.isclose(self.eps, 0.0): + for i in range(len(act)): + if np.random.rand() < self.eps: + q_ = np.random.rand(self.max_action_num) + if hasattr(batch["obs"], "mask"): + q_[~batch["obs"].mask[i]] = -np.inf + act[i] = q_.argmax() + return act \ No newline at end of file diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index fbdd12297..a8fb02f51 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -130,9 +130,7 @@ def forward( # type: ignore y = self._action_scale * (1 - y.pow(2)) + self.__eps log_prob = dist.log_prob(x).unsqueeze(-1) log_prob = log_prob - torch.log(y).sum(-1, keepdim=True) - if self._noise is not None and self.training and not self.updating: - act += to_torch_as(self._noise(act.shape), act) - act = act.clamp(self._range[0], self._range[1]) + return Batch( logits=logits, act=act, state=h, dist=dist, log_prob=log_prob) From 509e417575b1cf1a1020398e6679686f68d01820 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Sun, 7 Feb 2021 21:37:50 +0800 Subject: [PATCH 052/104] refactor basepolicy to coordinate with step collector --- test/base/test_returns.py | 2 +- tianshou/data/collector.py | 1 + tianshou/policy/base.py | 100 +++++++++++++++++++++---------- tianshou/policy/modelfree/a2c.py | 9 ++- tianshou/policy/modelfree/pg.py | 3 +- tianshou/policy/modelfree/ppo.py | 4 +- 6 files changed, 80 insertions(+), 39 deletions(-) diff --git a/test/base/test_returns.py b/test/base/test_returns.py index a8bdc7c9d..d3d7af52a 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -17,7 +17,7 @@ def compute_episodic_return_base(batch, gamma): batch.returns = returns return batch - +# TODO need to change def test_episodic_returns(size=2560): fn = BasePolicy.compute_episodic_return batch = Batch( diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 7623eb86a..b47222935 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -13,6 +13,7 @@ class Collector(object): # TODO change doc + # TODO change all test to ensure api """Collector enables the policy to interact with different types of envs. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 5152660be..124b33c37 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -189,23 +189,31 @@ def update( self.updating = False return result + @staticmethod + def value_mask(batch): + # TODO doc + return ~batch.done + @staticmethod def compute_episodic_return( batch: Batch, + buffer: ReplayBuffer, + indice: np.ndarray, v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None, gamma: float = 0.99, gae_lambda: float = 0.95, rew_norm: bool = False, ) -> Batch: + # TODO change doc """Compute returns over given full-length episodes. Implementation of Generalized Advantage Estimator (arXiv:1506.02438). :param batch: a data batch which contains several full-episode data - chronologically. + chronologically. 
TODO generalize :type batch: :class:`~tianshou.data.Batch` :param v_s_: the value function of all next states :math:`V(s')`. - :type v_s_: numpy.ndarray + :type v_s_: numpy.ndarray #TODO n+1 value shape :param float gamma: the discount factor, should be in [0, 1], defaults to 0.99. :param float gae_lambda: the parameter for Generalized Advantage @@ -217,8 +225,15 @@ def compute_episodic_return( array with shape (bsz, ). """ rew = batch.rew - v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_.flatten()) - returns = _episodic_return(v_s_, rew, batch.done, gamma, gae_lambda) + if v_s_ is None: + assert np.isclose(gae_lambda, 1.0) + v_s_ = np.zeros_like(rew) + else: + v_s_ = to_numpy(v_s_.flatten()) * BasePolicy.value_mask(batch) + + end_flag = batch.done.copy() + end_flag[np.isin(indice, buffer.unfinished_index())] = True + returns = _episodic_return(v_s_, rew, end_flag, gamma, gae_lambda) if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2): returns = (returns - returns.mean()) / returns.std() batch.returns = returns @@ -234,6 +249,7 @@ def compute_nstep_return( n_step: int = 1, rew_norm: bool = False, ) -> Batch: + # TODO, doc r"""Compute n-step return for Q-learning targets. .. math:: @@ -246,8 +262,7 @@ def compute_nstep_return( :param batch: a data batch, which is equal to buffer[indice]. :type batch: :class:`~tianshou.data.Batch` - :param buffer: a data buffer which contains several full-episode data - chronologically. + :param buffer: the data buffer. :type buffer: :class:`~tianshou.data.ReplayBuffer` :param indice: sampled timestep. :type indice: numpy.ndarray @@ -264,6 +279,7 @@ def compute_nstep_return( torch.Tensor with the same shape as target_q_fn's return tensor. """ rew = buffer.rew + # TODO this rew_norm will cause unstablity in training if rew_norm: bfr = rew[:min(len(buffer), 1000)] # avoid large buffer mean, std = bfr.mean(), bfr.std() @@ -271,14 +287,22 @@ def compute_nstep_return( mean, std = 0.0, 1.0 else: mean, std = 0.0, 1.0 - buf_len = len(buffer) - terminal = (indice + n_step - 1) % buf_len + indices = [indice] + for _ in range(n_step - 1): + indices.append(buffer.next(indices[-1])) + indices = np.stack(indices) + + # terminal indicates buffer indexes nstep after 'indice', + # and are truncated at the end of each episode + terminal = indices[-1] with torch.no_grad(): target_q_torch = target_q_fn(buffer, terminal) # (bsz, ?) 
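To see what the stacked indices array looks like, here is a toy sketch. It assumes buffer.next behaves as described for the new buffer API (step forward within an episode, stay put at an episode boundary); the buffer itself is replaced by a plain array so the example runs standalone:

import numpy as np

# toy buffer layout: six stored transitions, episodes end at index 2 and 5
done = np.array([False, False, True, False, False, True])

def toy_next(idx):
    # stand-in for ReplayBuffer.next(): advance by one unless the current
    # transition terminates its episode
    idx = np.asarray(idx)
    return np.where(done[idx], idx, (idx + 1) % len(done))

indice = np.array([0, 1, 4])   # sampled start indices
n_step = 3
indices = [indice]
for _ in range(n_step - 1):
    indices.append(toy_next(indices[-1]))
indices = np.stack(indices)
# rows walk forward in time; indices[-1] is the "terminal" row fed to target_q_fn:
# column 0: 0 -> 1 -> 2, column 1: 1 -> 2 -> 2, column 2: 4 -> 5 -> 5
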
- target_q = to_numpy(target_q_torch) + target_q = to_numpy(target_q_torch) * BasePolicy.value_mask(batch) + end_flag = buffer.done.copy() + end_flag[buffer.unfinished_index()] = True + target_q = _nstep_return(rew, end_flag, target_q, indices, + gamma, n_step, mean, std) - target_q = _nstep_return(rew, buffer.done, target_q, indice, - gamma, n_step, len(buffer), mean, std) batch.returns = to_torch_as(target_q, target_q_torch) if hasattr(batch, "weight"): # prio buffer update batch.weight = to_torch_as(batch.weight, target_q_torch) @@ -288,57 +312,69 @@ def _compile(self) -> None: f64 = np.array([0, 1], dtype=np.float64) f32 = np.array([0, 1], dtype=np.float32) b = np.array([False, True], dtype=np.bool_) - i64 = np.array([0, 1], dtype=np.int64) + i64 = np.array([[0, 1]], dtype=np.int64) + _gae_return(f64, f64, f64, b, 0.1, 0.1) + _gae_return(f32, f32, f64, b, 0.1, 0.1) _episodic_return(f64, f64, b, 0.1, 0.1) _episodic_return(f32, f64, b, 0.1, 0.1) - _nstep_return(f64, b, f32, i64, 0.1, 1, 4, 0.0, 1.0) - + _nstep_return(f64, b, f32, i64, 0.1, 1, 0.0, 1.0) @njit -def _episodic_return( +def _gae_return( + v_s: np.ndarray, v_s_: np.ndarray, rew: np.ndarray, - done: np.ndarray, + end_flag: np.ndarray, gamma: float, gae_lambda: float, ) -> np.ndarray: - """Numba speedup: 4.1s -> 0.057s.""" - returns = np.roll(v_s_, 1) - m = (1.0 - done) * gamma - delta = rew + v_s_ * m - returns - m *= gae_lambda + returns = np.zeros(rew.shape) + delta = rew + v_s_ * gamma - v_s + m = (1.0 - end_flag) * (gamma * gae_lambda) gae = 0.0 for i in range(len(rew) - 1, -1, -1): gae = delta[i] + m[i] * gae - returns[i] += gae + returns[i] = gae return returns +@njit +def _episodic_return( + v_s_: np.ndarray, + rew: np.ndarray, + end_flag: np.ndarray, + gamma: float, + gae_lambda: float, +) -> np.ndarray: + """Numba speedup: 4.1s -> 0.057s.""" + v_s = np.roll(v_s_, 1) + return _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda) + v_s + @njit def _nstep_return( rew: np.ndarray, - done: np.ndarray, + end_flag: np.ndarray, target_q: np.ndarray, - indice: np.ndarray, + indices: np.ndarray, gamma: float, n_step: int, - buf_len: int, mean: float, std: float, ) -> np.ndarray: - """Numba speedup: 0.3s -> 0.15s.""" + gamma_buffer = np.ones(n_step+1) + for i in range(1, n_step+1): + gamma_buffer[i] = gamma_buffer[i-1]*gamma target_shape = target_q.shape bsz = target_shape[0] # change target_q to 2d array target_q = target_q.reshape(bsz, -1) returns = np.zeros(target_q.shape) - gammas = np.full(indice.shape, n_step) + gammas = np.full(indices[0].shape, n_step) for n in range(n_step - 1, -1, -1): - now = (indice + n) % buf_len - gammas[done[now] > 0] = n - returns[done[now] > 0] = 0.0 + now = indices[n] + gammas[end_flag[now] > 0] = n + returns[end_flag[now] > 0] = 0.0 returns = (rew[now].reshape(-1, 1) - mean) / std + gamma * returns - target_q[gammas != n_step] = 0.0 gammas = gammas.reshape(-1, 1) - target_q = target_q * (gamma ** gammas) + returns + target_q = target_q * gamma_buffer[gammas] + returns return target_q.reshape(target_shape) diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index e215c6bbd..b50eb7fb6 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -69,15 +69,18 @@ def process_fn( ) -> Batch: if self._lambda in [0.0, 1.0]: return self.compute_episodic_return( - batch, None, gamma=self._gamma, gae_lambda=self._lambda) + batch, buffer, indice, + None, gamma=self._gamma, gae_lambda=self._lambda) v_ = [] with torch.no_grad(): for b in 
batch.split(self._batch, shuffle=False, merge_last=True): v_.append(to_numpy(self.critic(b.obs_next))) v_ = np.concatenate(v_, axis=0) return self.compute_episodic_return( - batch, v_, gamma=self._gamma, gae_lambda=self._lambda, - rew_norm=self._rew_norm) + batch, buffer, indice, + v_, gamma=self._gamma, + gae_lambda=self._lambda, rew_norm=self._rew_norm) + def forward( self, diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 2f3304658..92b7f5d89 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -56,7 +56,8 @@ def process_fn( # batch.returns = self._vanilla_returns(batch) # batch.returns = self._vectorized_returns(batch) return self.compute_episodic_return( - batch, gamma=self._gamma, gae_lambda=1.0, rew_norm=self._rew_norm) + batch, buffer, indice, gamma=self._gamma, + gae_lambda=1.0, rew_norm=self._rew_norm) def forward( self, diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 4cf9f9054..7fd6f1f26 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -100,8 +100,8 @@ def process_fn( ) v_ = to_numpy(torch.cat(v_, dim=0)) batch = self.compute_episodic_return( - batch, v_, gamma=self._gamma, gae_lambda=self._lambda, - rew_norm=self._rew_norm) + batch, buffer, indice, v_, gamma=self._gamma, + gae_lambda=self._lambda, rew_norm=self._rew_norm) batch.v = torch.cat(v, dim=0).flatten() # old value batch.act = to_torch_as(batch.act, v[0]) batch.logp_old = torch.cat(old_log_prob, dim=0) From a1c98518dff85b08640bb58662c871734a006591 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Mon, 8 Feb 2021 11:46:19 +0800 Subject: [PATCH 053/104] fix bug, test_collector passed now --- test/base/test_collector.py | 4 +--- tianshou/data/collector.py | 23 +++++++++++++---------- tianshou/policy/base.py | 12 +++++------- 3 files changed, 19 insertions(+), 20 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 9cefc3d1a..1fa7c1d19 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -100,8 +100,6 @@ def test_collector(): assert np.allclose(c1.buffer.obs[:, 0], obs) assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) - with pytest.raises(AssertionError): - c1.collect(n_episode=2) c1.collect(n_episode=4) assert len(c1.buffer) == 16 obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4] @@ -130,7 +128,7 @@ def test_collector(): dum = DummyVectorEnv(env_fns) num = len(env_fns) c3 = Collector(policy, dum, - VectorReplayBuffer(total_size=40000, buffer_num=num)) + VectorReplayBuffer(total_size=90000, buffer_num=num)) for i in range(num, 400): c3.reset() result = c3.collect(n_episode=i) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index b47222935..1699e1dea 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -97,7 +97,6 @@ def __init__( self.reset() def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None: - max_episode_steps = self.env._max_episode_steps[0] if buffer is None: buffer = VectorReplayBuffer( self.env_num * 1, self.env_num) @@ -213,8 +212,11 @@ def collect( assert n_step > 0 assert n_step % self.env_num == 0, \ "n_step should be a multiple of #envs" + ready_env_ids = np.arange(self.env_num) else: assert n_episode > 0 + ready_env_ids = np.arange(min(self.env_num, n_episode)) + self.data = self.data[:min(self.env_num, n_episode)] start_time = time.time() step_count = 0 @@ -223,9 +225,6 @@ def collect( episode_lens = 
[] episode_start_indices = [] - ready_env_ids = np.arange(min(self.env_num, n_episode)) - self.data = self.data[:min(self.env_num, n_episode)] - while True: assert len(self.data) == len(ready_env_ids) # restore the state: if the last state is None, it won't store @@ -294,12 +293,16 @@ def collect( self.data.obs_next[env_ind_local] = obs_reset for i in env_ind_local: self._reset_state(i) - surplus_env_n = len(ready_env_ids) - (n_episode - episode_count) - if n_episode and surplus_env_n > 0: - mask = np.ones_like(ready_env_ids, np.bool) - mask[env_ind_local[:surplus_env_n]] = False - ready_env_ids = ready_env_ids[mask] - self.data = self.data[mask] + + # remove surplus env id from ready_env_ids to avoid bias in + # selecting environments. + if n_episode: + surplus_env_n = len(ready_env_ids) - (n_episode - episode_count) + if surplus_env_n > 0: + mask = np.ones_like(ready_env_ids, np.bool) + mask[env_ind_local[:surplus_env_n]] = False + ready_env_ids = ready_env_ids[mask] + self.data = self.data[mask] self.data.obs = self.data.obs_next if (n_step and step_count >= n_step) or \ diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 124b33c37..c8bc9b9cb 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -349,7 +349,6 @@ def _episodic_return( v_s = np.roll(v_s_, 1) return _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda) + v_s - @njit def _nstep_return( rew: np.ndarray, @@ -361,9 +360,9 @@ def _nstep_return( mean: float, std: float, ) -> np.ndarray: - gamma_buffer = np.ones(n_step+1) - for i in range(1, n_step+1): - gamma_buffer[i] = gamma_buffer[i-1]*gamma + gamma_buffer = np.ones(n_step + 1) + for i in range(1, n_step + 1): + gamma_buffer[i] = gamma_buffer[i - 1]*gamma target_shape = target_q.shape bsz = target_shape[0] # change target_q to 2d array @@ -374,7 +373,6 @@ def _nstep_return( now = indices[n] gammas[end_flag[now] > 0] = n returns[end_flag[now] > 0] = 0.0 - returns = (rew[now].reshape(-1, 1) - mean) / std + gamma * returns - gammas = gammas.reshape(-1, 1) - target_q = target_q * gamma_buffer[gammas] + returns + returns = (rew[now].reshape(bsz, 1) - mean) / std + gamma * returns + target_q = target_q * gamma_buffer[gammas].reshape(bsz, 1) + returns return target_q.reshape(target_shape) From 6c6d503a1f67cda831c10208b3ed1082abe99bb2 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Mon, 8 Feb 2021 15:17:03 +0800 Subject: [PATCH 054/104] change buffer setup in all files, switch buffer to vec buffer --- examples/atari/atari_c51.py | 8 +++++--- examples/atari/atari_dqn.py | 16 ++++++++++------ examples/atari/atari_qrdqn.py | 8 +++++--- examples/atari/runnable/pong_a2c.py | 5 +++-- examples/atari/runnable/pong_ppo.py | 5 +++-- examples/box2d/acrobot_dualdqn.py | 4 ++-- examples/box2d/bipedal_hardcore_sac.py | 4 ++-- examples/box2d/lunarlander_dqn.py | 4 ++-- examples/box2d/mcc_sac.py | 4 ++-- examples/mujoco/mujoco_sac.py | 4 ++-- examples/mujoco/runnable/ant_v2_ddpg.py | 4 ++-- examples/mujoco/runnable/ant_v2_td3.py | 4 ++-- .../mujoco/runnable/halfcheetahBullet_v0_sac.py | 4 ++-- examples/mujoco/runnable/point_maze_td3.py | 4 ++-- test/base/test_collector.py | 2 +- test/continuous/test_ddpg.py | 4 ++-- test/continuous/test_ppo.py | 4 ++-- test/continuous/test_sac_with_il.py | 4 ++-- test/continuous/test_td3.py | 4 ++-- test/discrete/test_a2c_with_il.py | 4 ++-- test/discrete/test_c51.py | 9 +++++---- test/discrete/test_dqn.py | 11 ++++++----- test/discrete/test_drqn.py | 7 ++++--- test/discrete/test_pg.py | 4 ++-- 
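For readers checking the advantage computation by hand, this is a plain-NumPy mirror of the numba _gae_return helper above. It is a sketch only: it assumes v_s_ has already been multiplied by value_mask and that end_flag combines real done flags with the truncation flags for unfinished episodes, exactly as compute_episodic_return prepares them.

import numpy as np

def gae_advantages(v_s, v_s_, rew, end_flag, gamma=0.99, gae_lambda=0.95):
    # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), accumulated backwards;
    # the running accumulation is dropped at episode boundaries (m[i] == 0)
    delta = rew + gamma * v_s_ - v_s
    m = (1.0 - end_flag) * (gamma * gae_lambda)
    adv = np.zeros(len(rew))
    gae = 0.0
    for i in range(len(rew) - 1, -1, -1):
        gae = delta[i] + m[i] * gae
        adv[i] = gae
    return adv

# _episodic_return then adds the state value back, i.e. returns = advantages + V(s)
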
test/discrete/test_ppo.py | 4 ++-- test/discrete/test_qrdqn.py | 9 +++++---- test/discrete/test_sac.py | 4 ++-- test/modelbase/test_psrl.py | 4 ++-- test/multiagent/tic_tac_toe.py | 4 ++-- test/throughput/test_collector_profile.py | 14 +++++++++----- tianshou/data/buffer.py | 8 ++++---- tianshou/data/collector.py | 6 +++++- 32 files changed, 103 insertions(+), 81 deletions(-) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index c74da8f13..60f62cd08 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -8,7 +8,7 @@ from tianshou.policy import C51Policy from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from atari_network import C51 from atari_wrapper import wrap_deepmind @@ -90,8 +90,10 @@ def test_c51(args=get_args()): print("Loaded agent from: ", args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM - buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True, - save_only_last_obs=True, stack_num=args.frames_stack) + buffer = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack) # collector train_collector = Collector(policy, train_envs, buffer) test_collector = Collector(policy, test_envs) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 4f4c4f2df..a2a907147 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -8,7 +8,7 @@ from tianshou.policy import DQNPolicy from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from atari_network import DQN from atari_wrapper import wrap_deepmind @@ -86,8 +86,10 @@ def test_dqn(args=get_args()): print("Loaded agent from: ", args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM - buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True, - save_only_last_obs=True, stack_num=args.frames_stack) + buffer = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack) # collector train_collector = Collector(policy, train_envs, buffer) test_collector = Collector(policy, test_envs) @@ -127,9 +129,11 @@ def watch(): test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") - buffer = ReplayBuffer( - args.buffer_size, ignore_obs_next=True, - save_only_last_obs=True, stack_num=args.frames_stack) + buffer = VectorReplayBuffer(args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack) collector = Collector(policy, test_envs, buffer) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 08c34733c..a292e2496 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -8,7 +8,7 @@ from tianshou.policy import QRDQNPolicy from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer 
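The Atari scripts above keep all frame-stack related options and simply hand them to VectorReplayBuffer, which forwards its keyword arguments to every per-environment sub-buffer. A condensed sketch of that construction (sizes are placeholders, and the comments reflect my reading of the options):

from tianshou.data import VectorReplayBuffer

buffer_num = 16                    # e.g. len(train_envs)
buffer = VectorReplayBuffer(
    100000,                        # total size shared by all sub-buffers
    buffer_num=buffer_num,
    stack_num=4,                   # frame stacking is done at sampling time
    ignore_obs_next=True,          # reconstruct obs_next from obs to save RAM
    save_only_last_obs=True,       # store one frame per step instead of the whole stack
)
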
from atari_network import QRDQN from atari_wrapper import wrap_deepmind @@ -88,8 +88,10 @@ def test_qrdqn(args=get_args()): print("Loaded agent from: ", args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM - buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True, - save_only_last_obs=True, stack_num=args.frames_stack) + buffer = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack) # collector train_collector = Collector(policy, train_envs, buffer) test_collector = Collector(policy, test_envs) diff --git a/examples/atari/runnable/pong_a2c.py b/examples/atari/runnable/pong_a2c.py index 4da98ce67..4b0fd87af 100644 --- a/examples/atari/runnable/pong_a2c.py +++ b/examples/atari/runnable/pong_a2c.py @@ -9,7 +9,7 @@ from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.discrete import Actor, Critic from atari import create_atari_environment, preprocess_fn @@ -75,7 +75,8 @@ def test_a2c(args=get_args()): ent_coef=args.ent_coef, max_grad_norm=args.max_grad_norm) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size), + policy, train_envs, + VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)), preprocess_fn=preprocess_fn) test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn) # log diff --git a/examples/atari/runnable/pong_ppo.py b/examples/atari/runnable/pong_ppo.py index 3141a4bf2..55dc0760b 100644 --- a/examples/atari/runnable/pong_ppo.py +++ b/examples/atari/runnable/pong_ppo.py @@ -9,7 +9,7 @@ from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.discrete import Actor, Critic from atari import create_atari_environment, preprocess_fn @@ -79,7 +79,8 @@ def test_ppo(args=get_args()): action_range=None) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size), + policy, train_envs, + VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)), preprocess_fn=preprocess_fn) test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn) # log diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index a0b8f7788..d114dd58e 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -10,7 +10,7 @@ from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer def get_args(): @@ -72,7 +72,7 @@ def test_dqn(args=get_args()): target_update_freq=args.target_update_freq) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) test_collector = Collector(policy, test_envs) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size) diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 87539e46e..a9d38b309 100644 --- 
a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -10,7 +10,7 @@ from tianshou.utils.net.common import Net from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import ActorProb, Critic @@ -125,7 +125,7 @@ def test_sac_bipedal(args=get_args()): # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 781ae63ec..d0c7b83ed 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -9,7 +9,7 @@ from tianshou.policy import DQNPolicy from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv @@ -73,7 +73,7 @@ def test_dqn(args=get_args()): target_update_freq=args.target_update_freq) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) test_collector = Collector(policy, test_envs) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size) diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index a700e2f24..18b1af443 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -8,7 +8,7 @@ from tianshou.policy import SACPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import OUNoise from tianshou.utils.net.common import Net @@ -94,7 +94,7 @@ def test_sac(args=get_args()): exploration_noise=OUNoise(0.0, args.noise_std)) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 0b09c81a5..107cfbf34 100644 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -10,7 +10,7 @@ from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import ActorProb, Critic @@ -108,7 +108,7 @@ def test_sac(args=get_args()): # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'sac') diff --git a/examples/mujoco/runnable/ant_v2_ddpg.py b/examples/mujoco/runnable/ant_v2_ddpg.py index e1bcd4964..686a9f677 100644 --- a/examples/mujoco/runnable/ant_v2_ddpg.py +++ b/examples/mujoco/runnable/ant_v2_ddpg.py @@ -10,7 +10,7 
@@ from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer from tianshou.exploration import GaussianNoise -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import Actor, Critic @@ -74,7 +74,7 @@ def test_ddpg(args=get_args()): reward_normalization=True, ignore_done=True) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) test_collector = Collector(policy, test_envs) # log writer = SummaryWriter(args.logdir + '/' + 'ddpg') diff --git a/examples/mujoco/runnable/ant_v2_td3.py b/examples/mujoco/runnable/ant_v2_td3.py index 6e37c61f4..36863acde 100644 --- a/examples/mujoco/runnable/ant_v2_td3.py +++ b/examples/mujoco/runnable/ant_v2_td3.py @@ -10,7 +10,7 @@ from tianshou.utils.net.common import Net from tianshou.exploration import GaussianNoise from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import Actor, Critic @@ -82,7 +82,7 @@ def test_td3(args=get_args()): reward_normalization=True, ignore_done=True) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log diff --git a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py index ced529913..666c8974f 100644 --- a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py +++ b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py @@ -11,7 +11,7 @@ from tianshou.utils.net.common import Net from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import ActorProb, Critic @@ -82,7 +82,7 @@ def test_sac(args=get_args()): reward_normalization=True, ignore_done=True) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log diff --git a/examples/mujoco/runnable/point_maze_td3.py b/examples/mujoco/runnable/point_maze_td3.py index 781ec9734..78f0db22e 100644 --- a/examples/mujoco/runnable/point_maze_td3.py +++ b/examples/mujoco/runnable/point_maze_td3.py @@ -10,7 +10,7 @@ from tianshou.env import SubprocVectorEnv from tianshou.exploration import GaussianNoise from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import Actor, Critic from mujoco.register import reg @@ -87,7 +87,7 @@ def test_td3(args=get_args()): reward_normalization=True, ignore_done=True) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log diff --git a/test/base/test_collector.py 
b/test/base/test_collector.py index 1fa7c1d19..d4823d3d7 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -128,7 +128,7 @@ def test_collector(): dum = DummyVectorEnv(env_fns) num = len(env_fns) c3 = Collector(policy, dum, - VectorReplayBuffer(total_size=90000, buffer_num=num)) + VectorReplayBuffer(total_size=40000, buffer_num=num)) for i in range(num, 400): c3.reset() result = c3.collect(n_episode=i) diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 911cd8bf8..20c326be6 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -11,7 +11,7 @@ from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer from tianshou.exploration import GaussianNoise -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import Actor, Critic @@ -86,7 +86,7 @@ def test_ddpg(args=get_args()): estimation_step=args.n_step) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) test_collector = Collector( policy, test_envs, action_noise=GaussianNoise(sigma=args.test_noise)) # log diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index a6d7469d3..b7556683a 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -11,7 +11,7 @@ from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import ActorProb, Critic @@ -104,7 +104,7 @@ def dist(*logits): gae_lambda=args.gae_lambda) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'ppo') diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 20eb63e88..27701f0ca 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -9,7 +9,7 @@ from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import SACPolicy, ImitationPolicy from tianshou.utils.net.continuous import Actor, ActorProb, Critic @@ -92,7 +92,7 @@ def test_sac_with_il(args=get_args()): estimation_step=args.n_step) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 8eb0c8fe2..c8f5537e8 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -11,7 +11,7 @@ from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer from tianshou.exploration import GaussianNoise -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import Actor, Critic @@ 
-97,7 +97,7 @@ def test_td3(args=get_args()): estimation_step=args.n_step) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 7c5eff4c0..ce6d09bee 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -8,7 +8,7 @@ from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.discrete import Actor, Critic from tianshou.policy import A2CPolicy, ImitationPolicy from tianshou.trainer import onpolicy_trainer, offpolicy_trainer @@ -79,7 +79,7 @@ def test_a2c_with_il(args=get_args()): max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'a2c') diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index e726471f0..a4d496792 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -10,7 +10,7 @@ from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer def get_args(): @@ -75,10 +75,11 @@ def test_c51(args=get_args()): ).to(args.device) # buffer if args.prioritized_replay: - buf = PrioritizedReplayBuffer( - args.buffer_size, alpha=args.alpha, beta=args.beta) + buf = PrioritizedVectorReplayBuffer( + args.buffer_size, buffer_num = len(train_envs), + alpha=args.alpha, beta=args.beta) else: - buf = ReplayBuffer(args.buffer_size) + buf = VectorReplayBuffer(args.buffer_size, buffer_num = len(train_envs)) # collector train_collector = Collector(policy, train_envs, buf) test_collector = Collector(policy, test_envs) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 94b488886..b83b63ff1 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -11,7 +11,7 @@ from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer def get_args(): @@ -77,10 +77,11 @@ def test_dqn(args=get_args()): target_update_freq=args.target_update_freq) # buffer if args.prioritized_replay: - buf = PrioritizedReplayBuffer( - args.buffer_size, alpha=args.alpha, beta=args.beta) + buf = PrioritizedVectorReplayBuffer( + args.buffer_size, buffer_num = len(train_envs), + alpha=args.alpha, beta=args.beta) else: - buf = ReplayBuffer(args.buffer_size) + buf = VectorReplayBuffer(args.buffer_size, buffer_num = len(train_envs)) # collector train_collector = Collector(policy, train_envs, buf) test_collector = Collector(policy, test_envs) @@ -130,7 +131,7 @@ def test_fn(epoch, env_step): print(f'Final reward: {result["rews"].mean()}, length: 
{result["lens"].mean()}') # save buffer in pickle format, for imitation learning unittest - buf = ReplayBuffer(args.buffer_size) + buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs)) collector = Collector(policy, test_envs, buf) collector.collect(n_step=args.buffer_size) pickle.dump(buf, open(args.save_buffer_name, "wb")) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index c695aa916..758e6ae10 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -10,7 +10,7 @@ from tianshou.env import DummyVectorEnv from tianshou.trainer import offpolicy_trainer from tianshou.utils.net.common import Recurrent -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer def get_args(): @@ -66,8 +66,9 @@ def test_drqn(args=get_args()): target_update_freq=args.target_update_freq) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer( - args.buffer_size, stack_num=args.stack_num, ignore_obs_next=True)) + policy, train_envs, VectorReplayBuffer( + args.buffer_size, buffer_num = len(train_envs), + stack_num=args.stack_num, ignore_obs_next=True)) # the stack_num is for RNN training: sample framestack obs test_collector = Collector(policy, test_envs) # policy.set_eps(1) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 75447b768..f00803aea 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -10,7 +10,7 @@ from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer def get_args(): @@ -65,7 +65,7 @@ def test_pg(args=get_args()): reward_normalization=args.rew_norm) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'pg') diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 50539642c..dcaea39cc 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -10,7 +10,7 @@ from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.discrete import Actor, Critic @@ -91,7 +91,7 @@ def test_ppo(args=get_args()): value_clip=args.value_clip) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'ppo') diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 138f18f41..b63553f82 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -10,7 +10,7 @@ from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer def get_args(): @@ -73,10 +73,11 @@ def test_qrdqn(args=get_args()): ).to(args.device) # buffer if args.prioritized_replay: - buf = 
PrioritizedReplayBuffer( - args.buffer_size, alpha=args.alpha, beta=args.beta) + buf = PrioritizedVectorReplayBuffer( + args.buffer_size, buffer_num = len(train_envs), + alpha=args.alpha, beta=args.beta) else: - buf = ReplayBuffer(args.buffer_size) + buf = VectorReplayBuffer(args.buffer_size, buffer_num = len(train_envs)) # collector train_collector = Collector(policy, train_envs, buf) test_collector = Collector(policy, test_envs) diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 4c0c51625..7357947bc 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -9,7 +9,7 @@ from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import DiscreteSACPolicy from tianshou.utils.net.discrete import Actor, Critic @@ -87,7 +87,7 @@ def test_discrete_sac(args=get_args()): ignore_done=args.ignore_done) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 9795dd483..2e84bdd1d 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -7,7 +7,7 @@ from tianshou.policy import PSRLPolicy from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv @@ -61,7 +61,7 @@ def test_psrl(args=get_args()): args.add_done_loop) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) test_collector = Collector(policy, test_envs) # log writer = SummaryWriter(args.logdir + '/' + args.task) diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index fe85213e4..c5207b88c 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -9,7 +9,7 @@ from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import BasePolicy, DQNPolicy, RandomPolicy, \ MultiAgentPolicyManager @@ -124,7 +124,7 @@ def env_func(): # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) test_collector = Collector(policy, test_envs) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size) diff --git a/test/throughput/test_collector_profile.py b/test/throughput/test_collector_profile.py index 4036472f7..7e4e8ac75 100644 --- a/test/throughput/test_collector_profile.py +++ b/test/throughput/test_collector_profile.py @@ -4,7 +4,7 @@ from gym.spaces.discrete import Discrete from gym.utils import seeding -from tianshou.data import Batch, Collector, ReplayBuffer +from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import BasePolicy @@ -63,10 +63,13 @@ def data(): [lambda: 
SimpleEnv() for _ in range(8)]) env_subproc_init.seed(np.random.randint(1000, size=100).tolist()) buffer = ReplayBuffer(50000) + vec_buffer = VectorReplayBuffer(50000, 100) policy = SimplePolicy() collector = Collector(policy, env, ReplayBuffer(50000)) - collector_vec = Collector(policy, env_vec, ReplayBuffer(50000)) - collector_subproc = Collector(policy, env_subproc, ReplayBuffer(50000)) + collector_vec = Collector(policy, env_vec, + VectorReplayBuffer(50000, env_vec.env_num)) + collector_subproc = Collector(policy, env_subproc, + VectorReplayBuffer(50000, env_subproc.env_num)) return { "env": env, "env_vec": env_vec, @@ -74,6 +77,7 @@ def data(): "env_subproc_init": env_subproc_init, "policy": policy, "buffer": buffer, + "vec_buffer": vec_buffer, "collector": collector, "collector_vec": collector_vec, "collector_subproc": collector_subproc, @@ -102,7 +106,7 @@ def test_collect_ep(data): def test_init_vec_env(data): for _ in range(5000): - Collector(data["policy"], data["env_vec"], data["buffer"]) + Collector(data["policy"], data["env_vec"], data["vec_buffer"]) def test_reset_vec_env(data): @@ -122,7 +126,7 @@ def test_collect_vec_env_ep(data): def test_init_subproc_env(data): for _ in range(5000): - Collector(data["policy"], data["env_subproc_init"], data["buffer"]) + Collector(data["policy"], data["env_subproc_init"], data["vec_buffer"]) def test_reset_subproc_env(data): diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 5d53c39b7..b256f0168 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -598,8 +598,8 @@ class VectorReplayBuffer(ReplayBufferManager): def __init__( self, total_size: int, buffer_num: int, **kwargs: Any ) -> None: - assert buffer_num > 0 and total_size % buffer_num == 0 - size = total_size // buffer_num + assert buffer_num > 0 + size = int(np.ceil(total_size / buffer_num)) buffer_list = [ReplayBuffer(size, **kwargs) for _ in range(buffer_num)] super().__init__(buffer_list) @@ -608,8 +608,8 @@ class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager): def __init__( self, total_size: int, buffer_num: int, **kwargs: Any ) -> None: - assert buffer_num > 0 and total_size % buffer_num == 0 - size = total_size // buffer_num + assert buffer_num > 0 + size = int(np.ceil(total_size / buffer_num)) buffer_list = [PrioritizedReplayBuffer(size, **kwargs) for _ in range(buffer_num)] super().__init__(buffer_list) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 1699e1dea..d1666c3f5 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -46,14 +46,18 @@ class Collector(object): policy = PGPolicy(...) 
# or other policies if you wish env = gym.make('CartPole-v0') + replay_buffer = ReplayBuffer(size=10000) + # here we set up a collector with a single environment collector = Collector(policy, env, buffer=replay_buffer) # the collector supports vectorized environments as well + vec_buffer = VectorReplayBuffer(total_size=10000, buffer_num = 3) + # buffer_num should be equal(suggested) or larger than envs number envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)]) - collector = Collector(policy, envs, buffer=replay_buffer) + collector = Collector(policy, envs, buffer=vec_buffer) # collect 3 episodes collector.collect(n_episode=3) From e683c083cc53911e7bb21bf339b4ede524c91d57 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Mon, 8 Feb 2021 15:57:15 +0800 Subject: [PATCH 055/104] change all api to coordinate with new collector and vec buffer --- README.md | 2 +- docs/tutorials/tictactoe.rst | 1 + examples/atari/atari_bcq.py | 2 +- examples/atari/atari_c51.py | 4 ++-- examples/atari/atari_dqn.py | 6 +++--- examples/atari/atari_qrdqn.py | 4 ++-- examples/atari/runnable/pong_a2c.py | 2 +- examples/atari/runnable/pong_ppo.py | 2 +- examples/box2d/acrobot_dualdqn.py | 6 ++++-- examples/box2d/bipedal_hardcore_sac.py | 6 ++++-- examples/box2d/lunarlander_dqn.py | 6 ++++-- examples/box2d/mcc_sac.py | 6 ++++-- examples/mujoco/mujoco_sac.py | 6 ++++-- examples/mujoco/runnable/ant_v2_ddpg.py | 6 ++++-- examples/mujoco/runnable/ant_v2_td3.py | 6 ++++-- .../runnable/halfcheetahBullet_v0_sac.py | 6 ++++-- examples/mujoco/runnable/point_maze_td3.py | 6 ++++-- test/continuous/test_ddpg.py | 8 +++---- test/continuous/test_ppo.py | 4 +++- test/continuous/test_sac_with_il.py | 4 +++- test/continuous/test_td3.py | 4 +++- test/discrete/test_a2c_with_il.py | 4 +++- test/discrete/test_c51.py | 2 +- test/discrete/test_dqn.py | 2 +- test/discrete/test_drqn.py | 3 ++- test/discrete/test_pg.py | 4 +++- test/discrete/test_ppo.py | 4 +++- test/discrete/test_qrdqn.py | 2 +- test/discrete/test_sac.py | 4 +++- test/modelbase/test_psrl.py | 6 ++++-- test/multiagent/tic_tac_toe.py | 4 +++- tianshou/data/collector.py | 21 +++++++------------ 32 files changed, 92 insertions(+), 61 deletions(-) diff --git a/README.md b/README.md index 7226d7c52..1c845e545 100644 --- a/README.md +++ b/README.md @@ -158,7 +158,7 @@ Currently, the overall code of Tianshou platform is less than 2500 lines. Most o ```python result = collector.collect(n_step=n) ``` - +# TODO remove this If you have 3 environments in total and want to collect 1 episode in the first environment, 3 for the third environment: ```python diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst index bda6ff17a..11e5f71e4 100644 --- a/docs/tutorials/tictactoe.rst +++ b/docs/tutorials/tictactoe.rst @@ -128,6 +128,7 @@ Tianshou already provides some builtin classes for multi-agent learning. 
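The remaining hunks in this patch repeat one setup pattern, so here it is once as a condensed sketch (assuming policy is any BasePolicy instance, for example the PGPolicy from the docstring above):

import gym
from tianshou.env import DummyVectorEnv
from tianshou.data import Collector, VectorReplayBuffer

train_envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(8)])
# one sub-buffer per training env; each gets ceil(20000 / 8) = 2500 slots
buf = VectorReplayBuffer(20000, buffer_num=len(train_envs))
# exploration noise (eps-greedy, Gaussian, ...) is now injected through
# policy.exploration_noise(act, batch), enabled here with exploration_noise=True,
# rather than inside policy.forward()
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
train_collector.collect(n_step=256)   # n_step must be a multiple of the env number
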
You can >>> >>> # use collectors to collect a episode of trajectories >>> # the reward is a vector, so we need a scalar metric to monitor the training + # TODO remove reward_metric >>> collector = Collector(policy, env, reward_metric=lambda x: x[0]) >>> >>> # you will see a long trajectory showing the board status at each timestep diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py index e9edb8310..0db587700 100644 --- a/examples/atari/atari_bcq.py +++ b/examples/atari/atari_bcq.py @@ -130,7 +130,7 @@ def watch(): test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, + result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 60f62cd08..59b616c99 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -95,7 +95,7 @@ def test_c51(args=get_args()): save_only_last_obs=True, stack_num=args.frames_stack) # collector - train_collector = Collector(policy, train_envs, buffer) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'c51') @@ -132,7 +132,7 @@ def watch(): policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, + result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index a2a907147..2f611808b 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -91,7 +91,7 @@ def test_dqn(args=get_args()): save_only_last_obs=True, stack_num=args.frames_stack) # collector - train_collector = Collector(policy, train_envs, buffer) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'dqn') @@ -142,8 +142,8 @@ def watch(): else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, - render=args.render) + result = test_collector.collect(n_episode=args.test_num, + render=args.render) pprint.pprint(result) if args.watch: diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index a292e2496..7628a1bdb 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -93,7 +93,7 @@ def test_qrdqn(args=get_args()): save_only_last_obs=True, stack_num=args.frames_stack) # collector - train_collector = Collector(policy, train_envs, buffer) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'qrdqn') @@ -130,7 +130,7 @@ def watch(): policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, + result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) diff --git a/examples/atari/runnable/pong_a2c.py b/examples/atari/runnable/pong_a2c.py index 4b0fd87af..bc1001dd3 100644 --- a/examples/atari/runnable/pong_a2c.py +++ b/examples/atari/runnable/pong_a2c.py @@ -77,7 +77,7 @@ def test_a2c(args=get_args()): train_collector = Collector( policy, train_envs, 
VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)), - preprocess_fn=preprocess_fn) + preprocess_fn=preprocess_fn, exploration_noise=True) test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn) # log writer = SummaryWriter(os.path.join(args.logdir, args.task, 'a2c')) diff --git a/examples/atari/runnable/pong_ppo.py b/examples/atari/runnable/pong_ppo.py index 55dc0760b..d081df38b 100644 --- a/examples/atari/runnable/pong_ppo.py +++ b/examples/atari/runnable/pong_ppo.py @@ -81,7 +81,7 @@ def test_ppo(args=get_args()): train_collector = Collector( policy, train_envs, VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)), - preprocess_fn=preprocess_fn) + preprocess_fn=preprocess_fn, exploration_noise=True) test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn) # log writer = SummaryWriter(os.path.join(args.logdir, args.task, 'ppo')) diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index d114dd58e..00f459893 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -72,7 +72,9 @@ def test_dqn(args=get_args()): target_update_freq=args.target_update_freq) # collector train_collector = Collector( - policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size) @@ -114,7 +116,7 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, + result = test_collector.collect(n_episode=args.test_num, render=args.render) print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index a9d38b309..1a7c45ef8 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -125,7 +125,9 @@ def test_sac_bipedal(args=get_args()): # collector train_collector = Collector( - policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log @@ -151,7 +153,7 @@ def stop_fn(mean_rewards): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, + result = test_collector.collect(n_episode=args.test_num, render=args.render) print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index d0c7b83ed..24a194ebd 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -73,7 +73,9 @@ def test_dqn(args=get_args()): target_update_freq=args.target_update_freq) # collector train_collector = Collector( - policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size) @@ -110,7 +112,7 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) 
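Note that the evaluation calls are also switched from the old per-environment list form to a single integer episode count, which the collector spreads over the ready test environments. A sketch of the new call, assuming a test_collector set up as above (the episode count and render value are placeholders):

# old form: test_collector.collect(n_episode=[1] * args.test_num)
result = test_collector.collect(n_episode=10, render=0.0)
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
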
test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, + result = test_collector.collect(n_episode=args.test_num, render=args.render) print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 18b1af443..795766e7d 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -94,7 +94,9 @@ def test_sac(args=get_args()): exploration_noise=OUNoise(0.0, args.noise_std)) # collector train_collector = Collector( - policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log @@ -119,7 +121,7 @@ def stop_fn(mean_rewards): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, + result = test_collector.collect(n_episode=args.test_num, render=args.render) print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 107cfbf34..bd0ec1cc0 100644 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -108,7 +108,9 @@ def test_sac(args=get_args()): # collector train_collector = Collector( - policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'sac') @@ -120,7 +122,7 @@ def watch(): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, + result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) diff --git a/examples/mujoco/runnable/ant_v2_ddpg.py b/examples/mujoco/runnable/ant_v2_ddpg.py index 686a9f677..6b61d6c9f 100644 --- a/examples/mujoco/runnable/ant_v2_ddpg.py +++ b/examples/mujoco/runnable/ant_v2_ddpg.py @@ -74,7 +74,9 @@ def test_ddpg(args=get_args()): reward_normalization=True, ignore_done=True) # collector train_collector = Collector( - policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # log writer = SummaryWriter(args.logdir + '/' + 'ddpg') @@ -94,7 +96,7 @@ def stop_fn(mean_rewards): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, + result = test_collector.collect(n_episode=args.test_num, render=args.render) print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') diff --git a/examples/mujoco/runnable/ant_v2_td3.py b/examples/mujoco/runnable/ant_v2_td3.py index 36863acde..ef446a3c4 100644 --- a/examples/mujoco/runnable/ant_v2_td3.py +++ b/examples/mujoco/runnable/ant_v2_td3.py @@ -82,7 +82,9 @@ def test_td3(args=get_args()): reward_normalization=True, ignore_done=True) # collector train_collector = Collector( - policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = 
Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log @@ -103,7 +105,7 @@ def stop_fn(mean_rewards): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, + result = test_collector.collect(n_episode=args.test_num, render=args.render) print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') diff --git a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py index 666c8974f..aab38f939 100644 --- a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py +++ b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py @@ -82,7 +82,9 @@ def test_sac(args=get_args()): reward_normalization=True, ignore_done=True) # collector train_collector = Collector( - policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log @@ -105,7 +107,7 @@ def stop_fn(mean_rewards): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, + result = test_collector.collect(n_episode=args.test_num, render=args.render) print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') diff --git a/examples/mujoco/runnable/point_maze_td3.py b/examples/mujoco/runnable/point_maze_td3.py index 78f0db22e..cfd4aa6d0 100644 --- a/examples/mujoco/runnable/point_maze_td3.py +++ b/examples/mujoco/runnable/point_maze_td3.py @@ -87,7 +87,9 @@ def test_td3(args=get_args()): reward_normalization=True, ignore_done=True) # collector train_collector = Collector( - policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log @@ -111,7 +113,7 @@ def stop_fn(mean_rewards): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, + result = test_collector.collect(n_episode=args.test_num, render=args.render) print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 20c326be6..742cece4a 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -25,7 +25,6 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--tau', type=float, default=0.005) parser.add_argument('--exploration-noise', type=float, default=0.1) - parser.add_argument('--test-noise', type=float, default=0.1) parser.add_argument('--epoch', type=int, default=20) parser.add_argument('--step-per-epoch', type=int, default=2400) parser.add_argument('--collect-per-step', type=int, default=4) @@ -86,9 +85,10 @@ def test_ddpg(args=get_args()): estimation_step=args.n_step) # collector train_collector = Collector( - policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) - test_collector = Collector( - policy, test_envs, action_noise=GaussianNoise(sigma=args.test_noise)) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) + test_collector = Collector(policy, test_envs) # log log_path = 
os.path.join(args.logdir, args.task, 'ddpg') writer = SummaryWriter(log_path) diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index b7556683a..8e5710dcb 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -104,7 +104,9 @@ def dist(*logits): gae_lambda=args.gae_lambda) # collector train_collector = Collector( - policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'ppo') diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 27701f0ca..e4b2111d2 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -92,7 +92,9 @@ def test_sac_with_il(args=get_args()): estimation_step=args.n_step) # collector train_collector = Collector( - policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index c8f5537e8..15675f852 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -97,7 +97,9 @@ def test_td3(args=get_args()): estimation_step=args.n_step) # collector train_collector = Collector( - policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index ce6d09bee..bdc747a49 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -79,7 +79,9 @@ def test_a2c_with_il(args=get_args()): max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm) # collector train_collector = Collector( - policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'a2c') diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index a4d496792..38dcbebd5 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -81,7 +81,7 @@ def test_c51(args=get_args()): else: buf = VectorReplayBuffer(args.buffer_size, buffer_num = len(train_envs)) # collector - train_collector = Collector(policy, train_envs, buf) + train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index b83b63ff1..4b410865e 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -83,7 +83,7 @@ def test_dqn(args=get_args()): else: buf = VectorReplayBuffer(args.buffer_size, buffer_num = len(train_envs)) # collector - train_collector = Collector(policy, train_envs, buf) + train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs) # policy.set_eps(1) 
train_collector.collect(n_step=args.batch_size) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 758e6ae10..80a8dc47c 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -68,7 +68,8 @@ def test_drqn(args=get_args()): train_collector = Collector( policy, train_envs, VectorReplayBuffer( args.buffer_size, buffer_num = len(train_envs), - stack_num=args.stack_num, ignore_obs_next=True)) + stack_num=args.stack_num, ignore_obs_next=True), + exploration_noise=True) # the stack_num is for RNN training: sample framestack obs test_collector = Collector(policy, test_envs) # policy.set_eps(1) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index f00803aea..25606bebd 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -65,7 +65,9 @@ def test_pg(args=get_args()): reward_normalization=args.rew_norm) # collector train_collector = Collector( - policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'pg') diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index dcaea39cc..2a17d8194 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -91,7 +91,9 @@ def test_ppo(args=get_args()): value_clip=args.value_clip) # collector train_collector = Collector( - policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'ppo') diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index b63553f82..992d22b50 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -79,7 +79,7 @@ def test_qrdqn(args=get_args()): else: buf = VectorReplayBuffer(args.buffer_size, buffer_num = len(train_envs)) # collector - train_collector = Collector(policy, train_envs, buf) + train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size) diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 7357947bc..3808c1cf8 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -87,7 +87,9 @@ def test_discrete_sac(args=get_args()): ignore_done=args.ignore_done) # collector train_collector = Collector( - policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 2e84bdd1d..889f9888b 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -61,7 +61,9 @@ def test_psrl(args=get_args()): args.add_done_loop) # collector train_collector = Collector( - policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # log writer = SummaryWriter(args.logdir + '/' + args.task) @@ -86,7 +88,7 @@ def stop_fn(mean_rewards): 
policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, + result = test_collector.collect(n_episode=args.test_num, render=args.render) print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') elif env.spec.reward_threshold: diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index c5207b88c..bb2b7ec1b 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -124,7 +124,9 @@ def env_func(): # collector train_collector = Collector( - policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index d1666c3f5..1f2b42733 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -13,7 +13,6 @@ class Collector(object): # TODO change doc - # TODO change all test to ensure api """Collector enables the policy to interact with different types of envs. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` @@ -25,14 +24,9 @@ class Collector(object): :param function preprocess_fn: a function called before the data has been added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults to None. - :param BaseNoise action_noise: add a noise to continuous action. Normally - a policy already has a noise param for exploration in training phase, - so this is recommended to use in test collector for some purpose. - :param function reward_metric: to be used in multi-agent RL. The reward to - report is of shape [agent_num], but we need to return a single scalar - to monitor training. This function specifies what is the desired - metric, e.g., the reward of agent 1 or the average reward over all - agents. By default, the behavior is to select the reward of agent 1. + :param exploration_noise: a flag which determines when the collector is used for + training. If so, function exploration_noise() in policy will be called + automatically to add exploration noise. Default to True. The ``preprocess_fn`` is a function called before the data has been added to the buffer with batch format, which receives up to 7 keys as listed in @@ -54,7 +48,7 @@ class Collector(object): # the collector supports vectorized environments as well vec_buffer = VectorReplayBuffer(total_size=10000, buffer_num = 3) - # buffer_num should be equal(suggested) or larger than envs number + # buffer_num should be equal(suggested) to or larger than envs number envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)]) collector = Collector(policy, envs, buffer=vec_buffer) @@ -83,16 +77,15 @@ def __init__( env: Union[gym.Env, BaseVectorEnv], buffer: Optional[ReplayBuffer] = None, preprocess_fn: Optional[Callable[..., Batch]] = None, - training: bool = True, + exploration_noise: bool = False, ) -> None: - # TODO update training in all test/examples, remove action noise super().__init__() if not isinstance(env, BaseVectorEnv): env = DummyVectorEnv([lambda: env]) assert env.is_async is False, "Please use AsyncCollector if ..." 
self.env = env self.env_num = len(env) - self.training = training + self.exploration_noise = exploration_noise self._assign_buffer(buffer) self.policy = policy self.preprocess_fn = preprocess_fn @@ -252,7 +245,7 @@ def collect( if state is not None: policy.hidden_state = state # save state into buffer act = to_numpy(result.act) - if self.training and not random: + if self.exploration_noise and not random: act = self.policy.exploration_noise(act, self.data) self.data.update(policy=policy, act=act) From 8e12504a7c071a302d1b0b17b48f585fddb0b812 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Mon, 8 Feb 2021 16:44:20 +0800 Subject: [PATCH 056/104] small update --- tianshou/data/collector.py | 66 ++++++++++++++++---------------- tianshou/policy/modelfree/a2c.py | 1 - 2 files changed, 32 insertions(+), 35 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 1f2b42733..ab5f8d822 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -11,8 +11,8 @@ VectorReplayBuffer, CachedReplayBuffer, to_numpy +# TODO change doc class Collector(object): - # TODO change doc """Collector enables the policy to interact with different types of envs. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` @@ -20,13 +20,13 @@ class Collector(object): :param env: a ``gym.Env`` environment or an instance of the :class:`~tianshou.env.BaseVectorEnv` class. :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` - class. If set to ``None`` (testing phase), it will not store the data. + class. If set to None (testing phase), it will not store the data. :param function preprocess_fn: a function called before the data has been - added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults + added to the buffer, see issue #42 and :ref:`preprocess_fn`. Default to None. - :param exploration_noise: a flag which determines when the collector is used for - training. If so, function exploration_noise() in policy will be called - automatically to add exploration noise. Default to True. + :param exploration_noise: a flag which determines when the collector is + used for training. If so, function exploration_noise() in policy will + be called automatically to add exploration noise. Default to True. The ``preprocess_fn`` is a function called before the data has been added to the buffer with batch format, which receives up to 7 keys as listed in @@ -40,7 +40,7 @@ class Collector(object): policy = PGPolicy(...) # or other policies if you wish env = gym.make('CartPole-v0') - + replay_buffer = ReplayBuffer(size=10000) # here we set up a collector with a single environment @@ -48,7 +48,7 @@ class Collector(object): # the collector supports vectorized environments as well vec_buffer = VectorReplayBuffer(total_size=10000, buffer_num = 3) - # buffer_num should be equal(suggested) to or larger than envs number + # buffer_num should be equal (suggested) to or larger than #envs envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)]) collector = Collector(policy, envs, buffer=vec_buffer) @@ -58,17 +58,13 @@ class Collector(object): # collect at least 2 steps collector.collect(n_step=2) # collect episodes with visual rendering (the render argument is the - # sleep time between rendering consecutive frames) + # sleep time between rendering consecutive frames) collector.collect(n_episode=1, render=0.03) - Collected data always consist of full episodes. 
So if only ``n_step`` - argument is give, the collector may return the data more than the - ``n_step`` limitation. Same as ``n_episode`` for the multiple environment - case. - .. note:: - Please make sure the given environment has a time limitation. + Please make sure the given environment has a time limitation if using + n_episode collect option. """ def __init__( @@ -115,7 +111,7 @@ def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None: f"collect {self.env_num} envs,\n\tplease use {vector_type}" f"(total_size={buffer.maxsize}, buffer_num={self.env_num}," " ...) instead.") - self.buffer =buffer + self.buffer = buffer # TODO move to trainer # @staticmethod @@ -229,8 +225,8 @@ def collect( # get the next action if random: - result = Batch(act=[self._action_space[i].sample() - for i in ready_env_ids]) + self.data.update(act=[self._action_space[i].sample() + for i in ready_env_ids]) else: if no_grad: with torch.no_grad(): # faster than retain_grad version @@ -238,20 +234,20 @@ def collect( result = self.policy(self.data, last_state) else: result = self.policy(self.data, last_state) - # update state / act / policy into self.data - policy = result.get("policy", Batch()) - assert isinstance(policy, Batch) - state = result.get("state", None) - if state is not None: - policy.hidden_state = state # save state into buffer - act = to_numpy(result.act) - if self.exploration_noise and not random: - act = self.policy.exploration_noise(act, self.data) - self.data.update(policy=policy, act=act) + # update state / act / policy into self.data + policy = result.get("policy", Batch()) + assert isinstance(policy, Batch) + state = result.get("state", None) + if state is not None: + policy.hidden_state = state # save state into buffer + act = to_numpy(result.act) + if self.exploration_noise: + act = self.policy.exploration_noise(act, self.data) + self.data.update(policy=policy, act=act) # step in env obs_next, rew, done, info = self.env.step( - act, id=ready_env_ids) + self.data.act, id=ready_env_ids) self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) if self.preprocess_fn: @@ -290,16 +286,18 @@ def collect( self.data.obs_next[env_ind_local] = obs_reset for i in env_ind_local: self._reset_state(i) - - # remove surplus env id from ready_env_ids to avoid bias in + + # Remove surplus env id from ready_env_ids to avoid bias in # selecting environments. 
if n_episode: - surplus_env_n = len(ready_env_ids) - (n_episode - episode_count) - if surplus_env_n > 0: + episode_to_collect = n_episode - episode_count + surplus_env_num = len(ready_env_ids) - episode_to_collect + if surplus_env_num > 0: mask = np.ones_like(ready_env_ids, np.bool) - mask[env_ind_local[:surplus_env_n]] = False + mask[env_ind_local[:surplus_env_num]] = False ready_env_ids = ready_env_ids[mask] self.data = self.data[mask] + self.data.obs = self.data.obs_next if (n_step and step_count >= n_step) or \ diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index b50eb7fb6..1c8edb300 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -80,7 +80,6 @@ def process_fn( batch, buffer, indice, v_, gamma=self._gamma, gae_lambda=self._lambda, rew_norm=self._rew_norm) - def forward( self, From 05140aa952d72c8884e4b6497f41d2b3108adde0 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Mon, 8 Feb 2021 17:44:01 +0800 Subject: [PATCH 057/104] fix some bug --- test/continuous/test_ddpg.py | 2 +- test/throughput/test_buffer_profile.py | 3 +-- tianshou/trainer/offpolicy.py | 7 ++++--- tianshou/trainer/onpolicy.py | 7 ++++--- tianshou/trainer/utils.py | 10 ++++++++-- 5 files changed, 18 insertions(+), 11 deletions(-) diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 742cece4a..5d965ff98 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -31,7 +31,7 @@ def get_args(): parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=4) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
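A minimal sketch of the surplus-environment masking used in Collector.collect(n_episode=...) in the hunk above, assuming only NumPy; the variable values are invented and only the masking step mirrors the patched code:

import numpy as np

# Hypothetical snapshot of the collect loop state (illustrative values only).
ready_env_ids = np.array([0, 1, 2, 3])  # envs currently being collected from
env_ind_local = np.array([1, 3])        # local indices of envs that just finished an episode
episode_count, n_episode = 7, 8         # 7 of the requested 8 episodes already collected

episode_to_collect = n_episode - episode_count
surplus_env_num = len(ready_env_ids) - episode_to_collect
if surplus_env_num > 0:
    mask = np.ones_like(ready_env_ids, dtype=bool)
    mask[env_ind_local[:surplus_env_num]] = False  # drop just-finished envs first to avoid bias
    ready_env_ids = ready_env_ids[mask]            # -> array([0, 2]) in this example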
diff --git a/test/throughput/test_buffer_profile.py b/test/throughput/test_buffer_profile.py index 3134004f1..be82f0b7f 100644 --- a/test/throughput/test_buffer_profile.py +++ b/test/throughput/test_buffer_profile.py @@ -1,7 +1,7 @@ import pytest import numpy as np -from tianshou.data import (ListReplayBuffer, PrioritizedReplayBuffer, +from tianshou.data import (PrioritizedReplayBuffer, ReplayBuffer, SegmentTree) @@ -29,7 +29,6 @@ def test_init(): for _ in np.arange(1e5): _ = ReplayBuffer(1e5) _ = PrioritizedReplayBuffer(size=int(1e5), alpha=0.5, beta=0.5) - _ = ListReplayBuffer() def test_add(data): diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index d1d487fe1..847ad06f5 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -99,9 +99,10 @@ def offpolicy_trainer( "n/st": str(int(result["n/st"])), } if writer and env_step % log_interval == 0: - for k in result.keys(): - writer.add_scalar( - "train/" + k, result[k], global_step=env_step) + writer.add_scalar( + "train/rew", result['rews'].mean(), global_step=env_step) + writer.add_scalar( + "train/len", result['lens'].mean(), global_step=env_step) if test_in_train and stop_fn and stop_fn(result["rews"].mean()): test_result = test_episode( policy, test_collector, test_fn, diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 584c6a5bc..fe932506b 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -99,9 +99,10 @@ def onpolicy_trainer( "n/st": str(int(result["n/st"])), } if writer and env_step % log_interval == 0: - for k in result.keys(): - writer.add_scalar( - "train/" + k, result[k], global_step=env_step) + writer.add_scalar( + "train/rew", result['rews'].mean(), global_step=env_step) + writer.add_scalar( + "train/len", result['lens'].mean(), global_step=env_step) if test_in_train and stop_fn and stop_fn(result["rews"].mean()): test_result = test_episode( policy, test_collector, test_fn, diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index cb9a2ef2a..37a33f321 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -24,8 +24,14 @@ def test_episode( test_fn(epoch, global_step) result = collector.collect(n_episode=n_episode) if writer is not None and global_step is not None: - for k in result.keys(): - writer.add_scalar("test/" + k, result[k], global_step=global_step) + writer.add_scalar( + "test/rew", result['rews'].mean(), global_step=global_step) + writer.add_scalar( + "test/rew_std", result['rews'].std(), global_step=global_step) + writer.add_scalar( + "test/len", result['lens'].mean(), global_step=global_step) + writer.add_scalar( + "test/len_std", result['lens'].std(), global_step=global_step) return result From f96f81b5bf8432e3b73f4d56d36d198d28079287 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Mon, 8 Feb 2021 18:40:15 +0800 Subject: [PATCH 058/104] coordinate test files and fix on bug on summary writer --- docs/tutorials/tictactoe.rst | 2 +- examples/atari/atari_c51.py | 2 +- examples/atari/atari_dqn.py | 2 +- examples/atari/atari_qrdqn.py | 2 +- examples/box2d/acrobot_dualdqn.py | 2 +- examples/box2d/bipedal_hardcore_sac.py | 2 +- examples/box2d/lunarlander_dqn.py | 2 +- examples/box2d/mcc_sac.py | 2 +- examples/mujoco/mujoco_sac.py | 2 +- examples/mujoco/runnable/ant_v2_ddpg.py | 2 +- examples/mujoco/runnable/ant_v2_td3.py | 2 +- .../runnable/halfcheetahBullet_v0_sac.py | 2 +- examples/mujoco/runnable/point_maze_td3.py | 2 +- test/continuous/test_ppo.py | 2 +- 
test/continuous/test_sac_with_il.py | 2 +- test/continuous/test_td3.py | 2 +- test/discrete/test_c51.py | 2 +- test/discrete/test_dqn.py | 2 +- test/discrete/test_drqn.py | 2 +- test/discrete/test_qrdqn.py | 2 +- test/discrete/test_sac.py | 2 +- test/multiagent/tic_tac_toe.py | 2 +- tianshou/data/buffer.py | 2 + tianshou/policy/base.py | 3 +- tianshou/trainer/offpolicy.py | 41 +++++++++---------- tianshou/trainer/onpolicy.py | 4 +- 26 files changed, 47 insertions(+), 47 deletions(-) diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst index 11e5f71e4..605528b01 100644 --- a/docs/tutorials/tictactoe.rst +++ b/docs/tutorials/tictactoe.rst @@ -202,7 +202,7 @@ The explanation of each Tianshou class/function will be deferred to their first parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--collect-per-step', type=int, default=8) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 59b616c99..cb8f0d296 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -33,7 +33,7 @@ def get_args(): parser.add_argument('--step-per-epoch', type=int, default=10000) parser.add_argument('--collect-per-step', type=int, default=10) parser.add_argument('--batch-size', type=int, default=32) - parser.add_argument('--training-num', type=int, default=16) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=10) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 2f611808b..2d81747d4 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -30,7 +30,7 @@ def get_args(): parser.add_argument('--step-per-epoch', type=int, default=10000) parser.add_argument('--collect-per-step', type=int, default=10) parser.add_argument('--batch-size', type=int, default=32) - parser.add_argument('--training-num', type=int, default=16) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=10) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 7628a1bdb..2a23926de 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -31,7 +31,7 @@ def get_args(): parser.add_argument('--step-per-epoch', type=int, default=10000) parser.add_argument('--collect-per-step', type=int, default=10) parser.add_argument('--batch-size', type=int, default=32) - parser.add_argument('--training-num', type=int, default=16) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=10) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
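A small, self-contained sketch of the logging change made to offpolicy.py, onpolicy.py and utils.py in the previous commit: collect() now returns per-episode arrays under "rews" and "lens", so they are reduced to scalars before being written to TensorBoard. The writer tags match the patch; the numbers below are made up.

import numpy as np
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("log/demo")
env_step = 1000
# What a collect() result roughly looks like after this patch (values invented).
result = {"rews": np.array([195.0, 200.0, 187.0]),  # one entry per finished episode
          "lens": np.array([200, 200, 190])}

writer.add_scalar("train/rew", result["rews"].mean(), global_step=env_step)
writer.add_scalar("train/len", result["lens"].mean(), global_step=env_step)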
diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 00f459893..272087f21 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -34,7 +34,7 @@ def get_args(): nargs='*', default=[128, 128]) parser.add_argument('--dueling-v-hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 1a7c45ef8..ea2d4f239 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -32,7 +32,7 @@ def get_args(): parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 24a194ebd..1471a86a6 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -35,7 +35,7 @@ def get_args(): nargs='*', default=[128, 128]) parser.add_argument('--dueling-v-hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--training-num', type=int, default=16) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 795766e7d..4c54659a6 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -34,7 +34,7 @@ def get_args(): parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=16) + parser.add_argument('--training-num', type=int, default=5) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index bd0ec1cc0..894e59f81 100644 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -35,7 +35,7 @@ def get_args(): parser.add_argument('--batch-size', type=int, default=256) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=16) + parser.add_argument('--training-num', type=int, default=4) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
diff --git a/examples/mujoco/runnable/ant_v2_ddpg.py b/examples/mujoco/runnable/ant_v2_ddpg.py index 6b61d6c9f..dd95c2907 100644 --- a/examples/mujoco/runnable/ant_v2_ddpg.py +++ b/examples/mujoco/runnable/ant_v2_ddpg.py @@ -30,7 +30,7 @@ def get_args(): parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=4) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) diff --git a/examples/mujoco/runnable/ant_v2_td3.py b/examples/mujoco/runnable/ant_v2_td3.py index ef446a3c4..be59534f9 100644 --- a/examples/mujoco/runnable/ant_v2_td3.py +++ b/examples/mujoco/runnable/ant_v2_td3.py @@ -33,7 +33,7 @@ def get_args(): parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) diff --git a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py index aab38f939..58c8e4e20 100644 --- a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py +++ b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py @@ -32,7 +32,7 @@ def get_args(): parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=4) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--log-interval', type=int, default=100) diff --git a/examples/mujoco/runnable/point_maze_td3.py b/examples/mujoco/runnable/point_maze_td3.py index cfd4aa6d0..9e7f03faa 100644 --- a/examples/mujoco/runnable/point_maze_td3.py +++ b/examples/mujoco/runnable/point_maze_td3.py @@ -35,7 +35,7 @@ def get_args(): parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 8e5710dcb..da75f3368 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -29,7 +29,7 @@ def get_args(): parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=16) + parser.add_argument('--training-num', type=int, default=1) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index e4b2111d2..51f35341f 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -33,7 +33,7 @@ def get_args(): nargs='*', default=[128, 128]) parser.add_argument('--imitation-hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 15675f852..998b80e5e 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -34,7 +34,7 @@ def get_args(): parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 38dcbebd5..50c940492 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -29,7 +29,7 @@ def get_args(): parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--collect-per-step', type=int, default=8) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 4b410865e..c0268953d 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -31,7 +31,7 @@ def get_args(): parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 80a8dc47c..46e05fe48 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -30,7 +30,7 @@ def get_args(): parser.add_argument('--collect-per-step', type=int, default=10) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--layer-num', type=int, default=3) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 992d22b50..f721fac3a 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -31,7 +31,7 @@ def get_args(): parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 3808c1cf8..370d93dcb 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -32,7 +32,7 @@ def get_args(): parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=16) + parser.add_argument('--training-num', type=int, default=5) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.0) diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index bb2b7ec1b..94cc523ea 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -33,7 +33,7 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.1) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index b256f0168..318233a6e 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -50,6 +50,8 @@ def __init__( ignore_obs_next: bool = False, save_only_last_obs: bool = False, sample_avail: bool = False, + # TODO talk about + **kwargs, ) -> None: self.options: Dict[str, Any] = { "stack_num": stack_num, diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index c8bc9b9cb..f0e4137e9 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -297,7 +297,8 @@ def compute_nstep_return( terminal = indices[-1] with torch.no_grad(): target_q_torch = target_q_fn(buffer, terminal) # (bsz, ?) 
- target_q = to_numpy(target_q_torch) * BasePolicy.value_mask(batch) + target_q = to_numpy(target_q_torch) * \ + BasePolicy.value_mask(batch).reshape(-1, 1) end_flag = buffer.done.copy() end_flag[buffer.unfinished_index()] = True target_q = _nstep_return(rew, end_flag, target_q, indices, diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 847ad06f5..a6deb8131 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -94,30 +94,29 @@ def offpolicy_trainer( data = { "env_step": str(env_step), "rew": f"{result['rews'].mean():.2f}", - "len": str(int(result["lens"].mean())), + "len": str(result["lens"].mean()), "n/ep": str(int(result["n/ep"])), "n/st": str(int(result["n/st"])), } - if writer and env_step % log_interval == 0: - writer.add_scalar( - "train/rew", result['rews'].mean(), global_step=env_step) - writer.add_scalar( - "train/len", result['lens'].mean(), global_step=env_step) - if test_in_train and stop_fn and stop_fn(result["rews"].mean()): - test_result = test_episode( - policy, test_collector, test_fn, - epoch, episode_per_test, writer, env_step) - if stop_fn(test_result["rews"].mean()): - if save_fn: - save_fn(policy) - for k in result.keys(): - data[k] = f"{result[k]:.2f}" - t.set_postfix(**data) - return gather_info( - start_time, train_collector, test_collector, - test_result["rews"].mean(), test_result["rew_std"].std()) - else: - policy.train() + if result["n/ep"] > 0: + if writer and env_step % log_interval == 0: + writer.add_scalar( + "train/rew", result['rews'].mean(), global_step=env_step) + writer.add_scalar( + "train/len", result['lens'].mean(), global_step=env_step) + if test_in_train and stop_fn and stop_fn(result["rews"].mean()): + test_result = test_episode( + policy, test_collector, test_fn, + epoch, episode_per_test, writer, env_step) + if stop_fn(test_result["rews"].mean()): + if save_fn: + save_fn(policy) + t.set_postfix(**data) + return gather_info( + start_time, train_collector, test_collector, + test_result["rews"].mean(), test_result["rews"].std()) + else: + policy.train() for i in range(update_per_step * min( result["n/st"] // collect_per_step, t.total - t.n)): gradient_step += 1 diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index fe932506b..27b15b262 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -110,12 +110,10 @@ def onpolicy_trainer( if stop_fn(test_result["rews"].mean()): if save_fn: save_fn(policy) - for k in result.keys(): - data[k] = f"{result[k]:.2f}" t.set_postfix(**data) return gather_info( start_time, train_collector, test_collector, - test_result["rews"].mean(), test_result["rew_std"]) + test_result["rews"].mean(), test_result["rews"].std()) else: policy.train() losses = policy.update( From 1f105d5fba11abfdc99489876ad493a8caf5e39c Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Mon, 8 Feb 2021 21:24:07 +0800 Subject: [PATCH 059/104] AsyncCollector, still need more test --- test/base/test_collector.py | 45 ++------ tianshou/data/__init__.py | 3 +- tianshou/data/buffer.py | 2 + tianshou/data/collector.py | 198 ++++++++++++++++++++++++++++++++++-- 4 files changed, 202 insertions(+), 46 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index d4823d3d7..ea55a0c64 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -4,8 +4,8 @@ from tianshou.policy import BasePolicy from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.data import Collector, Batch, ReplayBuffer, 
VectorReplayBuffer, \ - CachedReplayBuffer +from tianshou.data import Batch, Collector, AsyncCollector +from tianshou.data import ReplayBuffer, VectorReplayBuffer, CachedReplayBuffer if __name__ == '__main__': from env import MyTestEnv @@ -145,40 +145,13 @@ def test_collector_with_async(): venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) policy = MyPolicy() - c1 = Collector( - policy, venv, ReplayBuffer(size=1000, ignore_obs_next=False), + c1 = AsyncCollector( + policy, venv, + VectorReplayBuffer(total_size=100, buffer_num=4), logger.preprocess_fn) - c1.collect(n_episode=10) - # check if the data in the buffer is chronological - # i.e. data in the buffer are full episodes, and each episode is - # returned by the same environment - env_id = c1.buffer.info['env_id'] - size = len(c1.buffer) - obs = c1.buffer.obs[:size] - done = c1.buffer.done[:size] - obs_ground_truth = [] - i = 0 - while i < size: - # i is the start of an episode - if done[i]: - # this episode has one transition - assert env_lens[env_id[i]] == 1 - i += 1 - continue - j = i - while True: - j += 1 - # in one episode, the environment id is the same - assert env_id[j] == env_id[i] - if done[j]: - break - j = j + 1 # j is the start of the next episode - assert j - i == env_lens[env_id[i]] - obs_ground_truth += list(range(j - i)) - i = j - obs_ground_truth = np.expand_dims( - np.array(obs_ground_truth), axis=-1) - assert np.allclose(obs, obs_ground_truth) + result = c1.collect(n_episode=10) + print(result) + print(c1.buffer) def test_collector_with_dict_state(): @@ -413,4 +386,4 @@ def test_collector_with_atari_setting(): test_collector_with_dict_state() test_collector_with_ma() test_collector_with_atari_setting() - # test_collector_with_async() + test_collector_with_async() diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index f32603d87..72db8401d 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -4,7 +4,7 @@ from tianshou.data.buffer import ReplayBuffer, PrioritizedReplayBuffer, \ ReplayBufferManager, PrioritizedReplayBufferManager, \ VectorReplayBuffer, PrioritizedVectorReplayBuffer, CachedReplayBuffer -from tianshou.data.collector import Collector +from tianshou.data.collector import Collector, AsyncCollector __all__ = [ "Batch", @@ -20,4 +20,5 @@ "PrioritizedVectorReplayBuffer", "CachedReplayBuffer", "Collector", + "AsyncCollector", ] diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index b256f0168..fa3eac88a 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -15,6 +15,8 @@ def _alloc_by_keys_diff( if key in meta.keys(): if isinstance(meta[key], Batch) and isinstance(batch[key], Batch): _alloc_by_keys_diff(meta[key], batch[key], size, stack) + elif isinstance(meta[key], Batch) and meta[key].is_empty(): + meta[key] = _create_value(batch[key], size, stack) else: meta[key] = _create_value(batch[key], size, stack) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index ab5f8d822..bdfe11314 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -3,12 +3,13 @@ import torch import warnings import numpy as np -from typing import Dict, List, Union, Optional, Callable +from typing import Any, Dict, List, Union, Optional, Callable from tianshou.policy import BasePolicy from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.data import Batch, ReplayBuffer, ReplayBufferManager, \ VectorReplayBuffer, CachedReplayBuffer, to_numpy +from tianshou.data.buffer import _alloc_by_keys_diff # TODO change doc 
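For context on the one-line change to compute_nstep_return in tianshou/policy/base.py in the previous commit: target_q coming out of target_q_fn has shape (bsz, ?) while value_mask(batch) yields one value per transition, so the mask has to be reshaped into a column before the elementwise product broadcasts. A tiny NumPy sketch of that shape fix, with made-up shapes:

import numpy as np

bsz, n = 4, 3                                  # n > 1 e.g. for quantile/distributional critics
target_q = np.ones((bsz, n))                   # stands in for to_numpy(target_q_torch)
value_mask = np.array([1.0, 0.0, 1.0, 1.0])    # 0 where the bootstrap value should be masked out

masked = target_q * value_mask.reshape(-1, 1)  # (4, 3) * (4, 1) broadcasts row-wise
assert masked.shape == (bsz, n)
assert (masked[1] == 0).all()                  # the masked transition's targets are zeroed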
@@ -78,7 +79,6 @@ def __init__( super().__init__() if not isinstance(env, BaseVectorEnv): env = DummyVectorEnv([lambda: env]) - assert env.is_async is False, "Please use AsyncCollector if ..." self.env = env self.env_num = len(env) self.exploration_noise = exploration_noise @@ -144,14 +144,12 @@ def reset_buffer(self) -> None: self.buffer.reset() def reset_env(self) -> None: - """Reset all of the environment(s)' states and the cache buffers.""" + """Reset all of the environment(s).""" obs = self.env.reset() if self.preprocess_fn: obs = self.preprocess_fn(obs=obs).get("obs", obs) self.data.obs = obs - if isinstance(self.buffer, CachedReplayBuffer): - for buf in self.buffer.cached_buffers: - buf.reset() + self._ready_env_ids = np.arange(self.env_num) def _reset_state(self, id: Union[int, List[int]]) -> None: """Reset the hidden state: self.data.state[id].""" @@ -198,13 +196,12 @@ def collect( """ # collect at least n_step or n_episode # TODO: modify docs, tell the constraints + assert self.env.is_async is False, "Please use AsyncCollector if ..." if n_step is not None: assert n_episode is None, ( "Only one of n_step or n_episode is allowed in Collector." f"collect, got n_step={n_step}, n_episode={n_episode}.") assert n_step > 0 - assert n_step % self.env_num == 0, \ - "n_step should be a multiple of #envs" ready_env_ids = np.arange(self.env_num) else: assert n_episode > 0 @@ -226,7 +223,7 @@ def collect( # get the next action if random: self.data.update(act=[self._action_space[i].sample() - for i in ready_env_ids]) + for i in ready_env_ids]) else: if no_grad: with torch.no_grad(): # faster than retain_grad version @@ -325,3 +322,186 @@ def collect( "n/ep": episode_count, "n/st": step_count, "rews": rews, "lens": lens, "idxs": idxs, } + + +class AsyncCollector(Collector): + """docstring for AsyncCollector""" + + def __init__( + self, + policy: BasePolicy, + env: Union[gym.Env, BaseVectorEnv], + buffer: Optional[ReplayBuffer] = None, + preprocess_fn: Optional[Callable[..., Batch]] = None, + exploration_noise: bool = False, + ) -> None: + super().__init__(policy, env, buffer, preprocess_fn, exploration_noise) + + def collect( + self, + n_step: Optional[int] = None, + n_episode: Optional[int] = None, + random: bool = False, + render: Optional[float] = None, + no_grad: bool = True, + ) -> Dict[str, float]: + """Collect a specified number of step or episode. + + :param int n_step: how many steps you want to collect. + :param n_episode: how many episodes you want to collect. + :param bool random: whether to use random policy for collecting data. + Default to False. + :param float render: the sleep time between rendering consecutive + frames. Default to None (no rendering). + :param bool no_grad: whether to retain gradient in policy.forward. + Default to True (no gradient retaining). + + .. note:: + + One and only one collection number specification is permitted, + either ``n_step`` or ``n_episode``. + + :return: A dict including the following keys + + * ``n/ep`` the collected number of episodes. + * ``n/st`` the collected number of steps. + * ``rews`` the list of episode reward over collected episodes. + * ``lens`` the list of episode length over collected episodes. + * ``idxs`` the list of episode start index over collected episodes. + """ + # collect at least n_step or n_episode + # TODO: modify docs, tell the constraints + if n_step is not None: + assert n_episode is None, ( + "Only one of n_step or n_episode is allowed in Collector." 
+ f"collect, got n_step={n_step}, n_episode={n_episode}.") + assert n_step > 0 + else: + assert n_episode > 0 + warnings.warn("Using n_episode under async setting may collect " + "extra frames into buffer.") + + finished_env_ids = [] + ready_env_ids = self._ready_env_ids + + start_time = time.time() + + step_count = 0 + episode_count = 0 + episode_rews = [] + episode_lens = [] + episode_start_indices = [] + + while True: + whole_data = self.data + self.data = self.data[ready_env_ids] + assert len(whole_data) == self.env_num # major difference + # restore the state: if the last state is None, it won't store + last_state = self.data.policy.pop("hidden_state", None) + + # get the next action + if random: + self.data.update(act=[self._action_space[i].sample() + for i in ready_env_ids]) + else: + if no_grad: + with torch.no_grad(): # faster than retain_grad version + # self.data.obs will be used by agent to get result + result = self.policy(self.data, last_state) + else: + result = self.policy(self.data, last_state) + # update state / act / policy into self.data + policy = result.get("policy", Batch()) + assert isinstance(policy, Batch) + state = result.get("state", None) + if state is not None: + policy.hidden_state = state # save state into buffer + act = to_numpy(result.act) + if self.exploration_noise: + act = self.policy.exploration_noise(act, self.data) + self.data.update(policy=policy, act=act) + + # save act/policy before env.step + try: + whole_data.act[ready_env_ids] = self.data.act + whole_data.policy[ready_env_ids] = self.data.policy + except ValueError: + _alloc_by_keys_diff(whole_data, self.data, self.env_num, False) + whole_data[ready_env_ids] = self.data # lots of overhead + + # step in env + obs_next, rew, done, info = self.env.step( + self.data.act, id=ready_env_ids) + + # change self.data here because ready_env_ids has changed + ready_env_ids = np.array([i["env_id"] for i in info]) + self.data = whole_data[ready_env_ids] + + self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) + if self.preprocess_fn: + self.data.update(self.preprocess_fn( + obs_next=self.data.obs_next, + rew=self.data.rew, + done=self.data.done, + info=self.data.info, + )) + + if render: + self.env.render() + if render > 0 and not np.isclose(render, 0): + time.sleep(render) + + # add data into the buffer + ptr, ep_rew, ep_len, ep_idx = self.buffer.add( + self.data, buffer_ids=ready_env_ids) + + # collect statistics + step_count += len(ready_env_ids) + + if np.any(done): + env_ind_local = np.where(done)[0] + env_ind_global = ready_env_ids[env_ind_local] + episode_count += len(env_ind_local) + episode_lens.append(ep_len[env_ind_local]) + episode_rews.append(ep_rew[env_ind_local]) + episode_start_indices.append(ep_idx[env_ind_local]) + # now we copy obs_next to obs, but since there might be + # finished episodes, we have to reset finished envs first. 
+ obs_reset = self.env.reset(env_ind_global) + if self.preprocess_fn: + obs_reset = self.preprocess_fn( + obs=obs_reset).get("obs", obs_reset) + self.data.obs_next[env_ind_local] = obs_reset + for i in env_ind_local: + self._reset_state(i) + + try: + whole_data.obs[ready_env_ids] = self.data.obs_next + whole_data.rew[ready_env_ids] = self.data.rew + whole_data.done[ready_env_ids] = self.data.done + whole_data.info[ready_env_ids] = self.data.info + except ValueError: + _alloc_by_keys_diff(whole_data, self.data, self.env_num, False) + whole_data[ready_env_ids] = self.data # lots of overhead + self.data = whole_data + + if (n_step and step_count >= n_step) or \ + (n_episode and episode_count >= n_episode): + break + + # generate statistics + self.collect_step += step_count + self.collect_episode += episode_count + self.collect_time += max(time.time() - start_time, 1e-9) + + if episode_count > 0: + rews, lens, idxs = list(map(np.concatenate, [ + episode_rews, episode_lens, episode_start_indices])) + else: + rews, lens, idxs = \ + np.array([]), np.array([], np.int), np.array([], np.int) + + return { + "n/ep": episode_count, "n/st": step_count, + "rews": rews, "lens": lens, "idxs": idxs, + } From 0a0dd26de1c5ddcba7d80b3df899e3a1a2cb0674 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 9 Feb 2021 00:08:48 +0800 Subject: [PATCH 060/104] AsyncCollector test --- test/base/test_collector.py | 39 +++++++++++++++++++++++++++++++------ tianshou/data/collector.py | 3 +++ 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index ea55a0c64..aee437c80 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -1,3 +1,4 @@ +import tqdm import pytest import numpy as np from torch.utils.tensorboard import SummaryWriter @@ -129,7 +130,7 @@ def test_collector(): num = len(env_fns) c3 = Collector(policy, dum, VectorReplayBuffer(total_size=40000, buffer_num=num)) - for i in range(num, 400): + for i in tqdm.trange(num, 400, desc="test step collector n_episode"): c3.reset() result = c3.collect(n_episode=i) assert result['n/ep'] == i @@ -140,18 +141,44 @@ def test_collector_with_async(): env_lens = [2, 3, 4, 5] writer = SummaryWriter('log/async_collector') logger = Logger(writer) - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.1, random_sleep=True) + env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.01, random_sleep=True) for i in env_lens] venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) policy = MyPolicy() + bufsize = 300 c1 = AsyncCollector( policy, venv, - VectorReplayBuffer(total_size=100, buffer_num=4), + VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4), logger.preprocess_fn) - result = c1.collect(n_episode=10) - print(result) - print(c1.buffer) + ptr = [0, 0, 0, 0] + for n_episode in tqdm.trange(1, 100, desc="test async n_episode"): + result = c1.collect(n_episode=n_episode) + assert result["n/ep"] >= n_episode + # check buffer data, obs and obs_next, env_id + for i, count in enumerate( + np.bincount(result["lens"], minlength=6)[2:]): + env_len = i + 2 + total = env_len * count + indices = np.arange(ptr[i], ptr[i] + total) % bufsize + ptr[i] = (ptr[i] + total) % bufsize + seq = np.arange(env_len) + buf = c1.buffer.buffers[i] + assert np.all(buf.info.env_id[indices] == i) + assert np.all(buf.obs[indices].reshape(count, env_len) == seq) + assert np.all(buf.obs_next[indices].reshape( + count, env_len) == seq + 1) + # test async n_step, for now the buffer should be full of data + for n_step in 
tqdm.trange(1, 150, desc="test async n_step"): + result = c1.collect(n_step=n_step) + assert result["n/st"] >= n_step + for i in range(4): + env_len = i + 2 + seq = np.arange(env_len) + buf = c1.buffer.buffers[i] + assert np.all(buf.info.env_id == i) + assert np.all(buf.obs.reshape(-1, env_len) == seq) + assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1) def test_collector_with_dict_state(): diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index bdfe11314..7167a9bf3 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -482,6 +482,7 @@ def collect( whole_data.info[ready_env_ids] = self.data.info except ValueError: _alloc_by_keys_diff(whole_data, self.data, self.env_num, False) + self.data.obs = self.data.obs_next whole_data[ready_env_ids] = self.data # lots of overhead self.data = whole_data @@ -489,6 +490,8 @@ def collect( (n_episode and episode_count >= n_episode): break + self._ready_env_ids = ready_env_ids + # generate statistics self.collect_step += step_count self.collect_episode += episode_count From 3dbb65f095b828d816236d909bcdc6707e33e08d Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 9 Feb 2021 10:00:59 +0800 Subject: [PATCH 061/104] flake8 maxlen 119 --- setup.cfg | 11 ++ tianshou/data/__init__.py | 12 +- tianshou/data/buffer.py | 222 +++++++++++++------------------------ tianshou/data/collector.py | 125 +++++++-------------- 4 files changed, 139 insertions(+), 231 deletions(-) diff --git a/setup.cfg b/setup.cfg index 0a4742891..7a9eb21cf 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,14 @@ +[flake8] +exclude = + .git + log + __pycache__ + docs + build + dist + *.egg-info +max-line-length = 119 + [mypy] files = tianshou/**/*.py allow_redefinition = True diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index 72db8401d..137cf171f 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -1,9 +1,15 @@ from tianshou.data.batch import Batch from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as from tianshou.data.utils.segtree import SegmentTree -from tianshou.data.buffer import ReplayBuffer, PrioritizedReplayBuffer, \ - ReplayBufferManager, PrioritizedReplayBufferManager, \ - VectorReplayBuffer, PrioritizedVectorReplayBuffer, CachedReplayBuffer +from tianshou.data.buffer import ( + ReplayBuffer, + PrioritizedReplayBuffer, + ReplayBufferManager, + PrioritizedReplayBufferManager, + VectorReplayBuffer, + PrioritizedVectorReplayBuffer, + CachedReplayBuffer, +) from tianshou.data.collector import Collector, AsyncCollector __all__ = [ diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 8c03db85f..0a876cadf 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -8,9 +8,7 @@ from tianshou.data.utils.converter import to_hdf5, from_hdf5 -def _alloc_by_keys_diff( - meta: Batch, batch: Batch, size: int, stack: bool = True -) -> None: +def _alloc_by_keys_diff(meta: Batch, batch: Batch, size: int, stack: bool = True) -> None: for key in batch.keys(): if key in meta.keys(): if isinstance(meta[key], Batch) and isinstance(batch[key], Batch): @@ -22,28 +20,24 @@ def _alloc_by_keys_diff( class ReplayBuffer: - """:class:`~tianshou.data.ReplayBuffer` stores data generated from \ - interaction between the policy and environment. + """:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. - ReplayBuffer can be considered as a specialized form (or management) of - Batch. 
It stores all the data in a batch with circular-queue style. + ReplayBuffer can be considered as a specialized form (or management) of Batch. It stores all the data in a batch + with circular-queue style. - For the example usage of ReplayBuffer, please check out Section Buffer in - :doc:`/tutorials/concepts`. + For the example usage of ReplayBuffer, please check out Section Buffer in :doc:`/tutorials/concepts`. :param int size: the maximum size of replay buffer. - :param int stack_num: the frame-stack sampling argument, should be greater - than or equal to 1, defaults to 1 (no stacking). + :param int stack_num: the frame-stack sampling argument, should be greater than or equal to 1. + Default to 1 (no stacking). :param bool ignore_obs_next: whether to store obs_next, defaults to False. - :param bool save_only_last_obs: only save the last obs/obs_next when it has - a shape of (timestep, ...) because of temporal stacking, defaults to - False. - :param bool sample_avail: the parameter indicating sampling only available - index when using frame-stack sampling method, defaults to False. + :param bool save_only_last_obs: only save the last obs/obs_next when it has a shape of (timestep, ...) because of + temporal stacking. Default to False. + :param bool sample_avail: the parameter indicating sampling only available index when using frame-stack sampling + method. Default to False. """ - _reserved_keys = ("obs", "act", "rew", "done", - "obs_next", "info", "policy") + _reserved_keys = ("obs", "act", "rew", "done", "obs_next", "info", "policy") def __init__( self, @@ -99,8 +93,7 @@ def __setstate__(self, state: Dict[str, Any]) -> None: def __setattr__(self, key: str, value: Any) -> None: """Set self.key = value.""" - assert key not in self._reserved_keys, ( - "key '{}' is reserved and cannot be assigned".format(key)) + assert key not in self._reserved_keys, "key '{}' is reserved and cannot be assigned".format(key) super().__setattr__(key, value) def save_hdf5(self, path: str) -> None: @@ -109,9 +102,7 @@ def save_hdf5(self, path: str) -> None: to_hdf5(self.__dict__, f) @classmethod - def load_hdf5( - cls, path: str, device: Optional[str] = None - ) -> "ReplayBuffer": + def load_hdf5(cls, path: str, device: Optional[str] = None) -> "ReplayBuffer": """Load replay buffer from HDF5 file.""" with h5py.File(path, "r") as f: buf = cls.__new__(cls) @@ -125,16 +116,14 @@ def reset(self) -> None: def set_batch(self, batch: Batch) -> None: """Manually choose the batch you want the ReplayBuffer to manage.""" - assert len(batch) == self.maxsize and \ - set(batch.keys()).issubset(self._reserved_keys), \ + assert len(batch) == self.maxsize and set(batch.keys()).issubset(self._reserved_keys), \ "Input batch doesn't meet ReplayBuffer's data form requirement." self._meta = batch def unfinished_index(self) -> np.ndarray: """Return the index of unfinished episode.""" last = (self._index - 1) % self._size if self._size else 0 - return np.array( - [last] if not self.done[last] and self._size else [], np.int) + return np.array([last] if not self.done[last] and self._size else [], np.int) def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: """Return the index of previous transition. 
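A minimal sketch of how these index helpers behave, assuming a small buffer
filled with a few short episodes (API as defined in this patch)::

    buf = ReplayBuffer(size=10)
    # ... buf.add(...) called for two episodes of length 3, done on every 3rd step ...
    indices = buf.sample_index(0)  # all valid indices, oldest first
    buf.prev(indices)              # one step back, but unchanged at each episode start
    buf.next(indices)              # one step forward, unchanged where done is True
                                   # or at the most recently written index
    buf.unfinished_index()         # the newest index if its episode is still running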
@@ -176,13 +165,10 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray: self._meta[to_indices] = buffer._meta[from_indices] return to_indices - def _add_index( - self, rew: Union[float, np.ndarray], done: bool - ) -> Tuple[int, Union[float, np.ndarray], int, int]: + def _add_index(self, rew: Union[float, np.ndarray], done: bool) -> Tuple[int, Union[float, np.ndarray], int, int]: """Maintain the buffer's state after adding one data batch. - Return (index_to_be_modified, episode_reward, episode_length, - episode_start_index). + Return (index_to_be_modified, episode_reward, episode_length, episode_start_index). """ ptr = self._index self._size = min(self._size + 1, self.maxsize) @@ -199,21 +185,17 @@ def _add_index( return ptr, self._ep_rew * 0.0, 0, self._ep_idx def add( - self, - batch: Batch, - buffer_ids: Optional[Union[np.ndarray, List[int]]] = None, + self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into replay buffer. - :param Batch batch: the input data batch. Its keys must belong to the 7 - reserved keys, and "obs", "act", "rew", "done" is required. - :param buffer_ids: to make consistent with other buffer's add function; - if it is not None, we assume the input batch's first dimension is - always 1. + :param Batch batch: the input data batch. Its keys must belong to the 7 reserved keys, and "obs", "act", + "rew", "done" is required. + :param buffer_ids: to make consistent with other buffer's add function; if it is not None, we assume the input + batch's first dimension is always 1. - Return (current_index, episode_reward, episode_length, - episode_start_index). If the episode is not finished, the return value - of episode_length and episode_reward is 0. + Return (current_index, episode_reward, episode_length, episode_start_index). If the episode is not finished, + the return value of episode_length and episode_reward is 0. """ # preprocess batch b = Batch() @@ -229,15 +211,13 @@ def add( if not self._save_obs_next: batch.pop("obs_next", None) elif self._save_only_last_obs: - batch.obs_next = \ - batch.obs_next[:, -1] if stacked_batch else batch.obs_next[-1] + batch.obs_next = batch.obs_next[:, -1] if stacked_batch else batch.obs_next[-1] # get ptr if stacked_batch: rew, done = batch.rew[0], batch.done[0] else: rew, done = batch.rew, batch.done - ptr, ep_rew, ep_len, ep_idx = list(map( - lambda x: np.array([x]), self._add_index(rew, done))) + ptr, ep_rew, ep_len, ep_idx = list(map(lambda x: np.array([x]), self._add_index(rew, done))) try: self._meta[ptr] = batch except ValueError: @@ -254,24 +234,20 @@ def add( def sample_index(self, batch_size: int) -> np.ndarray: """Get a random sample of index with size = batch_size. - Return all available indices in the buffer if batch_size is 0; return - an empty numpy array if batch_size < 0 or no available index can be - sampled. + Return all available indices in the buffer if batch_size is 0; return an empty numpy array if batch_size < 0 + or no available index can be sampled. 
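
        A sketch of the intended usage, assuming the buffer already holds some data::

            buf.sample_index(0)              # every valid index, ordered from oldest to newest
            buf.sample_index(64)             # 64 indices drawn uniformly (with replacement)
            batch, indices = buf.sample(64)  # the corresponding transitions plus their indices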
""" if self.stack_num == 1 or not self._sample_avail: # most often case if batch_size > 0: return np.random.choice(self._size, batch_size) elif batch_size == 0: # construct current available indices - return np.concatenate([ - np.arange(self._index, self._size), - np.arange(self._index)]) + return np.concatenate([np.arange(self._index, self._size), np.arange(self._index)]) else: return np.array([], np.int) else: if batch_size < 0: return np.array([], np.int) - all_indices = prev_indices = np.concatenate([ - np.arange(self._index, self._size), np.arange(self._index)]) + all_indices = prev_indices = np.concatenate([np.arange(self._index, self._size), np.arange(self._index)]) for _ in range(self.stack_num - 2): prev_indices = self.prev(prev_indices) all_indices = all_indices[prev_indices != self.prev(prev_indices)] @@ -295,12 +271,11 @@ def get( index: Union[int, np.integer, np.ndarray], key: str, default_value: Optional[Any] = None, - stack_num: Optional[int] = None, + stack_num: Optional[int] = None ) -> Union[Batch, np.ndarray]: """Return the stacked result. - E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the - index. + E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the index. """ if key not in self._meta and default_value is not None: return default_value @@ -324,20 +299,17 @@ def get( raise e # val != Batch() return Batch() - def __getitem__( - self, index: Union[slice, int, np.integer, np.ndarray] - ) -> Batch: + def __getitem__(self, index: Union[slice, int, np.integer, np.ndarray]) -> Batch: """Return a data batch: self[index]. - If stack_num is larger than 1, return the stacked obs and obs_next with - shape (batch, len, ...). + If stack_num is larger than 1, return the stacked obs and obs_next with shape (batch, len, ...). """ if isinstance(index, slice): # change slice to np array if index == slice(None): # buffer[:] will get all available data index = self.sample_index(0) else: index = self._indices[:len(self)][index] - # raise KeyError first instead of AttributeError, to support np.array + # raise KeyError first instead of AttributeError, to support np.array([ReplayBuffer()]) obs = self.get(index, "obs") if self._save_obs_next: obs_next = self.get(index, "obs_next", Batch()) @@ -362,13 +334,10 @@ class PrioritizedReplayBuffer(ReplayBuffer): .. seealso:: - Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed - explanation. + Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed explanation. 
""" - def __init__( - self, size: int, alpha: float, beta: float, **kwargs: Any - ) -> None: + def __init__(self, size: int, alpha: float, beta: float, **kwargs: Any) -> None: super().__init__(size, **kwargs) assert alpha > 0.0 and beta >= 0.0 self._alpha, self._beta = alpha, beta @@ -386,9 +355,7 @@ def update(self, buffer: ReplayBuffer) -> np.ndarray: self.init_weight(indices) def add( - self, - batch: Batch, - buffer_ids: Optional[Union[np.ndarray, List[int]]] = None, + self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids) self.init_weight(ptr) @@ -401,25 +368,18 @@ def sample_index(self, batch_size: int) -> np.ndarray: else: return super().sample_index(batch_size) - def get_weight( - self, index: Union[slice, int, np.integer, np.ndarray] - ) -> np.ndarray: + def get_weight(self, index: Union[slice, int, np.integer, np.ndarray]) -> np.ndarray: """Get the importance sampling weight. - The "weight" in the returned Batch is the weight on loss function - to de-bias the sampling process (some transition tuples are sampled - more often so their losses are weighted less). + The "weight" in the returned Batch is the weight on loss function to de-bias the sampling process (some + transition tuples are sampled more often so their losses are weighted less). """ # important sampling weight calculation # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta)) # simplified formula: (p_j/p_min)**(-beta) return (self.weight[index] / self._min_prio) ** (-self._beta) - def update_weight( - self, - index: np.ndarray, - new_weight: Union[np.ndarray, torch.Tensor], - ) -> None: + def update_weight(self, index: np.ndarray, new_weight: Union[np.ndarray, torch.Tensor]) -> None: """Update priority weight by index in this buffer. :param np.ndarray index: index you want to update weight. @@ -430,27 +390,23 @@ def update_weight( self._max_prio = max(self._max_prio, weight.max()) self._min_prio = min(self._min_prio, weight.min()) - def __getitem__( - self, index: Union[slice, int, np.integer, np.ndarray] - ) -> Batch: + def __getitem__(self, index: Union[slice, int, np.integer, np.ndarray]) -> Batch: batch = super().__getitem__(index) batch.weight = self.get_weight(index) return batch class ReplayBufferManager(ReplayBuffer): - """ReplayBufferManager contains a list of ReplayBuffer with exactly the \ - same configuration. + """ReplayBufferManager contains a list of ReplayBuffer with exactly the same configuration. - These replay buffers have contiguous memory layout, and the storage space - each buffer has is a shallow copy of the topmost memory. + These replay buffers have contiguous memory layout, and the storage space each buffer has is a shallow copy of the + topmost memory. :param int buffer_list: a list of ReplayBuffer needed to be handled. .. seealso:: - Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed - explanation. + Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed explanation. 
""" def __init__(self, buffer_list: List[ReplayBuffer]) -> None: @@ -510,18 +466,15 @@ def update(self, buffer: ReplayBuffer) -> np.ndarray: raise NotImplementedError def add( - self, - batch: Batch, - buffer_ids: Optional[Union[np.ndarray, List[int]]] = None, + self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into ReplayBufferManager. - Each of the data's length (first dimension) must equal to the length of - buffer_ids. By default buffer_ids is [0, 1, ..., buffer_num - 1]. + Each of the data's length (first dimension) must equal to the length of buffer_ids. By default buffer_ids is + [0, 1, ..., buffer_num - 1]. - Return (current_index, episode_reward, episode_length, - episode_start_index). If the episode is not finished, the return value - of episode_length and episode_reward is 0. + Return (current_index, episode_reward, episode_length, episode_start_index). If the episode is not finished, + the return value of episode_length and episode_reward is 0. """ # preprocess batch b = Batch() @@ -587,9 +540,7 @@ def sample_index(self, batch_size: int) -> np.ndarray: ]) -class PrioritizedReplayBufferManager( - PrioritizedReplayBuffer, ReplayBufferManager -): +class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManager): def __init__(self, buffer_list: List[PrioritizedReplayBuffer]) -> None: ReplayBufferManager.__init__(self, buffer_list) kwargs = buffer_list[0].options @@ -599,9 +550,7 @@ def __init__(self, buffer_list: List[PrioritizedReplayBuffer]) -> None: class VectorReplayBuffer(ReplayBufferManager): - def __init__( - self, total_size: int, buffer_num: int, **kwargs: Any - ) -> None: + def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: assert buffer_num > 0 size = int(np.ceil(total_size / buffer_num)) buffer_list = [ReplayBuffer(size, **kwargs) for _ in range(buffer_num)] @@ -609,73 +558,54 @@ def __init__( class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager): - def __init__( - self, total_size: int, buffer_num: int, **kwargs: Any - ) -> None: + def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: assert buffer_num > 0 size = int(np.ceil(total_size / buffer_num)) - buffer_list = [PrioritizedReplayBuffer(size, **kwargs) - for _ in range(buffer_num)] + buffer_list = [PrioritizedReplayBuffer(size, **kwargs) for _ in range(buffer_num)] super().__init__(buffer_list) class CachedReplayBuffer(ReplayBufferManager): - """CachedReplayBuffer contains a given main buffer and n cached buffers, \ - cached_buffer_num * ReplayBuffer(size=max_episode_length). + """CachedReplayBuffer contains a given main buffer and n cached buffers, cached_buffer_num * \ + ReplayBuffer(size=max_episode_length). - The memory layout is: ``| main_buffer | cached_buffers[0] | - cached_buffers[1] | ... | cached_buffers[cached_buffer_num - 1]``. + The memory layout is: ``| main_buffer | cached_buffers[0] | cached_buffers[1] | ... | + cached_buffers[cached_buffer_num - 1]``. - The data is first stored in cached buffers. When the episode is - terminated, the data will move to the main buffer and the corresponding - cached buffer will be reset. + The data is first stored in cached buffers. When the episode is terminated, the data will move to the main buffer + and the corresponding cached buffer will be reset. - :param ReplayBuffer main_buffer: the main buffer whose ``.update()`` - function behaves normally. 
- :param int cached_buffer_num: number of ReplayBuffer needs to be created - for cached buffer. - :param int max_episode_length: the maximum length of one episode, used in - each cached buffer's maxsize. + :param ReplayBuffer main_buffer: the main buffer whose ``.update()`` function behaves normally. + :param int cached_buffer_num: number of ReplayBuffer needs to be created for cached buffer. + :param int max_episode_length: the maximum length of one episode, used in each cached buffer's maxsize. .. seealso:: - Please refer to :class:`~tianshou.data.ReplayBuffer` or - :class:`~tianshou.data.ReplayBufferManager` for more detailed - explanation. + Please refer to :class:`~tianshou.data.ReplayBuffer` or :class:`~tianshou.data.ReplayBufferManager` for more + detailed explanation. """ - def __init__( - self, - main_buffer: ReplayBuffer, - cached_buffer_num: int, - max_episode_length: int, - ) -> None: + def __init__(self, main_buffer: ReplayBuffer, cached_buffer_num: int, max_episode_length: int) -> None: assert cached_buffer_num > 0 and max_episode_length > 0 assert type(main_buffer) == ReplayBuffer kwargs = main_buffer.options - buffers = [main_buffer] + [ReplayBuffer(max_episode_length, **kwargs) - for _ in range(cached_buffer_num)] + buffers = [main_buffer] + [ReplayBuffer(max_episode_length, **kwargs) for _ in range(cached_buffer_num)] super().__init__(buffer_list=buffers) self.main_buffer = self.buffers[0] self.cached_buffers = self.buffers[1:] self.cached_buffer_num = cached_buffer_num def add( - self, - batch: Batch, - buffer_ids: Optional[Union[np.ndarray, List[int]]] = None, + self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into CachedReplayBuffer. - Each of the data's length (first dimension) must equal to the length of - buffer_ids. By default the buffer_ids is [0, 1, ..., cached_buffer_num - - 1]. + Each of the data's length (first dimension) must equal to the length of buffer_ids. By default the buffer_ids + is [0, 1, ..., cached_buffer_num - 1]. - Return (current_index, episode_reward, episode_length, - episode_start_index) with each of the shape (len(buffer_ids), ...), - where (current_index[i], episode_reward[i], episode_length[i], - episode_start_index[i]) refers to the cached_buffer_ids[i]th cached - buffer's corresponding episode result. + Return (current_index, episode_reward, episode_length, episode_start_index) with each of the shape + (len(buffer_ids), ...), where (current_index[i], episode_reward[i], episode_length[i], episode_start_index[i]) + refers to the cached_buffer_ids[i]th cached buffer's corresponding episode result. """ if buffer_ids is None: buffer_ids = np.arange(1, 1 + self.cached_buffer_num) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 7167a9bf3..d223a0839 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -3,7 +3,7 @@ import torch import warnings import numpy as np -from typing import Any, Dict, List, Union, Optional, Callable +from typing import Dict, List, Union, Optional, Callable from tianshou.policy import BasePolicy from tianshou.env import BaseVectorEnv, DummyVectorEnv @@ -16,25 +16,19 @@ class Collector(object): """Collector enables the policy to interact with different types of envs. - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` - class. - :param env: a ``gym.Env`` environment or an instance of the - :class:`~tianshou.env.BaseVectorEnv` class. 
- :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` - class. If set to None (testing phase), it will not store the data. - :param function preprocess_fn: a function called before the data has been - added to the buffer, see issue #42 and :ref:`preprocess_fn`. Default - to None. - :param exploration_noise: a flag which determines when the collector is - used for training. If so, function exploration_noise() in policy will - be called automatically to add exploration noise. Default to True. - - The ``preprocess_fn`` is a function called before the data has been added - to the buffer with batch format, which receives up to 7 keys as listed in - :class:`~tianshou.data.Batch`. It will receive with only ``obs`` when the - collector resets the environment. It returns either a dict or a - :class:`~tianshou.data.Batch` with the modified keys and values. Examples - are in "test/base/test_collector.py". + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param env: a ``gym.Env`` environment or an instance of the :class:`~tianshou.env.BaseVectorEnv` class. + :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. If set to None (testing phase), it + will not store the data. + :param function preprocess_fn: a function called before the data has been added to the buffer, see issue #42 and + :ref:`preprocess_fn`. Default to None. + :param exploration_noise: a flag which determines when the collector is used for training. If so, function + exploration_noise() in policy will be called automatically to add exploration noise. Default to True. + + The ``preprocess_fn`` is a function called before the data has been added to the buffer with batch format, which + receives up to 7 keys as listed in :class:`~tianshou.data.Batch`. It will receive with only ``obs`` when the + collector resets the environment. It returns either a dict or a :class:`~tianshou.data.Batch` with the modified + keys and values. Examples are in "test/base/test_collector.py". Here is the example: :: @@ -50,22 +44,19 @@ class Collector(object): # the collector supports vectorized environments as well vec_buffer = VectorReplayBuffer(total_size=10000, buffer_num = 3) # buffer_num should be equal (suggested) to or larger than #envs - envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') - for _ in range(3)]) + envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)]) collector = Collector(policy, envs, buffer=vec_buffer) # collect 3 episodes collector.collect(n_episode=3) # collect at least 2 steps collector.collect(n_step=2) - # collect episodes with visual rendering (the render argument is the - # sleep time between rendering consecutive frames) + # collect episodes with visual rendering ("render" is the sleep time between rendering consecutive frames) collector.collect(n_episode=1, render=0.03) .. note:: - Please make sure the given environment has a time limitation if using - n_episode collect option. + Please make sure the given environment has a time limitation if using n_episode collect option. 
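
    A minimal ``preprocess_fn`` sketch (the reward scaling is only an arbitrary
    example of a modification)::

        def preprocess_fn(**kwargs):
            # receives only "obs" on reset; obs_next/rew/done/info after a step
            if "rew" in kwargs:
                return {"rew": kwargs["rew"] / 10.0}
            return {}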
""" def __init__( @@ -91,8 +82,7 @@ def __init__( def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None: if buffer is None: - buffer = VectorReplayBuffer( - self.env_num * 1, self.env_num) + buffer = VectorReplayBuffer(self.env_num * 1, self.env_num) elif isinstance(buffer, ReplayBufferManager): assert buffer.buffer_num >= self.env_num if isinstance(buffer, CachedReplayBuffer): @@ -107,10 +97,8 @@ def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None: buffer_type = "PrioritizedReplayBuffer" vector_type = "PrioritizedVectorReplayBuffer" raise TypeError( - f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) to " - f"collect {self.env_num} envs,\n\tplease use {vector_type}" - f"(total_size={buffer.maxsize}, buffer_num={self.env_num}," - " ...) instead.") + f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) to collect {self.env_num} envs,\n\t" + f"please use {vector_type}(total_size={buffer.maxsize}, buffer_num={self.env_num}, ...) instead.") self.buffer = buffer # TODO move to trainer @@ -129,8 +117,7 @@ def reset(self) -> None: """Reset all related variables in the collector.""" # use empty Batch for "state" so that self.data supports slicing # convert empty Batch to None when passing data to policy - self.data = Batch(obs={}, act={}, rew={}, done={}, - obs_next={}, info={}, policy={}) + self.data = Batch(obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={}) self.reset_env() self.reset_buffer() self.reset_stat() @@ -199,7 +186,7 @@ def collect( assert self.env.is_async is False, "Please use AsyncCollector if ..." if n_step is not None: assert n_episode is None, ( - "Only one of n_step or n_episode is allowed in Collector." + f"Only one of n_step or n_episode is allowed in Collector." f"collect, got n_step={n_step}, n_episode={n_episode}.") assert n_step > 0 ready_env_ids = np.arange(self.env_num) @@ -222,8 +209,7 @@ def collect( # get the next action if random: - self.data.update(act=[self._action_space[i].sample() - for i in ready_env_ids]) + self.data.update(act=[self._action_space[i].sample() for i in ready_env_ids]) else: if no_grad: with torch.no_grad(): # faster than retain_grad version @@ -243,8 +229,7 @@ def collect( self.data.update(policy=policy, act=act) # step in env - obs_next, rew, done, info = self.env.step( - self.data.act, id=ready_env_ids) + obs_next, rew, done, info = self.env.step(self.data.act, id=ready_env_ids) self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) if self.preprocess_fn: @@ -261,8 +246,7 @@ def collect( time.sleep(render) # add data into the buffer - ptr, ep_rew, ep_len, ep_idx = self.buffer.add( - self.data, buffer_ids=ready_env_ids) + ptr, ep_rew, ep_len, ep_idx = self.buffer.add(self.data, buffer_ids=ready_env_ids) # collect statistics step_count += len(ready_env_ids) @@ -278,8 +262,7 @@ def collect( # finished episodes, we have to reset finished envs first. obs_reset = self.env.reset(env_ind_global) if self.preprocess_fn: - obs_reset = self.preprocess_fn( - obs=obs_reset).get("obs", obs_reset) + obs_reset = self.preprocess_fn(obs=obs_reset).get("obs", obs_reset) self.data.obs_next[env_ind_local] = obs_reset for i in env_ind_local: self._reset_state(i) @@ -287,8 +270,7 @@ def collect( # Remove surplus env id from ready_env_ids to avoid bias in # selecting environments. 
if n_episode: - episode_to_collect = n_episode - episode_count - surplus_env_num = len(ready_env_ids) - episode_to_collect + surplus_env_num = len(ready_env_ids) - (n_episode - episode_count) if surplus_env_num > 0: mask = np.ones_like(ready_env_ids, np.bool) mask[env_ind_local[:surplus_env_num]] = False @@ -297,8 +279,7 @@ def collect( self.data.obs = self.data.obs_next - if (n_step and step_count >= n_step) or \ - (n_episode and episode_count >= n_episode): + if (n_step and step_count >= n_step) or (n_episode and episode_count >= n_episode): break # generate statistics @@ -307,21 +288,15 @@ def collect( self.collect_time += max(time.time() - start_time, 1e-9) if n_episode: - self.data = Batch(obs={}, act={}, rew={}, done={}, - obs_next={}, info={}, policy={}) + self.data = Batch(obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={}) self.reset_env() if episode_count > 0: - rews, lens, idxs = list(map(np.concatenate, [ - episode_rews, episode_lens, episode_start_indices])) + rews, lens, idxs = list(map(np.concatenate, [episode_rews, episode_lens, episode_start_indices])) else: - rews, lens, idxs = \ - np.array([]), np.array([], np.int), np.array([], np.int) + rews, lens, idxs = np.array([]), np.array([], np.int), np.array([], np.int) - return { - "n/ep": episode_count, "n/st": step_count, - "rews": rews, "lens": lens, "idxs": idxs, - } + return {"n/ep": episode_count, "n/st": step_count, "rews": rews, "lens": lens, "idxs": idxs} class AsyncCollector(Collector): @@ -349,12 +324,9 @@ def collect( :param int n_step: how many steps you want to collect. :param n_episode: how many episodes you want to collect. - :param bool random: whether to use random policy for collecting data. - Default to False. - :param float render: the sleep time between rendering consecutive - frames. Default to None (no rendering). - :param bool no_grad: whether to retain gradient in policy.forward. - Default to True (no gradient retaining). + :param bool random: whether to use random policy for collecting data. Default to False. + :param float render: the sleep time between rendering consecutive frames. Default to None (no rendering). + :param bool no_grad: whether to retain gradient in policy.forward. Default to True (no gradient retaining). .. note:: @@ -378,10 +350,8 @@ def collect( assert n_step > 0 else: assert n_episode > 0 - warnings.warn("Using n_episode under async setting may collect " - "extra frames into buffer.") + warnings.warn("Using n_episode under async setting may collect extra frames into buffer.") - finished_env_ids = [] ready_env_ids = self._ready_env_ids start_time = time.time() @@ -401,8 +371,7 @@ def collect( # get the next action if random: - self.data.update(act=[self._action_space[i].sample() - for i in ready_env_ids]) + self.data.update(act=[self._action_space[i].sample() for i in ready_env_ids]) else: if no_grad: with torch.no_grad(): # faster than retain_grad version @@ -452,8 +421,7 @@ def collect( time.sleep(render) # add data into the buffer - ptr, ep_rew, ep_len, ep_idx = self.buffer.add( - self.data, buffer_ids=ready_env_ids) + ptr, ep_rew, ep_len, ep_idx = self.buffer.add(self.data, buffer_ids=ready_env_ids) # collect statistics step_count += len(ready_env_ids) @@ -469,8 +437,7 @@ def collect( # finished episodes, we have to reset finished envs first. 
obs_reset = self.env.reset(env_ind_global) if self.preprocess_fn: - obs_reset = self.preprocess_fn( - obs=obs_reset).get("obs", obs_reset) + obs_reset = self.preprocess_fn(obs=obs_reset).get("obs", obs_reset) self.data.obs_next[env_ind_local] = obs_reset for i in env_ind_local: self._reset_state(i) @@ -486,8 +453,7 @@ def collect( whole_data[ready_env_ids] = self.data # lots of overhead self.data = whole_data - if (n_step and step_count >= n_step) or \ - (n_episode and episode_count >= n_episode): + if (n_step and step_count >= n_step) or (n_episode and episode_count >= n_episode): break self._ready_env_ids = ready_env_ids @@ -498,13 +464,8 @@ def collect( self.collect_time += max(time.time() - start_time, 1e-9) if episode_count > 0: - rews, lens, idxs = list(map(np.concatenate, [ - episode_rews, episode_lens, episode_start_indices])) + rews, lens, idxs = list(map(np.concatenate, [episode_rews, episode_lens, episode_start_indices])) else: - rews, lens, idxs = \ - np.array([]), np.array([], np.int), np.array([], np.int) + rews, lens, idxs = np.array([]), np.array([], np.int), np.array([], np.int) - return { - "n/ep": episode_count, "n/st": step_count, - "rews": rews, "lens": lens, "idxs": idxs, - } + return {"n/ep": episode_count, "n/st": step_count, "rews": rews, "lens": lens, "idxs": idxs} From 615584484d7dd87c938ce43a8665b26612e940fa Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 9 Feb 2021 10:09:59 +0800 Subject: [PATCH 062/104] black format --- tianshou/data/buffer.py | 24 ++++++++++------------- tianshou/data/collector.py | 39 +++++++++++++++++++------------------- 2 files changed, 29 insertions(+), 34 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 0a876cadf..6f7ce2afb 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -30,7 +30,7 @@ class ReplayBuffer: :param int size: the maximum size of replay buffer. :param int stack_num: the frame-stack sampling argument, should be greater than or equal to 1. Default to 1 (no stacking). - :param bool ignore_obs_next: whether to store obs_next, defaults to False. + :param bool ignore_obs_next: whether to store obs_next. Default to False. :param bool save_only_last_obs: only save the last obs/obs_next when it has a shape of (timestep, ...) because of temporal stacking. Default to False. :param bool sample_avail: the parameter indicating sampling only available index when using frame-stack sampling @@ -271,7 +271,7 @@ def get( index: Union[int, np.integer, np.ndarray], key: str, default_value: Optional[Any] = None, - stack_num: Optional[int] = None + stack_num: Optional[int] = None, ) -> Union[Batch, np.ndarray]: """Return the stacked result. 
@@ -439,9 +439,7 @@ def set_batch(self, batch: Batch) -> None: self._set_batch_for_children() def unfinished_index(self) -> np.ndarray: - return np.concatenate([ - buf.unfinished_index() + offset - for offset, buf in zip(self._offset, self.buffers)]) + return np.concatenate([buf.unfinished_index() + offset for offset, buf in zip(self._offset, self.buffers)]) def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: index = np.asarray(index) % self.maxsize @@ -517,9 +515,9 @@ def sample_index(self, batch_size: int) -> np.ndarray: if batch_size < 0: return np.array([], np.int) if self._sample_avail and self.stack_num > 1: - all_indices = np.concatenate([ - buf.sample_index(0) + offset - for offset, buf in zip(self._offset, self.buffers)]) + all_indices = np.concatenate( + [buf.sample_index(0) + offset for offset, buf in zip(self._offset, self.buffers)] + ) if batch_size == 0: return all_indices else: @@ -528,16 +526,14 @@ def sample_index(self, batch_size: int) -> np.ndarray: sample_num = np.zeros(self.buffer_num, np.int) else: buffer_lens = np.array([len(buf) for buf in self.buffers]) - buffer_idx = np.random.choice(self.buffer_num, batch_size, - p=buffer_lens / buffer_lens.sum()) + buffer_idx = np.random.choice(self.buffer_num, batch_size, p=buffer_lens / buffer_lens.sum()) sample_num = np.bincount(buffer_idx, minlength=self.buffer_num) # avoid batch_size > 0 and sample_num == 0 -> get child's all data sample_num[sample_num == 0] = -1 - return np.concatenate([ - buf.sample_index(bsz) + offset - for offset, buf, bsz in zip(self._offset, self.buffers, sample_num) - ]) + return np.concatenate( + [buf.sample_index(bsz) + offset for offset, buf, bsz in zip(self._offset, self.buffers, sample_num)] + ) class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManager): diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index d223a0839..5c7518546 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -6,10 +6,9 @@ from typing import Dict, List, Union, Optional, Callable from tianshou.policy import BasePolicy -from tianshou.env import BaseVectorEnv, DummyVectorEnv -from tianshou.data import Batch, ReplayBuffer, ReplayBufferManager, \ - VectorReplayBuffer, CachedReplayBuffer, to_numpy from tianshou.data.buffer import _alloc_by_keys_diff +from tianshou.env import BaseVectorEnv, DummyVectorEnv +from tianshou.data import Batch, ReplayBuffer, ReplayBufferManager, VectorReplayBuffer, CachedReplayBuffer, to_numpy # TODO change doc @@ -98,7 +97,8 @@ def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None: vector_type = "PrioritizedVectorReplayBuffer" raise TypeError( f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) to collect {self.env_num} envs,\n\t" - f"please use {vector_type}(total_size={buffer.maxsize}, buffer_num={self.env_num}, ...) instead.") + f"please use {vector_type}(total_size={buffer.maxsize}, buffer_num={self.env_num}, ...) instead." + ) self.buffer = buffer # TODO move to trainer @@ -187,7 +187,8 @@ def collect( if n_step is not None: assert n_episode is None, ( f"Only one of n_step or n_episode is allowed in Collector." - f"collect, got n_step={n_step}, n_episode={n_episode}.") + f"collect, got n_step={n_step}, n_episode={n_episode}." 
+ ) assert n_step > 0 ready_env_ids = np.arange(self.env_num) else: @@ -233,12 +234,11 @@ def collect( self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) if self.preprocess_fn: - self.data.update(self.preprocess_fn( - obs_next=self.data.obs_next, - rew=self.data.rew, - done=self.data.done, - info=self.data.info, - )) + self.data.update( + self.preprocess_fn( + obs_next=self.data.obs_next, rew=self.data.rew, done=self.data.done, info=self.data.info, + ) + ) if render: self.env.render() @@ -346,7 +346,8 @@ def collect( if n_step is not None: assert n_episode is None, ( "Only one of n_step or n_episode is allowed in Collector." - f"collect, got n_step={n_step}, n_episode={n_episode}.") + f"collect, got n_step={n_step}, n_episode={n_episode}." + ) assert n_step > 0 else: assert n_episode > 0 @@ -399,8 +400,7 @@ def collect( whole_data[ready_env_ids] = self.data # lots of overhead # step in env - obs_next, rew, done, info = self.env.step( - self.data.act, id=ready_env_ids) + obs_next, rew, done, info = self.env.step(self.data.act, id=ready_env_ids) # change self.data here because ready_env_ids has changed ready_env_ids = np.array([i["env_id"] for i in info]) @@ -408,12 +408,11 @@ def collect( self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) if self.preprocess_fn: - self.data.update(self.preprocess_fn( - obs_next=self.data.obs_next, - rew=self.data.rew, - done=self.data.done, - info=self.data.info, - )) + self.data.update( + self.preprocess_fn( + obs_next=self.data.obs_next, rew=self.data.rew, done=self.data.done, info=self.data.info, + ) + ) if render: self.env.render() From c6431cfd3889dd82b1b2477d0748c00c0d4a646a Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 9 Feb 2021 10:36:49 +0800 Subject: [PATCH 063/104] fix test and docs for collector --- test/base/test_collector.py | 6 +-- tianshou/data/collector.py | 78 ++++++++++++++++++------------------- 2 files changed, 40 insertions(+), 44 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index aee437c80..9148fe967 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -93,8 +93,6 @@ def test_collector(): policy, venv, VectorReplayBuffer(total_size=100, buffer_num=4), logger.preprocess_fn) - with pytest.raises(AssertionError): - c1.collect(n_step=6) c1.collect(n_step=8) obs = np.zeros(100) obs[[0, 1, 25, 26, 50, 51, 75, 76]] = [0, 1, 0, 1, 0, 1, 0, 1] @@ -141,7 +139,7 @@ def test_collector_with_async(): env_lens = [2, 3, 4, 5] writer = SummaryWriter('log/async_collector') logger = Logger(writer) - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.01, random_sleep=True) + env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens] venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) @@ -199,8 +197,6 @@ def test_collector_with_dict_state(): policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4), Logger.single_preprocess_fn) - with pytest.raises(AssertionError): - c1.collect(n_step=10) c1.collect(n_step=12) result = c1.collect(n_episode=8) assert result['n/ep'] == 8 diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 5c7518546..32786524a 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -11,29 +11,29 @@ from tianshou.data import Batch, ReplayBuffer, ReplayBufferManager, VectorReplayBuffer, CachedReplayBuffer, to_numpy -# TODO change doc class Collector(object): - """Collector enables the policy to interact with different types of envs. 
+ """Collector enables the policy to interact with different types of envs with exact number of steps or episodes. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param env: a ``gym.Env`` environment or an instance of the :class:`~tianshou.env.BaseVectorEnv` class. - :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. If set to None (testing phase), it - will not store the data. + :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. If set to None, it will not store + the data. Default to None. :param function preprocess_fn: a function called before the data has been added to the buffer, see issue #42 and :ref:`preprocess_fn`. Default to None. - :param exploration_noise: a flag which determines when the collector is used for training. If so, function - exploration_noise() in policy will be called automatically to add exploration noise. Default to True. + :param bool exploration_noise: determine whether the action needs to be modified with corresponding policy's + exploration noise. If so, "policy.exploration_noise(act, batch)" will be called automatically to add the + exploration noise into action. Default to False. - The ``preprocess_fn`` is a function called before the data has been added to the buffer with batch format, which - receives up to 7 keys as listed in :class:`~tianshou.data.Batch`. It will receive with only ``obs`` when the - collector resets the environment. It returns either a dict or a :class:`~tianshou.data.Batch` with the modified - keys and values. Examples are in "test/base/test_collector.py". + The "preprocess_fn" is a function called before the data has been added to the buffer with batch format. + It will receive with only "obs" when the collector resets the environment, and will receive four keys "obs_next", + "rew", "done", "info" in a normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with the + modified keys and values. Examples are in "test/base/test_collector.py". - Here is the example: + Here are some example usages: :: policy = PGPolicy(...) # or other policies if you wish - env = gym.make('CartPole-v0') + env = gym.make("CartPole-v0") replay_buffer = ReplayBuffer(size=10000) @@ -41,9 +41,9 @@ class Collector(object): collector = Collector(policy, env, buffer=replay_buffer) # the collector supports vectorized environments as well - vec_buffer = VectorReplayBuffer(total_size=10000, buffer_num = 3) - # buffer_num should be equal (suggested) to or larger than #envs - envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)]) + vec_buffer = VectorReplayBuffer(total_size=10000, buffer_num=3) + # buffer_num should be equal to (suggested) or larger than #envs + envs = DummyVectorEnv([lambda: gym.make("CartPole-v0") for _ in range(3)]) collector = Collector(policy, envs, buffer=vec_buffer) # collect 3 episodes @@ -80,13 +80,14 @@ def __init__( self.reset() def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None: + """Check if the buffer matches the constraint.""" if buffer is None: buffer = VectorReplayBuffer(self.env_num * 1, self.env_num) elif isinstance(buffer, ReplayBufferManager): assert buffer.buffer_num >= self.env_num if isinstance(buffer, CachedReplayBuffer): assert buffer.cached_buffer_num >= self.env_num - else: # ReplayBuffer or PrioritizedReplayBuffer + else: # ReplayBuffer or PrioritizedReplayBuffer cannot be used collecting with multi environments. 
assert buffer.maxsize > 0 if self.env_num > 1: if type(buffer) == ReplayBuffer: @@ -160,18 +161,14 @@ def collect( """Collect a specified number of step or episode. :param int n_step: how many steps you want to collect. - :param n_episode: how many episodes you want to collect. - :param bool random: whether to use random policy for collecting data. - Default to False. - :param float render: the sleep time between rendering consecutive - frames. Default to None (no rendering). - :param bool no_grad: whether to retain gradient in policy.forward. - Default to True (no gradient retaining). + :param int n_episode: how many episodes you want to collect. + :param bool random: whether to use random policy for collecting data. Default to False. + :param float render: the sleep time between rendering consecutive frames. Default to None (no rendering). + :param bool no_grad: whether to retain gradient in policy.forward(). Default to True (no gradient retaining). .. note:: - One and only one collection number specification is permitted, - either ``n_step`` or ``n_episode``. + One and only one collection number specification is permitted, either ``n_step`` or ``n_episode``. :return: A dict including the following keys @@ -179,10 +176,13 @@ def collect( * ``n/st`` the collected number of steps. * ``rews`` the list of episode reward over collected episodes. * ``lens`` the list of episode length over collected episodes. - * ``idxs`` the list of episode start index over collected episodes. + * ``idxs`` the list of episode start index in buffer over collected episodes. + + .. note:: + + To ensure unbiased sampling result with n_episode option, this function will first collect ``n_episode - + env_num`` episodes, then for the last ``env_num`` episodes, they will be collected evenly from each env. """ - # collect at least n_step or n_episode - # TODO: modify docs, tell the constraints assert self.env.is_async is False, "Please use AsyncCollector if ..." if n_step is not None: assert n_episode is None, ( @@ -190,6 +190,9 @@ def collect( f"collect, got n_step={n_step}, n_episode={n_episode}." ) assert n_step > 0 + if not n_step % self.env_num == 0: + warnings.warn(f"n_step={n_step} is not a multiple of #env ({self.env_num}), " + "which may cause extra frame collected into the buffer.") ready_env_ids = np.arange(self.env_num) else: assert n_episode > 0 @@ -236,7 +239,7 @@ def collect( if self.preprocess_fn: self.data.update( self.preprocess_fn( - obs_next=self.data.obs_next, rew=self.data.rew, done=self.data.done, info=self.data.info, + obs_next=self.data.obs_next, rew=self.data.rew, done=self.data.done, info=self.data.info ) ) @@ -267,8 +270,7 @@ def collect( for i in env_ind_local: self._reset_state(i) - # Remove surplus env id from ready_env_ids to avoid bias in - # selecting environments. + # Remove surplus env id from ready_env_ids to avoid bias in selecting environments. if n_episode: surplus_env_num = len(ready_env_ids) - (n_episode - episode_count) if surplus_env_num > 0: @@ -323,15 +325,14 @@ def collect( """Collect a specified number of step or episode. :param int n_step: how many steps you want to collect. - :param n_episode: how many episodes you want to collect. + :param int n_episode: how many episodes you want to collect. :param bool random: whether to use random policy for collecting data. Default to False. :param float render: the sleep time between rendering consecutive frames. Default to None (no rendering). - :param bool no_grad: whether to retain gradient in policy.forward. 
Default to True (no gradient retaining). + :param bool no_grad: whether to retain gradient in policy.forward(). Default to True (no gradient retaining). .. note:: - One and only one collection number specification is permitted, - either ``n_step`` or ``n_episode``. + One and only one collection number specification is permitted, either ``n_step`` or ``n_episode``. :return: A dict including the following keys @@ -339,10 +340,9 @@ def collect( * ``n/st`` the collected number of steps. * ``rews`` the list of episode reward over collected episodes. * ``lens`` the list of episode length over collected episodes. - * ``idxs`` the list of episode start index over collected episodes. + * ``idxs`` the list of episode start index in buffer over collected episodes. """ # collect at least n_step or n_episode - # TODO: modify docs, tell the constraints if n_step is not None: assert n_episode is None, ( "Only one of n_step or n_episode is allowed in Collector." @@ -351,7 +351,7 @@ def collect( assert n_step > 0 else: assert n_episode > 0 - warnings.warn("Using n_episode under async setting may collect extra frames into buffer.") + warnings.warn("Using async setting may collect extra frames into buffer.") ready_env_ids = self._ready_env_ids @@ -410,7 +410,7 @@ def collect( if self.preprocess_fn: self.data.update( self.preprocess_fn( - obs_next=self.data.obs_next, rew=self.data.rew, done=self.data.done, info=self.data.info, + obs_next=self.data.obs_next, rew=self.data.rew, done=self.data.done, info=self.data.info ) ) From 581e5c80567804aeaac66201ffd0d2dd110aaa8d Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Tue, 9 Feb 2021 11:16:27 +0800 Subject: [PATCH 064/104] some other bugs fix --- docs/tutorials/tictactoe.rst | 2 +- examples/box2d/acrobot_dualdqn.py | 2 +- examples/box2d/lunarlander_dqn.py | 2 +- test/discrete/test_c51.py | 2 +- test/discrete/test_dqn.py | 2 +- test/discrete/test_drqn.py | 2 +- test/discrete/test_qrdqn.py | 2 +- test/multiagent/tic_tac_toe.py | 2 +- tianshou/policy/base.py | 2 +- tianshou/trainer/offpolicy.py | 4 ++-- tianshou/trainer/onpolicy.py | 2 +- 11 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst index 605528b01..3801d8863 100644 --- a/docs/tutorials/tictactoe.rst +++ b/docs/tutorials/tictactoe.rst @@ -316,7 +316,7 @@ With the above preparation, we are close to the first learned agent. 
The followi # ======== collector setup ========= train_collector = Collector(policy, train_envs, ReplayBuffer(args.buffer_size)) test_collector = Collector(policy, test_envs) - train_collector.collect(n_step=args.batch_size) + train_collector.collect(n_step=args.batch_size * args.training_num) # ======== tensorboard logging setup ========= if not hasattr(args, 'writer'): diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 272087f21..bff170bc2 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -77,7 +77,7 @@ def test_dqn(args=get_args()): exploration_noise=True) test_collector = Collector(policy, test_envs) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size) + train_collector.collect(n_step=args.batch_size * args.training_num) # log log_path = os.path.join(args.logdir, args.task, 'dqn') writer = SummaryWriter(log_path) diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 1471a86a6..d6039c094 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -78,7 +78,7 @@ def test_dqn(args=get_args()): exploration_noise=True) test_collector = Collector(policy, test_envs) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size) + train_collector.collect(n_step=args.batch_size * args.training_num) # log log_path = os.path.join(args.logdir, args.task, 'dqn') writer = SummaryWriter(log_path) diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 50c940492..0f6251eb4 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -84,7 +84,7 @@ def test_c51(args=get_args()): train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size) + train_collector.collect(n_step=args.batch_size * args.training_num) # log log_path = os.path.join(args.logdir, args.task, 'c51') writer = SummaryWriter(log_path) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index c0268953d..7a9a078b2 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -86,7 +86,7 @@ def test_dqn(args=get_args()): train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size) + train_collector.collect(n_step=args.batch_size * args.training_num) # log log_path = os.path.join(args.logdir, args.task, 'dqn') writer = SummaryWriter(log_path) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 46e05fe48..d897ca7e4 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -73,7 +73,7 @@ def test_drqn(args=get_args()): # the stack_num is for RNN training: sample framestack obs test_collector = Collector(policy, test_envs) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size) + train_collector.collect(n_step=args.batch_size * args.training_num) # log log_path = os.path.join(args.logdir, args.task, 'drqn') writer = SummaryWriter(log_path) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index f721fac3a..6299e9efd 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -82,7 +82,7 @@ def test_qrdqn(args=get_args()): train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs) # policy.set_eps(1) - 
train_collector.collect(n_step=args.batch_size) + train_collector.collect(n_step=args.batch_size * args.training_num) # log log_path = os.path.join(args.logdir, args.task, 'qrdqn') writer = SummaryWriter(log_path) diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index 94cc523ea..9f35f0cf4 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -129,7 +129,7 @@ def env_func(): exploration_noise=True) test_collector = Collector(policy, test_envs) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size) + train_collector.collect(n_step=args.batch_size * args.training_num) # log if not hasattr(args, 'writer'): log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn') diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index f0e4137e9..5a78fa37c 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -192,7 +192,7 @@ def update( @staticmethod def value_mask(batch): # TODO doc - return ~batch.done + return ~batch.done.astype(np.bool) @staticmethod def compute_episodic_return( diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index a6deb8131..e3db968a2 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -135,13 +135,13 @@ def offpolicy_trainer( result = test_episode(policy, test_collector, test_fn, epoch, episode_per_test, writer, env_step) if best_epoch == -1 or best_reward < result["rews"].mean(): - best_reward, best_reward_std = result["rews"].mean(), result["rew_std"].std() + best_reward, best_reward = result["rews"].mean(), result["rews"].std() best_epoch = epoch if save_fn: save_fn(policy) if verbose: print(f"Epoch #{epoch}: test_reward: {result['rews'].mean():.6f} ± " - f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± " + f"{result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " f"{best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 27b15b262..4d5a496f7 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -137,7 +137,7 @@ def onpolicy_trainer( result = test_episode(policy, test_collector, test_fn, epoch, episode_per_test, writer, env_step) if best_epoch == -1 or best_reward < result["rews"].mean(): - best_reward, best_reward_std = result["rews"].mean(), result["rew_std"] + best_reward, best_reward_std = result["rews"].mean(), result["rews"].std() best_epoch = epoch if save_fn: save_fn(policy) From 26dc2a1553b2e2bc46568dfd0846acc22e095585 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 9 Feb 2021 12:14:43 +0800 Subject: [PATCH 065/104] 88 --- setup.cfg | 2 +- tianshou/data/buffer.py | 196 ++++++++++++++++++++++-------------- tianshou/data/collector.py | 197 ++++++++++++++++++++++++------------- 3 files changed, 254 insertions(+), 141 deletions(-) diff --git a/setup.cfg b/setup.cfg index 7a9eb21cf..d485e6d06 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,7 +7,7 @@ exclude = build dist *.egg-info -max-line-length = 119 +max-line-length = 87 [mypy] files = tianshou/**/*.py diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 6f7ce2afb..9b6c8583d 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -8,7 +8,9 @@ from tianshou.data.utils.converter import to_hdf5, from_hdf5 -def _alloc_by_keys_diff(meta: Batch, batch: Batch, size: int, stack: bool = True) -> None: +def _alloc_by_keys_diff( + meta: Batch, batch: Batch, size: int, stack: bool = True +) -> None: 
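    # lazily allocate storage in `meta` for every key of `batch` that `meta` does
    # not hold yet (or holds only as an empty Batch), recursing into nested Batch
    # objects, so that a later `meta[index] = batch` assignment cannot fail on a
    # missing key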
for key in batch.keys(): if key in meta.keys(): if isinstance(meta[key], Batch) and isinstance(batch[key], Batch): @@ -20,21 +22,23 @@ def _alloc_by_keys_diff(meta: Batch, batch: Batch, size: int, stack: bool = True class ReplayBuffer: - """:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. + """:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction \ + between the policy and environment. - ReplayBuffer can be considered as a specialized form (or management) of Batch. It stores all the data in a batch - with circular-queue style. + ReplayBuffer can be considered as a specialized form (or management) of Batch. It + stores all the data in a batch with circular-queue style. - For the example usage of ReplayBuffer, please check out Section Buffer in :doc:`/tutorials/concepts`. + For the example usage of ReplayBuffer, please check out Section Buffer in + :doc:`/tutorials/concepts`. :param int size: the maximum size of replay buffer. - :param int stack_num: the frame-stack sampling argument, should be greater than or equal to 1. - Default to 1 (no stacking). + :param int stack_num: the frame-stack sampling argument, should be greater than or + equal to 1. Default to 1 (no stacking). :param bool ignore_obs_next: whether to store obs_next. Default to False. - :param bool save_only_last_obs: only save the last obs/obs_next when it has a shape of (timestep, ...) because of - temporal stacking. Default to False. - :param bool sample_avail: the parameter indicating sampling only available index when using frame-stack sampling - method. Default to False. + :param bool save_only_last_obs: only save the last obs/obs_next when it has a shape + of (timestep, ...) because of temporal stacking. Default to False. + :param bool sample_avail: the parameter indicating sampling only available index + when using frame-stack sampling method. Default to False. """ _reserved_keys = ("obs", "act", "rew", "done", "obs_next", "info", "policy") @@ -46,8 +50,6 @@ def __init__( ignore_obs_next: bool = False, save_only_last_obs: bool = False, sample_avail: bool = False, - # TODO talk about - **kwargs, ) -> None: self.options: Dict[str, Any] = { "stack_num": stack_num, @@ -93,7 +95,9 @@ def __setstate__(self, state: Dict[str, Any]) -> None: def __setattr__(self, key: str, value: Any) -> None: """Set self.key = value.""" - assert key not in self._reserved_keys, "key '{}' is reserved and cannot be assigned".format(key) + assert ( + key not in self._reserved_keys + ), "key '{}' is reserved and cannot be assigned".format(key) super().__setattr__(key, value) def save_hdf5(self, path: str) -> None: @@ -116,8 +120,9 @@ def reset(self) -> None: def set_batch(self, batch: Batch) -> None: """Manually choose the batch you want the ReplayBuffer to manage.""" - assert len(batch) == self.maxsize and set(batch.keys()).issubset(self._reserved_keys), \ - "Input batch doesn't meet ReplayBuffer's data form requirement." + assert len(batch) == self.maxsize and set(batch.keys()).issubset( + self._reserved_keys + ), "Input batch doesn't meet ReplayBuffer's data form requirement." 
self._meta = batch def unfinished_index(self) -> np.ndarray: @@ -165,10 +170,13 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray: self._meta[to_indices] = buffer._meta[from_indices] return to_indices - def _add_index(self, rew: Union[float, np.ndarray], done: bool) -> Tuple[int, Union[float, np.ndarray], int, int]: + def _add_index( + self, rew: Union[float, np.ndarray], done: bool + ) -> Tuple[int, Union[float, np.ndarray], int, int]: """Maintain the buffer's state after adding one data batch. - Return (index_to_be_modified, episode_reward, episode_length, episode_start_index). + Return (index_to_be_modified, episode_reward, episode_length, + episode_start_index). """ ptr = self._index self._size = min(self._size + 1, self.maxsize) @@ -189,13 +197,14 @@ def add( ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into replay buffer. - :param Batch batch: the input data batch. Its keys must belong to the 7 reserved keys, and "obs", "act", - "rew", "done" is required. - :param buffer_ids: to make consistent with other buffer's add function; if it is not None, we assume the input - batch's first dimension is always 1. + :param Batch batch: the input data batch. Its keys must belong to the 7 + reserved keys, and "obs", "act", "rew", "done" is required. + :param buffer_ids: to make consistent with other buffer's add function; if it + is not None, we assume the input batch's first dimension is always 1. - Return (current_index, episode_reward, episode_length, episode_start_index). If the episode is not finished, - the return value of episode_length and episode_reward is 0. + Return (current_index, episode_reward, episode_length, episode_start_index). If + the episode is not finished, the return value of episode_length and + episode_reward is 0. """ # preprocess batch b = Batch() @@ -211,13 +220,17 @@ def add( if not self._save_obs_next: batch.pop("obs_next", None) elif self._save_only_last_obs: - batch.obs_next = batch.obs_next[:, -1] if stacked_batch else batch.obs_next[-1] + batch.obs_next = ( + batch.obs_next[:, -1] if stacked_batch else batch.obs_next[-1] + ) # get ptr if stacked_batch: rew, done = batch.rew[0], batch.done[0] else: rew, done = batch.rew, batch.done - ptr, ep_rew, ep_len, ep_idx = list(map(lambda x: np.array([x]), self._add_index(rew, done))) + ptr, ep_rew, ep_len, ep_idx = list( + map(lambda x: np.array([x]), self._add_index(rew, done)) + ) try: self._meta[ptr] = batch except ValueError: @@ -234,20 +247,24 @@ def add( def sample_index(self, batch_size: int) -> np.ndarray: """Get a random sample of index with size = batch_size. - Return all available indices in the buffer if batch_size is 0; return an empty numpy array if batch_size < 0 - or no available index can be sampled. + Return all available indices in the buffer if batch_size is 0; return an empty + numpy array if batch_size < 0 or no available index can be sampled. 
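A minimal sketch of the sampling workflow described above, assuming the ``add``/``sample_index``/``__getitem__`` signatures introduced in this commit (``buf`` and the toy transition values are illustrative only)::

    import numpy as np
    from tianshou.data import Batch, ReplayBuffer

    buf = ReplayBuffer(size=20)
    for i in range(3):
        # without buffer_ids, the batch is treated as one unstacked transition
        buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}))
    indices = buf.sample_index(2)   # two random indices into the stored data
    batch = buf[indices]            # a Batch holding the corresponding transitions
    assert len(buf) == 3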
""" if self.stack_num == 1 or not self._sample_avail: # most often case if batch_size > 0: return np.random.choice(self._size, batch_size) elif batch_size == 0: # construct current available indices - return np.concatenate([np.arange(self._index, self._size), np.arange(self._index)]) + return np.concatenate( + [np.arange(self._index, self._size), np.arange(self._index)] + ) else: return np.array([], np.int) else: if batch_size < 0: return np.array([], np.int) - all_indices = prev_indices = np.concatenate([np.arange(self._index, self._size), np.arange(self._index)]) + all_indices = prev_indices = np.concatenate( + [np.arange(self._index, self._size), np.arange(self._index)] + ) for _ in range(self.stack_num - 2): prev_indices = self.prev(prev_indices) all_indices = all_indices[prev_indices != self.prev(prev_indices)] @@ -302,14 +319,16 @@ def get( def __getitem__(self, index: Union[slice, int, np.integer, np.ndarray]) -> Batch: """Return a data batch: self[index]. - If stack_num is larger than 1, return the stacked obs and obs_next with shape (batch, len, ...). + If stack_num is larger than 1, return the stacked obs and obs_next with shape + (batch, len, ...). """ if isinstance(index, slice): # change slice to np array if index == slice(None): # buffer[:] will get all available data index = self.sample_index(0) else: index = self._indices[:len(self)][index] - # raise KeyError first instead of AttributeError, to support np.array([ReplayBuffer()]) + # raise KeyError first instead of AttributeError, + # to support np.array([ReplayBuffer()]) obs = self.get(index, "obs") if self._save_obs_next: obs_next = self.get(index, "obs_next", Batch()) @@ -334,7 +353,8 @@ class PrioritizedReplayBuffer(ReplayBuffer): .. seealso:: - Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed explanation. + Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed + explanation. """ def __init__(self, size: int, alpha: float, beta: float, **kwargs: Any) -> None: @@ -368,18 +388,23 @@ def sample_index(self, batch_size: int) -> np.ndarray: else: return super().sample_index(batch_size) - def get_weight(self, index: Union[slice, int, np.integer, np.ndarray]) -> np.ndarray: + def get_weight( + self, index: Union[slice, int, np.integer, np.ndarray] + ) -> np.ndarray: """Get the importance sampling weight. - The "weight" in the returned Batch is the weight on loss function to de-bias the sampling process (some - transition tuples are sampled more often so their losses are weighted less). + The "weight" in the returned Batch is the weight on loss function to de-bias + the sampling process (some transition tuples are sampled more often so their + losses are weighted less). """ # important sampling weight calculation # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta)) # simplified formula: (p_j/p_min)**(-beta) return (self.weight[index] / self._min_prio) ** (-self._beta) - def update_weight(self, index: np.ndarray, new_weight: Union[np.ndarray, torch.Tensor]) -> None: + def update_weight( + self, index: np.ndarray, new_weight: Union[np.ndarray, torch.Tensor] + ) -> None: """Update priority weight by index in this buffer. :param np.ndarray index: index you want to update weight. @@ -397,16 +422,18 @@ def __getitem__(self, index: Union[slice, int, np.integer, np.ndarray]) -> Batch class ReplayBufferManager(ReplayBuffer): - """ReplayBufferManager contains a list of ReplayBuffer with exactly the same configuration. 
+ """ReplayBufferManager contains a list of ReplayBuffer with exactly the same \ + configuration. - These replay buffers have contiguous memory layout, and the storage space each buffer has is a shallow copy of the - topmost memory. + These replay buffers have contiguous memory layout, and the storage space each + buffer has is a shallow copy of the topmost memory. :param int buffer_list: a list of ReplayBuffer needed to be handled. .. seealso:: - Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed explanation. + Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed + explanation. """ def __init__(self, buffer_list: List[ReplayBuffer]) -> None: @@ -439,7 +466,10 @@ def set_batch(self, batch: Batch) -> None: self._set_batch_for_children() def unfinished_index(self) -> np.ndarray: - return np.concatenate([buf.unfinished_index() + offset for offset, buf in zip(self._offset, self.buffers)]) + return np.concatenate([ + buf.unfinished_index() + offset + for offset, buf in zip(self._offset, self.buffers) + ]) def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: index = np.asarray(index) % self.maxsize @@ -468,11 +498,12 @@ def add( ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into ReplayBufferManager. - Each of the data's length (first dimension) must equal to the length of buffer_ids. By default buffer_ids is - [0, 1, ..., buffer_num - 1]. + Each of the data's length (first dimension) must equal to the length of + buffer_ids. By default buffer_ids is [0, 1, ..., buffer_num - 1]. - Return (current_index, episode_reward, episode_length, episode_start_index). If the episode is not finished, - the return value of episode_length and episode_reward is 0. + Return (current_index, episode_reward, episode_length, episode_start_index). If + the episode is not finished, the return value of episode_length and + episode_reward is 0. 
""" # preprocess batch b = Batch() @@ -492,7 +523,8 @@ def add( ptrs, ep_lens, ep_rews, ep_idxs = [], [], [], [] for batch_idx, buffer_id in enumerate(buffer_ids): ptr, ep_rew, ep_len, ep_idx = self.buffers[buffer_id]._add_index( - batch.rew[batch_idx], batch.done[batch_idx]) + batch.rew[batch_idx], batch.done[batch_idx] + ) ptrs.append(ptr + self._offset[buffer_id]) ep_lens.append(ep_len) ep_rews.append(ep_rew) @@ -515,9 +547,10 @@ def sample_index(self, batch_size: int) -> np.ndarray: if batch_size < 0: return np.array([], np.int) if self._sample_avail and self.stack_num > 1: - all_indices = np.concatenate( - [buf.sample_index(0) + offset for offset, buf in zip(self._offset, self.buffers)] - ) + all_indices = np.concatenate([ + buf.sample_index(0) + offset + for offset, buf in zip(self._offset, self.buffers) + ]) if batch_size == 0: return all_indices else: @@ -526,14 +559,17 @@ def sample_index(self, batch_size: int) -> np.ndarray: sample_num = np.zeros(self.buffer_num, np.int) else: buffer_lens = np.array([len(buf) for buf in self.buffers]) - buffer_idx = np.random.choice(self.buffer_num, batch_size, p=buffer_lens / buffer_lens.sum()) + buffer_idx = np.random.choice( + self.buffer_num, batch_size, p=buffer_lens / buffer_lens.sum() + ) sample_num = np.bincount(buffer_idx, minlength=self.buffer_num) # avoid batch_size > 0 and sample_num == 0 -> get child's all data sample_num[sample_num == 0] = -1 - return np.concatenate( - [buf.sample_index(bsz) + offset for offset, buf, bsz in zip(self._offset, self.buffers, sample_num)] - ) + return np.concatenate([ + buf.sample_index(bsz) + offset + for offset, buf, bsz in zip(self._offset, self.buffers, sample_num) + ]) class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManager): @@ -557,35 +593,48 @@ class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager): def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: assert buffer_num > 0 size = int(np.ceil(total_size / buffer_num)) - buffer_list = [PrioritizedReplayBuffer(size, **kwargs) for _ in range(buffer_num)] + buffer_list = [ + PrioritizedReplayBuffer(size, **kwargs) for _ in range(buffer_num) + ] super().__init__(buffer_list) class CachedReplayBuffer(ReplayBufferManager): - """CachedReplayBuffer contains a given main buffer and n cached buffers, cached_buffer_num * \ - ReplayBuffer(size=max_episode_length). + """CachedReplayBuffer contains a given main buffer and n cached buffers, \ + cached_buffer_num * ReplayBuffer(size=max_episode_length). - The memory layout is: ``| main_buffer | cached_buffers[0] | cached_buffers[1] | ... | - cached_buffers[cached_buffer_num - 1]``. + The memory layout is: ``| main_buffer | cached_buffers[0] | cached_buffers[1] | ... + | cached_buffers[cached_buffer_num - 1]``. - The data is first stored in cached buffers. When the episode is terminated, the data will move to the main buffer - and the corresponding cached buffer will be reset. + The data is first stored in cached buffers. When an episode is terminated, the data + will move to the main buffer and the corresponding cached buffer will be reset. - :param ReplayBuffer main_buffer: the main buffer whose ``.update()`` function behaves normally. - :param int cached_buffer_num: number of ReplayBuffer needs to be created for cached buffer. - :param int max_episode_length: the maximum length of one episode, used in each cached buffer's maxsize. + :param ReplayBuffer main_buffer: the main buffer whose ``.update()`` function + behaves normally. 
+ :param int cached_buffer_num: number of ReplayBuffer needs to be created for cached + buffer. + :param int max_episode_length: the maximum length of one episode, used in each + cached buffer's maxsize. .. seealso:: - Please refer to :class:`~tianshou.data.ReplayBuffer` or :class:`~tianshou.data.ReplayBufferManager` for more - detailed explanation. + Please refer to :class:`~tianshou.data.ReplayBuffer` or + :class:`~tianshou.data.ReplayBufferManager` for more detailed explanation. """ - def __init__(self, main_buffer: ReplayBuffer, cached_buffer_num: int, max_episode_length: int) -> None: + def __init__( + self, + main_buffer: ReplayBuffer, + cached_buffer_num: int, + max_episode_length: int, + ) -> None: assert cached_buffer_num > 0 and max_episode_length > 0 assert type(main_buffer) == ReplayBuffer kwargs = main_buffer.options - buffers = [main_buffer] + [ReplayBuffer(max_episode_length, **kwargs) for _ in range(cached_buffer_num)] + buffers = [main_buffer] + [ + ReplayBuffer(max_episode_length, **kwargs) + for _ in range(cached_buffer_num) + ] super().__init__(buffer_list=buffers) self.main_buffer = self.buffers[0] self.cached_buffers = self.buffers[1:] @@ -596,12 +645,13 @@ def add( ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into CachedReplayBuffer. - Each of the data's length (first dimension) must equal to the length of buffer_ids. By default the buffer_ids - is [0, 1, ..., cached_buffer_num - 1]. + Each of the data's length (first dimension) must equal to the length of + buffer_ids. By default the buffer_ids is [0, 1, ..., cached_buffer_num - 1]. - Return (current_index, episode_reward, episode_length, episode_start_index) with each of the shape - (len(buffer_ids), ...), where (current_index[i], episode_reward[i], episode_length[i], episode_start_index[i]) - refers to the cached_buffer_ids[i]th cached buffer's corresponding episode result. + Return (current_index, episode_reward, episode_length, episode_start_index) + with each of the shape (len(buffer_ids), ...), where (current_index[i], + episode_reward[i], episode_length[i], episode_start_index[i]) refers to the + cached_buffer_ids[i]th cached buffer's corresponding episode result. """ if buffer_ids is None: buffer_ids = np.arange(1, 1 + self.cached_buffer_num) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 32786524a..0ab7288ef 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -8,26 +8,37 @@ from tianshou.policy import BasePolicy from tianshou.data.buffer import _alloc_by_keys_diff from tianshou.env import BaseVectorEnv, DummyVectorEnv -from tianshou.data import Batch, ReplayBuffer, ReplayBufferManager, VectorReplayBuffer, CachedReplayBuffer, to_numpy +from tianshou.data import ( + Batch, + ReplayBuffer, + ReplayBufferManager, + VectorReplayBuffer, + CachedReplayBuffer, + to_numpy, +) class Collector(object): - """Collector enables the policy to interact with different types of envs with exact number of steps or episodes. + """Collector enables the policy to interact with different types of envs with \ + exact number of steps or episodes. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param env: a ``gym.Env`` environment or an instance of the :class:`~tianshou.env.BaseVectorEnv` class. - :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. If set to None, it will not store - the data. Default to None. 
- :param function preprocess_fn: a function called before the data has been added to the buffer, see issue #42 and - :ref:`preprocess_fn`. Default to None. - :param bool exploration_noise: determine whether the action needs to be modified with corresponding policy's - exploration noise. If so, "policy.exploration_noise(act, batch)" will be called automatically to add the + :param env: a ``gym.Env`` environment or an instance of the + :class:`~tianshou.env.BaseVectorEnv` class. + :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. + If set to None, it will not store the data. Default to None. + :param function preprocess_fn: a function called before the data has been added to + the buffer, see issue #42 and :ref:`preprocess_fn`. Default to None. + :param bool exploration_noise: determine whether the action needs to be modified + with corresponding policy's exploration noise. If so, "policy. + exploration_noise(act, batch)" will be called automatically to add the exploration noise into action. Default to False. - The "preprocess_fn" is a function called before the data has been added to the buffer with batch format. - It will receive with only "obs" when the collector resets the environment, and will receive four keys "obs_next", - "rew", "done", "info" in a normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with the - modified keys and values. Examples are in "test/base/test_collector.py". + The "preprocess_fn" is a function called before the data has been added to the + buffer with batch format. It will receive with only "obs" when the collector resets + the environment, and will receive four keys "obs_next", "rew", "done", "info" in a + normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with + the modified keys and values. Examples are in "test/base/test_collector.py". Here are some example usages: :: @@ -50,12 +61,14 @@ class Collector(object): collector.collect(n_episode=3) # collect at least 2 steps collector.collect(n_step=2) - # collect episodes with visual rendering ("render" is the sleep time between rendering consecutive frames) + # collect episodes with visual rendering ("render" is the sleep time between + # rendering consecutive frames) collector.collect(n_episode=1, render=0.03) .. note:: - Please make sure the given environment has a time limitation if using n_episode collect option. + Please make sure the given environment has a time limitation if using n_episode + collect option. """ def __init__( @@ -87,7 +100,7 @@ def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None: assert buffer.buffer_num >= self.env_num if isinstance(buffer, CachedReplayBuffer): assert buffer.cached_buffer_num >= self.env_num - else: # ReplayBuffer or PrioritizedReplayBuffer cannot be used collecting with multi environments. + else: # ReplayBuffer or PrioritizedReplayBuffer assert buffer.maxsize > 0 if self.env_num > 1: if type(buffer) == ReplayBuffer: @@ -97,8 +110,9 @@ def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None: buffer_type = "PrioritizedReplayBuffer" vector_type = "PrioritizedVectorReplayBuffer" raise TypeError( - f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) to collect {self.env_num} envs,\n\t" - f"please use {vector_type}(total_size={buffer.maxsize}, buffer_num={self.env_num}, ...) instead." + f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) to collect " + f"{self.env_num} envs,\n\tplease use {vector_type}(total_size=" + f"{buffer.maxsize}, buffer_num={self.env_num}, ...) 
instead." ) self.buffer = buffer @@ -118,7 +132,9 @@ def reset(self) -> None: """Reset all related variables in the collector.""" # use empty Batch for "state" so that self.data supports slicing # convert empty Batch to None when passing data to policy - self.data = Batch(obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={}) + self.data = Batch( + obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={} + ) self.reset_env() self.reset_buffer() self.reset_stat() @@ -137,7 +153,6 @@ def reset_env(self) -> None: if self.preprocess_fn: obs = self.preprocess_fn(obs=obs).get("obs", obs) self.data.obs = obs - self._ready_env_ids = np.arange(self.env_num) def _reset_state(self, id: Union[int, List[int]]) -> None: """Reset the hidden state: self.data.state[id].""" @@ -162,26 +177,31 @@ def collect( :param int n_step: how many steps you want to collect. :param int n_episode: how many episodes you want to collect. - :param bool random: whether to use random policy for collecting data. Default to False. - :param float render: the sleep time between rendering consecutive frames. Default to None (no rendering). - :param bool no_grad: whether to retain gradient in policy.forward(). Default to True (no gradient retaining). + :param bool random: whether to use random policy for collecting data. Default + to False. + :param float render: the sleep time between rendering consecutive frames. + Default to None (no rendering). + :param bool no_grad: whether to retain gradient in policy.forward(). Default to + True (no gradient retaining). .. note:: - One and only one collection number specification is permitted, either ``n_step`` or ``n_episode``. + One and only one collection number specification is permitted, either + ``n_step`` or ``n_episode``. :return: A dict including the following keys - * ``n/ep`` the collected number of episodes. - * ``n/st`` the collected number of steps. - * ``rews`` the list of episode reward over collected episodes. - * ``lens`` the list of episode length over collected episodes. - * ``idxs`` the list of episode start index in buffer over collected episodes. + * ``n/ep`` collected number of episodes. + * ``n/st`` collected number of steps. + * ``rews`` list of episode reward over collected episodes. + * ``lens`` list of episode length over collected episodes. + * ``idxs`` list of episode start index in buffer over collected episodes. .. note:: - To ensure unbiased sampling result with n_episode option, this function will first collect ``n_episode - - env_num`` episodes, then for the last ``env_num`` episodes, they will be collected evenly from each env. + To ensure unbiased sampling result with n_episode option, this function + will first collect ``n_episode - env_num`` episodes, then for the last + ``env_num`` episodes, they will be collected evenly from each env. """ assert self.env.is_async is False, "Please use AsyncCollector if ..." if n_step is not None: @@ -191,8 +211,10 @@ def collect( ) assert n_step > 0 if not n_step % self.env_num == 0: - warnings.warn(f"n_step={n_step} is not a multiple of #env ({self.env_num}), " - "which may cause extra frame collected into the buffer.") + warnings.warn( + f"n_step={n_step} is not a multiple of #env ({self.env_num}), " + "which may cause extra frame collected into the buffer." 
+ ) ready_env_ids = np.arange(self.env_num) else: assert n_episode > 0 @@ -213,7 +235,8 @@ def collect( # get the next action if random: - self.data.update(act=[self._action_space[i].sample() for i in ready_env_ids]) + self.data.update( + act=[self._action_space[i].sample() for i in ready_env_ids]) else: if no_grad: with torch.no_grad(): # faster than retain_grad version @@ -237,11 +260,12 @@ def collect( self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) if self.preprocess_fn: - self.data.update( - self.preprocess_fn( - obs_next=self.data.obs_next, rew=self.data.rew, done=self.data.done, info=self.data.info - ) - ) + self.data.update(self.preprocess_fn( + obs_next=self.data.obs_next, + rew=self.data.rew, + done=self.data.done, + info=self.data.info, + )) if render: self.env.render() @@ -249,7 +273,8 @@ def collect( time.sleep(render) # add data into the buffer - ptr, ep_rew, ep_len, ep_idx = self.buffer.add(self.data, buffer_ids=ready_env_ids) + ptr, ep_rew, ep_len, ep_idx = self.buffer.add( + self.data, buffer_ids=ready_env_ids) # collect statistics step_count += len(ready_env_ids) @@ -270,7 +295,8 @@ def collect( for i in env_ind_local: self._reset_state(i) - # Remove surplus env id from ready_env_ids to avoid bias in selecting environments. + # remove surplus env id from ready_env_ids + # to avoid bias in selecting environments if n_episode: surplus_env_num = len(ready_env_ids) - (n_episode - episode_count) if surplus_env_num > 0: @@ -281,7 +307,8 @@ def collect( self.data.obs = self.data.obs_next - if (n_step and step_count >= n_step) or (n_episode and episode_count >= n_episode): + if (n_step and step_count >= n_step) or \ + (n_episode and episode_count >= n_episode): break # generate statistics @@ -290,30 +317,47 @@ def collect( self.collect_time += max(time.time() - start_time, 1e-9) if n_episode: - self.data = Batch(obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={}) + self.data = Batch(obs={}, act={}, rew={}, done={}, + obs_next={}, info={}, policy={}) self.reset_env() if episode_count > 0: - rews, lens, idxs = list(map(np.concatenate, [episode_rews, episode_lens, episode_start_indices])) + rews, lens, idxs = list(map( + np.concatenate, [episode_rews, episode_lens, episode_start_indices])) else: rews, lens, idxs = np.array([]), np.array([], np.int), np.array([], np.int) - return {"n/ep": episode_count, "n/st": step_count, "rews": rews, "lens": lens, "idxs": idxs} + return { + "n/ep": episode_count, + "n/st": step_count, + "rews": rews, + "lens": lens, + "idxs": idxs, + } class AsyncCollector(Collector): - """docstring for AsyncCollector""" + """Async Collector handles async vector environment. + + The arguments are exactly the same as :class:`~tianshou.data.Collector`, please + refer to :class:`~tianshou.data.Collector` for more detailed explanation. + """ def __init__( self, policy: BasePolicy, - env: Union[gym.Env, BaseVectorEnv], + env: BaseVectorEnv, buffer: Optional[ReplayBuffer] = None, preprocess_fn: Optional[Callable[..., Batch]] = None, exploration_noise: bool = False, ) -> None: + assert env.is_async super().__init__(policy, env, buffer, preprocess_fn, exploration_noise) + def reset_env(self) -> None: + super().reset_env() + self._ready_env_ids = np.arange(self.env_num) + def collect( self, n_step: Optional[int] = None, @@ -322,25 +366,33 @@ def collect( render: Optional[float] = None, no_grad: bool = True, ) -> Dict[str, float]: - """Collect a specified number of step or episode. 
+ """Collect a specified number of step or episode with async env setting. + + This function doesn't collect exactly n_step or n_episode number of frames. + Instead, in order to support async setting, it may collect more than given + n_step or n_episode frames and save into buffer. :param int n_step: how many steps you want to collect. :param int n_episode: how many episodes you want to collect. - :param bool random: whether to use random policy for collecting data. Default to False. - :param float render: the sleep time between rendering consecutive frames. Default to None (no rendering). - :param bool no_grad: whether to retain gradient in policy.forward(). Default to True (no gradient retaining). + :param bool random: whether to use random policy for collecting data. Default + to False. + :param float render: the sleep time between rendering consecutive frames. + Default to None (no rendering). + :param bool no_grad: whether to retain gradient in policy.forward(). Default to + True (no gradient retaining). .. note:: - One and only one collection number specification is permitted, either ``n_step`` or ``n_episode``. + One and only one collection number specification is permitted, either + ``n_step`` or ``n_episode``. :return: A dict including the following keys - * ``n/ep`` the collected number of episodes. - * ``n/st`` the collected number of steps. - * ``rews`` the list of episode reward over collected episodes. - * ``lens`` the list of episode length over collected episodes. - * ``idxs`` the list of episode start index in buffer over collected episodes. + * ``n/ep`` collected number of episodes. + * ``n/st`` collected number of steps. + * ``rews`` list of episode reward over collected episodes. + * ``lens`` list of episode length over collected episodes. + * ``idxs`` list of episode start index in buffer over collected episodes. 
""" # collect at least n_step or n_episode if n_step is not None: @@ -372,7 +424,8 @@ def collect( # get the next action if random: - self.data.update(act=[self._action_space[i].sample() for i in ready_env_ids]) + self.data.update( + act=[self._action_space[i].sample() for i in ready_env_ids]) else: if no_grad: with torch.no_grad(): # faster than retain_grad version @@ -408,11 +461,12 @@ def collect( self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) if self.preprocess_fn: - self.data.update( - self.preprocess_fn( - obs_next=self.data.obs_next, rew=self.data.rew, done=self.data.done, info=self.data.info - ) - ) + self.data.update(self.preprocess_fn( + obs_next=self.data.obs_next, + rew=self.data.rew, + done=self.data.done, + info=self.data.info, + )) if render: self.env.render() @@ -420,7 +474,8 @@ def collect( time.sleep(render) # add data into the buffer - ptr, ep_rew, ep_len, ep_idx = self.buffer.add(self.data, buffer_ids=ready_env_ids) + ptr, ep_rew, ep_len, ep_idx = self.buffer.add( + self.data, buffer_ids=ready_env_ids) # collect statistics step_count += len(ready_env_ids) @@ -452,7 +507,8 @@ def collect( whole_data[ready_env_ids] = self.data # lots of overhead self.data = whole_data - if (n_step and step_count >= n_step) or (n_episode and episode_count >= n_episode): + if (n_step and step_count >= n_step) or \ + (n_episode and episode_count >= n_episode): break self._ready_env_ids = ready_env_ids @@ -463,8 +519,15 @@ def collect( self.collect_time += max(time.time() - start_time, 1e-9) if episode_count > 0: - rews, lens, idxs = list(map(np.concatenate, [episode_rews, episode_lens, episode_start_indices])) + rews, lens, idxs = list(map( + np.concatenate, [episode_rews, episode_lens, episode_start_indices])) else: rews, lens, idxs = np.array([]), np.array([], np.int), np.array([], np.int) - return {"n/ep": episode_count, "n/st": step_count, "rews": rews, "lens": lens, "idxs": idxs} + return { + "n/ep": episode_count, + "n/st": step_count, + "rews": rews, + "lens": lens, + "idxs": idxs, + } From d3dd8ea71936c30a6f8f7b53122bd690717825cf Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 9 Feb 2021 12:22:30 +0800 Subject: [PATCH 066/104] small update --- tianshou/data/collector.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 0ab7288ef..78d198b7a 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -95,7 +95,7 @@ def __init__( def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None: """Check if the buffer matches the constraint.""" if buffer is None: - buffer = VectorReplayBuffer(self.env_num * 1, self.env_num) + buffer = VectorReplayBuffer(self.env_num, self.env_num) elif isinstance(buffer, ReplayBufferManager): assert buffer.buffer_num >= self.env_num if isinstance(buffer, CachedReplayBuffer): @@ -132,9 +132,8 @@ def reset(self) -> None: """Reset all related variables in the collector.""" # use empty Batch for "state" so that self.data supports slicing # convert empty Batch to None when passing data to policy - self.data = Batch( - obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={} - ) + self.data = Batch(obs={}, act={}, rew={}, done={}, + obs_next={}, info={}, policy={}) self.reset_env() self.reset_buffer() self.reset_stat() @@ -148,7 +147,7 @@ def reset_buffer(self) -> None: self.buffer.reset() def reset_env(self) -> None: - """Reset all of the environment(s).""" + """Reset all of the environments.""" obs = 
self.env.reset() if self.preprocess_fn: obs = self.preprocess_fn(obs=obs).get("obs", obs) @@ -175,6 +174,10 @@ def collect( ) -> Dict[str, float]: """Collect a specified number of step or episode. + To ensure unbiased sampling result with n_episode option, this function will + first collect ``n_episode - env_num`` episodes, then for the last ``env_num`` + episodes, they will be collected evenly from each env. + :param int n_step: how many steps you want to collect. :param int n_episode: how many episodes you want to collect. :param bool random: whether to use random policy for collecting data. Default @@ -196,14 +199,8 @@ def collect( * ``rews`` list of episode reward over collected episodes. * ``lens`` list of episode length over collected episodes. * ``idxs`` list of episode start index in buffer over collected episodes. - - .. note:: - - To ensure unbiased sampling result with n_episode option, this function - will first collect ``n_episode - env_num`` episodes, then for the last - ``env_num`` episodes, they will be collected evenly from each env. """ - assert self.env.is_async is False, "Please use AsyncCollector if ..." + assert not self.env.is_async, "Please use AsyncCollector if using async venv." if n_step is not None: assert n_episode is None, ( f"Only one of n_step or n_episode is allowed in Collector." From c37bf2495bb4dc447a5ef61653d07bb0e4bc3070 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Tue, 9 Feb 2021 16:15:13 +0800 Subject: [PATCH 067/104] fix cuda bug --- examples/atari/runnable/pong_a2c.py | 4 ++-- examples/atari/runnable/pong_ppo.py | 4 ++-- test/discrete/test_a2c_with_il.py | 4 ++-- test/discrete/test_ppo.py | 4 ++-- tianshou/utils/net/continuous.py | 11 +++++++---- tianshou/utils/net/discrete.py | 10 ++++++++-- 6 files changed, 23 insertions(+), 14 deletions(-) diff --git a/examples/atari/runnable/pong_a2c.py b/examples/atari/runnable/pong_a2c.py index 100ae24a6..ffed1694d 100644 --- a/examples/atari/runnable/pong_a2c.py +++ b/examples/atari/runnable/pong_a2c.py @@ -65,8 +65,8 @@ def test_a2c(args=get_args()): # model net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape).to(args.device) - critic = Critic(net).to(args.device) + actor = Actor(net, args.action_shape, device=args.device).to(args.device) + critic = Critic(net, device=args.device).to(args.device) optim = torch.optim.Adam(set( actor.parameters()).union(critic.parameters()), lr=args.lr) dist = torch.distributions.Categorical diff --git a/examples/atari/runnable/pong_ppo.py b/examples/atari/runnable/pong_ppo.py index 55219da68..35ed0e749 100644 --- a/examples/atari/runnable/pong_ppo.py +++ b/examples/atari/runnable/pong_ppo.py @@ -65,8 +65,8 @@ def test_ppo(args=get_args()): # model net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape).to(args.device) - critic = Critic(net).to(args.device) + actor = Actor(net, args.action_shape, device=args.device).to(args.device) + critic = Critic(net, device=args.device).to(args.device) optim = torch.optim.Adam(set( actor.parameters()).union(critic.parameters()), lr=args.lr) dist = torch.distributions.Categorical diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 90f6681fd..c222bf9a3 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -68,8 +68,8 @@ def test_a2c_with_il(args=get_args()): # model net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, 
device=args.device) - actor = Actor(net, args.action_shape).to(args.device) - critic = Critic(net).to(args.device) + actor = Actor(net, args.action_shape, device=args.device).to(args.device) + critic = Critic(net, device=args.device).to(args.device) optim = torch.optim.Adam(set( actor.parameters()).union(critic.parameters()), lr=args.lr) dist = torch.distributions.Categorical diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 231ad5032..e2e671c99 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -68,8 +68,8 @@ def test_ppo(args=get_args()): # model net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape).to(args.device) - critic = Critic(net).to(args.device) + actor = Actor(net, args.action_shape, device=args.device).to(args.device) + critic = Critic(net, device=args.device).to(args.device) # orthogonal initialization for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 333c2ab8c..cf1647bf6 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -49,7 +49,8 @@ def __init__( self.output_dim = np.prod(action_shape) input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) - self.last = MLP(input_dim, self.output_dim, hidden_sizes) + self.last = MLP(input_dim, self.output_dim, + hidden_sizes, device = self.device) self._max = max_action def forward( @@ -98,7 +99,7 @@ def __init__( self.output_dim = 1 input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) - self.last = MLP(input_dim, 1, hidden_sizes) + self.last = MLP(input_dim, 1, hidden_sizes, device = self.device) def forward( self, @@ -164,10 +165,12 @@ def __init__( self.output_dim = np.prod(action_shape) input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) - self.mu = MLP(input_dim, self.output_dim, hidden_sizes) + self.mu = MLP(input_dim, self.output_dim, + hidden_sizes, device = self.device) self._c_sigma = conditioned_sigma if conditioned_sigma: - self.sigma = MLP(input_dim, self.output_dim, hidden_sizes) + self.sigma = MLP(input_dim, self.output_dim, + hidden_sizes, device = self.device) else: self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1)) self._max = max_action diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 05c02361c..fc7c9b002 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -40,13 +40,16 @@ def __init__( hidden_sizes: Sequence[int] = (), softmax_output: bool = True, preprocess_net_output_dim: Optional[int] = None, + device: Union[str, int, torch.device] = "cpu", ) -> None: super().__init__() + self.device = device self.preprocess = preprocess_net self.output_dim = np.prod(action_shape) input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) - self.last = MLP(input_dim, self.output_dim, hidden_sizes) + self.last = MLP(input_dim, self.output_dim, + hidden_sizes, device=self.device) self.softmax_output = softmax_output def forward( @@ -91,13 +94,16 @@ def __init__( hidden_sizes: Sequence[int] = (), last_size: int = 1, preprocess_net_output_dim: Optional[int] = None, + device: Union[str, int, torch.device] = "cpu", ) -> None: super().__init__() + self.device = device self.preprocess = preprocess_net self.output_dim = last_size input_dim = getattr(preprocess_net, "output_dim", 
preprocess_net_output_dim) - self.last = MLP(input_dim, last_size, hidden_sizes) + self.last = MLP(input_dim, last_size, + hidden_sizes, device=self.device) def forward( self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any From 9656a1efaf2b26094a3075f0310869bcb3e1e0bd Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Tue, 9 Feb 2021 16:19:38 +0800 Subject: [PATCH 068/104] another fix --- test/discrete/test_sac.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 67da3fb57..3d3df6f2c 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -62,15 +62,18 @@ def test_discrete_sac(args=get_args()): # model net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape, softmax_output=False).to(args.device) + actor = Actor(net, args.action_shape, + softmax_output=False, device=args.device).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c1 = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - critic1 = Critic(net_c1, last_size=args.action_shape).to(args.device) + critic1 = Critic(net_c1, last_size=args.action_shape, + device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) net_c2 = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - critic2 = Critic(net_c2, last_size=args.action_shape).to(args.device) + critic2 = Critic(net_c2, last_size=args.action_shape, + device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) # better not to use auto alpha in CartPole From 15bd6e7b53fa941f59652675fed11433e3929b0a Mon Sep 17 00:00:00 2001 From: n+e Date: Tue, 9 Feb 2021 16:42:10 +0800 Subject: [PATCH 069/104] Apply suggestions from code review --- tianshou/utils/net/continuous.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index cf1647bf6..a8f667532 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -50,7 +50,7 @@ def __init__( input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) self.last = MLP(input_dim, self.output_dim, - hidden_sizes, device = self.device) + hidden_sizes, device=self.device) self._max = max_action def forward( @@ -99,7 +99,7 @@ def __init__( self.output_dim = 1 input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) - self.last = MLP(input_dim, 1, hidden_sizes, device = self.device) + self.last = MLP(input_dim, 1, hidden_sizes, device=self.device) def forward( self, @@ -166,11 +166,11 @@ def __init__( input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) self.mu = MLP(input_dim, self.output_dim, - hidden_sizes, device = self.device) + hidden_sizes, device=self.device) self._c_sigma = conditioned_sigma if conditioned_sigma: self.sigma = MLP(input_dim, self.output_dim, - hidden_sizes, device = self.device) + hidden_sizes, device=self.device) else: self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1)) self._max = max_action From 38496cf4664e94f6798d3d8d5e220e47d964e826 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Tue, 9 Feb 2021 16:51:35 +0800 Subject: [PATCH 070/104] last fix --- test/discrete/test_a2c_with_il.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index c222bf9a3..08759f92e 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -113,7 +113,7 @@ def stop_fn(mean_rewards): env.spec.reward_threshold = 190 # lower the goal net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - net = Actor(net, args.action_shape).to(args.device) + net = Actor(net, args.action_shape, device=args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) il_policy = ImitationPolicy(net, optim, mode='discrete') il_test_collector = Collector( From a40700df0583464c4cbe44ae04dce79efb2b0a0b Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Tue, 9 Feb 2021 18:20:30 +0800 Subject: [PATCH 071/104] pep8 fix --- docs/tutorials/tictactoe.rst | 2 +- examples/atari/atari_dqn.py | 2 +- examples/atari/runnable/pong_a2c.py | 2 +- examples/atari/runnable/pong_ppo.py | 2 +- examples/box2d/acrobot_dualdqn.py | 2 +- examples/box2d/bipedal_hardcore_sac.py | 2 +- examples/box2d/lunarlander_dqn.py | 2 +- examples/box2d/mcc_sac.py | 2 +- examples/mujoco/runnable/ant_v2_ddpg.py | 2 +- examples/mujoco/runnable/ant_v2_td3.py | 2 +- examples/mujoco/runnable/halfcheetahBullet_v0_sac.py | 2 +- examples/mujoco/runnable/point_maze_td3.py | 2 +- test/base/test_returns.py | 1 + test/continuous/test_ddpg.py | 2 +- test/continuous/test_ppo.py | 2 +- test/continuous/test_sac_with_il.py | 4 ++-- test/continuous/test_td3.py | 2 +- test/discrete/test_a2c_with_il.py | 4 ++-- test/discrete/test_c51.py | 6 +++--- test/discrete/test_dqn.py | 6 +++--- test/discrete/test_drqn.py | 9 +++++---- test/discrete/test_il_bcq.py | 2 +- test/discrete/test_pg.py | 2 +- test/discrete/test_ppo.py | 2 +- test/discrete/test_qrdqn.py | 6 +++--- test/discrete/test_sac.py | 2 +- test/modelbase/test_psrl.py | 2 +- test/multiagent/tic_tac_toe.py | 2 +- tianshou/policy/base.py | 5 ++++- tianshou/policy/modelfree/ddpg.py | 4 ++-- tianshou/policy/modelfree/dqn.py | 2 +- tianshou/policy/modelfree/sac.py | 2 +- tianshou/trainer/offpolicy.py | 1 - tianshou/trainer/utils.py | 1 - 34 files changed, 48 insertions(+), 45 deletions(-) diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst index 3801d8863..6e840abbd 100644 --- a/docs/tutorials/tictactoe.rst +++ b/docs/tutorials/tictactoe.rst @@ -295,7 +295,7 @@ With the above preparation, we are close to the first learned agent. 
The followi policy.policies[args.agent_id - 1].set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') if args.watch: watch(args) exit(0) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 2d81747d4..ff589dcbb 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -143,7 +143,7 @@ def watch(): print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, - render=args.render) + render=args.render) pprint.pprint(result) if args.watch: diff --git a/examples/atari/runnable/pong_a2c.py b/examples/atari/runnable/pong_a2c.py index e071e6320..f77257a53 100644 --- a/examples/atari/runnable/pong_a2c.py +++ b/examples/atari/runnable/pong_a2c.py @@ -99,7 +99,7 @@ def stop_fn(mean_rewards): env = create_atari_environment(args.task) collector = Collector(policy, env, preprocess_fn=preprocess_fn) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/examples/atari/runnable/pong_ppo.py b/examples/atari/runnable/pong_ppo.py index 0f774ae84..e932c36fc 100644 --- a/examples/atari/runnable/pong_ppo.py +++ b/examples/atari/runnable/pong_ppo.py @@ -103,7 +103,7 @@ def stop_fn(mean_rewards): env = create_atari_environment(args.task) collector = Collector(policy, env, preprocess_fn=preprocess_fn) result = collector.collect(n_step=2000, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index bff170bc2..783d51171 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -118,7 +118,7 @@ def test_fn(epoch, env_step): test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index ea2d4f239..1448352ae 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -155,7 +155,7 @@ def stop_fn(mean_rewards): test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index d6039c094..587b69de2 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -114,7 +114,7 @@ def test_fn(epoch, env_step): test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, 
length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 4c54659a6..e350f48f2 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -123,7 +123,7 @@ def stop_fn(mean_rewards): test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/examples/mujoco/runnable/ant_v2_ddpg.py b/examples/mujoco/runnable/ant_v2_ddpg.py index dd95c2907..5213305b1 100644 --- a/examples/mujoco/runnable/ant_v2_ddpg.py +++ b/examples/mujoco/runnable/ant_v2_ddpg.py @@ -98,7 +98,7 @@ def stop_fn(mean_rewards): test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/examples/mujoco/runnable/ant_v2_td3.py b/examples/mujoco/runnable/ant_v2_td3.py index be59534f9..d8cce63ad 100644 --- a/examples/mujoco/runnable/ant_v2_td3.py +++ b/examples/mujoco/runnable/ant_v2_td3.py @@ -107,7 +107,7 @@ def stop_fn(mean_rewards): test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py index 58c8e4e20..fd74b45f1 100644 --- a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py +++ b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py @@ -109,7 +109,7 @@ def stop_fn(mean_rewards): test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/examples/mujoco/runnable/point_maze_td3.py b/examples/mujoco/runnable/point_maze_td3.py index 9e7f03faa..a8d3feaf6 100644 --- a/examples/mujoco/runnable/point_maze_td3.py +++ b/examples/mujoco/runnable/point_maze_td3.py @@ -115,7 +115,7 @@ def stop_fn(mean_rewards): test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/test/base/test_returns.py b/test/base/test_returns.py index d3d7af52a..fdc0ee727 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -17,6 +17,7 @@ def compute_episodic_return_base(batch, gamma): batch.returns = returns return batch + # TODO need to change def test_episodic_returns(size=2560): fn = BasePolicy.compute_episodic_return diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 5d965ff98..20f7b78d6 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -112,7 +112,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, 
render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index da75f3368..63d747b82 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -132,7 +132,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 51f35341f..466d5f9d0 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -120,7 +120,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') # here we define an imitation collector with a trivial policy policy.eval() @@ -151,7 +151,7 @@ def stop_fn(mean_rewards): il_policy.eval() collector = Collector(il_policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 998b80e5e..fb82c6912 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -125,7 +125,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index b7a04c7d1..a762d0083 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -107,7 +107,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') policy.eval() # here we define an imitation collector with a trivial policy @@ -136,7 +136,7 @@ def stop_fn(mean_rewards): il_policy.eval() collector = Collector(il_policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 0f6251eb4..2bc43763b 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -76,10 +76,10 @@ def test_c51(args=get_args()): # buffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( - args.buffer_size, buffer_num = len(train_envs), + args.buffer_size, buffer_num=len(train_envs), 
alpha=args.alpha, beta=args.beta) else: - buf = VectorReplayBuffer(args.buffer_size, buffer_num = len(train_envs)) + buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs) @@ -125,7 +125,7 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') def test_pc51(args=get_args()): diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 7a9a078b2..626128a1f 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -78,10 +78,10 @@ def test_dqn(args=get_args()): # buffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( - args.buffer_size, buffer_num = len(train_envs), + args.buffer_size, buffer_num=len(train_envs), alpha=args.alpha, beta=args.beta) else: - buf = VectorReplayBuffer(args.buffer_size, buffer_num = len(train_envs)) + buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs) @@ -128,7 +128,7 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') # save buffer in pickle format, for imitation learning unittest buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs)) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index d897ca7e4..aed7debaa 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -66,9 +66,10 @@ def test_drqn(args=get_args()): target_update_freq=args.target_update_freq) # collector train_collector = Collector( - policy, train_envs, VectorReplayBuffer( - args.buffer_size, buffer_num = len(train_envs), - stack_num=args.stack_num, ignore_obs_next=True), + policy, train_envs, + VectorReplayBuffer( + args.buffer_size, buffer_num=len(train_envs), + stack_num=args.stack_num, ignore_obs_next=True), exploration_noise=True) # the stack_num is for RNN training: sample framestack obs test_collector = Collector(policy, test_envs) @@ -105,7 +106,7 @@ def test_fn(epoch, env_step): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py index 697fab23b..2c719f43f 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/discrete/test_il_bcq.py @@ -104,7 +104,7 @@ def stop_fn(mean_rewards): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == "__main__": diff --git a/test/discrete/test_pg.py 
b/test/discrete/test_pg.py index 25606bebd..74cb7c6a8 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -93,7 +93,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 85eb9e5cd..f69b775d6 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -119,7 +119,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 6299e9efd..32c1f2fc7 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -74,10 +74,10 @@ def test_qrdqn(args=get_args()): # buffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( - args.buffer_size, buffer_num = len(train_envs), + args.buffer_size, buffer_num=len(train_envs), alpha=args.alpha, beta=args.beta) else: - buf = VectorReplayBuffer(args.buffer_size, buffer_num = len(train_envs)) + buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs) @@ -123,7 +123,7 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') def test_pqrdqn(args=get_args()): diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 43df24005..e745e6de7 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -119,7 +119,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') if __name__ == '__main__': diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 889f9888b..144d4b444 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -90,7 +90,7 @@ def stop_fn(mean_rewards): test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') elif env.spec.reward_threshold: assert result["best_reward"] >= env.spec.reward_threshold diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index 9f35f0cf4..2039b60aa 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -180,4 +180,4 @@ def watch( policy.policies[args.agent_id - 1].set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final 
reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 5a78fa37c..6698bfbfd 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -298,7 +298,7 @@ def compute_nstep_return( with torch.no_grad(): target_q_torch = target_q_fn(buffer, terminal) # (bsz, ?) target_q = to_numpy(target_q_torch) * \ - BasePolicy.value_mask(batch).reshape(-1, 1) + BasePolicy.value_mask(batch).reshape(-1, 1) end_flag = buffer.done.copy() end_flag[buffer.unfinished_index()] = True target_q = _nstep_return(rew, end_flag, target_q, indices, @@ -320,6 +320,7 @@ def _compile(self) -> None: _episodic_return(f32, f64, b, 0.1, 0.1) _nstep_return(f64, b, f32, i64, 0.1, 1, 0.0, 1.0) + @njit def _gae_return( v_s: np.ndarray, @@ -338,6 +339,7 @@ def _gae_return( returns[i] = gae return returns + @njit def _episodic_return( v_s_: np.ndarray, @@ -350,6 +352,7 @@ def _episodic_return( v_s = np.roll(v_s_, 1) return _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda) + v_s + @njit def _nstep_return( rew: np.ndarray, diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index bb81433b6..1bca2014a 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -5,7 +5,7 @@ from tianshou.policy import BasePolicy from tianshou.exploration import BaseNoise, GaussianNoise -from tianshou.data import Batch, ReplayBuffer, to_torch_as +from tianshou.data import Batch, ReplayBuffer class DDPGPolicy(BasePolicy): @@ -172,4 +172,4 @@ def exploration_noise( if self._noise: act = act + self._noise(act.shape) act = act.clip(self._range[0], self._range[1]) - return act \ No newline at end of file + return act diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 54b570971..d4cca7cb2 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -186,4 +186,4 @@ def exploration_noise( if hasattr(batch["obs"], "mask"): q_[~batch["obs"].mask[i]] = -np.inf act[i] = q_.argmax() - return act \ No newline at end of file + return act diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index a8fb02f51..091d1243c 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -6,7 +6,7 @@ from tianshou.policy import DDPGPolicy from tianshou.exploration import BaseNoise -from tianshou.data import Batch, ReplayBuffer, to_torch_as +from tianshou.data import Batch, ReplayBuffer class SACPolicy(DDPGPolicy): diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index e3db968a2..3c54819a9 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -147,4 +147,3 @@ def offpolicy_trainer( break return gather_info(start_time, train_collector, test_collector, best_reward, best_reward_std) - diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 37a33f321..da39b5532 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -1,5 +1,4 @@ import time -import numpy as np from torch.utils.tensorboard import SummaryWriter from typing import Dict, List, Union, Callable, Optional From 5c4641541dc3687448d1182c2b2e10304fbaf8d6 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 9 Feb 2021 19:27:49 +0800 Subject: [PATCH 072/104] fix test_ppo --- test/continuous/test_ppo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index da75f3368..786b809ba 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -24,12 +24,12 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--epoch', type=int, default=20) parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=1) + parser.add_argument('--collect-per-step', type=int, default=16) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=1) + parser.add_argument('--training-num', type=int, default=16) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) From b91e44fb61d447d023daa363a0843d60d4672e2a Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 9 Feb 2021 20:14:39 +0800 Subject: [PATCH 073/104] fix 4096 --- tianshou/policy/base.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 6698bfbfd..b93888cd9 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -291,18 +291,16 @@ def compute_nstep_return( for _ in range(n_step - 1): indices.append(buffer.next(indices[-1])) indices = np.stack(indices) - # terminal indicates buffer indexes nstep after 'indice', # and are truncated at the end of each episode terminal = indices[-1] with torch.no_grad(): target_q_torch = target_q_fn(buffer, terminal) # (bsz, ?) - target_q = to_numpy(target_q_torch) * \ - BasePolicy.value_mask(batch).reshape(-1, 1) + target_q = to_numpy(target_q_torch) * BasePolicy.value_mask(batch) end_flag = buffer.done.copy() end_flag[buffer.unfinished_index()] = True - target_q = _nstep_return(rew, end_flag, target_q, indices, - gamma, n_step, mean, std) + target_q = _nstep_return(rew, end_flag, target_q.reshape(-1, 1), + indices, gamma, n_step, mean, std) batch.returns = to_torch_as(target_q, target_q_torch) if hasattr(batch, "weight"): # prio buffer update @@ -366,7 +364,7 @@ def _nstep_return( ) -> np.ndarray: gamma_buffer = np.ones(n_step + 1) for i in range(1, n_step + 1): - gamma_buffer[i] = gamma_buffer[i - 1]*gamma + gamma_buffer[i] = gamma_buffer[i - 1] * gamma target_shape = target_q.shape bsz = target_shape[0] # change target_q to 2d array From d75f5b56421da7fb8130358986c1d976e72c8383 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 9 Feb 2021 20:17:41 +0800 Subject: [PATCH 074/104] fix a2c and pg --- test/discrete/test_a2c_with_il.py | 2 +- test/discrete/test_pg.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index a762d0083..c7d45f792 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -24,7 +24,7 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.9) parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--collect-per-step', type=int, default=8) parser.add_argument('--repeat-per-collect', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) 
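The compute_nstep_return hunk above ("fix 4096") masks the bootstrapped Q-value with BasePolicy.value_mask and hands a 2-D target_q to _nstep_return. The core idea — accumulate up to n discounted rewards and only bootstrap when no terminal transition was crossed — can be sketched in a few lines of plain NumPy. This is an illustrative toy, not the library's _nstep_return: the function name nstep_return, the q_next array (a value estimate for each transition's next observation), and the assumption of a flat, non-wrapping buffer with enough room after index t are all assumptions made for the sketch.

import numpy as np

def nstep_return(rew, done, q_next, t, gamma, n_step):
    # rew, done, q_next are 1-D arrays over a flat (non-circular) buffer;
    # q_next[i] is a value estimate for the next observation stored at index i.
    ret, discount = 0.0, 1.0
    for k in range(n_step):
        i = t + k
        ret += discount * rew[i]
        discount *= gamma
        if done[i]:                 # terminal: the bootstrap value is masked to zero
            return ret
    return ret + discount * q_next[t + n_step - 1]

rew = np.ones(5)
done = np.array([False, False, True, False, False])
q_next = np.full(5, 10.0)
print(nstep_return(rew, done, q_next, t=0, gamma=0.9, n_step=3))  # 2.71: episode ends at index 2, no bootstrap
print(nstep_return(rew, done, q_next, t=3, gamma=0.9, n_step=2))  # 1.9 + 0.81 * 10 = 10.0

The library version does the same thing vectorised over a batch of sampled indices, walks a circular buffer via precomputed next-step indices, and folds the terminal check into the end_flag array plus the value mask seen in the diff.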
parser.add_argument('--hidden-sizes', type=int, diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 74cb7c6a8..eb9aba675 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -22,7 +22,7 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.95) parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--collect-per-step', type=int, default=8) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, From f16c5ff8355cca4ea41cb54ba5912dcbef602850 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 9 Feb 2021 20:28:21 +0800 Subject: [PATCH 075/104] fix print --- examples/atari/runnable/pong_a2c.py | 3 ++- examples/atari/runnable/pong_ppo.py | 3 ++- examples/box2d/acrobot_dualdqn.py | 3 ++- examples/box2d/bipedal_hardcore_sac.py | 3 ++- examples/box2d/lunarlander_dqn.py | 6 +++--- examples/box2d/mcc_sac.py | 6 +++--- examples/mujoco/mujoco_sac.py | 3 +-- examples/mujoco/runnable/ant_v2_ddpg.py | 6 +++--- examples/mujoco/runnable/ant_v2_td3.py | 6 +++--- examples/mujoco/runnable/halfcheetahBullet_v0_sac.py | 3 ++- examples/mujoco/runnable/point_maze_td3.py | 3 ++- test/continuous/test_ddpg.py | 3 ++- test/continuous/test_ppo.py | 3 ++- test/continuous/test_sac_with_il.py | 6 ++++-- test/continuous/test_td3.py | 3 ++- test/discrete/test_a2c_with_il.py | 6 ++++-- test/discrete/test_c51.py | 3 ++- test/discrete/test_dqn.py | 3 ++- test/discrete/test_drqn.py | 3 ++- test/discrete/test_il_bcq.py | 3 ++- test/discrete/test_pg.py | 3 ++- test/discrete/test_ppo.py | 3 ++- test/discrete/test_qrdqn.py | 3 ++- test/discrete/test_sac.py | 3 ++- test/modelbase/test_psrl.py | 6 +++--- test/multiagent/tic_tac_toe.py | 3 ++- 26 files changed, 60 insertions(+), 39 deletions(-) diff --git a/examples/atari/runnable/pong_a2c.py b/examples/atari/runnable/pong_a2c.py index f77257a53..0b81cecd6 100644 --- a/examples/atari/runnable/pong_a2c.py +++ b/examples/atari/runnable/pong_a2c.py @@ -99,7 +99,8 @@ def stop_fn(mean_rewards): env = create_atari_environment(args.task) collector = Collector(policy, env, preprocess_fn=preprocess_fn) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/examples/atari/runnable/pong_ppo.py b/examples/atari/runnable/pong_ppo.py index e932c36fc..8ed04c21e 100644 --- a/examples/atari/runnable/pong_ppo.py +++ b/examples/atari/runnable/pong_ppo.py @@ -103,7 +103,8 @@ def stop_fn(mean_rewards): env = create_atari_environment(args.task) collector = Collector(policy, env, preprocess_fn=preprocess_fn) result = collector.collect(n_step=2000, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 783d51171..dadd7ddb1 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -118,7 +118,8 @@ def test_fn(epoch, env_step): test_collector.reset() result = 
test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 1448352ae..0bf802d3b 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -155,7 +155,8 @@ def stop_fn(mean_rewards): test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 587b69de2..c01410bb7 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -112,9 +112,9 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + result = test_collector.collect(n_episode=args.test_num, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index e350f48f2..14e5095e7 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -121,9 +121,9 @@ def stop_fn(mean_rewards): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + result = test_collector.collect(n_episode=args.test_num, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 894e59f81..ba9cd79a3 100644 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -122,8 +122,7 @@ def watch(): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) def save_fn(policy): diff --git a/examples/mujoco/runnable/ant_v2_ddpg.py b/examples/mujoco/runnable/ant_v2_ddpg.py index 5213305b1..b9a6e0118 100644 --- a/examples/mujoco/runnable/ant_v2_ddpg.py +++ b/examples/mujoco/runnable/ant_v2_ddpg.py @@ -96,9 +96,9 @@ def stop_fn(mean_rewards): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + result = test_collector.collect(n_episode=args.test_num, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/examples/mujoco/runnable/ant_v2_td3.py b/examples/mujoco/runnable/ant_v2_td3.py index d8cce63ad..2f8370217 100644 --- 
a/examples/mujoco/runnable/ant_v2_td3.py +++ b/examples/mujoco/runnable/ant_v2_td3.py @@ -105,9 +105,9 @@ def stop_fn(mean_rewards): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + result = test_collector.collect(n_episode=args.test_num, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py index fd74b45f1..6f34ce0ad 100644 --- a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py +++ b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py @@ -109,7 +109,8 @@ def stop_fn(mean_rewards): test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/examples/mujoco/runnable/point_maze_td3.py b/examples/mujoco/runnable/point_maze_td3.py index a8d3feaf6..76271f40f 100644 --- a/examples/mujoco/runnable/point_maze_td3.py +++ b/examples/mujoco/runnable/point_maze_td3.py @@ -115,7 +115,8 @@ def stop_fn(mean_rewards): test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 20f7b78d6..68d3fc433 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -112,7 +112,8 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 1b17573e3..45a59f425 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -132,7 +132,8 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 466d5f9d0..6e075bb5c 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -120,7 +120,8 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") # here we define an imitation collector with a trivial policy policy.eval() @@ -151,7 +152,8 
@@ def stop_fn(mean_rewards): il_policy.eval() collector = Collector(il_policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index fb82c6912..c90a92aa4 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -125,7 +125,8 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index c7d45f792..ea7e6b6ad 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -107,7 +107,8 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") policy.eval() # here we define an imitation collector with a trivial policy @@ -136,7 +137,8 @@ def stop_fn(mean_rewards): il_policy.eval() collector = Collector(il_policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 2bc43763b..31c0c2a1a 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -125,7 +125,8 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") def test_pc51(args=get_args()): diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 626128a1f..273213af9 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -128,7 +128,8 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") # save buffer in pickle format, for imitation learning unittest buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs)) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index aed7debaa..62569fe2e 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -106,7 +106,8 @@ def test_fn(epoch, env_step): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + 
rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py index 2c719f43f..9691119e3 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/discrete/test_il_bcq.py @@ -104,7 +104,8 @@ def stop_fn(mean_rewards): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == "__main__": diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index eb9aba675..a413111b7 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -93,7 +93,8 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index f69b775d6..9862ea7f7 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -119,7 +119,8 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 32c1f2fc7..3a4f0c998 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -123,7 +123,8 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") def test_pqrdqn(args=get_args()): diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index e745e6de7..16ab54cb2 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -119,7 +119,8 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 144d4b444..5a813874f 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -88,9 +88,9 @@ def stop_fn(mean_rewards): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + result = test_collector.collect(n_episode=args.test_num, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") elif env.spec.reward_threshold: 
assert result["best_reward"] >= env.spec.reward_threshold diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index 2039b60aa..cf58aadcb 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -180,4 +180,5 @@ def watch( policy.policies[args.agent_id - 1].set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") From 924fde6c71d1f8dcf3460f5847dccd2b80bc8c0e Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 9 Feb 2021 20:46:31 +0800 Subject: [PATCH 076/104] fix some mypy --- tianshou/data/collector.py | 17 ++++++++++++----- tianshou/policy/base.py | 2 +- tianshou/trainer/offline.py | 4 ++-- tianshou/trainer/offpolicy.py | 4 ++-- tianshou/trainer/onpolicy.py | 4 ++-- tianshou/trainer/utils.py | 19 ++++++++----------- 6 files changed, 27 insertions(+), 23 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 78d198b7a..63d02f9b2 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -3,7 +3,7 @@ import torch import warnings import numpy as np -from typing import Dict, List, Union, Optional, Callable +from typing import Any, Dict, List, Union, Optional, Callable from tianshou.policy import BasePolicy from tianshou.data.buffer import _alloc_by_keys_diff @@ -171,7 +171,7 @@ def collect( random: bool = False, render: Optional[float] = None, no_grad: bool = True, - ) -> Dict[str, float]: + ) -> Dict[str, Any]: """Collect a specified number of step or episode. To ensure unbiased sampling result with n_episode option, this function will @@ -213,10 +213,14 @@ def collect( "which may cause extra frame collected into the buffer." ) ready_env_ids = np.arange(self.env_num) - else: + elif n_episode is not None: assert n_episode > 0 ready_env_ids = np.arange(min(self.env_num, n_episode)) self.data = self.data[:min(self.env_num, n_episode)] + else: + raise TypeError("Please specify at least one (either n_step or n_episode) " + "in AsyncCollector.collect().") + start_time = time.time() step_count = 0 @@ -362,7 +366,7 @@ def collect( random: bool = False, render: Optional[float] = None, no_grad: bool = True, - ) -> Dict[str, float]: + ) -> Dict[str, Any]: """Collect a specified number of step or episode with async env setting. This function doesn't collect exactly n_step or n_episode number of frames. @@ -398,8 +402,11 @@ def collect( f"collect, got n_step={n_step}, n_episode={n_episode}." 
) assert n_step > 0 - else: + elif n_episode is not None: assert n_episode > 0 + else: + raise TypeError("Please specify at least one (either n_step or n_episode) " + "in AsyncCollector.collect().") warnings.warn("Using async setting may collect extra frames into buffer.") ready_env_ids = self._ready_env_ids diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index b93888cd9..bdfa7b464 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -190,7 +190,7 @@ def update( return result @staticmethod - def value_mask(batch): + def value_mask(batch: Batch) -> np.ndarray: # TODO doc return ~batch.done.astype(np.bool) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 2547a2507..56c9ef9e3 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -2,7 +2,7 @@ import tqdm from collections import defaultdict from torch.utils.tensorboard import SummaryWriter -from typing import Dict, List, Union, Callable, Optional +from typing import Dict, Union, Callable, Optional from tianshou.policy import BasePolicy from tianshou.utils import tqdm_config, MovAvg @@ -16,7 +16,7 @@ def offline_trainer( test_collector: Collector, max_epoch: int, step_per_epoch: int, - episode_per_test: Union[int, List[int]], + episode_per_test: int, batch_size: int, test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 3c54819a9..4e32929aa 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -2,7 +2,7 @@ import tqdm from collections import defaultdict from torch.utils.tensorboard import SummaryWriter -from typing import Dict, List, Union, Callable, Optional +from typing import Dict, Union, Callable, Optional from tianshou.data import Collector from tianshou.policy import BasePolicy @@ -17,7 +17,7 @@ def offpolicy_trainer( max_epoch: int, step_per_epoch: int, collect_per_step: int, - episode_per_test: Union[int, List[int]], + episode_per_test: int, batch_size: int, update_per_step: int = 1, train_fn: Optional[Callable[[int, int], None]] = None, diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 4d5a496f7..11b5da71c 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -2,7 +2,7 @@ import tqdm from collections import defaultdict from torch.utils.tensorboard import SummaryWriter -from typing import Dict, List, Union, Callable, Optional +from typing import Dict, Union, Callable, Optional from tianshou.data import Collector from tianshou.policy import BasePolicy @@ -18,7 +18,7 @@ def onpolicy_trainer( step_per_epoch: int, collect_per_step: int, repeat_per_collect: int, - episode_per_test: Union[int, List[int]], + episode_per_test: int, batch_size: int, train_fn: Optional[Callable[[int, int], None]] = None, test_fn: Optional[Callable[[int, Optional[int]], None]] = None, diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index da39b5532..927c61141 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -1,6 +1,6 @@ import time from torch.utils.tensorboard import SummaryWriter -from typing import Dict, List, Union, Callable, Optional +from typing import Any, Dict, Union, Callable, Optional from tianshou.data import Collector from tianshou.policy import BasePolicy @@ -11,10 +11,10 @@ def test_episode( collector: Collector, test_fn: Optional[Callable[[int, Optional[int]], None]], epoch: int, - n_episode: Union[int, 
List[int]], + n_episode: int, writer: Optional[SummaryWriter] = None, global_step: Optional[int] = None, -) -> Dict[str, float]: +) -> Dict[str, Any]: """A simple wrapper of testing policy in collector.""" collector.reset_env() collector.reset_buffer() @@ -23,14 +23,11 @@ def test_episode( test_fn(epoch, global_step) result = collector.collect(n_episode=n_episode) if writer is not None and global_step is not None: - writer.add_scalar( - "test/rew", result['rews'].mean(), global_step=global_step) - writer.add_scalar( - "test/rew_std", result['rews'].std(), global_step=global_step) - writer.add_scalar( - "test/len", result['lens'].mean(), global_step=global_step) - writer.add_scalar( - "test/len_std", result['lens'].std(), global_step=global_step) + rews, lens = result["rews"], result["lens"] + writer.add_scalar("test/rew", rews.mean(), global_step=global_step) + writer.add_scalar("test/rew_std", rews.std(), global_step=global_step) + writer.add_scalar("test/len", lens.mean(), global_step=global_step) + writer.add_scalar("test/len_std", lens.std(), global_step=global_step) return result From 7f5b51c0f63c5eb8bcca89188913559d343d7889 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 9 Feb 2021 21:00:37 +0800 Subject: [PATCH 077/104] fix multidim target_q nstep error --- tianshou/policy/base.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index bdfa7b464..0726ca2be 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -279,6 +279,7 @@ def compute_nstep_return( torch.Tensor with the same shape as target_q_fn's return tensor. """ rew = buffer.rew + bsz = len(indice) # TODO this rew_norm will cause unstablity in training if rew_norm: bfr = rew[:min(len(buffer), 1000)] # avoid large buffer @@ -296,10 +297,11 @@ def compute_nstep_return( terminal = indices[-1] with torch.no_grad(): target_q_torch = target_q_fn(buffer, terminal) # (bsz, ?) 
- target_q = to_numpy(target_q_torch) * BasePolicy.value_mask(batch) + target_q = to_numpy(target_q_torch.reshape(bsz, -1)) + target_q = target_q * BasePolicy.value_mask(batch).reshape(-1, 1) end_flag = buffer.done.copy() end_flag[buffer.unfinished_index()] = True - target_q = _nstep_return(rew, end_flag, target_q.reshape(-1, 1), + target_q = _nstep_return(rew, end_flag, target_q, indices, gamma, n_step, mean, std) batch.returns = to_torch_as(target_q, target_q_torch) @@ -316,7 +318,7 @@ def _compile(self) -> None: _gae_return(f32, f32, f64, b, 0.1, 0.1) _episodic_return(f64, f64, b, 0.1, 0.1) _episodic_return(f32, f64, b, 0.1, 0.1) - _nstep_return(f64, b, f32, i64, 0.1, 1, 0.0, 1.0) + _nstep_return(f64, b, f32.reshape(-1, 1), i64, 0.1, 1, 0.0, 1.0) @njit @@ -351,7 +353,7 @@ def _episodic_return( return _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda) + v_s -@njit +# @njit def _nstep_return( rew: np.ndarray, end_flag: np.ndarray, From 4ab067f069a1e2cba5dfe2df9e7bdc9c2096c377 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 9 Feb 2021 21:28:06 +0800 Subject: [PATCH 078/104] fix throughput --- test/base/test_collector.py | 16 +- test/throughput/env.py | 1 + test/throughput/test_buffer_profile.py | 141 +++++-------- test/throughput/test_collector_profile.py | 243 +++++++++------------- 4 files changed, 163 insertions(+), 238 deletions(-) create mode 120000 test/throughput/env.py diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 9148fe967..e8c886c51 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -123,16 +123,6 @@ def test_collector(): obs[[4, 5, 28, 29, 30, 54, 55, 56, 57]] = [0, 1, 0, 1, 2, 0, 1, 2, 3] assert np.all(c2.buffer.obs[:, 0] == obs) c2.collect(n_episode=4, random=True) - env_fns = [lambda x=i: MyTestEnv(size=x) for i in np.arange(2, 11)] - dum = DummyVectorEnv(env_fns) - num = len(env_fns) - c3 = Collector(policy, dum, - VectorReplayBuffer(total_size=40000, buffer_num=num)) - for i in tqdm.trange(num, 400, desc="test step collector n_episode"): - c3.reset() - result = c3.collect(n_episode=i) - assert result['n/ep'] == i - assert result['n/st'] == len(c3.buffer) def test_collector_with_async(): @@ -144,13 +134,13 @@ def test_collector_with_async(): venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) policy = MyPolicy() - bufsize = 300 + bufsize = 60 c1 = AsyncCollector( policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4), logger.preprocess_fn) ptr = [0, 0, 0, 0] - for n_episode in tqdm.trange(1, 100, desc="test async n_episode"): + for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): result = c1.collect(n_episode=n_episode) assert result["n/ep"] >= n_episode # check buffer data, obs and obs_next, env_id @@ -167,7 +157,7 @@ def test_collector_with_async(): assert np.all(buf.obs_next[indices].reshape( count, env_len) == seq + 1) # test async n_step, for now the buffer should be full of data - for n_step in tqdm.trange(1, 150, desc="test async n_step"): + for n_step in tqdm.trange(1, 15, desc="test async n_step"): result = c1.collect(n_step=n_step) assert result["n/st"] >= n_step for i in range(4): diff --git a/test/throughput/env.py b/test/throughput/env.py new file mode 120000 index 000000000..9a57534db --- /dev/null +++ b/test/throughput/env.py @@ -0,0 +1 @@ +../base/env.py \ No newline at end of file diff --git a/test/throughput/test_buffer_profile.py b/test/throughput/test_buffer_profile.py index be82f0b7f..40ce68889 100644 --- a/test/throughput/test_buffer_profile.py 
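The "fix multidim target_q nstep error" hunk above reshapes the bootstrapped values to (bsz, -1) and turns the value mask into a column before multiplying, so that policies whose target Q has more than one column per sample (e.g. one value per quantile or atom, as in QR-DQN or C51) are masked row-wise instead of tripping NumPy broadcasting. A toy illustration of the shape issue — the sizes and names below are chosen for the sketch, not taken from the library:

import numpy as np

bsz, n_quantiles = 4, 3
target_q = np.arange(bsz * n_quantiles, dtype=float).reshape(bsz, n_quantiles)
done = np.array([False, True, False, False])
value_mask = ~done                        # shape (bsz,)

# target_q * value_mask would broadcast the (4,) mask against the last axis (3,)
# and raise; reshaping the mask to a column zeroes whole rows instead.
masked = target_q * value_mask.reshape(-1, 1)
print(masked[1])                          # [0. 0. 0.] -- the terminal sample is fully masked

When the target has a single column the reshape changes nothing in effect, so the scalar (DQN-style) case keeps working.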
+++ b/test/throughput/test_buffer_profile.py @@ -1,88 +1,61 @@ -import pytest +import sys +import gym +import time +import tqdm import numpy as np - -from tianshou.data import (PrioritizedReplayBuffer, - ReplayBuffer, SegmentTree) - - -@pytest.fixture(scope="module") -def data(): - np.random.seed(0) - obs = {'observable': np.random.rand(100, 100), - 'hidden': np.random.randint(1000, size=200)} - info = {'policy': "dqn", 'base': np.arange(10)} - add_data = {'obs': obs, 'rew': 1., 'act': np.random.rand(30), - 'done': False, 'obs_next': obs, 'info': info} - buffer = ReplayBuffer(int(1e3), stack_num=100) - buffer2 = ReplayBuffer(int(1e4), stack_num=100) - indexes = np.random.choice(int(1e3), size=3, replace=False) - return { - 'add_data': add_data, - 'buffer': buffer, - 'buffer2': buffer2, - 'slice': slice(-3000, -1000, 2), - 'indexes': indexes, - } - - -def test_init(): - for _ in np.arange(1e5): - _ = ReplayBuffer(1e5) - _ = PrioritizedReplayBuffer(size=int(1e5), alpha=0.5, beta=0.5) - - -def test_add(data): - buffer = data['buffer'] - for _ in np.arange(1e5): - buffer.add(**data['add_data']) - - -def test_update(data): - buffer = data['buffer'] - buffer2 = data['buffer2'] - for _ in np.arange(1e2): - buffer2.update(buffer) - - -def test_getitem_slice(data): - Slice = data['slice'] - buffer = data['buffer'] - for _ in np.arange(1e3): - _ = buffer[Slice] - - -def test_getitem_indexes(data): - indexes = data['indexes'] - buffer = data['buffer'] - for _ in np.arange(1e2): - _ = buffer[indexes] - - -def test_get(data): - indexes = data['indexes'] - buffer = data['buffer'] - for _ in np.arange(3e2): - buffer.get(indexes, 'obs') - buffer.get(indexes, 'rew') - buffer.get(indexes, 'done') - buffer.get(indexes, 'info') - - -def test_sample(data): - buffer = data['buffer'] - for _ in np.arange(1e1): - buffer.sample(int(1e2)) - - -def test_segtree(data): - size = 100000 - tree = SegmentTree(size) - tree[np.arange(size)] = np.random.rand(size) - - for i in np.arange(1e5): - scalar = np.random.rand(64) * tree.reduce() - tree.get_prefix_sum_idx(scalar) +from tianshou.data import Batch, ReplayBuffer, VectorReplayBuffer + + +def test_replaybuffer(task="Pendulum-v0"): + total_count = 5 + for _ in tqdm.trange(total_count, desc="ReplayBuffer"): + env = gym.make(task) + buf = ReplayBuffer(10000) + obs = env.reset() + for i in range(100000): + act = env.action_space.sample() + obs_next, rew, done, info = env.step(act) + batch = Batch( + obs=np.array([obs]), + act=np.array([act]), + rew=np.array([rew]), + done=np.array([done]), + obs_next=np.array([obs_next]), + info=np.array([info]), + ) + buf.add(batch, buffer_ids=[0]) + obs = obs_next + if done: + obs = env.reset() + + +def test_vectorbuffer(task="Pendulum-v0"): + total_count = 5 + for _ in tqdm.trange(total_count, desc="VectorReplayBuffer"): + env = gym.make(task) + buf = VectorReplayBuffer(total_size=10000, buffer_num=1) + obs = env.reset() + for i in range(100000): + act = env.action_space.sample() + obs_next, rew, done, info = env.step(act) + batch = Batch( + obs=np.array([obs]), + act=np.array([act]), + rew=np.array([rew]), + done=np.array([done]), + obs_next=np.array([obs_next]), + info=np.array([info]), + ) + buf.add(batch) + obs = obs_next + if done: + obs = env.reset() if __name__ == '__main__': - pytest.main(["-s", "-k buffer_profile", "--durations=0", "-v"]) + t0 = time.time() + test_replaybuffer(sys.argv[-1]) + print("test replaybuffer: ", time.time() - t0) + t0 = time.time() + test_vectorbuffer(sys.argv[-1]) + print("test vectorbuffer: ", 
time.time() - t0) diff --git a/test/throughput/test_collector_profile.py b/test/throughput/test_collector_profile.py index 7e4e8ac75..6242e694b 100644 --- a/test/throughput/test_collector_profile.py +++ b/test/throughput/test_collector_profile.py @@ -1,148 +1,109 @@ -import gym +import tqdm import numpy as np -import pytest -from gym.spaces.discrete import Discrete -from gym.utils import seeding -from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer -from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import BasePolicy +from tianshou.env import DummyVectorEnv, SubprocVectorEnv +from tianshou.data import Batch, Collector, AsyncCollector, VectorReplayBuffer - -class SimpleEnv(gym.Env): - """A simplest example of self-defined env, used to minimize - data collect time and profile collector.""" - - def __init__(self): - self.action_space = Discrete(200) - self._fake_data = np.ones((10, 10, 1)) - self.seed(0) - self.reset() - - def reset(self): - self._index = 0 - self.done = np.random.randint(3, high=200) - return {'observable': np.zeros((10, 10, 1)), 'hidden': self._index} - - def step(self, action): - if self._index == self.done: - raise ValueError('step after done !!!') - self._index += 1 - return {'observable': self._fake_data, 'hidden': self._index}, -1, \ - self._index == self.done, {} - - def seed(self, seed=None): - self.np_random, seed = seeding.np_random(seed) - return [seed] - - -class SimplePolicy(BasePolicy): - """A simplest example of self-defined policy, used - to minimize data collect time.""" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def learn(self, batch, **kwargs): - return super().learn(batch, **kwargs) - - def forward(self, batch, state=None, **kwargs): - return Batch(act=np.array([30] * len(batch)), state=None, logits=None) - - -@pytest.fixture(scope="module") -def data(): - np.random.seed(0) - env = SimpleEnv() - env.seed(0) - env_vec = DummyVectorEnv([lambda: SimpleEnv() for _ in range(100)]) - env_vec.seed(np.random.randint(1000, size=100).tolist()) - env_subproc = SubprocVectorEnv([lambda: SimpleEnv() for _ in range(8)]) - env_subproc.seed(np.random.randint(1000, size=100).tolist()) - env_subproc_init = SubprocVectorEnv( - [lambda: SimpleEnv() for _ in range(8)]) - env_subproc_init.seed(np.random.randint(1000, size=100).tolist()) - buffer = ReplayBuffer(50000) - vec_buffer = VectorReplayBuffer(50000, 100) - policy = SimplePolicy() - collector = Collector(policy, env, ReplayBuffer(50000)) - collector_vec = Collector(policy, env_vec, - VectorReplayBuffer(50000, env_vec.env_num)) - collector_subproc = Collector(policy, env_subproc, - VectorReplayBuffer(50000, env_subproc.env_num)) - return { - "env": env, - "env_vec": env_vec, - "env_subproc": env_subproc, - "env_subproc_init": env_subproc_init, - "policy": policy, - "buffer": buffer, - "vec_buffer": vec_buffer, - "collector": collector, - "collector_vec": collector_vec, - "collector_subproc": collector_subproc, - } - - -def test_init(data): - for _ in range(5000): - Collector(data["policy"], data["env"], data["buffer"]) - - -def test_reset(data): - for _ in range(5000): - data["collector"].reset() - - -def test_collect_st(data): - for _ in range(50): - data["collector"].collect(n_step=1000) - - -def test_collect_ep(data): - for _ in range(50): - data["collector"].collect(n_episode=10) - - -def test_init_vec_env(data): - for _ in range(5000): - Collector(data["policy"], data["env_vec"], data["vec_buffer"]) - - -def test_reset_vec_env(data): - for _ 
in range(5000): - data["collector_vec"].reset() - - -def test_collect_vec_env_st(data): - for _ in range(50): - data["collector_vec"].collect(n_step=1000) - - -def test_collect_vec_env_ep(data): - for _ in range(50): - data["collector_vec"].collect(n_episode=10) - - -def test_init_subproc_env(data): - for _ in range(5000): - Collector(data["policy"], data["env_subproc_init"], data["vec_buffer"]) - - -def test_reset_subproc_env(data): - for _ in range(5000): - data["collector_subproc"].reset() - - -def test_collect_subproc_env_st(data): - for _ in range(50): - data["collector_subproc"].collect(n_step=1000) - - -def test_collect_subproc_env_ep(data): - for _ in range(50): - data["collector_subproc"].collect(n_episode=10) +if __name__ == '__main__': + from env import MyTestEnv +else: # pytest + from test.base.env import MyTestEnv + + +class MyPolicy(BasePolicy): + def __init__(self, dict_state=False, need_state=True): + """ + :param bool dict_state: if the observation of the environment is a dict + :param bool need_state: if the policy needs the hidden state (for RNN) + """ + super().__init__() + self.dict_state = dict_state + self.need_state = need_state + + def forward(self, batch, state=None): + if self.need_state: + if state is None: + state = np.zeros((len(batch.obs), 2)) + else: + state += 1 + if self.dict_state: + return Batch(act=np.ones(len(batch.obs['index'])), state=state) + return Batch(act=np.ones(len(batch.obs)), state=state) + + def learn(self): + pass + + +def test_collector_nstep(): + policy = MyPolicy() + env_fns = [lambda x=i: MyTestEnv(size=x) for i in np.arange(2, 11)] + dum = DummyVectorEnv(env_fns) + num = len(env_fns) + c3 = Collector(policy, dum, + VectorReplayBuffer(total_size=40000, buffer_num=num)) + for i in tqdm.trange(1, 400, desc="test step collector n_step"): + c3.reset() + result = c3.collect(n_step=i * len(env_fns)) + assert result['n/st'] >= i + + +def test_collector_nepisode(): + policy = MyPolicy() + env_fns = [lambda x=i: MyTestEnv(size=x) for i in np.arange(2, 11)] + dum = DummyVectorEnv(env_fns) + num = len(env_fns) + c3 = Collector(policy, dum, + VectorReplayBuffer(total_size=40000, buffer_num=num)) + for i in tqdm.trange(1, 400, desc="test step collector n_episode"): + c3.reset() + result = c3.collect(n_episode=i) + assert result['n/ep'] == i + assert result['n/st'] == len(c3.buffer) + + +def test_asynccollector(): + env_lens = [2, 3, 4, 5] + env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) + for i in env_lens] + + venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) + policy = MyPolicy() + bufsize = 300 + c1 = AsyncCollector( + policy, venv, + VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4)) + ptr = [0, 0, 0, 0] + for n_episode in tqdm.trange(1, 100, desc="test async n_episode"): + result = c1.collect(n_episode=n_episode) + assert result["n/ep"] >= n_episode + # check buffer data, obs and obs_next, env_id + for i, count in enumerate( + np.bincount(result["lens"], minlength=6)[2:]): + env_len = i + 2 + total = env_len * count + indices = np.arange(ptr[i], ptr[i] + total) % bufsize + ptr[i] = (ptr[i] + total) % bufsize + seq = np.arange(env_len) + buf = c1.buffer.buffers[i] + assert np.all(buf.info.env_id[indices] == i) + assert np.all(buf.obs[indices].reshape(count, env_len) == seq) + assert np.all(buf.obs_next[indices].reshape( + count, env_len) == seq + 1) + # test async n_step, for now the buffer should be full of data + for n_step in tqdm.trange(1, 150, desc="test async n_step"): + result = 
c1.collect(n_step=n_step) + assert result["n/st"] >= n_step + for i in range(4): + env_len = i + 2 + seq = np.arange(env_len) + buf = c1.buffer.buffers[i] + assert np.all(buf.info.env_id == i) + assert np.all(buf.obs.reshape(-1, env_len) == seq) + assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1) if __name__ == '__main__': - pytest.main(["-s", "-k collector_profile", "--durations=0", "-v"]) + test_collector_nstep() + test_collector_nepisode() + test_asynccollector() From 4f6a9834ffe10504afe4de26a93eccc69f55d98c Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 9 Feb 2021 21:29:22 +0800 Subject: [PATCH 079/104] open nstep njit --- tianshou/policy/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 0726ca2be..0721f62d0 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -353,7 +353,7 @@ def _episodic_return( return _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda) + v_s -# @njit +@njit def _nstep_return( rew: np.ndarray, end_flag: np.ndarray, From cbb0fda35596de7565c961ad6c8054b0d09f3d97 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Tue, 9 Feb 2021 21:36:26 +0800 Subject: [PATCH 080/104] fix assert False --- tianshou/trainer/offpolicy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 4e32929aa..d7c6842a3 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -135,7 +135,7 @@ def offpolicy_trainer( result = test_episode(policy, test_collector, test_fn, epoch, episode_per_test, writer, env_step) if best_epoch == -1 or best_reward < result["rews"].mean(): - best_reward, best_reward = result["rews"].mean(), result["rews"].std() + best_reward, best_reward_std = result["rews"].mean(), result["rews"].std() best_epoch = epoch if save_fn: save_fn(policy) From 6c0b52c0f1f58b4e33c93a9f6ac41047bd0088d4 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Wed, 10 Feb 2021 15:02:27 +0800 Subject: [PATCH 081/104] fix a strange bug in buffer init --- tianshou/data/buffer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 9b6c8583d..3770db931 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -50,6 +50,7 @@ def __init__( ignore_obs_next: bool = False, save_only_last_obs: bool = False, sample_avail: bool = False, + **kwargs: Any, ) -> None: self.options: Dict[str, Any] = { "stack_num": stack_num, @@ -358,7 +359,8 @@ class PrioritizedReplayBuffer(ReplayBuffer): """ def __init__(self, size: int, alpha: float, beta: float, **kwargs: Any) -> None: - super().__init__(size, **kwargs) + # super().__init__(size, **kwargs) will raise keyword error, don't know why, yet. 
+ ReplayBuffer.__init__(self, size, **kwargs) assert alpha > 0.0 and beta >= 0.0 self._alpha, self._beta = alpha, beta self._max_prio = self._min_prio = 1.0 From e8a71a71fae47abd9dfe581eaea3f390eaf95b2a Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Wed, 10 Feb 2021 17:28:53 +0800 Subject: [PATCH 082/104] some bugs in nstep --- test/base/test_returns.py | 13 +++++++------ tianshou/data/buffer.py | 5 +++-- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/test/base/test_returns.py b/test/base/test_returns.py index fdc0ee727..e262eca1a 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -102,14 +102,15 @@ def compute_nstep_return_base(nstep, gamma, buffer, indice): def test_nstep_returns(size=10000): buf = ReplayBuffer(10) for i in range(12): - buf.add(obs=0, act=0, rew=i + 1, done=i % 4 == 3) + buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3)) batch, indice = buf.sample(0) assert np.allclose(indice, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1]) # rew: [10, 11, 2, 3, 4, 5, 6, 7, 8, 9] # done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0] # test nstep = 1 returns = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indice, target_q_fn, gamma=.1, n_step=1).pop('returns')) + batch, buf, indice, target_q_fn, gamma=.1, n_step=1 + ).pop('returns').reshape(-1)) assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12]) r_ = compute_nstep_return_base(1, .1, buf, indice) assert np.allclose(returns, r_), (r_, returns) @@ -119,9 +120,9 @@ def test_nstep_returns(size=10000): assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 2 returns = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indice, target_q_fn, gamma=.1, n_step=2).pop('returns')) - assert np.allclose(returns, [ - 3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12]) + batch, buf, indice, target_q_fn, gamma=.1, n_step=2 + ).pop('returns').reshape(-1)) + assert np.allclose(returns, [3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12]) r_ = compute_nstep_return_base(2, .1, buf, indice) assert np.allclose(returns, r_) returns_multidim = to_numpy(BasePolicy.compute_nstep_return( @@ -143,7 +144,7 @@ def test_nstep_returns(size=10000): if __name__ == '__main__': buf = ReplayBuffer(size) for i in range(int(size * 1.5)): - buf.add(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0) + buf.add(Batch(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0)) batch, indice = buf.sample(256) def vanilla(): diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 3770db931..f3636afe6 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -50,7 +50,7 @@ def __init__( ignore_obs_next: bool = False, save_only_last_obs: bool = False, sample_avail: bool = False, - **kwargs: Any, + **kwargs: Any, # otherwise PrioritizedVectorReplayBuffer will cause TypeError ) -> None: self.options: Dict[str, Any] = { "stack_num": stack_num, @@ -359,7 +359,8 @@ class PrioritizedReplayBuffer(ReplayBuffer): """ def __init__(self, size: int, alpha: float, beta: float, **kwargs: Any) -> None: - # super().__init__(size, **kwargs) will raise keyword error, don't know why, yet. 
+ # will raise KeyError in PrioritizedVectorReplayBuffer + # super().__init__(size, **kwargs) ReplayBuffer.__init__(self, size, **kwargs) assert alpha > 0.0 and beta >= 0.0 self._alpha, self._beta = alpha, beta From ea6b3c31620b6fc248def38a01cc80e55bc759d8 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Wed, 10 Feb 2021 17:51:47 +0800 Subject: [PATCH 083/104] fix n_episode test --- test/base/test_returns.py | 41 +++++++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/test/base/test_returns.py b/test/base/test_returns.py index e262eca1a..b51fae9a5 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -21,50 +21,71 @@ def compute_episodic_return_base(batch, gamma): # TODO need to change def test_episodic_returns(size=2560): fn = BasePolicy.compute_episodic_return + buf = ReplayBuffer(20) batch = Batch( done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]), rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]), ) - batch = fn(batch, None, gamma=.1, gae_lambda=1) + for b in batch: + b.obs = b.act = 1 + buf.add(b) + batch = fn(batch, buf, buf.sample_index(0), None, gamma=.1, gae_lambda=1) ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7]) assert np.allclose(batch.returns, ans) + buf.reset() batch = Batch( done=np.array([0, 1, 0, 1, 0, 1, 0.]), rew=np.array([7, 6, 1, 2, 3, 4, 5.]), ) - batch = fn(batch, None, gamma=.1, gae_lambda=1) + for b in batch: + b.obs = b.act = 1 + buf.add(b) + batch = fn(batch, buf, buf.sample_index(0), None, gamma=.1, gae_lambda=1) ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5]) assert np.allclose(batch.returns, ans) + buf.reset() batch = Batch( done=np.array([0, 1, 0, 1, 0, 0, 1.]), rew=np.array([7, 6, 1, 2, 3, 4, 5.]), ) - batch = fn(batch, None, gamma=.1, gae_lambda=1) + for b in batch: + b.obs = b.act = 1 + buf.add(b) + batch = fn(batch, buf, buf.sample_index(0), None, gamma=.1, gae_lambda=1) ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5]) assert np.allclose(batch.returns, ans) + buf.reset() batch = Batch( done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]), - rew=np.array([ - 101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]) + rew=np.array([101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]), ) + for b in batch: + b.obs = b.act = 1 + buf.add(b) v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3]) - ret = fn(batch, v, gamma=0.99, gae_lambda=0.95) + ret = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95) returns = np.array([ 454.8344, 376.1143, 291.298, 200., 464.5610, 383.1085, 295.387, 201., 474.2876, 390.1027, 299.476, 202.]) assert np.allclose(ret.returns, returns) + buf.reset() if __name__ == '__main__': + buf = ReplayBuffer(size) batch = Batch( done=np.random.randint(100, size=size) == 0, rew=np.random.random(size), ) + for b in batch: + b.obs = b.act = 1 + buf.add(b) + indice = buf.sample_index(0) def vanilla(): return compute_episodic_return_base(batch, gamma=.1) def optimized(): - return fn(batch, gamma=.1) + return fn(batch, buf, indice, gamma=.1, gae_lambda=1.0) cnt = 3000 print('GAE vanilla', timeit(vanilla, setup=vanilla, number=cnt)) @@ -131,9 +152,9 @@ def test_nstep_returns(size=10000): assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 10 returns = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indice, target_q_fn, gamma=.1, n_step=10).pop('returns')) - assert np.allclose(returns, [ - 3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12]) + batch, buf, indice, target_q_fn, gamma=.1, n_step=10 + ).pop('returns').reshape(-1)) + assert 
np.allclose(returns, [3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12]) r_ = compute_nstep_return_base(10, .1, buf, indice) assert np.allclose(returns, r_) returns_multidim = to_numpy(BasePolicy.compute_nstep_return( From 3365eec57de3c880b86aa97fb43ea3b8896d8ab4 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Wed, 10 Feb 2021 18:58:45 +0800 Subject: [PATCH 084/104] value mask bug fix --- test/base/test_returns.py | 4 ++-- tianshou/policy/base.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/base/test_returns.py b/test/base/test_returns.py index b51fae9a5..9884a2b23 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -94,7 +94,7 @@ def optimized(): def target_q_fn(buffer, indice): # return the next reward - indice = (indice + 1 - buffer.done[indice]) % len(buffer) + indice = buffer.next(indice) return torch.tensor(-buffer.rew[indice], dtype=torch.float32) @@ -126,7 +126,7 @@ def test_nstep_returns(size=10000): buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3)) batch, indice = buf.sample(0) assert np.allclose(indice, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1]) - # rew: [10, 11, 2, 3, 4, 5, 6, 7, 8, 9] + # rew: [11, 12, 3, 4, 5, 6, 7, 8, 9, 10] # done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0] # test nstep = 1 returns = to_numpy(BasePolicy.compute_nstep_return( diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 0721f62d0..3e107d963 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -190,9 +190,9 @@ def update( return result @staticmethod - def value_mask(batch: Batch) -> np.ndarray: + def value_mask(buffer: ReplayBuffer, indice: np.ndarray) -> np.ndarray: # TODO doc - return ~batch.done.astype(np.bool) + return ~buffer.done[indice].astype(np.bool) @staticmethod def compute_episodic_return( @@ -229,7 +229,7 @@ def compute_episodic_return( assert np.isclose(gae_lambda, 1.0) v_s_ = np.zeros_like(rew) else: - v_s_ = to_numpy(v_s_.flatten()) * BasePolicy.value_mask(batch) + v_s_ = to_numpy(v_s_.flatten()) * BasePolicy.value_mask(buffer, indice) end_flag = batch.done.copy() end_flag[np.isin(indice, buffer.unfinished_index())] = True @@ -298,7 +298,7 @@ def compute_nstep_return( with torch.no_grad(): target_q_torch = target_q_fn(buffer, terminal) # (bsz, ?) target_q = to_numpy(target_q_torch.reshape(bsz, -1)) - target_q = target_q * BasePolicy.value_mask(batch).reshape(-1, 1) + target_q = target_q * BasePolicy.value_mask(buffer, terminal).reshape(-1, 1) end_flag = buffer.done.copy() end_flag[buffer.unfinished_index()] = True target_q = _nstep_return(rew, end_flag, target_q, From 8a61dd956fe8fa6b7df03aa578a98fef755ec174 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Wed, 17 Feb 2021 12:02:30 +0800 Subject: [PATCH 085/104] update doc --- README.md | 7 +++--- docs/tutorials/tictactoe.rst | 3 +-- test/base/test_returns.py | 1 - tianshou/data/buffer.py | 44 ------------------------------------ tianshou/policy/base.py | 41 ++++++++++++++++++++++----------- 5 files changed, 33 insertions(+), 63 deletions(-) diff --git a/README.md b/README.md index 1c845e545..3ba5f705e 100644 --- a/README.md +++ b/README.md @@ -158,13 +158,14 @@ Currently, the overall code of Tianshou platform is less than 2500 lines. 
Most o ```python result = collector.collect(n_step=n) ``` -# TODO remove this -If you have 3 environments in total and want to collect 1 episode in the first environment, 3 for the third environment: +If you have 3 environments in total and want to collect 4 episodes: ```python -result = collector.collect(n_episode=[1, 0, 3]) +result = collector.collect(n_episode=4) ``` +Collector will collect exactly 4 episodes without any bias of episode length despite we only have 3 parallel environments. + If you want to train the given policy with a sampled batch: ```python diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst index 6e840abbd..4aa356e6f 100644 --- a/docs/tutorials/tictactoe.rst +++ b/docs/tutorials/tictactoe.rst @@ -128,8 +128,7 @@ Tianshou already provides some builtin classes for multi-agent learning. You can >>> >>> # use collectors to collect a episode of trajectories >>> # the reward is a vector, so we need a scalar metric to monitor the training - # TODO remove reward_metric - >>> collector = Collector(policy, env, reward_metric=lambda x: x[0]) + >>> collector = Collector(policy, env) >>> >>> # you will see a long trajectory showing the board status at each timestep >>> result = collector.collect(n_episode=1, render=.1) diff --git a/test/base/test_returns.py b/test/base/test_returns.py index 9884a2b23..5aba7f56a 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -18,7 +18,6 @@ def compute_episodic_return_base(batch, gamma): return batch -# TODO need to change def test_episodic_returns(size=2560): fn = BasePolicy.compute_episodic_return buf = ReplayBuffer(20) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 4181efeee..9904fa52d 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -102,54 +102,10 @@ def __setattr__(self, key: str, value: Any) -> None: ), "key '{}' is reserved and cannot be assigned".format(key) super().__setattr__(key, value) -<<<<<<< HEAD def save_hdf5(self, path: str) -> None: """Save replay buffer to HDF5 file.""" with h5py.File(path, "w") as f: to_hdf5(self.__dict__, f) -======= - def _add_to_buffer(self, name: str, inst: Any) -> None: - try: - value = self._meta.__dict__[name] - except KeyError: - self._meta.__dict__[name] = _create_value(inst, self._maxsize) - value = self._meta.__dict__[name] - if isinstance(inst, (torch.Tensor, np.ndarray)): - if inst.shape != value.shape[1:]: - raise ValueError( - "Cannot add data to a buffer with different shape with key" - f" {name}, expect {value.shape[1:]}, given {inst.shape}." 
- ) - try: - value[self._index] = inst - except ValueError: - for key in set(inst.keys()).difference(value.__dict__.keys()): - value.__dict__[key] = _create_value(inst[key], self._maxsize) - value[self._index] = inst - - @property - def stack_num(self) -> int: - return self._stack - - @stack_num.setter - def stack_num(self, num: int) -> None: - assert num > 0, "stack_num should greater than 0" - self._stack = num - - def update(self, buffer: "ReplayBuffer") -> None: - """Move the data from the given buffer to self.""" - if len(buffer) == 0: - return - i = begin = buffer._index % len(buffer) - stack_num_orig = buffer.stack_num - buffer.stack_num = 1 - while True: - self.add(**buffer[i]) # type: ignore - i = (i + 1) % len(buffer) - if i == begin: - break - buffer.stack_num = stack_num_orig ->>>>>>> upstream/master @classmethod def load_hdf5(cls, path: str, device: Optional[str] = None) -> "ReplayBuffer": diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 3e107d963..91d22e71c 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -191,7 +191,22 @@ def update( @staticmethod def value_mask(buffer: ReplayBuffer, indice: np.ndarray) -> np.ndarray: - # TODO doc + """Value mask determines whether the obs_next of buffer[indice] is valid. + + For instance, usually 'obs_next' after 'done' flag is considered to be invalid, + and its q/advantage value can provide meaningless(even misleading) information, + and should be set to 0 by hand. But if 'done' flag is generated because of + timelimit of game length (info['TimeLimit.truncated'] is set to True in 'gym' + settings), 'obs_next' will instead be valid. Value mask is typically used + to assist in calculating correct q/advantage value. + + :param ReplayBuffer buffer: the corresponding replay buffer. + :param numpy.ndarray indice: indices of replay buffer whose 'obs_next' will be + judged. + + :return: A bool type numpy.ndarray in the same shape with indice. 'True' means + 'obs_next' of that buffer[indice] is valid. + """ return ~buffer.done[indice].astype(np.bool) @staticmethod @@ -204,16 +219,20 @@ def compute_episodic_return( gae_lambda: float = 0.95, rew_norm: bool = False, ) -> Batch: - # TODO change doc - """Compute returns over given full-length episodes. + """Compute returns over given batch. - Implementation of Generalized Advantage Estimator (arXiv:1506.02438). + Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438) + to calculate q function/reward to go of give batch. - :param batch: a data batch which contains several full-episode data - chronologically. TODO generalize + :param batch: a data batch which contains several episodes of data + in sequential order. Mind that the end of each finished episode of batch + should be marked by done flag, unfinished(collecting) episodes will be + recongized by buffer.unfinished_index(). :type batch: :class:`~tianshou.data.Batch` + :param numpy.ndarray indice: tell batch's location in buffer, batch is + equal to buffer[indice]. :param v_s_: the value function of all next states :math:`V(s')`. - :type v_s_: numpy.ndarray #TODO n+1 value shape + :type v_s_: numpy.ndarray :param float gamma: the discount factor, should be in [0, 1], defaults to 0.99. :param float gae_lambda: the parameter for Generalized Advantage @@ -249,7 +268,6 @@ def compute_nstep_return( n_step: int = 1, rew_norm: bool = False, ) -> Batch: - # TODO, doc r"""Compute n-step return for Q-learning targets. .. 
math:: @@ -264,10 +282,8 @@ def compute_nstep_return( :type batch: :class:`~tianshou.data.Batch` :param buffer: the data buffer. :type buffer: :class:`~tianshou.data.ReplayBuffer` - :param indice: sampled timestep. - :type indice: numpy.ndarray - :param function target_q_fn: a function receives :math:`t+n-1` step's - data and compute target Q value. + :param function target_q_fn: a function which compute target Q value + of 'obs_next' given data buffer and wanted indices. :param float gamma: the discount factor, should be in [0, 1], defaults to 0.99. :param int n_step: the number of estimation step, should be an int @@ -280,7 +296,6 @@ def compute_nstep_return( """ rew = buffer.rew bsz = len(indice) - # TODO this rew_norm will cause unstablity in training if rew_norm: bfr = rew[:min(len(buffer), 1000)] # avoid large buffer mean, std = bfr.mean(), bfr.std() From dc3e150d533cdaa0badd792377e158728183143c Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Wed, 17 Feb 2021 12:04:42 +0800 Subject: [PATCH 086/104] pep8 fix --- tianshou/data/buffer.py | 1 - tianshou/policy/base.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 9904fa52d..f3636afe6 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -1,6 +1,5 @@ import h5py import torch -import warnings import numpy as np from typing import Any, Dict, List, Tuple, Union, Optional diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 91d22e71c..64e226206 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -192,11 +192,11 @@ def update( @staticmethod def value_mask(buffer: ReplayBuffer, indice: np.ndarray) -> np.ndarray: """Value mask determines whether the obs_next of buffer[indice] is valid. - + For instance, usually 'obs_next' after 'done' flag is considered to be invalid, and its q/advantage value can provide meaningless(even misleading) information, - and should be set to 0 by hand. But if 'done' flag is generated because of - timelimit of game length (info['TimeLimit.truncated'] is set to True in 'gym' + and should be set to 0 by hand. But if 'done' flag is generated because of + timelimit of game length (info['TimeLimit.truncated'] is set to True in 'gym' settings), 'obs_next' will instead be valid. Value mask is typically used to assist in calculating correct q/advantage value. From f8e73106ab21f54963dc3420d631f323824356ac Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Wed, 17 Feb 2021 19:07:28 +0800 Subject: [PATCH 087/104] drqn --- test/discrete/test_drqn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 62569fe2e..a91d5d3eb 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -31,7 +31,7 @@ def get_args(): parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--layer-num', type=int, default=3) parser.add_argument('--training-num', type=int, default=10) - parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--test-num', type=int, default=10) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( From 4ffbf5817ef28a0b56588b1d6301cfd2ec79392c Mon Sep 17 00:00:00 2001 From: n+e Date: Thu, 18 Feb 2021 13:01:18 +0800 Subject: [PATCH 088/104] Apply suggestions from code review --- tianshou/policy/base.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 64e226206..9b3858dcf 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -193,19 +193,19 @@ def update( def value_mask(buffer: ReplayBuffer, indice: np.ndarray) -> np.ndarray: """Value mask determines whether the obs_next of buffer[indice] is valid. - For instance, usually 'obs_next' after 'done' flag is considered to be invalid, - and its q/advantage value can provide meaningless(even misleading) information, - and should be set to 0 by hand. But if 'done' flag is generated because of - timelimit of game length (info['TimeLimit.truncated'] is set to True in 'gym' - settings), 'obs_next' will instead be valid. Value mask is typically used - to assist in calculating correct q/advantage value. + For instance, usually "obs_next" after "done" flag is considered to be invalid, + and its q/advantage value can provide meaningless (even misleading) information, + and should be set to 0 by hand. But if "done" flag is generated because + timelimit of game length (info["TimeLimit.truncated"] is set to True in gym's + settings), "obs_next" will instead be valid. Value mask is typically used + for assisting in calculating the correct q/advantage value. :param ReplayBuffer buffer: the corresponding replay buffer. - :param numpy.ndarray indice: indices of replay buffer whose 'obs_next' will be - judged. + :param numpy.ndarray indice: indices of replay buffer whose "obs_next" will be + judged. - :return: A bool type numpy.ndarray in the same shape with indice. 'True' means - 'obs_next' of that buffer[indice] is valid. + :return: A bool type numpy.ndarray in the same shape with indice. "True" means + "obs_next" of that buffer[indice] is valid. """ return ~buffer.done[indice].astype(np.bool) @@ -222,11 +222,11 @@ def compute_episodic_return( """Compute returns over given batch. Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438) - to calculate q function/reward to go of give batch. + to calculate q function/reward to go of given batch. :param batch: a data batch which contains several episodes of data in sequential order. Mind that the end of each finished episode of batch - should be marked by done flag, unfinished(collecting) episodes will be + should be marked by done flag, unfinished (or collecting) episodes will be recongized by buffer.unfinished_index(). :type batch: :class:`~tianshou.data.Batch` :param numpy.ndarray indice: tell batch's location in buffer, batch is @@ -283,7 +283,7 @@ def compute_nstep_return( :param buffer: the data buffer. :type buffer: :class:`~tianshou.data.ReplayBuffer` :param function target_q_fn: a function which compute target Q value - of 'obs_next' given data buffer and wanted indices. + of "obs_next" given data buffer and wanted indices. :param float gamma: the discount factor, should be in [0, 1], defaults to 0.99. 
:param int n_step: the number of estimation step, should be an int From 51f06b8f5e596c0b657f52134ecb0bc862094420 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 18 Feb 2021 16:31:40 +0800 Subject: [PATCH 089/104] add reward_metric in trainer --- test/multiagent/tic_tac_toe.py | 7 +++++-- tianshou/policy/base.py | 8 ++++---- tianshou/trainer/offline.py | 13 ++++++++++--- tianshou/trainer/offpolicy.py | 15 ++++++++++++--- tianshou/trainer/onpolicy.py | 13 +++++++++++-- tianshou/trainer/utils.py | 4 ++++ 6 files changed, 46 insertions(+), 14 deletions(-) diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index cf58aadcb..dfe25d42b 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -157,13 +157,16 @@ def train_fn(epoch, env_step): def test_fn(epoch, env_step): policy.policies[args.agent_id - 1].set_eps(args.eps_test) + def reward_metric(rews): + return rews[:, args.agent_id - 1] + # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.collect_per_step, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, - test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, reward_metric=reward_metric, + writer=writer, test_in_train=False) return result, policy.policies[args.agent_id - 1] diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 9b3858dcf..6167b30cb 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -194,10 +194,10 @@ def value_mask(buffer: ReplayBuffer, indice: np.ndarray) -> np.ndarray: """Value mask determines whether the obs_next of buffer[indice] is valid. For instance, usually "obs_next" after "done" flag is considered to be invalid, - and its q/advantage value can provide meaningless (even misleading) information, - and should be set to 0 by hand. But if "done" flag is generated because - timelimit of game length (info["TimeLimit.truncated"] is set to True in gym's - settings), "obs_next" will instead be valid. Value mask is typically used + and its q/advantage value can provide meaningless (even misleading) + information, and should be set to 0 by hand. But if "done" flag is generated + because timelimit of game length (info["TimeLimit.truncated"] is set to True in + gym's settings), "obs_next" will instead be valid. Value mask is typically used for assisting in calculating the correct q/advantage value. :param ReplayBuffer buffer: the corresponding replay buffer. diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 56c9ef9e3..7eb6ec5b5 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -1,5 +1,6 @@ import time import tqdm +import numpy as np from collections import defaultdict from torch.utils.tensorboard import SummaryWriter from typing import Dict, Union, Callable, Optional @@ -21,6 +22,7 @@ def offline_trainer( test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None, + reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, writer: Optional[SummaryWriter] = None, log_interval: int = 1, verbose: bool = True, @@ -29,8 +31,7 @@ def offline_trainer( The "step" in trainer means a policy network update. - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` - class. + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. 
:param test_collector: the collector used for testing. :type test_collector: :class:`~tianshou.data.Collector` :param int max_epoch: the maximum number of epochs for training. The @@ -49,6 +50,12 @@ def offline_trainer( :param function stop_fn: a function with signature ``f(mean_rewards: float) -> bool``, receives the average undiscounted returns of the testing result, returns a boolean which indicates whether reaching the goal. + :param function reward_metric: a function with signature ``f(rewards: np.ndarray + with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, + used in multi-agent RL. We need to return a single scalar for each episode's + result to monitor training in the multi-agent RL setting. This function + specifies what is the desired metric, e.g., the reward of agent 1 or the + average reward over all agents. :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter; if None is given, it will not write logs to TensorBoard. :param int log_interval: the log interval of the writer. @@ -81,7 +88,7 @@ def offline_trainer( t.set_postfix(**data) # test result = test_episode(policy, test_collector, test_fn, epoch, - episode_per_test, writer, gradient_step) + episode_per_test, writer, gradient_step, reward_metric) if best_epoch == -1 or best_reward < result["rews"].mean(): best_reward, best_reward_std = result["rews"].mean(), result['rews'].std() best_epoch = epoch diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index d7c6842a3..ba08c2e0b 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -1,5 +1,6 @@ import time import tqdm +import numpy as np from collections import defaultdict from torch.utils.tensorboard import SummaryWriter from typing import Dict, Union, Callable, Optional @@ -24,6 +25,7 @@ def offpolicy_trainer( test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None, + reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, writer: Optional[SummaryWriter] = None, log_interval: int = 1, verbose: bool = True, @@ -33,8 +35,7 @@ def offpolicy_trainer( The "step" in trainer means a policy network update. - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` - class. + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param train_collector: the collector used for training. :type train_collector: :class:`~tianshou.data.Collector` :param test_collector: the collector used for testing. @@ -65,6 +66,12 @@ def offpolicy_trainer( :param function stop_fn: a function with signature ``f(mean_rewards: float) -> bool``, receives the average undiscounted returns of the testing result, returns a boolean which indicates whether reaching the goal. + :param function reward_metric: a function with signature ``f(rewards: np.ndarray + with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, + used in multi-agent RL. We need to return a single scalar for each episode's + result to monitor training in the multi-agent RL setting. This function + specifies what is the desired metric, e.g., the reward of agent 1 or the + average reward over all agents. :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter; if None is given, it will not write logs to TensorBoard. :param int log_interval: the log interval of the writer. 
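For reference, a minimal sketch of the `reward_metric` signature documented in the hunk above. The function itself is runnable as shown; the choice of monitoring the first agent mirrors the `rews[:, args.agent_id - 1]` pattern from `tic_tac_toe.py`, and the trailing comment only indicates how it would be handed to the trainer.

```python
import numpy as np

def reward_metric(rews: np.ndarray) -> np.ndarray:
    # rews has shape (num_episode, agent_num); return one scalar per episode,
    # here the reward of the first agent, as tic_tac_toe.py does with agent_id
    return rews[:, 0]

rews = np.arange(8.0).reshape(4, 2)   # 4 episodes, 2 agents
assert reward_metric(rews).shape == (4,)
# passed to the trainer as: offpolicy_trainer(..., reward_metric=reward_metric)
```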
@@ -90,6 +97,8 @@ def offpolicy_trainer( if train_fn: train_fn(epoch, env_step) result = train_collector.collect(n_step=collect_per_step) + if len(result["rews"]) > 0 and reward_metric: + result["rews"] = reward_metric(result["rews"]) env_step += int(result["n/st"]) data = { "env_step": str(env_step), @@ -133,7 +142,7 @@ def offpolicy_trainer( t.update() # test result = test_episode(policy, test_collector, test_fn, epoch, - episode_per_test, writer, env_step) + episode_per_test, writer, env_step, reward_metric) if best_epoch == -1 or best_reward < result["rews"].mean(): best_reward, best_reward_std = result["rews"].mean(), result["rews"].std() best_epoch = epoch diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 11b5da71c..b951f9a9e 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -1,5 +1,6 @@ import time import tqdm +import numpy as np from collections import defaultdict from torch.utils.tensorboard import SummaryWriter from typing import Dict, Union, Callable, Optional @@ -24,6 +25,7 @@ def onpolicy_trainer( test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None, + reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, writer: Optional[SummaryWriter] = None, log_interval: int = 1, verbose: bool = True, @@ -33,8 +35,7 @@ def onpolicy_trainer( The "step" in trainer means a policy network update. - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` - class. + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param train_collector: the collector used for training. :type train_collector: :class:`~tianshou.data.Collector` :param test_collector: the collector used for testing. @@ -65,6 +66,12 @@ def onpolicy_trainer( :param function stop_fn: a function with signature ``f(mean_rewards: float) -> bool``, receives the average undiscounted returns of the testing result, returns a boolean which indicates whether reaching the goal. + :param function reward_metric: a function with signature ``f(rewards: np.ndarray + with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, + used in multi-agent RL. We need to return a single scalar for each episode's + result to monitor training in the multi-agent RL setting. This function + specifies what is the desired metric, e.g., the reward of agent 1 or the + average reward over all agents. :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter; if None is given, it will not write logs to TensorBoard. :param int log_interval: the log interval of the writer. 
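The docstring above also names "the average reward over all agents" as a valid metric; a small self-contained sketch of that variant, assuming the same `(num_episode, agent_num)` input shape:

```python
import numpy as np

# alternative metric from the docstring above: average reward over all agents
def reward_metric(rews: np.ndarray) -> np.ndarray:
    return rews.mean(axis=1)

rews = np.array([[1.0, -1.0], [0.5, 0.5], [2.0, 0.0]])  # 3 episodes, 2 agents
assert np.allclose(reward_metric(rews), [0.0, 0.5, 1.0])
```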
@@ -90,6 +97,8 @@ def onpolicy_trainer( if train_fn: train_fn(epoch, env_step) result = train_collector.collect(n_episode=collect_per_step) + if reward_metric: + result["rews"] = reward_metric(result["rews"]) env_step += int(result["n/st"]) data = { "env_step": str(env_step), diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 927c61141..dfc60a789 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -1,4 +1,5 @@ import time +import numpy as np from torch.utils.tensorboard import SummaryWriter from typing import Any, Dict, Union, Callable, Optional @@ -14,6 +15,7 @@ def test_episode( n_episode: int, writer: Optional[SummaryWriter] = None, global_step: Optional[int] = None, + reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, ) -> Dict[str, Any]: """A simple wrapper of testing policy in collector.""" collector.reset_env() @@ -24,6 +26,8 @@ def test_episode( result = collector.collect(n_episode=n_episode) if writer is not None and global_step is not None: rews, lens = result["rews"], result["lens"] + if reward_metric: + rews = reward_metric(rews) writer.add_scalar("test/rew", rews.mean(), global_step=global_step) writer.add_scalar("test/rew_std", rews.std(), global_step=global_step) writer.add_scalar("test/len", lens.mean(), global_step=global_step) From d0357c1bba5251505f811cb1644f4908d149ba9a Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 18 Feb 2021 16:57:09 +0800 Subject: [PATCH 090/104] add exploration_noise in mapolicy --- tianshou/policy/base.py | 21 ++++++++------------- tianshou/policy/modelfree/ddpg.py | 6 +----- tianshou/policy/modelfree/discrete_sac.py | 4 +--- tianshou/policy/modelfree/dqn.py | 6 +----- tianshou/policy/multiagent/mapolicy.py | 12 ++++++++++++ 5 files changed, 23 insertions(+), 26 deletions(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 6167b30cb..be6d8216b 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -68,9 +68,7 @@ def set_agent_id(self, agent_id: int) -> None: self.agent_id = agent_id def exploration_noise( - self, - act: Union[np.ndarray, Batch], - batch: Batch, + self, act: Union[np.ndarray, Batch], batch: Batch ) -> Union[np.ndarray, Batch]: """Modify the action from policy.forward with exploration noise. @@ -78,8 +76,8 @@ def exploration_noise( policy.forward. :param batch: the input batch for policy.forward, kept for advanced usage. - :return: action in the same form of input 'act' but with added - exploration noise. + :return: action in the same form of input "act" but with added exploration + noise. """ return act @@ -92,8 +90,7 @@ def forward( ) -> Batch: """Compute action over the given batch data. - :return: A :class:`~tianshou.data.Batch` which MUST have the following\ - keys: + :return: A :class:`~tianshou.data.Batch` which MUST have the following keys: * ``act`` an numpy.ndarray or a torch.Tensor, the action over \ given batch data. @@ -122,8 +119,7 @@ def process_fn( ) -> Batch: """Pre-process the data from the provided replay buffer. - Used in :meth:`update`. Check out :ref:`process_fn` for more - information. + Used in :meth:`update`. Check out :ref:`process_fn` for more information. """ return batch @@ -274,9 +270,8 @@ def compute_nstep_return( G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i + \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n}) - where :math:`\gamma` is the discount factor, - :math:`\gamma \in [0, 1]`, :math:`d_t` is the done flag of step - :math:`t`. 
+ where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`, + :math:`d_t` is the done flag of step :math:`t`. :param batch: a data batch, which is equal to buffer[indice]. :type batch: :class:`~tianshou.data.Batch` @@ -296,7 +291,7 @@ def compute_nstep_return( """ rew = buffer.rew bsz = len(indice) - if rew_norm: + if rew_norm: # TODO: remove it or fix this bug bfr = rew[:min(len(buffer), 1000)] # avoid large buffer mean, std = bfr.mean(), bfr.std() if np.isclose(std, 0, 1e-2): diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 1bca2014a..efa9fb7f9 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -164,11 +164,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: "loss/critic": critic_loss.item(), } - def exploration_noise( - self, - act: np.ndarray, - batch: Batch, - ) -> np.ndarray: + def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray: if self._noise: act = act + self._noise(act.shape) act = act.clip(self._range[0], self._range[1]) diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 02781a32d..1a0dc0f7e 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -147,8 +147,6 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: return result def exploration_noise( - self, - act: Union[np.ndarray, Batch], - batch: Batch, + self, act: Union[np.ndarray, Batch], batch: Batch ) -> Union[np.ndarray, Batch]: return act diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index d4cca7cb2..50321860b 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -174,11 +174,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: self._iter += 1 return {"loss": loss.item()} - def exploration_noise( - self, - act: np.ndarray, - batch: Batch, - ) -> np.ndarray: + def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray: if not np.isclose(self.eps, 0.0): for i in range(len(act)): if np.random.rand() < self.eps: diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index 5bfd5e9b2..7aa1f661c 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -59,6 +59,18 @@ def process_fn( buffer._meta.rew = save_rew return Batch(results) + def exploration_noise( + self, act: Union[np.ndarray, Batch], batch: Batch + ) -> Union[np.ndarray, Batch]: + """Add exploration noise from sub-policy onto act.""" + for policy in self.policies: + agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0] + if len(agent_index) == 0: + continue + act[agent_index] = policy.exploration_noise( + act[agent_index], batch[agent_index]) + return act + def forward( self, batch: Batch, From 64fe3ce4c068a5ff45b06fcdac69ad18e8f27a73 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 18 Feb 2021 17:09:37 +0800 Subject: [PATCH 091/104] it works! 
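A numpy-only sketch of the per-agent dispatch added to `mapolicy.py` in the patch above: each sub-policy perturbs only the actions belonging to its own `agent_id`. Here `noisy` is a hypothetical stand-in for a sub-policy's `exploration_noise`, and the arrays are dummy data.

```python
import numpy as np

agent_id = np.array([1, 2, 1, 2])   # batch.obs.agent_id in a two-agent game
act = np.zeros(4)                   # joint action array produced by forward()

def noisy(act_slice, batch_slice):
    # hypothetical stand-in for policy.exploration_noise of one sub-policy
    return act_slice + 1.0

for policy_agent_id in (1, 2):
    agent_index = np.nonzero(agent_id == policy_agent_id)[0]
    if len(agent_index) == 0:
        continue
    act[agent_index] = noisy(act[agent_index], None)

assert np.allclose(act, 1.0)  # every action was perturbed by its own agent's policy
```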
--- test/multiagent/test_tic_tac_toe.py | 2 -- tianshou/trainer/utils.py | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/test/multiagent/test_tic_tac_toe.py b/test/multiagent/test_tic_tac_toe.py index 92ecb97c6..1cc06d374 100644 --- a/test/multiagent/test_tic_tac_toe.py +++ b/test/multiagent/test_tic_tac_toe.py @@ -1,10 +1,8 @@ import pprint -from tianshou.data import Collector from tic_tac_toe import get_args, train_agent, watch def test_tic_tac_toe(args=get_args()): - Collector._default_rew_metric = lambda x: x[args.agent_id - 1] if args.watch: watch(args) return diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index dfc60a789..2cdeb15fe 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -24,10 +24,10 @@ def test_episode( if test_fn: test_fn(epoch, global_step) result = collector.collect(n_episode=n_episode) + if reward_metric: + result["rews"] = reward_metric(result["rews"]) if writer is not None and global_step is not None: rews, lens = result["rews"], result["lens"] - if reward_metric: - rews = reward_metric(rews) writer.add_scalar("test/rew", rews.mean(), global_step=global_step) writer.add_scalar("test/rew_std", rews.std(), global_step=global_step) writer.add_scalar("test/len", lens.mean(), global_step=global_step) From 64e04dfc9d7fed7197bfbeeeadbb8387e38e85ff Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 18 Feb 2021 18:01:25 +0800 Subject: [PATCH 092/104] fix test --- tianshou/data/buffer.py | 69 ++++++++++++++++++++++++++++++++++---- tianshou/data/collector.py | 12 ------- 2 files changed, 62 insertions(+), 19 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index f3636afe6..b24e8c6b2 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -1,7 +1,7 @@ import h5py import torch import numpy as np -from typing import Any, Dict, List, Tuple, Union, Optional +from typing import Any, Dict, List, Tuple, Union, Sequence, Optional from tianshou.data.batch import _create_value from tianshou.data import Batch, SegmentTree, to_numpy @@ -167,7 +167,8 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray: self._size = min(self._size + 1, self.maxsize) to_indices = np.array(to_indices) if self._meta.is_empty(): - self._meta = _create_value(buffer._meta, self.maxsize, stack=False) + self._meta = _create_value( # type: ignore + buffer._meta, self.maxsize, stack=False) self._meta[to_indices] = buffer._meta[from_indices] return to_indices @@ -239,7 +240,8 @@ def add( batch.rew = batch.rew.astype(np.float) batch.done = batch.done.astype(np.bool_) if self._meta.is_empty(): - self._meta = _create_value(batch, self.maxsize, stack) + self._meta = _create_value( # type: ignore + batch, self.maxsize, stack) else: # dynamic key pops up in batch _alloc_by_keys_diff(self._meta, batch, self.maxsize, stack) self._meta[ptr] = batch @@ -431,7 +433,7 @@ class ReplayBufferManager(ReplayBuffer): These replay buffers have contiguous memory layout, and the storage space each buffer has is a shallow copy of the topmost memory. - :param int buffer_list: a list of ReplayBuffer needed to be handled. + :param buffer_list: a list of ReplayBuffer needed to be handled. .. 
seealso:: @@ -539,7 +541,8 @@ def add( batch.rew = batch.rew.astype(np.float) batch.done = batch.done.astype(np.bool_) if self._meta.is_empty(): - self._meta = _create_value(batch, self.maxsize, stack=False) + self._meta = _create_value( # type: ignore + batch, self.maxsize, stack=False) else: # dynamic key pops up in batch _alloc_by_keys_diff(self._meta, batch, self.maxsize, False) self._set_batch_for_children() @@ -576,8 +579,23 @@ def sample_index(self, batch_size: int) -> np.ndarray: class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManager): - def __init__(self, buffer_list: List[PrioritizedReplayBuffer]) -> None: - ReplayBufferManager.__init__(self, buffer_list) + """PrioritizedReplayBufferManager contains a list of PrioritizedReplayBuffer with \ + exactly the same configuration. + + These replay buffers have contiguous memory layout, and the storage space each + buffer has is a shallow copy of the topmost memory. + + :param buffer_list: a list of PrioritizedReplayBuffer needed to be handled. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer`, + :class:`~tianshou.data.ReplayBufferManager`, and + :class:`~tianshou.data.PrioritizedReplayBuffer` for more detailed explanation. + """ + + def __init__(self, buffer_list: Sequence[PrioritizedReplayBuffer]) -> None: + ReplayBufferManager.__init__(self, buffer_list) # type: ignore kwargs = buffer_list[0].options for buf in buffer_list: del buf.weight @@ -585,6 +603,24 @@ def __init__(self, buffer_list: List[PrioritizedReplayBuffer]) -> None: class VectorReplayBuffer(ReplayBufferManager): + """VectorReplayBuffer contains n ReplayBuffer with the same size. + + It is used for storing data frame from different environments yet keeping the order + of time. + + :param int total_size: the total size of VectorReplayBuffer. + :param int buffer_num: the number of ReplayBuffer it uses, which are under the same + configuration. + + Other input arguments (stack_num/ignore_obs_next/save_only_last_obs/sample_avail) + are the same as :class:`~tianshou.data.ReplayBuffer`. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` and + :class:`~tianshou.data.ReplayBufferManager` for more detailed explanation. + """ + def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: assert buffer_num > 0 size = int(np.ceil(total_size / buffer_num)) @@ -593,6 +629,25 @@ def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager): + """PrioritizedVectorReplayBuffer contains n PrioritizedReplayBuffer with same size. + + It is used for storing data frame from different environments yet keeping the order + of time. + + :param int total_size: the total size of PrioritizedVectorReplayBuffer. + :param int buffer_num: the number of PrioritizedReplayBuffer it uses, which are + under the same configuration. + + Other input arguments (alpha/beta/stack_num/ignore_obs_next/save_only_last_obs/ + sample_avail) are the same as :class:`~tianshou.data.PrioritizedReplayBuffer`. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` and + :class:`~tianshou.data.PrioritizedReplayBufferManager` for more detailed + explanation. 
+ """ + def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: assert buffer_num > 0 size = int(np.ceil(total_size / buffer_num)) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 63d02f9b2..74ee72d11 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -116,18 +116,6 @@ def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None: ) self.buffer = buffer - # TODO move to trainer - # @staticmethod - # def _default_rew_metric( - # x: Union[Number, np.number] - # ) -> Union[Number, np.number]: - # # this internal function is designed for single-agent RL - # # for multi-agent RL, a reward_metric must be provided - # assert np.asanyarray(x).size == 1, ( - # "Please specify the reward_metric " - # "since the reward is not a scalar.") - # return x - def reset(self) -> None: """Reset all related variables in the collector.""" # use empty Batch for "state" so that self.data supports slicing From 2275efbc0c83aecfb77be6f1d05f5d7021b5fac8 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 18 Feb 2021 19:49:55 +0800 Subject: [PATCH 093/104] fix dqn family eps-test --- examples/atari/atari_bcq.py | 2 +- examples/atari/atari_c51.py | 18 +++++++----------- examples/atari/atari_dqn.py | 24 ++++++++++-------------- examples/atari/atari_qrdqn.py | 18 +++++++----------- examples/box2d/acrobot_dualdqn.py | 8 +++----- examples/box2d/lunarlander_dqn.py | 2 +- test/discrete/test_c51.py | 2 +- test/discrete/test_dqn.py | 2 +- test/discrete/test_drqn.py | 12 +++++------- test/discrete/test_il_bcq.py | 2 +- test/discrete/test_qrdqn.py | 2 +- 11 files changed, 38 insertions(+), 54 deletions(-) diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py index 0db587700..e2b8f0778 100644 --- a/examples/atari/atari_bcq.py +++ b/examples/atari/atari_bcq.py @@ -111,7 +111,7 @@ def test_discrete_bcq(args=get_args()): exit(0) # collector - test_collector = Collector(policy, test_envs) + test_collector = Collector(policy, test_envs, exploration_noise=True) log_path = os.path.join(args.logdir, args.task, 'discrete_bcq') writer = SummaryWriter(log_path) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index cb8f0d296..42cffa9be 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -84,19 +84,16 @@ def test_c51(args=get_args()): ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load( - args.resume_path, map_location=args.device - )) + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM - buffer = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs), - ignore_obs_next=True, - save_only_last_obs=True, - stack_num=args.frames_stack) + buffer = VectorReplayBuffer( + args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True, + save_only_last_obs=True, stack_num=args.frames_stack) # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs) + test_collector = Collector(policy, test_envs, exploration_noise=True) # log log_path = os.path.join(args.logdir, args.task, 'c51') writer = SummaryWriter(log_path) @@ -132,8 +129,7 @@ def watch(): policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - 
render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) if args.watch: @@ -141,7 +137,7 @@ def watch(): exit(0) # test train_collector and start filling replay buffer - train_collector.collect(n_step=args.batch_size * 4) + train_collector.collect(n_step=args.batch_size * args.training_num) # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index ff589dcbb..559b0878e 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -80,19 +80,16 @@ def test_dqn(args=get_args()): target_update_freq=args.target_update_freq) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load( - args.resume_path, map_location=args.device - )) + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM - buffer = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs), - ignore_obs_next=True, - save_only_last_obs=True, - stack_num=args.frames_stack) + buffer = VectorReplayBuffer( + args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True, + save_only_last_obs=True, stack_num=args.frames_stack) # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs) + test_collector = Collector(policy, test_envs, exploration_noise=True) # log log_path = os.path.join(args.logdir, args.task, 'dqn') writer = SummaryWriter(log_path) @@ -129,11 +126,10 @@ def watch(): test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") - buffer = VectorReplayBuffer(args.buffer_size, - buffer_num=len(test_envs), - ignore_obs_next=True, - save_only_last_obs=True, - stack_num=args.frames_stack) + buffer = VectorReplayBuffer( + args.buffer_size, buffer_num=len(test_envs), + ignore_obs_next=True, save_only_last_obs=True, + stack_num=args.frames_stack) collector = Collector(policy, test_envs, buffer) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") @@ -151,7 +147,7 @@ def watch(): exit(0) # test train_collector and start filling replay buffer - train_collector.collect(n_step=args.batch_size * 4) + train_collector.collect(n_step=args.batch_size * args.training_num) # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 2a23926de..ed356381f 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -82,19 +82,16 @@ def test_qrdqn(args=get_args()): ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load( - args.resume_path, map_location=args.device - )) + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM - buffer = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs), - ignore_obs_next=True, - save_only_last_obs=True, - stack_num=args.frames_stack) + buffer = VectorReplayBuffer( + args.buffer_size, buffer_num=len(train_envs), + ignore_obs_next=True, save_only_last_obs=True, stack_num=args.frames_stack) # 
collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs) + test_collector = Collector(policy, test_envs, exploration_noise=True) # log log_path = os.path.join(args.logdir, args.task, 'qrdqn') writer = SummaryWriter(log_path) @@ -130,8 +127,7 @@ def watch(): policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) if args.watch: @@ -139,7 +135,7 @@ def watch(): exit(0) # test train_collector and start filling replay buffer - train_collector.collect(n_step=args.batch_size * 4) + train_collector.collect(n_step=args.batch_size * args.training_num) # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index dadd7ddb1..444c357f8 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -28,8 +28,7 @@ def get_args(): parser.add_argument('--step-per-epoch', type=int, default=1000) parser.add_argument('--collect-per-step', type=int, default=100) parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128]) parser.add_argument('--dueling-q-hidden-sizes', type=int, nargs='*', default=[128, 128]) parser.add_argument('--dueling-v-hidden-sizes', type=int, @@ -75,7 +74,7 @@ def test_dqn(args=get_args()): policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True) - test_collector = Collector(policy, test_envs) + test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size * args.training_num) # log @@ -116,8 +115,7 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render) rews, lens = result["rews"], result["lens"] print(f"Final reward: {rews.mean()}, length: {lens.mean()}") diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index c01410bb7..de88aa315 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -76,7 +76,7 @@ def test_dqn(args=get_args()): policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True) - test_collector = Collector(policy, test_envs) + test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size * args.training_num) # log diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 31c0c2a1a..684ce9696 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -82,7 +82,7 @@ def test_c51(args=get_args()): buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector train_collector = Collector(policy, train_envs, buf, exploration_noise=True) - test_collector = Collector(policy, test_envs) + test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size * args.training_num) # log diff 
--git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 273213af9..df02684c0 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -84,7 +84,7 @@ def test_dqn(args=get_args()): buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector train_collector = Collector(policy, train_envs, buf, exploration_noise=True) - test_collector = Collector(policy, test_envs) + test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size * args.training_num) # log diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index a91d5d3eb..2375edc54 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -65,14 +65,12 @@ def test_drqn(args=get_args()): net, optim, args.gamma, args.n_step, target_update_freq=args.target_update_freq) # collector - train_collector = Collector( - policy, train_envs, - VectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), - stack_num=args.stack_num, ignore_obs_next=True), - exploration_noise=True) + buffer = VectorReplayBuffer( + args.buffer_size, buffer_num=len(train_envs), + stack_num=args.stack_num, ignore_obs_next=True) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) # the stack_num is for RNN training: sample framestack obs - test_collector = Collector(policy, test_envs) + test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size * args.training_num) # log diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py index 9691119e3..3dd2b8d63 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/discrete/test_il_bcq.py @@ -78,7 +78,7 @@ def test_discrete_bcq(args=get_args()): buffer = pickle.load(open(args.load_buffer_name, "rb")) # collector - test_collector = Collector(policy, test_envs) + test_collector = Collector(policy, test_envs, exploration_noise=True) log_path = os.path.join(args.logdir, args.task, 'discrete_bcq') writer = SummaryWriter(log_path) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 3a4f0c998..006dd827b 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -80,7 +80,7 @@ def test_qrdqn(args=get_args()): buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector train_collector = Collector(policy, train_envs, buf, exploration_noise=True) - test_collector = Collector(policy, test_envs) + test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size * args.training_num) # log From 114535958e4aa79f94452c010473d05f8ef118ff Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 18 Feb 2021 20:08:24 +0800 Subject: [PATCH 094/104] fix dead loop in creating new Batch (drqn _is_scalar replace np.asanyarray with np.isscalar) --- test/discrete/test_drqn.py | 2 +- tianshou/data/batch.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 2375edc54..420f8e6cd 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -31,7 +31,7 @@ def get_args(): parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--layer-num', type=int, default=3) parser.add_argument('--training-num', type=int, default=10) - parser.add_argument('--test-num', type=int, default=10) + parser.add_argument('--test-num', type=int, 
default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument( diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 4f15622ab..889467077 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -35,8 +35,8 @@ def _is_scalar(value: Any) -> bool: if isinstance(value, torch.Tensor): return value.numel() == 1 and not value.shape else: - value = np.asanyarray(value) - return value.size == 1 and not value.shape + # np.asanyarray will cause dead loop in some cases + return np.isscalar(value) def _is_number(value: Any) -> bool: From c7a624f06e6c1b261288695315d26bda74f82727 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 18 Feb 2021 21:18:17 +0800 Subject: [PATCH 095/104] split exploration_noise in bcq --- tianshou/policy/imitation/discrete_bcq.py | 26 +++++++++++------------ 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 688a9901d..0061ea20f 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -74,7 +74,7 @@ def _target_q( ) -> torch.Tensor: batch = buffer[indice] # batch.obs_next: s_{t+n} # target_Q = Q_old(s_, argmax(Q_new(s_, *))) - act = self(batch, input="obs_next", eps=0.0).act + act = self(batch, input="obs_next").act target_q, _ = self.model_old(batch.obs_next) target_q = target_q[np.arange(len(act)), act] return target_q @@ -84,13 +84,12 @@ def forward( # type: ignore batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, input: str = "obs", - eps: Optional[float] = None, **kwargs: Any, ) -> Batch: - if eps is None: - eps = self._eps obs = batch[input] q_value, state = self.model(obs, state=state, info=batch.info) + if not hasattr(self, "max_action_num"): + self.max_action_num = q_value.shape[1] imitation_logits, _ = self.imitator(obs, state=state, info=batch.info) # mask actions for argmax @@ -99,24 +98,25 @@ def forward( # type: ignore mask = (ratio < self._log_tau).float() action = (q_value - np.inf * mask).argmax(dim=-1) - # add eps to act - if not np.isclose(eps, 0.0): - bsz, action_num = q_value.shape - mask = np.random.rand(bsz) < eps - action_rand = torch.randint( - action_num, size=[bsz], device=action.device) - action[mask] = action_rand[mask] - return Batch(act=action, state=state, q_value=q_value, imitation_logits=imitation_logits) + def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray: + # add eps to act + if not np.isclose(self._eps, 0.0): + bsz = len(act) + mask = np.random.rand(bsz) < self._eps + act_rand = np.random.randint(self.max_action_num, size=[bsz]) + act[mask] = act_rand[mask] + return act + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: if self._iter % self._freq == 0: self.sync_weight() self._iter += 1 target_q = batch.returns.flatten() - result = self(batch, eps=0.0) + result = self(batch) imitation_logits = result.imitation_logits current_q = result.q_value[np.arange(len(target_q)), batch.act] act = to_torch(batch.act, dtype=torch.long, device=target_q.device) From d5fa008d00de88d55fbd4d74a498d45e5eccba6f Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 18 Feb 2021 21:43:42 +0800 Subject: [PATCH 096/104] add test for priovecbuf --- test/base/test_buffer.py | 31 +++++++++++++---------- tianshou/policy/modelfree/c51.py | 4 +-- tianshou/policy/modelfree/discrete_sac.py | 3 +-- tianshou/policy/modelfree/dqn.py | 8 ++---- 
tianshou/policy/modelfree/qrdqn.py | 4 +-- 5 files changed, 22 insertions(+), 28 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 3fd09590f..3f65ad633 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -11,6 +11,8 @@ from tianshou.data import Batch, SegmentTree, ReplayBuffer from tianshou.data import PrioritizedReplayBuffer from tianshou.data import VectorReplayBuffer, CachedReplayBuffer +from tianshou.data import PrioritizedVectorReplayBuffer + if __name__ == '__main__': from env import MyTestEnv @@ -151,13 +153,16 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): def test_priortized_replaybuffer(size=32, bufsize=15): env = MyTestEnv(size) buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5) + buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5) obs = env.reset() action_list = [1] * 5 + [0] * 10 + [1] * 10 for i, a in enumerate(action_list): obs_next, rew, done, info = env.step(a) batch = Batch(obs=obs, act=a, rew=rew, done=done, obs_next=obs_next, info=info, policy=np.random.randn() - 0.5) + batch_stack = Batch.stack([batch, batch, batch]) buf.add(Batch.stack([batch]), buffer_ids=[0]) + buf2.add(batch_stack, buffer_ids=[0, 1, 2]) obs = obs_next data, indice = buf.sample(len(buf) // 2) if len(buf) // 2 == 0: @@ -165,13 +170,23 @@ def test_priortized_replaybuffer(size=32, bufsize=15): else: assert len(data) == len(buf) // 2 assert len(buf) == min(bufsize, i + 1) + assert len(buf2) == min(bufsize, 3 * (i + 1)) + # check single buffer's data assert buf.info.key.shape == (buf.maxsize,) assert buf.rew.dtype == np.float assert buf.done.dtype == np.bool_ data, indice = buf.sample(len(buf) // 2) buf.update_weight(indice, -data.weight / 2) - assert np.allclose( - buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha) + assert np.allclose(buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha) + # check multi buffer's data + assert np.allclose(buf2[np.arange(buf2.maxsize)].weight, 1) + batch, indice = buf2.sample(10) + buf2.update_weight(indice, batch.weight * 0) + weight = buf2[np.arange(buf2.maxsize)].weight + mask = np.isin(np.arange(buf2.maxsize), indice) + assert np.all(weight[mask] == weight[mask][0]) + assert np.all(weight[~mask] == weight[~mask][0]) + assert weight[~mask][0] < weight[mask][0] and weight[mask][0] < 1 def test_update(): @@ -608,18 +623,6 @@ def test_multibuf_stack(): indice = buf5.sample_index(0) assert np.allclose(sorted(indice), [0, 1, 2, 5, 6, 7, 10, 15, 20]) batch, _ = buf5.sample(0) - # the below test code should move to PrioritizedReplayBufferManager - # assert np.allclose(buf5[np.arange(buf5.maxsize)].weight, 1) - # buf5.update_weight(indice, batch.weight * 0) - # weight = buf5[np.arange(buf5.maxsize)].weight - # modified_weight = weight[[0, 1, 2, 5, 6, 7]] - # assert modified_weight.min() == modified_weight.max() - # assert modified_weight.max() < 1 - # unmodified_weight = weight[[3, 4, 8]] - # assert unmodified_weight.min() == unmodified_weight.max() - # assert unmodified_weight.max() < 1 - # cached_weight = weight[9:] - # assert cached_weight.min() == cached_weight.max() == 1 # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next buf6 = CachedReplayBuffer( ReplayBuffer(bufsize, stack_num=stack_num, diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index b0e94c616..dce2112a8 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -70,9 +70,7 @@ def compute_q_value(self, logits: 
torch.Tensor) -> torch.Tensor: def _target_dist(self, batch: Batch) -> torch.Tensor: if self._target: a = self(batch, input="obs_next").act - next_dist = self( - batch, model="model_old", input="obs_next" - ).logits + next_dist = self(batch, model="model_old", input="obs_next").logits else: next_b = self(batch, input="obs_next") a = next_b.act diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 1a0dc0f7e..9c46fc4a3 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -119,8 +119,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: current_q1a = self.critic1(batch.obs) current_q2a = self.critic2(batch.obs) q = torch.min(current_q1a, current_q2a) - actor_loss = -(self._alpha * entropy - + (dist.probs * q).sum(dim=-1)).mean() + actor_loss = -(self._alpha * entropy + (dist.probs * q).sum(dim=-1)).mean() self.actor_optim.zero_grad() actor_loss.backward() self.actor_optim.step() diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 50321860b..e79ff3206 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -46,9 +46,7 @@ def __init__( self.model = model self.optim = optim self.eps = 0.0 - assert ( - 0.0 <= discount_factor <= 1.0 - ), "discount factor should be in [0, 1]" + assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]" self._gamma = discount_factor assert estimation_step > 0, "estimation_step should be greater than 0" self._n_step = estimation_step @@ -81,9 +79,7 @@ def _target_q( # target_Q = Q_old(s_, argmax(Q_new(s_, *))) if self._target: a = self(batch, input="obs_next").act - target_q = self( - batch, model="model_old", input="obs_next" - ).logits + target_q = self(batch, model="model_old", input="obs_next").logits target_q = target_q[np.arange(len(a)), a] else: target_q = self(batch, input="obs_next").logits.max(dim=1)[0] diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 754d9acce..8816b6b1a 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -56,9 +56,7 @@ def _target_q( batch = buffer[indice] # batch.obs_next: s_{t+n} if self._target: a = self(batch, input="obs_next").act - next_dist = self( - batch, model="model_old", input="obs_next" - ).logits + next_dist = self(batch, model="model_old", input="obs_next").logits else: next_b = self(batch, input="obs_next") a = next_b.act From c3b35d42af9401abdfb7243cdc00d40829f5b760 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 18 Feb 2021 21:55:43 +0800 Subject: [PATCH 097/104] improve coverage --- test/base/test_collector.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index e8c886c51..b9d789193 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -6,7 +6,12 @@ from tianshou.policy import BasePolicy from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.data import Batch, Collector, AsyncCollector -from tianshou.data import ReplayBuffer, VectorReplayBuffer, CachedReplayBuffer +from tianshou.data import ( + ReplayBuffer, + PrioritizedReplayBuffer, + VectorReplayBuffer, + CachedReplayBuffer, +) if __name__ == '__main__': from env import MyTestEnv @@ -124,6 +129,14 @@ def test_collector(): assert np.all(c2.buffer.obs[:, 0] == obs) c2.collect(n_episode=4, random=True) + # test corner case + with 
pytest.raises(TypeError): + Collector(policy, dum, ReplayBuffer(10)) + with pytest.raises(TypeError): + Collector(policy, dum, PrioritizedReplayBuffer(10, 0.5, 0.5)) + with pytest.raises(TypeError): + c2.collect() + def test_collector_with_async(): env_lens = [2, 3, 4, 5] @@ -167,6 +180,8 @@ def test_collector_with_async(): assert np.all(buf.info.env_id == i) assert np.all(buf.obs.reshape(-1, env_len) == seq) assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1) + with pytest.raises(TypeError): + c1.collect() def test_collector_with_dict_state(): From 0cae28c4e096078a8860cd0bf15ac7c59e3ced15 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 18 Feb 2021 22:49:51 +0800 Subject: [PATCH 098/104] fix several bugs of documentation --- README.md | 8 ++--- docs/tutorials/cheatsheet.rst | 14 ++++----- docs/tutorials/concepts.rst | 12 ++++---- docs/tutorials/dqn.rst | 9 +++--- docs/tutorials/tictactoe.rst | 53 +++++++++++++++++++--------------- test/multiagent/tic_tac_toe.py | 17 +++++------ 6 files changed, 56 insertions(+), 57 deletions(-) diff --git a/README.md b/README.md index 3ba5f705e..c50321f98 100644 --- a/README.md +++ b/README.md @@ -195,7 +195,7 @@ train_num, test_num = 8, 100 gamma, n_step, target_freq = 0.9, 3, 320 buffer_size = 20000 eps_train, eps_test = 0.1, 0.05 -step_per_epoch, collect_per_step = 1000, 10 +step_per_epoch, collect_per_step = 1000, 8 writer = SummaryWriter('log/dqn') # tensorboard is also supported! ``` @@ -224,8 +224,8 @@ Setup policy and collectors: ```python policy = ts.policy.DQNPolicy(net, optim, gamma, n_step, target_update_freq=target_freq) -train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(buffer_size)) -test_collector = ts.data.Collector(policy, test_envs) +train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(buffer_size, train_num), exploration_noise=True) +test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True) # because DQN uses epsilon-greedy method ``` Let's train it: @@ -253,7 +253,7 @@ Watch the performance with 35 FPS: ```python policy.eval() policy.set_eps(eps_test) -collector = ts.data.Collector(policy, env) +collector = ts.data.Collector(policy, env, exploration_noise=True) collector.collect(n_episode=1, render=1 / 35) ``` diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index 07155f356..69495ba63 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -144,7 +144,7 @@ And finally, :: test_processor = MyProcessor(size=100) - collector = Collector(policy, env, buffer, test_processor.preprocess_fn) + collector = Collector(policy, env, buffer, preprocess_fn=test_processor.preprocess_fn) Some examples are in `test/base/test_collector.py `_. @@ -156,7 +156,7 @@ RNN-style Training This is related to `Issue 19 `_. -First, add an argument "stack_num" to :class:`~tianshou.data.ReplayBuffer`: +First, add an argument "stack_num" to :class:`~tianshou.data.ReplayBuffer`, :class:`~tianshou.data.VectorReplayBuffer`, or other types of buffer you are using, like: :: buf = ReplayBuffer(size=size, stack_num=stack_num) @@ -206,14 +206,13 @@ The state can be a ``numpy.ndarray`` or a Python dictionary. Take "FetchReach-v1 It shows that the state is a dictionary which has 3 keys. 
It will stored in :class:`~tianshou.data.ReplayBuffer` as: :: - >>> from tianshou.data import ReplayBuffer + >>> from tianshou.data import Batch, ReplayBuffer >>> b = ReplayBuffer(size=3) - >>> b.add(obs=e.reset(), act=0, rew=0, done=0) + >>> b.add(Batch(obs=e.reset(), act=0, rew=0, done=0)) >>> print(b) ReplayBuffer( act: array([0, 0, 0]), - done: array([0, 0, 0]), - info: Batch(), + done: array([False, False, False]), obs: Batch( achieved_goal: array([[1.34183265, 0.74910039, 0.53472272], [0. , 0. , 0. ], @@ -234,7 +233,6 @@ It shows that the state is a dictionary which has 3 keys. It will stored in :cla 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]), ), - policy: Batch(), rew: array([0, 0, 0]), ) >>> print(b.obs.achieved_goal) @@ -278,7 +276,7 @@ For self-defined class, the replay buffer will store the reference into a ``nump >>> import networkx as nx >>> b = ReplayBuffer(size=3) - >>> b.add(obs=nx.Graph(), act=0, rew=0, done=0) + >>> b.add(Batch(obs=nx.Graph(), act=0, rew=0, done=0)) >>> print(b) ReplayBuffer( act: array([0, 0, 0]), diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index a314cdedb..b3e126352 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -53,7 +53,7 @@ In short, you can define a :class:`~tianshou.data.Batch` with any key-value pair Buffer ------ -:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. ReplayBuffer can be considered as a specialized form (or management) of Batch. It stores all the data in a batch with circular-queue style. +:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. ReplayBuffer can be considered as a specialized form (or management) of :class:`~tianshou.data.Batch`. It stores all the data in a batch with circular-queue style. The current implementation of Tianshou typically use 7 reserved keys in :class:`~tianshou.data.Batch`: @@ -209,7 +209,7 @@ The following code snippet illustrates its usage, including:
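For reference, a small hedged illustration of the updated ``add``/``sample`` interface used throughout these documentation changes: a whole ``Batch`` goes in, and a ``(batch, indices)`` pair comes out. The scalar transitions below are made up: ::

    from tianshou.data import Batch, ReplayBuffer

    buf = ReplayBuffer(size=10)
    for i in range(3):
        # new-style add(): pass one Batch instead of separate keyword arguments
        buf.add(Batch(obs=i, act=i, rew=float(i), done=False,
                      obs_next=i + 1, info={}))
    batch, indices = buf.sample(2)  # a random batch plus its indices in the buffer
    print(len(buf), batch.obs, indices)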
-Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``), :class:`~tianshou.data.CachedReplayBuffer` (add different episodes' data but without losing chronological order). Check out :class:`~tianshou.data.ReplayBuffer` for more detail. +Tianshou provides other type of data buffer such as :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``) and :class:`~tianshou.data.VectorReplayBuffer` (add different episodes' data but without losing chronological order). Check out :class:`~tianshou.data.ReplayBuffer` for more detail. Policy @@ -339,14 +339,12 @@ Collector The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently. -:meth:`~tianshou.data.Collector.collect` is the main method of Collector: it let the policy perform (at least) a specified number of step ``n_step`` or episode ``n_episode`` and store the data in the replay buffer. - -Why do we mention **at least** here? For multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically. - -The proposed solution is to add some cache buffers inside the collector. Once collecting **a full episode of trajectory**, it will move the stored data from the cache buffer to the main buffer. To satisfy this condition, the collector will interact with environments that may exceed the given step number or episode number. +:meth:`~tianshou.data.Collector.collect` is the main method of Collector: it let the policy perform a specified number of step ``n_step`` or episode ``n_episode`` and store the data in the replay buffer, then return the statistics of the collected data such as episode's total reward. The general explanation is listed in :ref:`pseudocode`. Other usages of collector are listed in :class:`~tianshou.data.Collector` documentation. +There is also another type of collector :class:`~tianshou.data.AsyncCollector` which supports asynchronous environment setting (for those taking a long time to step). However, AsyncCollector only supports **at least** ``n_step`` or ``n_episode`` collection due to the property of asynchronous environments. + Trainer ------- diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index edd39355b..361f79f3c 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -113,8 +113,8 @@ The collector is a key concept in Tianshou. It allows the policy to interact wit In each step, the collector will let the policy perform (at least) a specified number of steps or episodes and store the data in a replay buffer. :: - train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(size=20000)) - test_collector = ts.data.Collector(policy, test_envs) + train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(20000, 8), exploration_noise=True) + test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True) Train Policy with a Trainer @@ -191,7 +191,7 @@ Watch the Agent's Performance policy.eval() policy.set_eps(0.05) - collector = ts.data.Collector(policy, env) + collector = ts.data.Collector(policy, env, exploration_noise=True) collector.collect(n_episode=1, render=1 / 35) @@ -206,8 +206,7 @@ Tianshou supports user-defined training code. 
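The revised Collector section above states that ``collect()`` performs the requested number of steps or episodes and then returns the statistics of the collected episodes. A rough self-contained check of that interface; ``FixedPolicy`` is a throwaway stand-in used only in this sketch, not part of Tianshou: ::

    import gym
    import numpy as np

    from tianshou.data import Batch, Collector, VectorReplayBuffer
    from tianshou.env import DummyVectorEnv
    from tianshou.policy import BasePolicy


    class FixedPolicy(BasePolicy):
        """Minimal policy for the sketch: always picks action 0."""

        def forward(self, batch, state=None, **kwargs):
            return Batch(act=np.zeros(len(batch.obs), dtype=int))

        def learn(self, batch, **kwargs):
            return {}


    envs = DummyVectorEnv([lambda: gym.make("CartPole-v0") for _ in range(4)])
    # vectorized envs need a VectorReplayBuffer; a plain ReplayBuffer raises
    # TypeError, as the corner-case tests added above check
    collector = Collector(FixedPolicy(), envs, VectorReplayBuffer(1000, buffer_num=4))
    result = collector.collect(n_episode=4)
    print(result["rews"].mean(), result["lens"].mean())  # stats of the 4 episodes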
Here is the code snippet: :: # pre-collect at least 5000 frames with random action before training - policy.set_eps(1) - train_collector.collect(n_step=5000) + train_collector.collect(n_step=5000, random=True) policy.set_eps(0.1) for i in range(int(1e6)): # total step diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst index 4aa356e6f..c656c1ee2 100644 --- a/docs/tutorials/tictactoe.rst +++ b/docs/tutorials/tictactoe.rst @@ -180,7 +180,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modul from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer - from tianshou.data import Collector, ReplayBuffer + from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import BasePolicy, RandomPolicy, DQNPolicy, MultiAgentPolicyManager from tic_tac_toe_env import TicTacToeEnv @@ -199,27 +199,27 @@ The explanation of each Tianshou class/function will be deferred to their first help='a smaller gamma favors earlier win') parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) - parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=8) + parser.add_argument('--epoch', type=int, default=20) + parser.add_argument('--step-per-epoch', type=int, default=500) + parser.add_argument('--collect-per-step', type=int, default=10) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.1) - parser.add_argument('--board_size', type=int, default=6) - parser.add_argument('--win_size', type=int, default=4) - parser.add_argument('--win-rate', type=float, default=np.float32(0.9), + parser.add_argument('--board-size', type=int, default=6) + parser.add_argument('--win-size', type=int, default=4) + parser.add_argument('--win-rate', type=float, default=0.9, help='the expected winning rate') parser.add_argument('--watch', default=False, action='store_true', help='no training, watch the play of pre-trained models') - parser.add_argument('--agent_id', type=int, default=2, + parser.add_argument('--agent-id', type=int, default=2, help='the learned agent plays as the agent_id-th player. 
Choices are 1 and 2.') - parser.add_argument('--resume_path', type=str, default='', + parser.add_argument('--resume-path', type=str, default='', help='the path of agent pth file for resuming from a pre-trained agent') - parser.add_argument('--opponent_path', type=str, default='', + parser.add_argument('--opponent-path', type=str, default='', help='the path of opponent agent pth file for resuming from a pre-trained agent') parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') @@ -240,11 +240,13 @@ Both agents are passed to :class:`~tianshou.policy.MultiAgentPolicyManager`, whi Here it is: :: - def get_agents(args=get_args(), - agent_learn=None, # BasePolicy - agent_opponent=None, # BasePolicy - optim=None, # torch.optim.Optimizer - ): # return a tuple of (BasePolicy, torch.optim.Optimizer) + def get_agents( + args=get_args(), + agent_learn=None, # BasePolicy + agent_opponent=None, # BasePolicy + optim=None, # torch.optim.Optimizer + ): # return a tuple of (BasePolicy, torch.optim.Optimizer) + env = TicTacToeEnv(args.board_size, args.win_size) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n @@ -279,9 +281,6 @@ With the above preparation, we are close to the first learned agent. The followi :: args = get_args() - # the reward is a vector, we need a scalar metric to monitor the training. - # we choose the reward of the learning agent - Collector._default_rew_metric = lambda x: x[args.agent_id - 1] # ======== a test function that tests a pre-trained agent and exit ====== def watch(args=get_args(), @@ -313,8 +312,9 @@ With the above preparation, we are close to the first learned agent. The followi policy, optim = get_agents() # ======== collector setup ========= - train_collector = Collector(policy, train_envs, ReplayBuffer(args.buffer_size)) - test_collector = Collector(policy, test_envs) + buffer = VectorReplayBuffer(args.buffer_size, args.training_num) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) train_collector.collect(n_step=args.batch_size * args.training_num) # ======== tensorboard logging setup ========= @@ -347,13 +347,18 @@ With the above preparation, we are close to the first learned agent. The followi def test_fn(epoch, env_step): policy.policies[args.agent_id - 1].set_eps(args.eps_test) + # the reward is a vector, we need a scalar metric to monitor the training. + # we choose the reward of the learning agent + def reward_metric(rews): + return rews[:, args.agent_id - 1] + # start training, this may require about three minutes result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.collect_per_step, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, - test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, reward_metric=reward_metric, + writer=writer, test_in_train=False) agent = policy.policies[args.agent_id - 1] # let's watch the match! @@ -476,7 +481,7 @@ By default, the trained agent is stored in ``log/tic_tac_toe/dqn/policy.pth``. Y .. 
code-block:: console - $ python test_tic_tac_toe.py --watch --resume_path=log/tic_tac_toe/dqn/policy.pth --opponent_path=log/tic_tac_toe/dqn/policy.pth + $ python test_tic_tac_toe.py --watch --resume-path log/tic_tac_toe/dqn/policy.pth --opponent-path log/tic_tac_toe/dqn/policy.pth Here is our output: diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index dfe25d42b..6fce2c509 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -37,20 +37,20 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.1) - parser.add_argument('--board_size', type=int, default=6) - parser.add_argument('--win_size', type=int, default=4) - parser.add_argument('--win_rate', type=float, default=0.9, + parser.add_argument('--board-size', type=int, default=6) + parser.add_argument('--win-size', type=int, default=4) + parser.add_argument('--win-rate', type=float, default=0.9, help='the expected winning rate') parser.add_argument('--watch', default=False, action='store_true', help='no training, ' 'watch the play of pre-trained models') - parser.add_argument('--agent_id', type=int, default=2, + parser.add_argument('--agent-id', type=int, default=2, help='the learned agent plays as the' - ' agent_id-th player. choices are 1 and 2.') - parser.add_argument('--resume_path', type=str, default='', + ' agent_id-th player. Choices are 1 and 2.') + parser.add_argument('--resume-path', type=str, default='', help='the path of agent pth file ' 'for resuming from a pre-trained agent') - parser.add_argument('--opponent_path', type=str, default='', + parser.add_argument('--opponent-path', type=str, default='', help='the path of opponent agent pth file ' 'for resuming from a pre-trained agent') parser.add_argument( @@ -61,8 +61,7 @@ def get_parser() -> argparse.ArgumentParser: def get_args() -> argparse.Namespace: parser = get_parser() - args = parser.parse_known_args()[0] - return args + return parser.parse_args() def get_agents( From b7efc68fe971233cc04f25a1b8a20e51c2db1bdd Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 18 Feb 2021 22:56:27 +0800 Subject: [PATCH 099/104] fix test --- test/multiagent/tic_tac_toe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index 6fce2c509..b081a5055 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -61,7 +61,7 @@ def get_parser() -> argparse.ArgumentParser: def get_args() -> argparse.Namespace: parser = get_parser() - return parser.parse_args() + return parser.parse_known_args()[0] def get_agents( From b2098eeb68a5ce72afcdfed4196a962aa808b4e0 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 18 Feb 2021 23:12:18 +0800 Subject: [PATCH 100/104] add a test of batch --- test/base/test_batch.py | 2 ++ tianshou/data/batch.py | 10 +++------- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 4553edff7..58631325d 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -21,6 +21,8 @@ def test_batch(): assert not Batch(a=[1, 2, 3]).is_empty() b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None]) assert b.c.dtype == np.object + b = Batch(d=[None], e=[starmap], f=Batch) + assert b.d.dtype == b.e.dtype == b.f.dtype == np.object b = Batch() b.update() assert b.is_empty() diff --git 
a/tianshou/data/batch.py b/tianshou/data/batch.py index 889467077..4476f900b 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -48,10 +48,8 @@ def _is_number(value: Any) -> bool: def _to_array_with_correct_type(v: Any) -> np.ndarray: - if isinstance(v, np.ndarray) and issubclass( - v.dtype.type, (np.bool_, np.number) - ): # most often case - return v + if isinstance(v, np.ndarray) and issubclass(v.dtype.type, (np.bool_, np.number)): + return v # most often case # convert the value to np.ndarray # convert to np.object data type if neither bool nor number # raises an exception if array's elements are tensors themself @@ -66,9 +64,7 @@ def _to_array_with_correct_type(v: Any) -> np.ndarray: # array([{}, array({}, dtype=object)], dtype=object) if not v.shape: v = v.item(0) - elif any( - isinstance(e, (np.ndarray, torch.Tensor)) for e in v.reshape(-1) - ): + elif any(isinstance(e, (np.ndarray, torch.Tensor)) for e in v.reshape(-1)): raise ValueError("Numpy arrays of tensors are not supported yet.") return v From 4f181cc6281ca57ced4617e8355cbb4ebe1ea420 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 18 Feb 2021 23:32:29 +0800 Subject: [PATCH 101/104] fix test --- test/base/test_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 58631325d..0898e154c 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -22,7 +22,7 @@ def test_batch(): b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None]) assert b.c.dtype == np.object b = Batch(d=[None], e=[starmap], f=Batch) - assert b.d.dtype == b.e.dtype == b.f.dtype == np.object + assert b.d.dtype == b.e.dtype == np.object and b.f == Batch b = Batch() b.update() assert b.is_empty() From 8be5b2c6b65c5802fe33f2cb5c207e106d9ec272 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 18 Feb 2021 23:36:35 +0800 Subject: [PATCH 102/104] add a note --- docs/tutorials/cheatsheet.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index 69495ba63..fefb30934 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -297,6 +297,10 @@ But the state stored in the buffer may be a shallow-copy. To make sure each of y ... return copy.deepcopy(self.graph), reward, done, {} +.. note :: + + Please make sure this variable is numpy-compatible, e.g., np.array([variable]) will not result in an empty array. Otherwise, ReplayBuffer cannot create an numpy array to store it. + .. _marl_example: From cb4dbda68fb8b40090838cf219a4fb0d179fff7b Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Fri, 19 Feb 2021 08:35:13 +0800 Subject: [PATCH 103/104] remove redundant code --- tianshou/data/batch.py | 53 +++++++++++++----------------------------- 1 file changed, 16 insertions(+), 37 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 4476f900b..fd12e0ec4 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -5,8 +5,7 @@ from copy import deepcopy from numbers import Number from collections.abc import Collection -from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \ - Sequence +from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, Sequence def _is_batch_set(data: Any) -> bool: @@ -74,20 +73,16 @@ def _create_value( ) -> Union["Batch", np.ndarray, torch.Tensor]: """Create empty place-holders accroding to inst's shape. - :param bool stack: whether to stack or to concatenate. E.g. 
if inst has - shape of (3, 5), size = 10, stack=True returns an np.ndarry with shape - of (10, 3, 5), otherwise (10, 5) + :param bool stack: whether to stack or to concatenate. E.g. if inst has shape of + (3, 5), size = 10, stack=True returns an np.ndarry with shape of (10, 3, 5), + otherwise (10, 5) """ has_shape = isinstance(inst, (np.ndarray, torch.Tensor)) is_scalar = _is_scalar(inst) if not stack and is_scalar: - # _create_value(Batch(a={}, b=[1, 2, 3]), 10, False) will fail here - if isinstance(inst, Batch) and inst.is_empty(recurse=True): - return inst - # should never hit since it has already checked in Batch.cat_ - # here we do not consider scalar types, following the behavior of numpy - # which does not support concatenation of zero-dimensional arrays - # (scalars) + # should never hit since it has already checked in Batch.cat_ , here we do not + # consider scalar types, following the behavior of numpy which does not support + # concatenation of zero-dimensional arrays (scalars) raise TypeError(f"cannot concatenate with {inst} which is scalar") if has_shape: shape = (size, *inst.shape) if stack else (size, *inst.shape[1:]) @@ -102,9 +97,7 @@ def _create_value( dtype=target_type ) elif isinstance(inst, torch.Tensor): - return torch.full( - shape, fill_value=0, device=inst.device, dtype=inst.dtype - ) + return torch.full(shape, fill_value=0, device=inst.device, dtype=inst.dtype) elif isinstance(inst, (dict, Batch)): zero_batch = Batch() for key, val in inst.items(): @@ -117,9 +110,8 @@ def _create_value( def _assert_type_keys(keys: Iterable[str]) -> None: - assert all( - isinstance(e, str) for e in keys - ), f"keys should all be string, but got {keys}" + assert all(isinstance(e, str) for e in keys), \ + f"keys should all be string, but got {keys}" def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]: @@ -436,9 +428,7 @@ def __cat( val, sum_lens[-1], stack=False) self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val - def cat_( - self, batches: Union["Batch", Sequence[Union[dict, "Batch"]]] - ) -> None: + def cat_(self, batches: Union["Batch", Sequence[Union[dict, "Batch"]]]) -> None: """Concatenate a list of (or one) Batch objects into current batch.""" if isinstance(batches, Batch): batches = [batches] @@ -494,9 +484,7 @@ def cat(batches: Sequence[Union[dict, "Batch"]]) -> "Batch": batch.cat_(batches) return batch - def stack_( - self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0 - ) -> None: + def stack_(self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> None: """Stack a list of Batch object into current batch.""" # check input format batch_list = [] @@ -560,9 +548,7 @@ def stack_( self.__dict__[k][i] = val @staticmethod - def stack( - batches: Sequence[Union[dict, "Batch"]], axis: int = 0 - ) -> "Batch": + def stack(batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> "Batch": """Stack a list of Batch object into a single new batch. For keys that are not shared across all batches, batches that do not @@ -589,10 +575,7 @@ def stack( return batch def empty_( - self, - index: Union[ - str, slice, int, np.integer, np.ndarray, List[int] - ] = None, + self, index: Union[str, slice, int, np.integer, np.ndarray, List[int]] = None ) -> "Batch": """Return an empty Batch object with 0 or None filled. 
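The batch.py hunks in this patch only reflow ``_create_value``, ``cat_``, ``stack_`` and ``empty_`` and remove a redundant branch, without changing behaviour. A short sanity check of that behaviour, offered as a sketch rather than as part of the test suite: ::

    import numpy as np
    from tianshou.data import Batch

    a = Batch(x=np.array([1, 2]), nested=Batch(y=np.zeros((2, 3))))
    b = Batch(x=np.array([3, 4]), nested=Batch(y=np.ones((2, 3))))
    assert Batch.cat([a, b]).x.tolist() == [1, 2, 3, 4]     # concatenate along axis 0
    assert Batch.stack([a, b]).nested.y.shape == (2, 2, 3)  # stack adds a new axis
    assert len(Batch.cat([a, b])) == 4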
@@ -642,9 +625,7 @@ def empty_( @staticmethod def empty( batch: "Batch", - index: Union[ - str, slice, int, np.integer, np.ndarray, List[int] - ] = None, + index: Union[str, slice, int, np.integer, np.ndarray, List[int]] = None, ) -> "Batch": """Return an empty Batch object with 0 or None filled. @@ -670,9 +651,7 @@ def __len__(self) -> int: for v in self.__dict__.values(): if isinstance(v, Batch) and v.is_empty(recurse=True): continue - elif hasattr(v, "__len__") and ( - isinstance(v, Batch) or v.ndim > 0 - ): + elif hasattr(v, "__len__") and (isinstance(v, Batch) or v.ndim > 0): r.append(len(v)) else: raise TypeError(f"Object {v} in {self} has no len()") From 9b466780f00d928e7f239cd742f4784fcf074f39 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Fri, 19 Feb 2021 09:58:22 +0800 Subject: [PATCH 104/104] polish docs organization --- docs/api/tianshou.data.rst | 24 +++++++++++++++++++++++- docs/api/tianshou.env.rst | 8 ++++++++ docs/api/tianshou.utils.rst | 4 ++++ docs/conf.py | 1 + tianshou/policy/__init__.py | 4 ++-- 5 files changed, 38 insertions(+), 3 deletions(-) diff --git a/docs/api/tianshou.data.rst b/docs/api/tianshou.data.rst index fa1d5c738..555d35640 100644 --- a/docs/api/tianshou.data.rst +++ b/docs/api/tianshou.data.rst @@ -1,7 +1,29 @@ tianshou.data ============= -.. automodule:: tianshou.data + +Batch +----- + +.. automodule:: tianshou.data.batch + :members: + :undoc-members: + :show-inheritance: + + +Buffer +------ + +.. automodule:: tianshou.data.buffer + :members: + :undoc-members: + :show-inheritance: + + +Collector +--------- + +.. automodule:: tianshou.data.collector :members: :undoc-members: :show-inheritance: diff --git a/docs/api/tianshou.env.rst b/docs/api/tianshou.env.rst index 7201bae46..f7eec6998 100644 --- a/docs/api/tianshou.env.rst +++ b/docs/api/tianshou.env.rst @@ -1,11 +1,19 @@ tianshou.env ============ + +VectorEnv +--------- + .. automodule:: tianshou.env :members: :undoc-members: :show-inheritance: + +Worker +------ + .. automodule:: tianshou.env.worker :members: :undoc-members: diff --git a/docs/api/tianshou.utils.rst b/docs/api/tianshou.utils.rst index 3a293b1c1..b2ac6a976 100644 --- a/docs/api/tianshou.utils.rst +++ b/docs/api/tianshou.utils.rst @@ -6,6 +6,10 @@ tianshou.utils :undoc-members: :show-inheritance: + +Pre-defined Networks +-------------------- + .. 
automodule:: tianshou.utils.net.common :members: :undoc-members: diff --git a/docs/conf.py b/docs/conf.py index f7bcc562d..b981eb4d4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -70,6 +70,7 @@ ] ) } +autodoc_member_order = "bysource" bibtex_bibfiles = ['refs.bib'] # -- Options for HTML output ------------------------------------------------- diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 01b7019af..a3625ca3f 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -1,6 +1,5 @@ from tianshou.policy.base import BasePolicy from tianshou.policy.random import RandomPolicy -from tianshou.policy.imitation.base import ImitationPolicy from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.modelfree.c51 import C51Policy from tianshou.policy.modelfree.qrdqn import QRDQNPolicy @@ -11,6 +10,7 @@ from tianshou.policy.modelfree.td3 import TD3Policy from tianshou.policy.modelfree.sac import SACPolicy from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy +from tianshou.policy.imitation.base import ImitationPolicy from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy from tianshou.policy.modelbase.psrl import PSRLPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager @@ -19,7 +19,6 @@ __all__ = [ "BasePolicy", "RandomPolicy", - "ImitationPolicy", "DQNPolicy", "C51Policy", "QRDQNPolicy", @@ -30,6 +29,7 @@ "TD3Policy", "SACPolicy", "DiscreteSACPolicy", + "ImitationPolicy", "DiscreteBCQPolicy", "PSRLPolicy", "MultiAgentPolicyManager",
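The final patch reorders the imports in ``tianshou/policy/__init__.py`` so that, combined with ``autodoc_member_order = "bysource"``, the generated API pages list base, model-free, imitation, model-based and multi-agent policies in that order; the public interface itself should be unchanged. A tiny hedged check: ::

    import tianshou.policy as policy_module

    # every name advertised in __all__ is still importable after the reshuffle
    assert all(hasattr(policy_module, name) for name in policy_module.__all__)
    from tianshou.policy import ImitationPolicy, DiscreteBCQPolicy  # order-only change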