async def test_replay_mode_models(self, temp_storage_dir):
    """Verify that a recorded model listing is replayed without invoking the listing backend.

    Three phases run against a stubbed ``_get_api_list``:

    1. baseline - no recording context; the stub must be called once.
    2. RECORD   - the stub must be called once while the call is recorded.
    3. REPLAY   - the same ids must come back and the stub must stay untouched.
    """
    # Relies on the module-level imports of Mock, AsyncOpenAI, OpenAIModel,
    # inference_recording and InferenceMode.

    class MockAsyncPaginator:
        """Async-iterable stand-in used as the return value of the stubbed ``_get_api_list``."""

        def __init__(self, models: list[OpenAIModel]):
            self._models = models

        async def __aiter__(self):
            # Written as an async generator so no separate iterator helper is needed.
            for entry in self._models:
                yield entry

    models = [
        OpenAIModel(id="foo", created=1, object="model", owned_by="test"),
        OpenAIModel(id="bar", created=2, object="model", owned_by="test"),
    ]
    expected_ids = {entry.id for entry in models}

    # Keep this test's recordings in their own sub-directory of the fixture dir.
    temp_storage_dir = temp_storage_dir / "test_replay_mode_models"

    def stubbed_client():
        """Build a fresh client whose ``models._get_api_list`` is a Mock; return both."""
        api = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
        stub = Mock(return_value=MockAsyncPaginator(models))
        api.models._get_api_list = stub
        return api, stub

    async def listed_ids(api):
        """Drain ``api.models.list()`` and collect the model ids into a set."""
        return {entry.id async for entry in api.models.list()}

    # baseline - mock works without recording
    api, stub = stubbed_client()
    assert await listed_ids(api) == expected_ids
    stub.assert_called_once()

    # record the call
    with inference_recording(mode=InferenceMode.RECORD, storage_dir=temp_storage_dir):
        api, stub = stubbed_client()
        assert await listed_ids(api) == expected_ids
        stub.assert_called_once()

    # replay the call - ids must come back while the stub is never invoked
    with inference_recording(mode=InferenceMode.REPLAY, storage_dir=temp_storage_dir):
        api, stub = stubbed_client()
        assert await listed_ids(api) == expected_ids
        stub.assert_not_called()