From 3de9ad0a87d7bfad50ab23c859cebcaf06b6911b Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Fri, 12 Sep 2025 17:59:56 -0400
Subject: [PATCH] chore(recorder, tests): add test for openai /v1/models
 (#3426)

# What does this PR do?

- [x] adds a test for the recorder's handling of /v1/models
- [x] adds a fix for /v1/models handling

## Test Plan

ci
---
 llama_stack/testing/inference_recorder.py     | 60 ++++++++++---------
 .../distribution/test_inference_recordings.py | 51 ++++++++++++++--
 2 files changed, 79 insertions(+), 32 deletions(-)

diff --git a/llama_stack/testing/inference_recorder.py b/llama_stack/testing/inference_recorder.py
index 6f017c51d..745160976 100644
--- a/llama_stack/testing/inference_recorder.py
+++ b/llama_stack/testing/inference_recorder.py
@@ -7,6 +7,7 @@
 from __future__ import annotations  # for forward references
 
 import hashlib
+import inspect
 import json
 import os
 from collections.abc import Generator
@@ -198,16 +199,11 @@ def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
         Supported endpoints:
         - '/api/tags' (Ollama): response body has 'models': [ { name/model/digest/id/... }, ... ]
-        - '/v1/models' (OpenAI): response body has 'data': [ { id: ... }, ... ]
+        - '/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
 
         Returns a list of unique identifiers or None if structure doesn't match.
         """
-        body = response["body"]
-        if endpoint == "/api/tags":
-            items = body.get("models")
-            idents = [m.model for m in items]
-        else:
-            items = body.get("data")
-            idents = [m.id for m in items]
+        items = response["body"]
+        idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
         return sorted(set(idents))
 
     identifiers = _extract_model_identifiers()
@@ -219,28 +215,22 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
     seen: dict[str, dict[str, Any]] = {}
     for rec in records:
         body = rec["response"]["body"]
-        if endpoint == "/api/tags":
-            items = body.models
-        elif endpoint == "/v1/models":
-            items = body.data
-        else:
-            items = []
-
-        for m in items:
-            if endpoint == "/v1/models":
+        if endpoint == "/v1/models":
+            for m in body:
                 key = m.id
-            else:
+                seen[key] = m
+        elif endpoint == "/api/tags":
+            for m in body.models:
                 key = m.model
-            seen[key] = m
+                seen[key] = m
 
     ordered = [seen[k] for k in sorted(seen.keys())]
     canonical = records[0]
     canonical_req = canonical.get("request", {})
     if isinstance(canonical_req, dict):
         canonical_req["endpoint"] = endpoint
-    if endpoint == "/v1/models":
-        body = {"data": ordered, "object": "list"}
-    else:
+    body = ordered
+    if endpoint == "/api/tags":
         from ollama import ListResponse
 
         body = ListResponse(models=ordered)
@@ -252,7 +242,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
 
     if _current_mode == InferenceMode.LIVE or _current_storage is None:
         # Normal operation
-        return await original_method(self, *args, **kwargs)
+        if inspect.iscoroutinefunction(original_method):
+            return await original_method(self, *args, **kwargs)
+        else:
+            return original_method(self, *args, **kwargs)
 
     # Get base URL based on client type
     if client_type == "openai":
@@ -300,7 +293,14 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         )
 
     elif _current_mode == InferenceMode.RECORD:
-        response = await original_method(self, *args, **kwargs)
+        if inspect.iscoroutinefunction(original_method):
+            response = await original_method(self, *args, **kwargs)
+        else:
+            response = original_method(self, *args, **kwargs)
+
+        # we want to store the result of the iterator, not the iterator itself
+        if endpoint == "/v1/models":
+            response = [m async for m in response]
 
         request_data = {
             "method": method,
@@ -380,10 +380,14 @@ def patch_inference_clients():
             _original_methods["embeddings_create"], self, "openai", "/v1/embeddings", *args, **kwargs
         )
 
-    async def patched_models_list(self, *args, **kwargs):
-        return await _patched_inference_method(
-            _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
-        )
+    def patched_models_list(self, *args, **kwargs):
+        async def _iter():
+            for item in await _patched_inference_method(
+                _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
+            ):
+                yield item
+
+        return _iter()
 
     # Apply OpenAI patches
     AsyncChatCompletions.create = patched_chat_completions_create
diff --git a/tests/unit/distribution/test_inference_recordings.py b/tests/unit/distribution/test_inference_recordings.py
index c69cf319b..94fd2536e 100644
--- a/tests/unit/distribution/test_inference_recordings.py
+++ b/tests/unit/distribution/test_inference_recordings.py
@@ -6,10 +6,11 @@
 
 import tempfile
 from pathlib import Path
-from unittest.mock import patch
+from unittest.mock import AsyncMock, Mock, patch
 
 import pytest
 from openai import AsyncOpenAI
+from openai.types.model import Model as OpenAIModel
 
 # Import the real Pydantic response types instead of using Mocks
 from llama_stack.apis.inference import (
@@ -158,7 +159,9 @@ class TestInferenceRecording:
             return real_openai_chat_response
 
         temp_storage_dir = temp_storage_dir / "test_recording_mode"
-        with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
+        with patch(
+            "openai.resources.chat.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
+        ):
             with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
 
@@ -184,7 +187,9 @@ class TestInferenceRecording:
         temp_storage_dir = temp_storage_dir / "test_replay_mode"
 
         # First, record a response
-        with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
+        with patch(
+            "openai.resources.chat.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
+        ):
             with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
 
@@ -213,6 +218,42 @@ class TestInferenceRecording:
             # Verify the original method was NOT called
             mock_create_patch.assert_not_called()
 
+    async def test_replay_mode_models(self, temp_storage_dir):
+        """Test that replay mode returns stored responses without making real model listing calls."""
+
+        async def _async_iterator(models):
+            for model in models:
+                yield model
+
+        models = [
+            OpenAIModel(id="foo", created=1, object="model", owned_by="test"),
+            OpenAIModel(id="bar", created=2, object="model", owned_by="test"),
+        ]
+
+        expected_ids = {m.id for m in models}
+
+        temp_storage_dir = temp_storage_dir / "test_replay_mode_models"
+
+        # baseline - mock works without recording
+        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        client.models._get_api_list = Mock(return_value=_async_iterator(models))
+        assert {m.id async for m in client.models.list()} == expected_ids
+        client.models._get_api_list.assert_called_once()
+
+        # record the call
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=temp_storage_dir):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.models._get_api_list = Mock(return_value=_async_iterator(models))
+            assert {m.id async for m in client.models.list()} == expected_ids
+            client.models._get_api_list.assert_called_once()
+
+        # replay the call
+        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=temp_storage_dir):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.models._get_api_list = Mock(return_value=_async_iterator(models))
+            assert {m.id async for m in client.models.list()} == expected_ids
+            client.models._get_api_list.assert_not_called()
+
     async def test_replay_missing_recording(self, temp_storage_dir):
         """Test that replay mode fails when no recording is found."""
         temp_storage_dir = temp_storage_dir / "test_replay_missing_recording"
@@ -233,7 +274,9 @@ class TestInferenceRecording:
         temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
 
         # Record
-        with patch("openai.resources.embeddings.AsyncEmbeddings.create", side_effect=mock_create):
+        with patch(
+            "openai.resources.embeddings.AsyncEmbeddings.create", new_callable=AsyncMock, side_effect=mock_create
+        ):
             with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
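
The round trip the new test pins down can also be exercised end to end. A minimal sketch, not part of the change above: it assumes `inference_recording` and `InferenceMode` are importable from `llama_stack.testing.inference_recorder`, and that an OpenAI-compatible server (e.g. Ollama) is serving `/v1/models` on localhost:11434 for the RECORD pass.

```python
# Minimal sketch of the /v1/models record-then-replay flow this patch fixes.
# Assumptions: import path below, and a live OpenAI-compatible server for RECORD.
import asyncio

from openai import AsyncOpenAI

from llama_stack.testing.inference_recorder import InferenceMode, inference_recording


async def main() -> None:
    # RECORD: the patched models.list() drains the async iterator into a list
    # before storing it, then hands the caller an async iterator over the items.
    with inference_recording(mode=InferenceMode.RECORD, storage_dir="./recordings"):
        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
        recorded = [m.id async for m in client.models.list()]

    # REPLAY: the same call is answered from ./recordings; nothing hits the network.
    with inference_recording(mode=InferenceMode.REPLAY, storage_dir="./recordings"):
        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
        replayed = [m.id async for m in client.models.list()]

    assert set(recorded) == set(replayed)


asyncio.run(main())
```

Before this patch, `patched_models_list` awaited the original method and returned its result directly, so `async for` over `client.models.list()` failed under recording; wrapping the stored list back in an async generator keeps the patched client's surface identical to the real one.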