167 changes: 79 additions & 88 deletions test/llm/test_wrapper.py
@@ -38,6 +38,7 @@
_has_transformers = importlib.util.find_spec("transformers") is not None
_has_vllm = importlib.util.find_spec("vllm") is not None
_has_datasets = importlib.util.find_spec("datasets") is not None
_has_ray = importlib.util.find_spec("ray") is not None

TransformersWrapperMaxTokens = partial(
TransformersWrapper, generate_kwargs={"max_new_tokens": 10, "do_sample": True}
@@ -2508,94 +2509,6 @@ def test_batching_min_batch_size_one_immediate_processing(
finally:
pool.shutdown(wait=False, cancel_futures=True)

@pytest.mark.parametrize(
"wrapper_class",
[vLLMWrapper, TransformersWrapperMaxTokens],
ids=["vllm", "transformers"],
)
def test_batching_continuous_throughput(
self,
wrapper_class,
vllm_instance,
transformers_instance,
monkey_patch_forward_for_instrumentation,
):
"""Test that the wrapper stays busy with continuous requests."""
import time
from concurrent.futures import ThreadPoolExecutor, wait

# Create wrapper using helper function
wrapper = create_batching_test_wrapper(
wrapper_class,
vllm_instance,
transformers_instance,
min_batch_size=1,
max_batch_size=2, # Small batch size to maximize throughput
batching_timeout=5.0,
)

# Monkey patch the forward method using fixture
processing_events = monkey_patch_forward_for_instrumentation[
"processing_events"
]

# Submit continuous requests
futures = []
pool = ThreadPoolExecutor(max_workers=5)
try:
# Submit requests rapidly
for i in range(10):
input_td = TensorDict(
text=Text(prompt=[f"Continuous request {i}"]), batch_size=(1,)
)
future = pool.submit(wrapper.instrumented_forward, input_td)
futures.append(future)
time.sleep(0.02) # Small delay between submissions

# Wait for all futures to complete
wait(futures, timeout=30)

# Verify all futures completed successfully
for future in futures:
result = future.result(timeout=5)
assert "text" in result

# Analyze processing patterns
assert len(processing_events) > 0, "No processing occurred"

# Check that processing happened across multiple threads (indicating concurrent processing)
thread_ids = {event["thread_id"] for event in processing_events} # noqa
assert (
len(thread_ids) > 1
), f"All processing happened in single thread: {thread_ids}"

# Check that we have multiple processing events (indicating continuous activity)
assert (
len(processing_events) >= 5
), f"Too few processing events: {len(processing_events)}"

# Check that batches were formed (some batch sizes > 1)
batch_sizes = [event["batch_size"] for event in processing_events]
assert any(
bs > 1 for bs in batch_sizes
), f"No batching occurred: {batch_sizes}"

# Check processing timing - should be relatively continuous
timestamps = [event["timestamp"] for event in processing_events]
time_diffs = [
timestamps[i + 1] - timestamps[i] for i in range(len(timestamps) - 1)
]

# Most time differences should be small (indicating continuous processing)
small_diffs = [diff for diff in time_diffs if diff < 1.0]
assert (
len(small_diffs) >= len(time_diffs) * 0.7
), f"Too many large gaps in processing: {time_diffs}"
finally:
pool.shutdown(wait=False, cancel_futures=True)
del wrapper
gc.collect()

@pytest.mark.parametrize(
"wrapper_class",
[vLLMWrapper, TransformersWrapperMaxTokens],
@@ -2921,6 +2834,84 @@ def test_ray_wrapper(self, sample_text, backend):
gc.collect()


@pytest.mark.skipif(not _has_ray, reason="Ray not available")
class TestActorSharing:
"""Test actor sharing functionality for Remote wrappers."""

@pytest.mark.parametrize("backend", ["transformers", "vllm"])
def test_actor_sharing(self, backend):
"""Test that creating the same wrapper twice uses the same actor."""
import ray
from torchrl.modules.llm.policies import (
RemoteTransformersWrapper,
RemotevLLMWrapper,
)

# Initialize Ray if not already done
if not ray.is_initialized():
ray.init()

# Choose the wrapper class based on backend
if backend == "vllm":
if not _has_vllm:
pytest.skip("vllm not available")
WrapperClass = RemotevLLMWrapper
elif backend == "transformers":
if not _has_transformers:
pytest.skip("transformers not available")
WrapperClass = RemoteTransformersWrapper
else:
raise ValueError(f"Invalid backend: {backend}")

try:
# Create first wrapper with explicit actor name
wrapper1 = WrapperClass(
model="Qwen/Qwen2.5-0.5B",
generate=True,
input_mode="text",
generate_kwargs={"max_new_tokens": 5},
actor_name="test_shared_actor",
)

# Create second wrapper with same actor name
wrapper2 = WrapperClass(
model="Qwen/Qwen2.5-0.5B",
generate=True,
input_mode="text",
generate_kwargs={"max_new_tokens": 5},
actor_name="test_shared_actor",
)

# Check that both wrappers use the same actor
assert (
wrapper1._remote_wrapper == wrapper2._remote_wrapper
), f"Wrappers should share the same actor for backend {backend}"

# Test that both wrappers work
test_data = TensorDict(
text=Text(prompt="Hello, how are you?"),
batch_size=(),
)

result1 = wrapper1(test_data)
result2 = wrapper2(test_data)

# Both should produce valid results
assert "text" in result1
assert "text" in result2
assert isinstance(result1["text"].response, str)
assert isinstance(result2["text"].response, str)

finally:
# Cleanup
try:
del wrapper1
del wrapper2
gc.collect()
except Exception:
pass


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
87 changes: 47 additions & 40 deletions torchrl/modules/llm/policies/common.py
@@ -1300,11 +1300,11 @@ def _extract_responses_from_full_histories(
def _batching(func):
@wraps(func)
def _batched_func(self, td_input: TensorDictBase, **kwargs):
# -- 0. skip if batching is disabled
# -- 0. Bypass if batching disabled
if not self.batching:
return func(self, td_input, **kwargs)

# ── 1. Normalise input ──────────────────────────────────────────────────
# -- 1. Normalise --------------------------------------------------------
if td_input.batch_dims > 1:
raise ValueError(
f"Batching not supported for batch_dims > 1: {td_input.batch_dims}"
@@ -1313,52 +1313,59 @@ def _batched_func(self, td_input: TensorDictBase, **kwargs):
single = td_input.batch_dims == 0
inputs = [td_input] if single else list(td_input.unbind(0))
futures = [Future() for _ in inputs]
pending = set(futures) # ← track our own Futures

# ── 2. Enqueue work and, if first in, do the draining ───────────────────
# -- 2. Enqueue ----------------------------------------------------------
self._batch_queue.extend(inputs)
self._futures.extend(futures)

min_bs = getattr(self, "_min_batch_size", 1)
max_bs = getattr(self, "_max_batch_size", None)

# -- 3. Drain while holding the lock ------------------------------------
with self._batching_lock:
# Only the thread that managed to grab the lock will run the loop
while len(self._batch_queue) >= min_bs:
# Determine slice
slice_size = (
len(self._batch_queue)
if max_bs is None
else min(max_bs, len(self._batch_queue))
)
batch = self._batch_queue[:slice_size]
fut_slice = self._futures[:slice_size]
if all(f.done() for f in futures):
# Our items were already processed by another thread.
# Skip draining; other workers will handle the rest of the queue.
pass
else:
while len(self._batch_queue) >= min_bs:
slice_size = (
len(self._batch_queue)
if max_bs is None
else min(max_bs, len(self._batch_queue))
)
batch = self._batch_queue[:slice_size]
fut_slice = self._futures[:slice_size]

try:
results = func(self, lazy_stack(batch), **kwargs).unbind(0)
if len(results) != slice_size:
raise RuntimeError(
f"Expected {slice_size} results, got {len(results)}"
)
for fut, res in zip(fut_slice, results):
fut.set_result(res)
pending.discard(fut) # ← mark as done
except Exception as exc:
for fut in fut_slice:
fut.set_exception(exc)
pending.discard(fut)
raise

# Execute model
try:
results = func(self, lazy_stack(batch), **kwargs).unbind(0)
if len(results) != slice_size: # sanity
raise RuntimeError(
f"Expected {slice_size} results, got {len(results)}"
)
# Fulfil the corresponding futures
for fut, res in zip(fut_slice, results):
fut.set_result(res)
except Exception as exc:
for fut in fut_slice:
fut.set_exception(exc)
# Propagate to caller; other waiters will read the exception from their future
raise

# Pop processed work
del self._batch_queue[:slice_size]
del self._futures[:slice_size]

# ── 3. Outside the lock: wait only for OUR futures (they may already be done) ──
wait(
futures
) # no timeout → immediate return if set_result()/set_exception() already called
result = [f.result() for f in futures]

return result[0] if single else lazy_stack(result)
# Pop processed work
del self._batch_queue[:slice_size]
del self._futures[:slice_size]

# ---- Early-exit: all *our* Futures are done -------------------
if not pending:
break

# -- 4. Outside the lock: wait only on remaining (rare) -----------------
if pending: # usually empty; safety for min_bs > queue size
wait(pending)
results = [f.result() for f in futures]

return results[0] if single else lazy_stack(results)

return _batched_func
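
The rewritten decorator lets the first thread that acquires the lock drain the queue on behalf of every waiter, and the new `pending` set gives callers whose futures were already fulfilled an early exit. The following is a minimal, self-contained sketch of the same pattern, written against a hypothetical `process_batch` callable rather than the wrapped forward method, to make the queue/futures/lock interplay easier to follow in isolation:

import threading
from concurrent.futures import Future

class MicroBatcher:
    """Simplified stand-in for the _batching decorator's bookkeeping."""

    def __init__(self, process_batch, min_batch_size=1, max_batch_size=None):
        self.process_batch = process_batch  # hypothetical batch-level model call
        self.min_bs = min_batch_size
        self.max_bs = max_batch_size
        self._queue = []                    # pending inputs
        self._futures = []                  # futures paired 1:1 with inputs
        self._lock = threading.Lock()

    def __call__(self, item):
        fut = Future()
        self._queue.append(item)
        self._futures.append(fut)
        with self._lock:
            # Another thread may already have served our item while we waited for the lock.
            while not fut.done() and len(self._queue) >= self.min_bs:
                n = len(self._queue) if self.max_bs is None else min(self.max_bs, len(self._queue))
                batch, futs = self._queue[:n], self._futures[:n]
                try:
                    for f, res in zip(futs, self.process_batch(batch)):
                        f.set_result(res)
                except Exception as exc:
                    for f in futs:
                        f.set_exception(exc)
                    raise
                finally:
                    del self._queue[:n]
                    del self._futures[:n]
        return fut.result()  # already set, or set later by whichever caller holds the lock

Under load, whichever caller grabs the lock first processes batches for everyone queued behind it; the others return as soon as their future carries a result.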
26 changes: 22 additions & 4 deletions torchrl/modules/llm/policies/transformers_wrapper.py
@@ -23,7 +23,7 @@
from tensordict.utils import _zip_strict, NestedKey
from torch import distributions as D
from torch.nn.utils.rnn import pad_sequence

from torchrl import logger as torchrl_logger
from torchrl.modules.llm.policies.common import (
_batching,
_extract_responses_from_full_histories,
@@ -2443,7 +2443,12 @@ class RemoteTransformersWrapper:
"""

def __init__(
self, model, max_concurrency: int = 16, validate_model: bool = True, **kwargs
self,
model,
max_concurrency: int = 16,
validate_model: bool = True,
actor_name: str = None,
**kwargs,
):
import ray

@@ -2458,10 +2463,23 @@ def __init__(

if not ray.is_initialized():
ray.init()
# Create the remote actor

if actor_name is not None:
# Check if an actor with this name already exists
try:
existing_actor = ray.get_actor(actor_name)
# If we can get the actor, assume it's alive and use it
self._remote_wrapper = existing_actor
torchrl_logger.info(f"Using existing actor {actor_name}")
return
except ValueError:
# Actor doesn't exist, create a new one
torchrl_logger.info(f"Creating new actor {actor_name}")

# Create the remote actor with the unique name
self._remote_wrapper = (
ray.remote(TransformersWrapper)
.options(max_concurrency=max_concurrency)
.options(max_concurrency=max_concurrency, name=actor_name)
.remote(model, **kwargs)
)

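The new `actor_name` argument turns the remote wrapper into a get-or-create handle: the first construction registers a named Ray actor, and later constructions with the same name attach to it instead of spawning a second model. A short usage sketch mirroring the `TestActorSharing` test above (the model name and generation kwargs are illustrative):

from torchrl.modules.llm.policies import RemoteTransformersWrapper

shared_kwargs = dict(
    model="Qwen/Qwen2.5-0.5B",
    generate=True,
    input_mode="text",
    generate_kwargs={"max_new_tokens": 5},
    actor_name="shared_transformers_actor",
)

policy_a = RemoteTransformersWrapper(**shared_kwargs)  # creates the named actor
policy_b = RemoteTransformersWrapper(**shared_kwargs)  # attaches to the same actor
assert policy_a._remote_wrapper == policy_b._remote_wrapper

Both handles route their calls through the single underlying TransformersWrapper actor, so the model is loaded once per name rather than once per wrapper instance.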
24 changes: 21 additions & 3 deletions torchrl/modules/llm/policies/vllm_wrapper.py
@@ -23,6 +23,7 @@
from tensordict.utils import _zip_strict, NestedKey
from torch import distributions as D
from torch.nn.utils.rnn import pad_sequence
from torchrl import logger as torchrl_logger

from torchrl.envs.utils import _classproperty
from torchrl.modules.llm.policies.common import (
Expand Down Expand Up @@ -2101,7 +2102,12 @@ class RemotevLLMWrapper:
"""

def __init__(
self, model, max_concurrency: int = 16, validate_model: bool = True, **kwargs
self,
model,
max_concurrency: int = 16,
validate_model: bool = True,
actor_name: str = None,
**kwargs,
):
import ray

@@ -2141,10 +2147,22 @@ def __init__(
if not ray.is_initialized():
ray.init()

# Create the remote actor
if actor_name is not None:
# Check if an actor with this name already exists
try:
existing_actor = ray.get_actor(actor_name)
torchrl_logger.info(f"Using existing actor {actor_name}")
# If we can get the actor, assume it's alive and use it
self._remote_wrapper = existing_actor
return
except ValueError:
# Actor doesn't exist, create a new one
torchrl_logger.info(f"Creating new actor {actor_name}")

# Create the remote actor with the unique name
self._remote_wrapper = (
ray.remote(vLLMWrapper)
.options(max_concurrency=max_concurrency)
.options(max_concurrency=max_concurrency, name=actor_name)
.remote(model, **kwargs)
)

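Both remote wrappers lean on the same Ray convention: `ray.get_actor(name)` returns a live handle when a named actor exists and raises `ValueError` when it does not. A minimal sketch of that get-or-create pattern in isolation, using a hypothetical `Worker` actor class:

import ray

@ray.remote
class Worker:
    def ping(self):
        return "pong"

def get_or_create(name: str, max_concurrency: int = 16):
    if not ray.is_initialized():
        ray.init()
    try:
        return ray.get_actor(name)  # reuse the live actor registered under this name
    except ValueError:              # raised when no actor with this name exists
        return Worker.options(name=name, max_concurrency=max_concurrency).remote()

actor = get_or_create("shared_worker")
print(ray.get(actor.ping.remote()))  # "pong", whichever process created the actor

One caveat: without lifetime="detached", the named actor is owned by the job that created it and is torn down when that driver exits.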