From 6787755c0c8af6b59322352f985cffb224aadd3b Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Sat, 13 Sep 2025 14:11:38 -0400
Subject: [PATCH] chore(recorder): add support for NOT_GIVEN (#3430)

# What does this PR do?

The recorder mocks the openai-python interface, and that interface allows
`NOT_GIVEN` as a sentinel value for unset optional parameters. This change
handles `NOT_GIVEN` properly by filtering it out of the request kwargs
before they are recorded or matched for replay.
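
For illustration, a minimal sketch of the sentinel-filtering idea (the
`normalize_kwargs` helper is hypothetical and for exposition only;
`NOT_GIVEN` is the real sentinel exported by openai-python):

```python
from openai import NOT_GIVEN

# openai-python methods default unset keyword arguments to the NOT_GIVEN
# sentinel rather than None, so a recorder that matches replayed requests
# against recorded ones must drop them: otherwise a call passing
# user=NOT_GIVEN and a call omitting `user` entirely would look like
# different requests.
def normalize_kwargs(kwargs: dict) -> dict:
    # Keep only the arguments the caller actually supplied.
    return {k: v for k, v in kwargs.items() if v is not NOT_GIVEN}

assert normalize_kwargs({"model": "llama3.2:3b", "user": NOT_GIVEN}) == {"model": "llama3.2:3b"}
```

The comparison is by identity (`is not`) rather than equality, matching the
patch below, since `NOT_GIVEN` is a module-level singleton.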
## Test Plan

ci (coverage for chat, completions, embeddings)
---
 llama_stack/testing/inference_recorder.py         |  5 ++
 .../distribution/test_inference_recordings.py     | 65 ++++++++++++++++++-
 2 files changed, 67 insertions(+), 3 deletions(-)

diff --git a/llama_stack/testing/inference_recorder.py b/llama_stack/testing/inference_recorder.py
index 745160976..f899d73d3 100644
--- a/llama_stack/testing/inference_recorder.py
+++ b/llama_stack/testing/inference_recorder.py
@@ -16,6 +16,8 @@ from enum import StrEnum
 from pathlib import Path
 from typing import Any, Literal, cast
 
+from openai import NOT_GIVEN
+
 from llama_stack.log import get_logger
 
 logger = get_logger(__name__, category="testing")
@@ -250,6 +252,9 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     # Get base URL based on client type
     if client_type == "openai":
         base_url = str(self._client.base_url)
+
+        # the OpenAI client methods may pass NOT_GIVEN for unset parameters; filter these out
+        kwargs = {k: v for k, v in kwargs.items() if v is not NOT_GIVEN}
     elif client_type == "ollama":
         # Get base URL from the client (Ollama client uses host attribute)
         base_url = getattr(self, "host", "http://localhost:11434")
diff --git a/tests/unit/distribution/test_inference_recordings.py b/tests/unit/distribution/test_inference_recordings.py
index 94fd2536e..4909bbe1e 100644
--- a/tests/unit/distribution/test_inference_recordings.py
+++ b/tests/unit/distribution/test_inference_recordings.py
@@ -9,7 +9,7 @@ from pathlib import Path
 from unittest.mock import AsyncMock, Mock, patch
 
 import pytest
-from openai import AsyncOpenAI
+from openai import NOT_GIVEN, AsyncOpenAI
 from openai.types.model import Model as OpenAIModel
 
 # Import the real Pydantic response types instead of using Mocks
@@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
     OpenAIChoice,
+    OpenAICompletion,
     OpenAIEmbeddingData,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
@@ -170,6 +171,7 @@ class TestInferenceRecording:
                     messages=[{"role": "user", "content": "Hello, how are you?"}],
                     temperature=0.7,
                     max_tokens=50,
+                    user=NOT_GIVEN,
                 )
 
                 # Verify the response was returned correctly
@@ -198,6 +200,7 @@ class TestInferenceRecording:
                     messages=[{"role": "user", "content": "Hello, how are you?"}],
                     temperature=0.7,
                     max_tokens=50,
+                    user=NOT_GIVEN,
                 )
 
         # Now test replay mode - should not call the original method
@@ -281,7 +284,11 @@ class TestInferenceRecording:
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
 
                 response = await client.embeddings.create(
-                    model="nomic-embed-text", input=["Hello world", "Test embedding"]
+                    model=real_embeddings_response.model,
+                    input=["Hello world", "Test embedding"],
+                    encoding_format=NOT_GIVEN,
+                    dimensions=NOT_GIVEN,
+                    user=NOT_GIVEN,
                 )
 
                 assert len(response.data) == 2
@@ -292,7 +299,8 @@ class TestInferenceRecording:
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
 
                 response = await client.embeddings.create(
-                    model="nomic-embed-text", input=["Hello world", "Test embedding"]
+                    model=real_embeddings_response.model,
+                    input=["Hello world", "Test embedding"],
                 )
 
                 # Verify we got the recorded response
@@ -302,6 +310,57 @@ class TestInferenceRecording:
 
         # Verify original method was not called
         mock_create_patch.assert_not_called()
 
+    async def test_completions_recording(self, temp_storage_dir):
+        real_completions_response = OpenAICompletion(
+            id="test_completion",
+            object="text_completion",
+            created=1234567890,
+            model="llama3.2:3b",
+            choices=[
+                {
+                    "text": "Hello! I'm doing well, thank you for asking.",
+                    "index": 0,
+                    "logprobs": None,
+                    "finish_reason": "stop",
+                }
+            ],
+        )
+
+        async def mock_create(*args, **kwargs):
+            return real_completions_response
+
+        temp_storage_dir = temp_storage_dir / "test_completions_recording"
+
+        # Record
+        with patch(
+            "openai.resources.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
+        ):
+            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+
+                response = await client.completions.create(
+                    model=real_completions_response.model,
+                    prompt="Hello, how are you?",
+                    temperature=0.7,
+                    max_tokens=50,
+                    user=NOT_GIVEN,
+                )
+
+                assert response.choices[0].text == real_completions_response.choices[0].text
+
+        # Replay
+        with patch("openai.resources.completions.AsyncCompletions.create") as mock_create_patch:
+            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
+                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+                response = await client.completions.create(
+                    model=real_completions_response.model,
+                    prompt="Hello, how are you?",
+                    temperature=0.7,
+                    max_tokens=50,
+                )
+                assert response.choices[0].text == real_completions_response.choices[0].text
+                mock_create_patch.assert_not_called()
+
     async def test_live_mode(self, real_openai_chat_response):
         """Test that live mode passes through to original methods."""