chore(recorder): add support for NOT_GIVEN

Matthew Farrellee 2025-09-13 05:06:07 -04:00
parent 3de9ad0a87
commit d37978508f
2 changed files with 67 additions and 3 deletions


@@ -16,6 +16,8 @@ from enum import StrEnum
 from pathlib import Path
 from typing import Any, Literal, cast
 
+from openai import NOT_GIVEN
+
 from llama_stack.log import get_logger
 
 logger = get_logger(__name__, category="testing")
@@ -250,6 +252,9 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     # Get base URL based on client type
     if client_type == "openai":
         base_url = str(self._client.base_url)
+
+        # the OpenAI client methods may pass NOT_GIVEN for unset parameters; filter these out
+        kwargs = {k: v for k, v in kwargs.items() if v is not NOT_GIVEN}
     elif client_type == "ollama":
         # Get base URL from the client (Ollama client uses host attribute)
         base_url = getattr(self, "host", "http://localhost:11434")
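
For context, NOT_GIVEN is the sentinel the openai Python client uses for optional parameters the caller did not set. Dropping those entries before the recorder builds its request key means a call that passes user=NOT_GIVEN explicitly and a call that simply omits user normalize to the same recording. A minimal sketch of the filtering behaviour (the kwargs below are made up for illustration, not taken from the recorder):

from openai import NOT_GIVEN

# Illustrative kwargs as the patched client method might receive them
kwargs = {
    "model": "llama3.2:3b",
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 50,
    "user": NOT_GIVEN,  # unset optional parameter passed through by the client
}

# Identity check against the sentinel, mirroring the change above
filtered = {k: v for k, v in kwargs.items() if v is not NOT_GIVEN}

assert "user" not in filtered
assert filtered["max_tokens"] == 50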


@@ -9,7 +9,7 @@ from pathlib import Path
 from unittest.mock import AsyncMock, Mock, patch
 
 import pytest
-from openai import AsyncOpenAI
+from openai import NOT_GIVEN, AsyncOpenAI
 from openai.types.model import Model as OpenAIModel
 
 # Import the real Pydantic response types instead of using Mocks
@@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
     OpenAIChoice,
+    OpenAICompletion,
     OpenAIEmbeddingData,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
@@ -170,6 +171,7 @@ class TestInferenceRecording:
                     messages=[{"role": "user", "content": "Hello, how are you?"}],
                     temperature=0.7,
                     max_tokens=50,
+                    user=NOT_GIVEN,
                 )
 
         # Verify the response was returned correctly
@@ -198,6 +200,7 @@ class TestInferenceRecording:
                     messages=[{"role": "user", "content": "Hello, how are you?"}],
                     temperature=0.7,
                     max_tokens=50,
+                    user=NOT_GIVEN,
                 )
 
         # Now test replay mode - should not call the original method
@@ -281,7 +284,11 @@ class TestInferenceRecording:
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
 
                 response = await client.embeddings.create(
-                    model="nomic-embed-text", input=["Hello world", "Test embedding"]
+                    model=real_embeddings_response.model,
+                    input=["Hello world", "Test embedding"],
+                    encoding_format=NOT_GIVEN,
+                    dimensions=NOT_GIVEN,
+                    user=NOT_GIVEN,
                 )
 
                 assert len(response.data) == 2
@@ -292,7 +299,8 @@ class TestInferenceRecording:
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
 
                 response = await client.embeddings.create(
-                    model="nomic-embed-text", input=["Hello world", "Test embedding"]
+                    model=real_embeddings_response.model,
+                    input=["Hello world", "Test embedding"],
                 )
 
                 # Verify we got the recorded response
@@ -302,6 +310,57 @@ class TestInferenceRecording:
         # Verify original method was not called
         mock_create_patch.assert_not_called()
 
+    async def test_completions_recording(self, temp_storage_dir):
+        real_completions_response = OpenAICompletion(
+            id="test_completion",
+            object="text_completion",
+            created=1234567890,
+            model="llama3.2:3b",
+            choices=[
+                {
+                    "text": "Hello! I'm doing well, thank you for asking.",
+                    "index": 0,
+                    "logprobs": None,
+                    "finish_reason": "stop",
+                }
+            ],
+        )
+
+        async def mock_create(*args, **kwargs):
+            return real_completions_response
+
+        temp_storage_dir = temp_storage_dir / "test_completions_recording"
+        # Record
+        with patch(
+            "openai.resources.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
+        ):
+            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+
+                response = await client.completions.create(
+                    model=real_completions_response.model,
+                    prompt="Hello, how are you?",
+                    temperature=0.7,
+                    max_tokens=50,
+                    user=NOT_GIVEN,
+                )
+
+                assert response.choices[0].text == real_completions_response.choices[0].text
+
+        # Replay
+        with patch("openai.resources.completions.AsyncCompletions.create") as mock_create_patch:
+            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
+                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+
+                response = await client.completions.create(
+                    model=real_completions_response.model,
+                    prompt="Hello, how are you?",
+                    temperature=0.7,
+                    max_tokens=50,
+                )
+
+                assert response.choices[0].text == real_completions_response.choices[0].text
+
+        mock_create_patch.assert_not_called()
+
     async def test_live_mode(self, real_openai_chat_response):
         """Test that live mode passes through to original methods."""