mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-07 04:45:44 +00:00
further simplify the async management
This commit is contained in:
parent
c6403706b4
commit
a7caacf1bf
2 changed files with 15 additions and 23 deletions
|
@ -221,16 +221,9 @@ class TestInferenceRecording:
|
|||
async def test_replay_mode_models(self, temp_storage_dir):
|
||||
"""Test that replay mode returns stored responses without making real model listing calls."""
|
||||
|
||||
class MockAsyncPaginator:
|
||||
def __init__(self, models: list[OpenAIModel]):
|
||||
self._models = models
|
||||
|
||||
def __aiter__(self):
|
||||
return self._async_iterator()
|
||||
|
||||
async def _async_iterator(self):
|
||||
for model in self._models:
|
||||
yield model
|
||||
async def _async_iterator(models):
|
||||
for model in models:
|
||||
yield model
|
||||
|
||||
models = [
|
||||
OpenAIModel(id="foo", created=1, object="model", owned_by="test"),
|
||||
|
@ -243,21 +236,21 @@ class TestInferenceRecording:
|
|||
|
||||
# baseline - mock works without recording
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.models._get_api_list = Mock(return_value=MockAsyncPaginator(models))
|
||||
client.models._get_api_list = Mock(return_value=_async_iterator(models))
|
||||
assert {m.id async for m in client.models.list()} == expected_ids
|
||||
client.models._get_api_list.assert_called_once()
|
||||
|
||||
# record the call
|
||||
with inference_recording(mode=InferenceMode.RECORD, storage_dir=temp_storage_dir):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.models._get_api_list = Mock(return_value=MockAsyncPaginator(models))
|
||||
client.models._get_api_list = Mock(return_value=_async_iterator(models))
|
||||
assert {m.id async for m in client.models.list()} == expected_ids
|
||||
client.models._get_api_list.assert_called_once()
|
||||
|
||||
# replay the call
|
||||
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=temp_storage_dir):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.models._get_api_list = Mock(return_value=MockAsyncPaginator(models))
|
||||
client.models._get_api_list = Mock(return_value=_async_iterator(models))
|
||||
assert {m.id async for m in client.models.list()} == expected_ids
|
||||
client.models._get_api_list.assert_not_called()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue