Closed
Labels: question (Further information is requested)
Description
Hi, when I run my toy script I hit an error, and I have a question about it:
the batch passed into the policy class's forward function is of Batch type, but it is then fed into the model, whose obs input is actually a tensor/ndarray. I cannot find the mechanism that performs this conversion.
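From the traceback at the bottom of this issue, the conversion seems to happen inside Net.forward via torch.as_tensor (tianshou/utils/net/common.py), which is fine for a plain ndarray obs but not for the nested Batch that a Dict observation space produces. A minimal sketch of what I believe is going on (the Batch contents below are made up to mirror my case):

import numpy as np
import torch
from tianshou.data import Batch

# With a flat Box observation space the collector hands the model an ndarray,
# and the conversion line from the traceback (Net.forward) works fine:
flat_obs = np.zeros((1, 4), dtype=np.float32)
torch.as_tensor(flat_obs, dtype=torch.float32)  # ok, shape (1, 4)

# With a Dict observation space the obs arrives as a nested Batch instead,
# and the same conversion fails with a TypeError like the one at the bottom:
dict_obs = Batch(agent=np.array(0), target=np.array(0))
torch.as_tensor(dict_obs, dtype=torch.float32)  # TypeError: Object 0 in Batch(...) has no len()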
So my toy script is:
import gymnasium as gym
import torch
from torch.utils.tensorboard import SummaryWriter
import tianshou as ts
from tianshou.utils.net.common import Net
from gymnasium.envs.registration import register
task = "DummyAudioAmplify-v0"
register(
    id=task,
    entry_point="uim_sfit.envs.dummy:DummpyEnv",
    max_episode_steps=20,
)
lr, epoch, batch_size = 1e-3, 10, 64
train_num, test_num = 10, 100
gamma, n_step, target_freq = 0.9, 3, 320
buffer_size = 20000
eps_train, eps_test = 0.1, 0.05
step_per_epoch, step_per_collect = 10000, 10
logger = ts.utils.TensorboardLogger(SummaryWriter('log/dqn')) # TensorBoard is supported!
# you can also try with SubprocVectorEnv
train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
# you can define other net by following the API:
# https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network
env = gym.make(task, render_mode="human")
state_shape = [x.shape for x in env.observation_space.values()]
action_shape = env.action_space.shape or env.action_space.n
net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
optim = torch.optim.Adam(net.parameters(), lr=lr)
policy = ts.policy.DQNPolicy(
    model=net,
    optim=optim,
    discount_factor=gamma,
    action_space=env.action_space,
    estimation_step=n_step,
    target_update_freq=target_freq,
)
train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(buffer_size, train_num), exploration_noise=True)
test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True) # because DQN uses epsilon-greedy method
result = ts.trainer.OffpolicyTrainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=epoch,
    step_per_epoch=step_per_epoch,
    step_per_collect=step_per_collect,
    episode_per_test=test_num,
    batch_size=batch_size,
    update_per_step=1 / step_per_collect,
    train_fn=lambda epoch, env_step: policy.set_eps(eps_train),
    test_fn=lambda epoch, env_step: policy.set_eps(eps_test),
    stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
    logger=logger,
).run()
print(f"Finished training in {result.timing.total_time} seconds")
torch.save(policy.state_dict(), 'dqn.pth')
policy.load_state_dict(torch.load('dqn.pth'))
policy.eval()
policy.set_eps(eps_test)
collector = ts.data.Collector(policy, env, exploration_noise=True)
collector.collect(n_episode=1, render=1 / 35)
and the env is defined as:
import numpy as np
import soundfile as sf
import librosa
import torch
from gymnasium import Env
from gymnasium import spaces
AUDIO_PATH = "/data/data/aishell3/test/SSB0005.wav"
# https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/
class DummpyEnv(Env):
    metadata = {"render_modes": ["human"]}

    def __init__(self, render_mode=None, size=3):
        super().__init__()
        self.size = size
        self.observation_space = spaces.Dict(
            {
                "agent": spaces.Box(0, size - 1, shape=(2,), dtype=int),
                "target": spaces.Box(0, size - 1, shape=(2,), dtype=int),
            }
        )
        # We have 4 actions, corresponding to "right", "up", "left", "down"
        self.action_space = spaces.Discrete(4)
        """
        The following dictionary maps abstract actions from `self.action_space` to
        the direction we will walk in if that action is taken.
        I.e. 0 corresponds to "right", 1 to "up" etc.
        """
        self._action_to_direction = {
            0: np.array([1, 0]),
            1: np.array([0, 1]),
            2: np.array([-1, 0]),
            3: np.array([0, -1]),
        }
        self.audio_spec = librosa.stft(sf.read(AUDIO_PATH)[0])
        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode
    def _get_obs(self):
        return {
            "agent": torch.as_tensor(self._agent_location),
            "target": torch.as_tensor(self._target_location),
        }

    def _get_info(self):
        return {
            "distance": np.linalg.norm(
                self._agent_location - self._target_location, ord=1
            )
        }
    def reset(self):
        self._agent_location = self.np_random.integers(0, self.size, size=2, dtype=int)
        self._target_location = self._agent_location
        while np.array_equal(self._target_location, self._agent_location):
            self._target_location = self.np_random.integers(
                0, self.size, size=2, dtype=int
            )
        observation = self._get_obs()
        info = self._get_info()
        return observation, info
    def step(self, action):
        # Map the action (element of {0,1,2,3}) to the direction we walk in
        direction = self._action_to_direction[action]
        # We use `np.clip` to make sure we don't leave the grid
        self._agent_location = np.clip(
            self._agent_location + direction, 0, self.size - 1
        )
        # An episode is done iff the agent has reached the target
        terminated = np.array_equal(self._agent_location, self._target_location)
        reward = 1 if terminated else 0  # Binary sparse rewards
        observation = self._get_obs()
        info = self._get_info()
        # observation, reward, terminated, truncated, info
        return observation, reward, terminated, False, info
    def render(self):
        # TODO: using self._agent_location to amplify the audio
        audio_spec = self.audio_spec * np.repeat(
            np.expand_dims(self._location_to_amplify(), axis=1),
            self.audio_spec.shape[1],
            axis=1
        )
        return librosa.istft(audio_spec)

    def _location_to_amplify(self):
        size = self.audio_spec.shape[0]
        amplify = self._agent_location * 0.2 + 0.8
        size1 = size // 2
        size2 = size - size1
        amplify = np.concatenate(
            [
                np.array([amplify[0]] * size1),
                np.array([amplify[1]] * size2)
            ], axis=0
        )
        return amplify
    def close(self):
        pass
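One side note on this env (unrelated to the crash itself): the passive env checker warning in the log below complains that reset() returns torch tensors rather than numpy arrays. A hypothetical numpy-only variant of _get_obs, purely for illustration, would be:

def _get_obs(self):
    # Return plain numpy arrays that match the Box spaces declared in
    # observation_space, instead of wrapping the locations in torch tensors.
    return {
        "agent": np.asarray(self._agent_location),
        "target": np.asarray(self._target_location),
    }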
the bug is:
/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/librosa/core/intervals.py:8: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
from pkg_resources import resource_filename
/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/gymnasium/utils/passive_env_checker.py:168: DeprecationWarning: WARN: Current gymnasium version requires that `Env.reset` can be passed a `seed` instead of using `Env.seed` for resetting the environment random number generator.
logger.deprecation(
/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/gymnasium/utils/passive_env_checker.py:181: DeprecationWarning: WARN: Current gymnasium version requires that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information.
logger.deprecation(
/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/gymnasium/utils/passive_env_checker.py:131: UserWarning: WARN: The obs returned by the `reset()` method was expecting a numpy array, actual type: <class 'torch.Tensor'>
logger.warn(
/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/gymnasium/spaces/box.py:240: UserWarning: WARN: Casting input x to numpy array.
gym.logger.warn("Casting input x to numpy array.")
Traceback (most recent call last):
File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/haoyu.tang/.vscode-server/extensions/ms-python.python-2024.0.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
cli.main()
File "/home/haoyu.tang/.vscode-server/extensions/ms-python.python-2024.0.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
run()
File "/home/haoyu.tang/.vscode-server/extensions/ms-python.python-2024.0.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
runpy.run_path(target, run_name="__main__")
File "/home/haoyu.tang/.vscode-server/extensions/ms-python.python-2024.0.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
return _run_module_code(code, init_globals, run_name,
File "/home/haoyu.tang/.vscode-server/extensions/ms-python.python-2024.0.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
_run_code(code, mod_globals, init_globals,
File "/home/haoyu.tang/.vscode-server/extensions/ms-python.python-2024.0.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
exec(code, run_globals)
File "/home/haoyu.tang/uim_sfit/test_pipeline.py", line 64, in <module>
).run()
File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/tianshou/trainer/base.py", line 441, in run
deque(self, maxlen=0) # feed the entire iterator into a zero-length deque
File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/tianshou/trainer/base.py", line 252, in __iter__
self.reset()
File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/tianshou/trainer/base.py", line 237, in reset
test_result = test_episode(
File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/tianshou/trainer/utils.py", line 27, in test_episode
result = collector.collect(n_episode=n_episode)
File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/tianshou/data/collector.py", line 279, in collect
result = self.policy(self.data, last_state)
File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/tianshou/policy/modelfree/dqn.py", line 160, in forward
logits, hidden = model(obs_next, state=state, info=batch.info)
File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/tianshou/utils/net/common.py", line 248, in forward
logits = self.model(obs)
File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/tianshou/utils/net/common.py", line 142, in forward
obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/tianshou/data/batch.py", line 689, in __len__
raise TypeError(f"Object {obj} in {self} has no len()")
TypeError: Object 0 in Batch(
target: tensor(0),
agent: tensor(0),
) has no len()
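For what it is worth, the workaround I am considering (not sure it is the intended way) is to flatten the Dict observation space with gymnasium's FlattenObservation wrapper, so the stock Net only ever sees a flat ndarray. A minimal sketch, reusing task, train_num, test_num and the Net setup from the script above, and assuming _get_obs returns values compatible with the declared spaces:

import gymnasium as gym
from gymnasium.wrappers import FlattenObservation
import tianshou as ts
from tianshou.utils.net.common import Net

# Wrap every env instance so the Dict obs {"agent": ..., "target": ...} is
# flattened into a single 1-D ndarray before it reaches the collector.
def make_env():
    return FlattenObservation(gym.make(task))

train_envs = ts.env.DummyVectorEnv([make_env for _ in range(train_num)])
test_envs = ts.env.DummyVectorEnv([make_env for _ in range(test_num)])

env = make_env()
state_shape = env.observation_space.shape  # now a flat Box, e.g. (4,)
action_shape = env.action_space.shape or env.action_space.n
net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])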