
[Feature] CompressedStorage #3058

Closed
wants to merge 1 commit into from
1 change: 1 addition & 0 deletions .github/unittest/linux/scripts/environment.yml
@@ -35,3 +35,4 @@ dependencies:
- transformers
- ninja
- timm
- zstandard
62 changes: 62 additions & 0 deletions docs/source/reference/data.rst
@@ -144,6 +144,8 @@ using the following components:
    :template: rl_template.rst


    CompressedStorage
    CompressedStorageCheckpointer
    FlatStorageCheckpointer
    H5StorageCheckpointer
    ImmutableDatasetWriter
@@ -191,6 +193,66 @@ were found from rough benchmarking in https://github.com/pytorch/rl/tree/main/be
| :class:`LazyMemmapStorage` | 3.44x |
+-------------------------------+-----------+

Compressed Storage for Memory Efficiency
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

For applications where memory usage is a primary concern, especially when storing
large sensory observations like images or audio, the :class:`~torchrl.data.replay_buffers.storages.CompressedStorage`
provides significant memory savings through compression.

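The ``CompressedStorage`` compresses data when it is stored and decompresses it when it is retrieved.
With the default zstd codec, compression is lossless, so retrieved data is identical to what was stored.
Compression ratios of roughly 2-10x are typical for image data, but the actual ratio depends on the
content: structured frames compress well, while random noise barely compresses at all.
zstd is used by default, and custom compression functions can be supplied instead.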
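The compress-on-write / decompress-on-read idea can be sketched in a few lines. This is a toy illustration under stated assumptions, not the ``CompressedStorage`` implementation. The class name ``ToyCompressedStorage`` is hypothetical, NumPy arrays stand in for tensors, and the standard-library ``zlib`` stands in for zstd so the sketch runs without extra dependencies. Each item is kept as compressed bytes plus the dtype and shape needed to rebuild it.

```python
import zlib

import numpy as np


class ToyCompressedStorage:
    """Hypothetical minimal storage: compress on set, decompress on get."""

    def __init__(self, max_size: int, level: int = 3):
        self.max_size = max_size
        self.level = level  # zlib accepts -1 to 9; zstd itself would accept 1 to 22
        self._blobs = {}  # index -> (compressed bytes, dtype string, shape)

    def set(self, index: int, array: np.ndarray) -> None:
        if not 0 <= index < self.max_size:
            raise IndexError(f"index {index} out of range for max_size={self.max_size}")
        array = np.ascontiguousarray(array)
        payload = zlib.compress(array.tobytes(), self.level)
        self._blobs[index] = (payload, array.dtype.str, array.shape)

    def get(self, index: int) -> np.ndarray:
        payload, dtype, shape = self._blobs[index]  # KeyError if never set
        flat = np.frombuffer(zlib.decompress(payload), dtype=dtype)
        return flat.reshape(shape).copy()  # copy: frombuffer returns a read-only view

    def __len__(self) -> int:
        return len(self._blobs)


storage = ToyCompressedStorage(max_size=10)
# A structured 3x84x84 uint8 "frame": each row is the ramp 0..83.
frame = np.tile(np.arange(84, dtype=np.uint8), (3, 84, 1))
storage.set(0, frame)
restored = storage.get(0)
print(len(storage), restored.shape, restored.dtype)
```

Because the dtype and shape travel with the bytes, decompression does not need to know anything about the original tensor ahead of time.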

Key features:

- **Memory Efficiency**: Achieves significant memory savings through compression
- **Data Integrity**: Maintains full data fidelity through lossless compression
- **Flexible Compression**: Supports custom compression algorithms or uses zstd by default
- **TensorDict Support**: Seamlessly works with TensorDict structures
- **Checkpointing**: Full support for saving and loading compressed data
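One way to checkpoint compressed data is to write the compressed bytes as-is, next to the dtype and shape needed to decode them. Nothing then has to be recompressed on save or load. The sketch below assumes a hypothetical on-disk layout: one ``.bin`` file per index plus an ``index.json``. The helper names ``save_blobs`` and ``load_blobs`` are also hypothetical. The sketch does not reproduce the format written by ``CompressedStorage.dumps``, and ``zlib`` again stands in for zstd.

```python
import json
import tempfile
import zlib
from pathlib import Path

import numpy as np


def save_blobs(blobs: dict, path: Path) -> None:
    """Write each compressed payload unchanged, plus a JSON index of dtype/shape."""
    path.mkdir(parents=True, exist_ok=True)
    index = {}
    for key, (payload, dtype, shape) in blobs.items():
        (path / f"{key}.bin").write_bytes(payload)
        index[str(key)] = {"dtype": dtype, "shape": list(shape)}
    (path / "index.json").write_text(json.dumps(index))


def load_blobs(path: Path) -> dict:
    """Read the payloads back without decompressing them."""
    index = json.loads((path / "index.json").read_text())
    return {
        int(key): ((path / f"{key}.bin").read_bytes(), meta["dtype"], tuple(meta["shape"]))
        for key, meta in index.items()
    }


frame = np.tile(np.arange(84, dtype=np.uint8), (3, 84, 1))
blobs = {0: (zlib.compress(frame.tobytes(), 3), frame.dtype.str, frame.shape)}

with tempfile.TemporaryDirectory() as tmp:
    save_blobs(blobs, Path(tmp) / "checkpoint")
    loaded = load_blobs(Path(tmp) / "checkpoint")

# Decompress only when the data is actually needed.
payload, dtype, shape = loaded[0]
roundtrip = np.frombuffer(zlib.decompress(payload), dtype=dtype).reshape(shape)
```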

Example usage:

>>> import torch
>>> from torchrl.data import ReplayBuffer, CompressedStorage
>>> from tensordict import TensorDict
>>>
>>> # Create a compressed storage for image data
>>> storage = CompressedStorage(max_size=1000, compression_level=3)
>>> rb = ReplayBuffer(storage=storage, batch_size=16)
>>>
>>> # Add image data: random uint8 stand-ins for Atari-like frames
>>> images = torch.randint(0, 256, (100, 3, 84, 84), dtype=torch.uint8)
>>> data = TensorDict({"obs": images}, batch_size=[100])
>>> indices = rb.extend(data)
>>>
>>> # Sample data (automatically decompressed)
>>> sample = rb.sample()
>>> print(sample["obs"].shape)
torch.Size([16, 3, 84, 84])

The compression level ranges from 1 (fastest, least compression) to 22 (slowest, most compression).
Level 3, which is also zstd's own default, is a good starting point for most use cases.
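How much a given level saves depends heavily on the data. The sketch below uses ``zlib`` (levels 1-9) as a stand-in for zstd (levels 1-22), so its absolute numbers will not match what ``CompressedStorage`` achieves. It only illustrates the trade-off. Structured frames shrink dramatically at every level. Uniformly random bytes, much like random test tensors, barely compress at any level.

```python
import time
import zlib

import numpy as np

rng = np.random.default_rng(0)
# Structured "frames": 100 frames of 3x84x84 uint8 ramps (highly redundant).
structured = np.tile(np.arange(84, dtype=np.uint8), (100, 3, 84, 1)).tobytes()
# Incompressible data of the same size: uniformly random bytes.
noise = rng.integers(0, 256, size=len(structured), dtype=np.uint8).tobytes()


def ratio(data: bytes, level: int) -> float:
    """Original size divided by compressed size (higher means more savings)."""
    return len(data) / len(zlib.compress(data, level))


for level in (1, 6, 9):
    start = time.perf_counter()
    r_structured = ratio(structured, level)
    elapsed_ms = (time.perf_counter() - start) * 1e3
    print(
        f"level={level}: structured {r_structured:.1f}x in {elapsed_ms:.1f} ms, "
        f"noise {ratio(noise, level):.3f}x"
    )
```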

For custom compression algorithms:

>>> def my_compress(tensor):
...     return tensor.to(torch.uint8)  # illustrative only: a lossy cast, not real compression
>>>
>>> def my_decompress(compressed_tensor, metadata):
...     return compressed_tensor.to(metadata["dtype"])  # restore the original dtype
>>>
>>> storage = CompressedStorage(
...     max_size=1000,
...     compression_fn=my_compress,
...     decompression_fn=my_decompress,
... )
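Casting to ``uint8`` discards information, so a custom pair meant for real use should be lossless. Below is a sketch with the same call shape: ``compress(tensor)`` and ``decompress(compressed, metadata)``, reading only ``metadata["dtype"]``.

It rests on several assumptions:

- NumPy arrays stand in for tensors.
- ``zlib`` stands in for the real codec.
- The function names are hypothetical.

Because it is not certain which metadata ``CompressedStorage`` passes besides ``"dtype"``, the shape is packed into a small header inside the compressed buffer itself.

```python
import struct
import zlib

import numpy as np


def lossless_compress(array: np.ndarray) -> np.ndarray:
    """Hypothetical: prepend a shape header, then zlib-compress the raw bytes."""
    array = np.ascontiguousarray(array)
    # Little-endian header: ndim as uint32, then each dimension as int64.
    header = struct.pack(f"<I{array.ndim}q", array.ndim, *array.shape)
    body = zlib.compress(array.tobytes(), 3)
    # Return a flat uint8 buffer, like the "compressed tensor" in the example above.
    return np.frombuffer(header + body, dtype=np.uint8)


def lossless_decompress(compressed: np.ndarray, metadata: dict) -> np.ndarray:
    """Hypothetical inverse of lossless_compress; only metadata["dtype"] is needed."""
    raw = compressed.tobytes()
    (ndim,) = struct.unpack_from("<I", raw, 0)
    shape = struct.unpack_from(f"<{ndim}q", raw, 4)
    offset = 4 + 8 * ndim  # skip the header
    flat = np.frombuffer(zlib.decompress(raw[offset:]), dtype=metadata["dtype"])
    return flat.reshape(shape).copy()


rng = np.random.default_rng(0)
x = rng.standard_normal((3, 4)).astype(np.float32)
packed = lossless_compress(x)
restored = lossless_decompress(packed, {"dtype": x.dtype})
```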

.. note:: The ``CompressedStorage`` requires the ``zstandard`` library for default compression.
   Install with: ``pip install zstandard``
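Calling code can detect the optional dependency the same way the test suite does, with ``importlib.util.find_spec("zstandard")``. The helper ``pick_codec`` below is hypothetical, and so is its ``zlib`` fallback. It is not claimed that ``CompressedStorage`` itself falls back to ``zlib``; the sketch only shows how a script might degrade gracefully when ``zstandard`` is missing.

```python
import importlib.util
import zlib


def pick_codec(level: int = 3):
    """Hypothetical helper: return (name, compress, decompress), preferring zstd."""
    if importlib.util.find_spec("zstandard") is not None:
        import zstandard

        compressor = zstandard.ZstdCompressor(level=level)
        # compress() records the content size in the frame, so a plain decompress() works.
        decompressor = zstandard.ZstdDecompressor()
        return "zstd", compressor.compress, decompressor.decompress
    # Fallback so the sketch still runs: zlib levels only go up to 9.
    zlib_level = min(level, 9)
    return "zlib", (lambda data: zlib.compress(data, zlib_level)), zlib.decompress


name, compress, decompress = pick_codec()
data = b"frame" * 1000
print(name, len(data), len(compress(data)))
```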

.. note:: An example of how to use the ``CompressedStorage`` is available in the
   `examples/replay-buffers/compressed_replay_buffer_example.py <https://github.com/pytorch/rl/blob/main/examples/replay-buffers/compressed_replay_buffer_example.py>`_ file.

Sharing replay buffers across processes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

266 changes: 266 additions & 0 deletions test/test_rb.py
@@ -11,10 +11,13 @@
import os
import pickle
import sys
import tempfile
from functools import partial
from pathlib import Path
from unittest import mock

import numpy as np

import pytest
import torch
from packaging import version
@@ -35,6 +38,7 @@
from torchrl.collectors import RandomPolicy, SyncDataCollector
from torchrl.collectors.utils import split_trajectories
from torchrl.data import (
    CompressedStorage,
    FlatStorageCheckpointer,
    MultiStep,
    NestedStorageCheckpointer,
@@ -129,6 +133,7 @@
_os_is_windows = sys.platform == "win32"
_has_transformers = importlib.util.find_spec("transformers") is not None
_has_ray = importlib.util.find_spec("ray") is not None
_has_zstandard = importlib.util.find_spec("zstandard") is not None

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)

@@ -4027,6 +4032,267 @@ def test_ray_rb_iter(self):
        rb.close()


@pytest.mark.skipif(not _has_zstandard, reason="zstandard required for this test.")
class TestCompressedStorage:
    """Test cases for CompressedStorage."""

    def test_compressed_storage_initialization(self):
        """Test that CompressedStorage initializes correctly."""
        storage = CompressedStorage(max_size=100, compression_level=3)
        assert storage.max_size == 100
        assert storage.compression_level == 3
        assert len(storage) == 0

    def test_compressed_storage_tensor(self):
        """Test compression and decompression of tensor data."""
        storage = CompressedStorage(max_size=10, compression_level=3)

        # Create test tensor
        test_tensor = torch.randn(3, 84, 84, dtype=torch.float32)

        # Store tensor
        storage.set(0, test_tensor)

        # Retrieve tensor
        retrieved_tensor = storage.get(0)

        # Verify data integrity
        assert torch.allclose(test_tensor, retrieved_tensor, atol=1e-6)
        assert test_tensor.shape == retrieved_tensor.shape
        assert test_tensor.dtype == retrieved_tensor.dtype

    def test_compressed_storage_tensordict(self):
        """Test compression and decompression of TensorDict data."""
        storage = CompressedStorage(max_size=10, compression_level=3)

        # Create test TensorDict
        test_td = TensorDict(
            {
                "obs": torch.randn(3, 84, 84, dtype=torch.float32),
                "action": torch.tensor([1, 2, 3]),
                "reward": torch.randn(3),
                "done": torch.tensor([False, True, False]),
            },
            batch_size=[3],
        )

        # Store TensorDict
        storage.set(0, test_td)

        # Retrieve TensorDict
        retrieved_td = storage.get(0)

        # Verify data integrity (exact equality for integer and boolean entries)
        assert torch.allclose(test_td["obs"], retrieved_td["obs"], atol=1e-6)
        assert torch.equal(test_td["action"], retrieved_td["action"])
        assert torch.allclose(test_td["reward"], retrieved_td["reward"], atol=1e-6)
        assert torch.equal(test_td["done"], retrieved_td["done"])

    def test_compressed_storage_multiple_indices(self):
        """Test storing and retrieving multiple items."""
        storage = CompressedStorage(max_size=10, compression_level=3)

        # Store multiple tensors
        tensors = [
            torch.randn(2, 2, dtype=torch.float32),
            torch.randn(3, 3, dtype=torch.float32),
            torch.randn(4, 4, dtype=torch.float32),
        ]

        for i, tensor in enumerate(tensors):
            storage.set(i, tensor)

        # Retrieve multiple tensors
        retrieved = storage.get([0, 1, 2])

        # Verify data integrity
        for original, retrieved_tensor in zip(tensors, retrieved):
            assert torch.allclose(original, retrieved_tensor, atol=1e-6)

    def test_compressed_storage_with_replay_buffer(self):
        """Test CompressedStorage with ReplayBuffer."""
        storage = CompressedStorage(max_size=100, compression_level=3)
        rb = ReplayBuffer(storage=storage, batch_size=5)

        # Create test data
        data = TensorDict(
            {
                "obs": torch.randn(10, 3, 84, 84, dtype=torch.float32),
                "action": torch.randint(0, 4, (10,)),
                "reward": torch.randn(10),
            },
            batch_size=[10],
        )

        # Add data to replay buffer
        rb.extend(data)

        # Sample from replay buffer
        sample = rb.sample(5)

        # Verify sample has correct shape
        assert is_tensor_collection(sample), sample
        assert sample["obs"].shape[0] == 5
        assert sample["obs"].shape[1:] == (3, 84, 84)
        assert sample["action"].shape[0] == 5
        assert sample["reward"].shape[0] == 5

    def test_compressed_storage_state_dict(self):
        """Test saving and loading state dict."""
        storage = CompressedStorage(max_size=10, compression_level=3)

        # Add some data
        test_tensor = torch.randn(3, 3, dtype=torch.float32)
        storage.set(0, test_tensor)

        # Save state dict
        state_dict = storage.state_dict()

        # Create new storage and load state dict
        new_storage = CompressedStorage(max_size=10, compression_level=3)
        new_storage.load_state_dict(state_dict)

        # Verify data integrity
        retrieved_tensor = new_storage.get(0)
        assert torch.allclose(test_tensor, retrieved_tensor, atol=1e-6)

    def test_compressed_storage_checkpointing(self):
        """Test checkpointing functionality."""
        storage = CompressedStorage(max_size=10, compression_level=3)

        # Add some data
        test_td = TensorDict(
            {
                "obs": torch.randn(3, 84, 84, dtype=torch.float32),
                "action": torch.tensor([1, 2, 3]),
            },
            batch_size=[3],
        )
        storage.set(0, test_td)

        # Create temporary directory for checkpointing
        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint_path = Path(tmpdir) / "checkpoint"

            # Save checkpoint
            storage.dumps(checkpoint_path)

            # Create new storage and load checkpoint
            new_storage = CompressedStorage(max_size=10, compression_level=3)
            new_storage.loads(checkpoint_path)

            # Verify data integrity
            retrieved_td = new_storage.get(0)
            assert torch.allclose(test_td["obs"], retrieved_td["obs"], atol=1e-6)
            assert torch.equal(test_td["action"], retrieved_td["action"])

    def test_compressed_storage_length(self):
        """Test that length is calculated correctly."""
        storage = CompressedStorage(max_size=10, compression_level=3)

        # Initially empty
        assert len(storage) == 0

        # Add some data
        storage.set(0, torch.randn(2, 2))
        assert len(storage) == 1

        storage.set(2, torch.randn(2, 2))
        assert len(storage) == 2

        storage.set(1, torch.randn(2, 2))
        assert len(storage) == 3

    def test_compressed_storage_contains(self):
        """Test the contains method."""
        storage = CompressedStorage(max_size=10, compression_level=3)

        # Initially empty
        assert not storage.contains(0)

        # Add data
        storage.set(0, torch.randn(2, 2))
        assert storage.contains(0)
        assert not storage.contains(1)

    def test_compressed_storage_empty(self):
        """Test emptying the storage."""
        storage = CompressedStorage(max_size=10, compression_level=3)

        # Add some data
        storage.set(0, torch.randn(2, 2))
        storage.set(1, torch.randn(2, 2))
        assert len(storage) == 2

        # Empty storage
        storage._empty()
        assert len(storage) == 0

    def test_compressed_storage_custom_compression(self):
        """Test custom compression functions."""

        def custom_compress(tensor):
            # Simple (lossy) "compression": just convert to uint8
            return tensor.to(torch.uint8)

        def custom_decompress(compressed_tensor, metadata):
            # Simple decompression: convert back to the original dtype
            return compressed_tensor.to(metadata["dtype"])

        storage = CompressedStorage(
            max_size=10,
            compression_fn=custom_compress,
            decompression_fn=custom_decompress,
        )

        # Test with tensor
        test_tensor = torch.randn(2, 2, dtype=torch.float32)
        storage.set(0, test_tensor)
        retrieved_tensor = storage.get(0)

        # The uint8 cast loses information, so only the round-tripped shape is checked
        assert retrieved_tensor.shape == test_tensor.shape

    def test_compressed_storage_error_handling(self):
        """Test error handling for invalid operations."""
        storage = CompressedStorage(max_size=5, compression_level=3)

        # Test setting data beyond max_size
        with pytest.raises(RuntimeError):
            storage.set(10, torch.randn(2, 2))

        # Test getting non-existent data
        with pytest.raises(IndexError):
            storage.get(0)

    def test_compressed_storage_memory_efficiency(self):
        """Test that compression actually reduces memory usage."""
        storage = CompressedStorage(max_size=100, compression_level=3)

        # Create large, highly structured tensor data: frame i is filled with the value i
        large_tensor = torch.zeros(100, 3, 84, 84, dtype=torch.int64)
        large_tensor.copy_(
            torch.arange(large_tensor.numel(), dtype=torch.int32).view_as(large_tensor)
            // (3 * 84 * 84)
        )
        original_size = large_tensor.numel() * large_tensor.element_size()

        # Store in compressed storage
        storage.set(0, large_tensor)

        # Estimate compressed size
        compressed_data = storage._compressed_data[0]
        compressed_size = compressed_data.numel()  # uint8 bytes

        # Verify the compression ratio is reasonable; this redundant data should
        # compress well (random data would not)
        compression_ratio = original_size / compressed_size
        assert (
            compression_ratio > 1.5
        ), f"Compression ratio {compression_ratio} is too low"


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)