Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 12:07:34 +00:00)
chore(recorder, tests): add test for openai /v1/models (#3426)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
Python Package Build Test / build (3.12) (push) Failing after 2s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 5s
Unit Tests / unit-tests (3.12) (push) Failing after 3s
Unit Tests / unit-tests (3.13) (push) Failing after 3s
Python Package Build Test / build (3.13) (push) Failing after 2s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 4s
Vector IO Integration Tests / test-matrix (push) Failing after 4s
API Conformance Tests / check-schema-compatibility (push) Successful in 6s
Test External API and Providers / test-external (venv) (push) Failing after 5s
UI Tests / ui-tests (22) (push) Successful in 39s
Pre-commit / pre-commit (push) Successful in 1m19s
# What does this PR do?
- [x] adds a test for the recorder's handling of /v1/models
- [x] adds a fix for /v1/models handling

## Test Plan
ci
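For context, a minimal sketch of the call shape the recorder has to handle, assuming openai-python's `AsyncOpenAI` client (the base URL and API key are placeholders taken from the test fixtures below): `client.models.list()` hands back an async paginator rather than a plain JSON body, so the result has to be iterated before it can be stored.

```python
import asyncio

from openai import AsyncOpenAI


async def list_model_ids() -> list[str]:
    # models.list() returns an async paginator, not a plain dict,
    # so the items must be iterated to produce a recordable value.
    client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
    return [m.id async for m in client.models.list()]


if __name__ == "__main__":
    # Requires an OpenAI-compatible server (e.g. Ollama) at base_url.
    print(asyncio.run(list_model_ids()))
```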
This commit is contained in:
parent f67081d2d6
commit 3de9ad0a87
2 changed files with 79 additions and 32 deletions
@@ -7,6 +7,7 @@
 from __future__ import annotations  # for forward references
 
 import hashlib
+import inspect
 import json
 import os
 from collections.abc import Generator
@@ -198,16 +199,11 @@ def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
 
         Supported endpoints:
         - '/api/tags' (Ollama): response body has 'models': [ { name/model/digest/id/... }, ... ]
-        - '/v1/models' (OpenAI): response body has 'data': [ { id: ... }, ... ]
+        - '/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
         Returns a list of unique identifiers or None if structure doesn't match.
         """
-        body = response["body"]
-        if endpoint == "/api/tags":
-            items = body.get("models")
-            idents = [m.model for m in items]
-        else:
-            items = body.get("data")
-            idents = [m.id for m in items]
+        items = response["body"]
+        idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
         return sorted(set(idents))
 
     identifiers = _extract_model_identifiers()
@@ -219,28 +215,22 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
     seen: dict[str, dict[str, Any]] = {}
     for rec in records:
         body = rec["response"]["body"]
-        if endpoint == "/api/tags":
-            items = body.models
-        elif endpoint == "/v1/models":
-            items = body.data
-        else:
-            items = []
-
-        for m in items:
-            if endpoint == "/v1/models":
-                key = m.id
-            else:
-                key = m.model
-            seen[key] = m
+        if endpoint == "/v1/models":
+            for m in body:
+                key = m.id
+                seen[key] = m
+        elif endpoint == "/api/tags":
+            for m in body.models:
+                key = m.model
+                seen[key] = m
 
     ordered = [seen[k] for k in sorted(seen.keys())]
     canonical = records[0]
     canonical_req = canonical.get("request", {})
     if isinstance(canonical_req, dict):
         canonical_req["endpoint"] = endpoint
-    if endpoint == "/v1/models":
-        body = {"data": ordered, "object": "list"}
-    else:
+    body = ordered
+    if endpoint == "/api/tags":
         from ollama import ListResponse
 
         body = ListResponse(models=ordered)
@@ -252,7 +242,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     if _current_mode == InferenceMode.LIVE or _current_storage is None:
         # Normal operation
-        return await original_method(self, *args, **kwargs)
+        if inspect.iscoroutinefunction(original_method):
+            return await original_method(self, *args, **kwargs)
+        else:
+            return original_method(self, *args, **kwargs)
 
     # Get base URL based on client type
     if client_type == "openai":
@@ -300,7 +293,14 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         )
 
     elif _current_mode == InferenceMode.RECORD:
-        response = await original_method(self, *args, **kwargs)
+        if inspect.iscoroutinefunction(original_method):
+            response = await original_method(self, *args, **kwargs)
+        else:
+            response = original_method(self, *args, **kwargs)
+
+        # we want to store the result of the iterator, not the iterator itself
+        if endpoint == "/v1/models":
+            response = [m async for m in response]
 
         request_data = {
             "method": method,
@@ -380,10 +380,14 @@ def patch_inference_clients():
             _original_methods["embeddings_create"], self, "openai", "/v1/embeddings", *args, **kwargs
         )
 
-    async def patched_models_list(self, *args, **kwargs):
-        return await _patched_inference_method(
-            _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
-        )
+    def patched_models_list(self, *args, **kwargs):
+        async def _iter():
+            for item in await _patched_inference_method(
+                _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
+            ):
+                yield item
+
+        return _iter()
 
     # Apply OpenAI patches
     AsyncChatCompletions.create = patched_chat_completions_create
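A self-contained sketch of the wrapper pattern introduced in `patched_models_list` above, with a hypothetical `fetch_models` standing in for `_patched_inference_method`: the patched method stays a plain `def` that returns an async generator, so callers can still consume it with `async for` without awaiting the method first.

```python
import asyncio


async def fetch_models() -> list[str]:
    # Hypothetical stand-in for _patched_inference_method(...), which in RECORD
    # mode materializes /v1/models into a plain list before returning it.
    return ["foo", "bar"]


def patched_models_list():
    async def _iter():
        # Await the underlying call once, then re-yield each item so the caller
        # can iterate the result the same way as the OpenAI client's paginator.
        for item in await fetch_models():
            yield item

    return _iter()


async def main() -> None:
    assert [m async for m in patched_models_list()] == ["foo", "bar"]


asyncio.run(main())
```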
@@ -6,10 +6,11 @@
 import tempfile
 from pathlib import Path
-from unittest.mock import patch
+from unittest.mock import AsyncMock, Mock, patch
 
 import pytest
 from openai import AsyncOpenAI
+from openai.types.model import Model as OpenAIModel
 
 # Import the real Pydantic response types instead of using Mocks
 from llama_stack.apis.inference import (
@@ -158,7 +159,9 @@ class TestInferenceRecording:
             return real_openai_chat_response
 
         temp_storage_dir = temp_storage_dir / "test_recording_mode"
-        with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
+        with patch(
+            "openai.resources.chat.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
+        ):
             with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
@@ -184,7 +187,9 @@ class TestInferenceRecording:
 
         temp_storage_dir = temp_storage_dir / "test_replay_mode"
         # First, record a response
-        with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
+        with patch(
+            "openai.resources.chat.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
+        ):
             with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
@@ -213,6 +218,42 @@ class TestInferenceRecording:
             # Verify the original method was NOT called
             mock_create_patch.assert_not_called()
 
+    async def test_replay_mode_models(self, temp_storage_dir):
+        """Test that replay mode returns stored responses without making real model listing calls."""
+
+        async def _async_iterator(models):
+            for model in models:
+                yield model
+
+        models = [
+            OpenAIModel(id="foo", created=1, object="model", owned_by="test"),
+            OpenAIModel(id="bar", created=2, object="model", owned_by="test"),
+        ]
+
+        expected_ids = {m.id for m in models}
+
+        temp_storage_dir = temp_storage_dir / "test_replay_mode_models"
+
+        # baseline - mock works without recording
+        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        client.models._get_api_list = Mock(return_value=_async_iterator(models))
+        assert {m.id async for m in client.models.list()} == expected_ids
+        client.models._get_api_list.assert_called_once()
+
+        # record the call
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=temp_storage_dir):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.models._get_api_list = Mock(return_value=_async_iterator(models))
+            assert {m.id async for m in client.models.list()} == expected_ids
+            client.models._get_api_list.assert_called_once()
+
+        # replay the call
+        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=temp_storage_dir):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.models._get_api_list = Mock(return_value=_async_iterator(models))
+            assert {m.id async for m in client.models.list()} == expected_ids
+            client.models._get_api_list.assert_not_called()
+
     async def test_replay_missing_recording(self, temp_storage_dir):
         """Test that replay mode fails when no recording is found."""
         temp_storage_dir = temp_storage_dir / "test_replay_missing_recording"
@@ -233,7 +274,9 @@ class TestInferenceRecording:
 
         temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
         # Record
-        with patch("openai.resources.embeddings.AsyncEmbeddings.create", side_effect=mock_create):
+        with patch(
+            "openai.resources.embeddings.AsyncEmbeddings.create", new_callable=AsyncMock, side_effect=mock_create
+        ):
             with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")