Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 12:07:34 +00:00)
chore(recorder): add support for NOT_GIVEN
parent: 3de9ad0a87
commit: d37978508f
2 changed files with 67 additions and 3 deletions
Changed file 1 (inference recorder):

@@ -16,6 +16,8 @@ from enum import StrEnum
 from pathlib import Path
 from typing import Any, Literal, cast
 
+from openai import NOT_GIVEN
+
 from llama_stack.log import get_logger
 
 logger = get_logger(__name__, category="testing")
@@ -250,6 +252,9 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     # Get base URL based on client type
     if client_type == "openai":
         base_url = str(self._client.base_url)
+
+        # the OpenAI client methods may pass NOT_GIVEN for unset parameters; filter these out
+        kwargs = {k: v for k, v in kwargs.items() if v is not NOT_GIVEN}
     elif client_type == "ollama":
         # Get base URL from the client (Ollama client uses host attribute)
         base_url = getattr(self, "host", "http://localhost:11434")
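Note on the recorder hunk above: NOT_GIVEN is the sentinel singleton the OpenAI Python client uses for optional parameters the caller did not supply, so an identity check is the right way to strip it before the request is recorded. A minimal, self-contained sketch of the filtering idea (the helper name drop_not_given is illustrative, not part of the recorder):

from openai import NOT_GIVEN

def drop_not_given(kwargs: dict) -> dict:
    # NOT_GIVEN is a singleton, so "is not" is the correct test; a truthiness
    # check would also drop legitimate values such as 0, "" or None.
    return {k: v for k, v in kwargs.items() if v is not NOT_GIVEN}

# A call that passes user=NOT_GIVEN normalizes to the same kwargs as one that
# omits user entirely, so both map to the same recorded request.
assert drop_not_given({"model": "llama3.2:3b", "user": NOT_GIVEN}) == {"model": "llama3.2:3b"}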
Changed file 2 (inference recorder unit tests):

@@ -9,7 +9,7 @@ from pathlib import Path
 from unittest.mock import AsyncMock, Mock, patch
 
 import pytest
-from openai import AsyncOpenAI
+from openai import NOT_GIVEN, AsyncOpenAI
 from openai.types.model import Model as OpenAIModel
 
 # Import the real Pydantic response types instead of using Mocks
@@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
     OpenAIChoice,
+    OpenAICompletion,
     OpenAIEmbeddingData,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
@@ -170,6 +171,7 @@ class TestInferenceRecording:
                     messages=[{"role": "user", "content": "Hello, how are you?"}],
                     temperature=0.7,
                     max_tokens=50,
+                    user=NOT_GIVEN,
                 )
 
                 # Verify the response was returned correctly
@@ -198,6 +200,7 @@ class TestInferenceRecording:
                     messages=[{"role": "user", "content": "Hello, how are you?"}],
                     temperature=0.7,
                     max_tokens=50,
+                    user=NOT_GIVEN,
                 )
 
         # Now test replay mode - should not call the original method
@@ -281,7 +284,11 @@ class TestInferenceRecording:
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
 
                 response = await client.embeddings.create(
-                    model="nomic-embed-text", input=["Hello world", "Test embedding"]
+                    model=real_embeddings_response.model,
+                    input=["Hello world", "Test embedding"],
+                    encoding_format=NOT_GIVEN,
+                    dimensions=NOT_GIVEN,
+                    user=NOT_GIVEN,
                 )
 
                 assert len(response.data) == 2
@@ -292,7 +299,8 @@ class TestInferenceRecording:
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
 
                 response = await client.embeddings.create(
-                    model="nomic-embed-text", input=["Hello world", "Test embedding"]
+                    model=real_embeddings_response.model,
+                    input=["Hello world", "Test embedding"],
                 )
 
                 # Verify we got the recorded response
@@ -302,6 +310,57 @@ class TestInferenceRecording:
             # Verify original method was not called
             mock_create_patch.assert_not_called()
 
+    async def test_completions_recording(self, temp_storage_dir):
+        real_completions_response = OpenAICompletion(
+            id="test_completion",
+            object="text_completion",
+            created=1234567890,
+            model="llama3.2:3b",
+            choices=[
+                {
+                    "text": "Hello! I'm doing well, thank you for asking.",
+                    "index": 0,
+                    "logprobs": None,
+                    "finish_reason": "stop",
+                }
+            ],
+        )
+
+        async def mock_create(*args, **kwargs):
+            return real_completions_response
+
+        temp_storage_dir = temp_storage_dir / "test_completions_recording"
+
+        # Record
+        with patch(
+            "openai.resources.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
+        ):
+            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+
+                response = await client.completions.create(
+                    model=real_completions_response.model,
+                    prompt="Hello, how are you?",
+                    temperature=0.7,
+                    max_tokens=50,
+                    user=NOT_GIVEN,
+                )
+
+                assert response.choices[0].text == real_completions_response.choices[0].text
+
+        # Replay
+        with patch("openai.resources.completions.AsyncCompletions.create") as mock_create_patch:
+            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
+                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+                response = await client.completions.create(
+                    model=real_completions_response.model,
+                    prompt="Hello, how are you?",
+                    temperature=0.7,
+                    max_tokens=50,
+                )
+                assert response.choices[0].text == real_completions_response.choices[0].text
+                mock_create_patch.assert_not_called()
+
     async def test_live_mode(self, real_openai_chat_response):
         """Test that live mode passes through to original methods."""
 
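The updated embeddings tests lean on the same normalization: the recording pass sends NOT_GIVEN for the optional parameters while the replay pass omits them, and both should resolve to the same stored response. A small sketch of that equivalence, assuming the literal values shown in the test (the variable names here are illustrative only):

from openai import NOT_GIVEN

recorded_kwargs = {
    "model": "nomic-embed-text",
    "input": ["Hello world", "Test embedding"],
    "encoding_format": NOT_GIVEN,
    "dimensions": NOT_GIVEN,
    "user": NOT_GIVEN,
}
replayed_kwargs = {
    "model": "nomic-embed-text",
    "input": ["Hello world", "Test embedding"],
}

# After the recorder's filter drops the NOT_GIVEN entries, the record-time and
# replay-time requests are identical, which is why replay finds the recording.
assert {k: v for k, v in recorded_kwargs.items() if v is not NOT_GIVEN} == replayed_kwargs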