这是indexloc提供的服务,不要输入任何密码
Skip to content

Standardized behavior of Batch.cat and misc code refactor #137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Jul 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
1f84806
code refactor; remove unused kwargs; add reward_normalization for dqn
youkaichao Jul 14, 2020
80a4f98
bugfix for __setitem__ with torch.Tensor; add Batch.condense
youkaichao Jul 14, 2020
3d7cc24
minor fix
Trinkle23897 Jul 14, 2020
fc84433
support cat with empty Batch
youkaichao Jul 14, 2020
9fa118c
remove the dependency of is_empty on len; specify the semantic of emp…
youkaichao Jul 14, 2020
ceca419
support stack with empty Batch
youkaichao Jul 14, 2020
4557baa
remove condense
youkaichao Jul 14, 2020
3de0218
refactor code to reflect the shared / partial / reserved categories o…
youkaichao Jul 14, 2020
f840c73
add is_empty(recursive=False)
youkaichao Jul 16, 2020
ce08bac
doc fix
youkaichao Jul 16, 2020
35b1533
docfix and bugfix for _is_batch_set
youkaichao Jul 16, 2020
8c2847f
add doc for key reservation
youkaichao Jul 16, 2020
e1e36e0
bugfix for algebra operators
youkaichao Jul 16, 2020
6d2cda6
fix cat with lens hint
youkaichao Jul 16, 2020
ebf19ea
code refactor
youkaichao Jul 16, 2020
2541bec
bugfix for storing None
youkaichao Jul 16, 2020
28b33e8
use ValueError instead of exception
youkaichao Jul 16, 2020
3046067
hide lens away from users
youkaichao Jul 16, 2020
fba94a6
add comment for __cat
youkaichao Jul 16, 2020
6287326
move the computation of the initial value of lens in cat_ itself.
youkaichao Jul 16, 2020
f64faf5
change the place of doc string
youkaichao Jul 16, 2020
b795493
doc fix for Batch doc string
youkaichao Jul 16, 2020
02f9d1a
change recursive to recurse
youkaichao Jul 16, 2020
2a063f0
doc string fix
youkaichao Jul 16, 2020
7912eae
minor fix for batch doc
Trinkle23897 Jul 16, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 70 additions & 1 deletion test/base/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
def test_batch():
assert list(Batch()) == []
assert Batch().is_empty()
assert Batch(b={'c': {}}).is_empty()
assert not Batch(b={'c': {}}).is_empty()
assert Batch(b={'c': {}}).is_empty(recurse=True)
assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
assert not Batch(d=1).is_empty()
assert not Batch(a=np.float64(1.0)).is_empty()
assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
assert not Batch(a=[1, 2, 3]).is_empty()
b = Batch()
Expand Down Expand Up @@ -109,6 +114,11 @@ def test_batch():
assert isinstance(batch5.b, Batch)
assert np.allclose(batch5.b.index, [1])

# None is a valid object and can be stored in Batch
a = Batch.stack([Batch(a=None), Batch(b=None)])
assert a.a[0] is None and a.a[1] is None
assert a.b[0] is None and a.b[1] is None


def test_batch_over_batch():
batch = Batch(a=[3, 4, 5], b=[4, 5, 6])
Expand Down Expand Up @@ -162,6 +172,20 @@ def test_batch_cat_and_stack():
assert isinstance(b12_cat_in.a.d.e, np.ndarray)
assert b12_cat_in.a.d.e.ndim == 1

a = Batch(a=Batch(a=np.random.randn(3, 4)))
assert np.allclose(
np.concatenate([a.a.a, a.a.a]),
Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a)

# test cat with lens infer
a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4))
b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4))
ans = Batch.cat([a, b, a])
assert np.allclose(ans.a.a,
np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a]))
assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b]))
assert ans.a.t.is_empty()

b12_stack = Batch.stack((b1, b2))
assert isinstance(b12_stack.a.d.e, np.ndarray)
assert b12_stack.a.d.e.ndim == 2
Expand All @@ -177,6 +201,32 @@ def test_batch_cat_and_stack():
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)

# test cat with reserved keys (values are Batch())
b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
b2 = Batch(a=Batch(),
b=torch.rand(4, 3),
common=Batch(c=np.random.rand(4, 5)))
test = Batch.cat([b1, b2])
ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
b=torch.cat([torch.zeros(3, 3), b2.b]),
common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
assert np.allclose(test.a, ans.a)
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)

# test cat with all reserved keys (values are Batch())
b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5)))
b2 = Batch(a=Batch(),
b=torch.rand(4, 3),
common=Batch(c=np.random.rand(4, 5)))
test = Batch.cat([b1, b2])
ans = Batch(a=Batch(),
b=torch.cat([torch.zeros(3, 3), b2.b]),
common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
assert ans.a.is_empty()
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)

# test stack with compatible keys
b3 = Batch(a=np.zeros((3, 4)),
b=torch.ones((2, 5)),
Expand Down Expand Up @@ -205,6 +255,25 @@ def test_batch_cat_and_stack():
assert np.allclose(d.c, [3, 0, 7])
assert np.allclose(d.d, [0, 6, 9])

# test stack with empty Batch()
assert Batch.stack([Batch(), Batch(), Batch()]).is_empty()
a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
b = Batch(a=4, b=5, d=6, e=Batch())
c = Batch(c=7, b=6, d=9, e=Batch())
d = Batch.stack([a, b, c])
assert np.allclose(d.a, [1, 4, 0])
assert np.allclose(d.b, [2, 5, 6])
assert np.allclose(d.c, [3, 0, 7])
assert np.allclose(d.d, [0, 6, 9])
assert d.e.is_empty()
b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
test = Batch.stack([b1, b2], axis=-1)
assert test.a.is_empty()
assert test.b.is_empty()
assert np.allclose(test.common.c,
np.stack([b1.common.c, b2.common.c], axis=-1))

b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
test = Batch.stack([b1, b2])
Expand Down
Loading