further simplify the async management

commit a7caacf1bf (parent c6403706b4)
Matthew Farrellee, 2025-09-12 17:44:48 -04:00
2 changed files with 15 additions and 23 deletions

Changed file 1 of 2:

@@ -275,7 +275,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         if recording:
             response_body = recording["response"]["body"]
 
-            if recording["response"].get("is_streaming", False) or endpoint == "/v1/models":
+            if recording["response"].get("is_streaming", False):
 
                 async def replay_stream():
                     for chunk in response_body:

@@ -298,6 +298,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         else:
             response = original_method(self, *args, **kwargs)
 
+        # we want to store the result of the iterator, not the iterator itself
+        if endpoint == "/v1/models":
+            response = [m async for m in response]
+
         request_data = {
             "method": method,
             "url": url,

@@ -310,7 +314,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         # Determine if this is a streaming request based on request parameters
         is_streaming = body.get("stream", False)
 
-        if is_streaming or endpoint == "/v1/models":
+        if is_streaming:
             # For streaming responses, we need to collect all chunks immediately before yielding
             # This ensures the recording is saved even if the generator isn't fully consumed
             chunks = []
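
The two comments in this hunk describe the core recording trick for streamed responses: drain the upstream async iterator eagerly, persist the collected chunks, and only then hand the caller a generator that replays the list. A minimal sketch of that pattern, where record_streaming_response and the save_recording callable are illustrative stand-ins rather than the module's real API:

    # Sketch of the eager-collection pattern for streaming recordings.
    # `save_recording` is an illustrative callable, not the module's actual storage call.
    async def record_streaming_response(upstream, save_recording, request_data):
        # Drain the upstream async iterator first, so the recording is written
        # even if the caller never finishes consuming the replayed stream.
        chunks = [chunk async for chunk in upstream]
        save_recording(request_data, {"body": chunks, "is_streaming": True})

        async def replay_stream():
            for chunk in chunks:
                yield chunk

        return replay_stream()
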
@@ -377,15 +381,10 @@ def patch_inference_clients():
     )
 
     def patched_models_list(self, *args, **kwargs):
-        import asyncio
-
-        task = asyncio.create_task(
-            _patched_inference_method(_original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs)
-        )
-
         async def _iter():
-            result = await task
-            async for item in result:
+            for item in await _patched_inference_method(
+                _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
+            ):
                 yield item
 
         return _iter()
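
The net effect of this hunk is that patched_models_list no longer needs asyncio.create_task, and therefore no longer needs a running event loop at patch time: because /v1/models responses are now materialized into a plain list inside _patched_inference_method (see the hunk at -298 above), the wrapper can await the coroutine lazily inside a local async generator and iterate the result with an ordinary for loop. A small self-contained sketch of that shape, where get_models() is an illustrative stand-in for the patched coroutine:

    import asyncio
    from typing import Any, AsyncIterator

    def patched_list(get_models) -> AsyncIterator[Any]:
        async def _iter():
            # get_models() returns a coroutine; awaiting it gives an already
            # materialized list, so a plain `for` (not `async for`) suffices.
            for item in await get_models():
                yield item

        return _iter()

    async def _demo():
        async def get_models():
            return ["model-a", "model-b"]  # stands in for the recorded model list

        assert [m async for m in patched_list(get_models)] == ["model-a", "model-b"]

    asyncio.run(_demo())

Nothing runs until the caller starts iterating, so the wrapper itself stays synchronous while all awaiting happens inside _iter().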

Changed file 2 of 2:

@@ -221,16 +221,9 @@ class TestInferenceRecording:
     async def test_replay_mode_models(self, temp_storage_dir):
         """Test that replay mode returns stored responses without making real model listing calls."""
 
-        class MockAsyncPaginator:
-            def __init__(self, models: list[OpenAIModel]):
-                self._models = models
-
-            def __aiter__(self):
-                return self._async_iterator()
-
-            async def _async_iterator(self):
-                for model in self._models:
-                    yield model
+        async def _async_iterator(models):
+            for model in models:
+                yield model
 
         models = [
             OpenAIModel(id="foo", created=1, object="model", owned_by="test"),

@@ -243,21 +236,21 @@ class TestInferenceRecording:
         # baseline - mock works without recording
         client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-        client.models._get_api_list = Mock(return_value=MockAsyncPaginator(models))
+        client.models._get_api_list = Mock(return_value=_async_iterator(models))
         assert {m.id async for m in client.models.list()} == expected_ids
         client.models._get_api_list.assert_called_once()
 
         # record the call
         with inference_recording(mode=InferenceMode.RECORD, storage_dir=temp_storage_dir):
             client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-            client.models._get_api_list = Mock(return_value=MockAsyncPaginator(models))
+            client.models._get_api_list = Mock(return_value=_async_iterator(models))
 
             assert {m.id async for m in client.models.list()} == expected_ids
             client.models._get_api_list.assert_called_once()
 
         # replay the call
         with inference_recording(mode=InferenceMode.REPLAY, storage_dir=temp_storage_dir):
             client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-            client.models._get_api_list = Mock(return_value=MockAsyncPaginator(models))
+            client.models._get_api_list = Mock(return_value=_async_iterator(models))
 
             assert {m.id async for m in client.models.list()} == expected_ids
             client.models._get_api_list.assert_not_called()
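
Replacing MockAsyncPaginator with a plain async generator function works because these tests only ever `async for` over whatever _get_api_list returns, and calling an async generator function already produces an object that supports async iteration. A minimal sketch under that assumption (the _demo wrapper is illustrative):

    import asyncio
    from unittest.mock import Mock

    async def _async_iterator(models):
        for model in models:
            yield model

    async def _demo():
        # The mock returns an async generator, which is all `async for` needs.
        fake_list = Mock(return_value=_async_iterator(["foo", "bar"]))
        assert {m async for m in fake_list()} == {"foo", "bar"}

    asyncio.run(_demo())

The one caveat is that the mock hands back a single-use generator, which is fine here because each phase of the test builds a fresh client and a fresh Mock.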