From a7caacf1bff9d55c62098dada78d2302bb4a1a66 Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Fri, 12 Sep 2025 17:44:48 -0400
Subject: [PATCH] further simplify the async management

---
 llama_stack/testing/inference_recorder.py     | 19 +++++++++----------
 .../distribution/test_inference_recordings.py | 19 ++++++-------------
 2 files changed, 15 insertions(+), 23 deletions(-)

diff --git a/llama_stack/testing/inference_recorder.py b/llama_stack/testing/inference_recorder.py
index e09513e4d..0b21d0c10 100644
--- a/llama_stack/testing/inference_recorder.py
+++ b/llama_stack/testing/inference_recorder.py
@@ -275,7 +275,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     if recording:
         response_body = recording["response"]["body"]
 
-        if recording["response"].get("is_streaming", False) or endpoint == "/v1/models":
+        if recording["response"].get("is_streaming", False):
 
             async def replay_stream():
                 for chunk in response_body:
@@ -298,6 +298,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     else:
         response = original_method(self, *args, **kwargs)
 
+    # we want to store the result of the iterator, not the iterator itself
+    if endpoint == "/v1/models":
+        response = [m async for m in response]
+
     request_data = {
         "method": method,
         "url": url,
@@ -310,7 +314,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     # Determine if this is a streaming request based on request parameters
     is_streaming = body.get("stream", False)
 
-    if is_streaming or endpoint == "/v1/models":
+    if is_streaming:
         # For streaming responses, we need to collect all chunks immediately before yielding
         # This ensures the recording is saved even if the generator isn't fully consumed
         chunks = []
@@ -377,15 +381,10 @@ def patch_inference_clients():
     )
 
     def patched_models_list(self, *args, **kwargs):
-        import asyncio
-
-        task = asyncio.create_task(
-            _patched_inference_method(_original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs)
-        )
-
         async def _iter():
-            result = await task
-            async for item in result:
+            for item in await _patched_inference_method(
+                _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
+            ):
                 yield item
 
         return _iter()
diff --git a/tests/unit/distribution/test_inference_recordings.py b/tests/unit/distribution/test_inference_recordings.py
index d9cda5af9..94fd2536e 100644
--- a/tests/unit/distribution/test_inference_recordings.py
+++ b/tests/unit/distribution/test_inference_recordings.py
@@ -221,16 +221,9 @@ class TestInferenceRecording:
     async def test_replay_mode_models(self, temp_storage_dir):
         """Test that replay mode returns stored responses without making real model listing calls."""
 
-        class MockAsyncPaginator:
-            def __init__(self, models: list[OpenAIModel]):
-                self._models = models
-
-            def __aiter__(self):
-                return self._async_iterator()
-
-            async def _async_iterator(self):
-                for model in self._models:
-                    yield model
+        async def _async_iterator(models):
+            for model in models:
+                yield model
 
         models = [
             OpenAIModel(id="foo", created=1, object="model", owned_by="test"),
@@ -243,21 +236,21 @@ class TestInferenceRecording:
         # baseline - mock works without recording
         client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-        client.models._get_api_list = Mock(return_value=MockAsyncPaginator(models))
+        client.models._get_api_list = Mock(return_value=_async_iterator(models))
         assert {m.id async for m in client.models.list()} == expected_ids
         client.models._get_api_list.assert_called_once()
 
         # record the call
         with inference_recording(mode=InferenceMode.RECORD, storage_dir=temp_storage_dir):
             client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-            client.models._get_api_list = Mock(return_value=MockAsyncPaginator(models))
+            client.models._get_api_list = Mock(return_value=_async_iterator(models))
             assert {m.id async for m in client.models.list()} == expected_ids
             client.models._get_api_list.assert_called_once()
 
         # replay the call
         with inference_recording(mode=InferenceMode.REPLAY, storage_dir=temp_storage_dir):
             client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-            client.models._get_api_list = Mock(return_value=MockAsyncPaginator(models))
+            client.models._get_api_list = Mock(return_value=_async_iterator(models))
             assert {m.id async for m in client.models.list()} == expected_ids
             client.models._get_api_list.assert_not_called()
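
Note: the heart of this change is that /v1/models responses are materialized
into a plain list before recording, so the streaming special case for that
endpoint can go away, and the wrapper no longer needs asyncio.create_task --
it just awaits the patched coroutine lazily inside a single async generator.
A minimal, self-contained sketch of that pattern (the names fake_models_list
and record_models are illustrative, not from llama_stack):

import asyncio


async def fake_models_list():
    # stands in for the provider's async model listing
    for model_id in ("foo", "bar"):
        yield {"id": model_id, "object": "model"}


async def record_models(response):
    # equivalent of `response = [m async for m in response]` in the patch:
    # store the items, not the single-use iterator object
    return [m async for m in response]


def patched_models_list():
    # mirrors the simplified wrapper: one async generator, no create_task
    async def _iter():
        for item in await record_models(fake_models_list()):
            yield item

    return _iter()


async def main():
    assert [m["id"] async for m in patched_models_list()] == ["foo", "bar"]


asyncio.run(main())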
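
The test simplification follows the same idea: anything that supports
`async for` can stand in for the SDK's paginator, so a three-line async
generator replaces the MockAsyncPaginator class. A sketch under that
assumption (FakeModelsResource is a made-up stand-in for the openai client's
models resource, not part of the SDK):

import asyncio
from unittest.mock import Mock


async def _async_iterator(models):
    for model in models:
        yield model


class FakeModelsResource:
    def list(self):
        # delegates to _get_api_list the same way the test patches it
        return self._get_api_list()


async def main():
    resource = FakeModelsResource()
    resource._get_api_list = Mock(return_value=_async_iterator(["foo", "bar"]))
    assert {m async for m in resource.list()} == {"foo", "bar"}
    resource._get_api_list.assert_called_once()


asyncio.run(main())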