mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 20:14:13 +00:00)

chore(recorder, tests): add support for openai /v1/models

parent 48dda8bed8
commit a14f42f1b8

1 changed file with 35 additions and 30 deletions
@@ -7,6 +7,7 @@
 from __future__ import annotations  # for forward references
 
 import hashlib
+import inspect
 import json
 import os
 from collections.abc import Generator
@@ -198,16 +199,11 @@ def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
 
         Supported endpoints:
         - '/api/tags' (Ollama): response body has 'models': [ { name/model/digest/id/... }, ... ]
-        - '/v1/models' (OpenAI): response body has 'data': [ { id: ... }, ... ]
+        - '/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
         Returns a list of unique identifiers or None if structure doesn't match.
         """
-        body = response["body"]
-        if endpoint == "/api/tags":
-            items = body.get("models")
-            idents = [m.model for m in items]
-        else:
-            items = body.get("data")
-            idents = [m.id for m in items]
+        items = response["body"]
+        idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
         return sorted(set(idents))
 
     identifiers = _extract_model_identifiers()
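
For orientation, a minimal sketch (not from this commit) of how a sorted, de-duplicated identifier list can be reduced to a short, order-independent digest; the hashing details here are an assumption for illustration only:

import hashlib

def identifiers_digest(identifiers: list[str]) -> str:
    # Join the sorted, de-duplicated identifiers and hash them so the same
    # set of models always yields the same short digest.
    joined = ",".join(sorted(set(identifiers)))
    return hashlib.sha256(joined.encode("utf-8")).hexdigest()[:8]

# identifiers_digest(["all-minilm:l6-v2", "llama3.2:3b"]) returns an 8-character hex prefix.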
@@ -219,28 +215,22 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
     seen: dict[str, dict[str, Any]] = {}
     for rec in records:
         body = rec["response"]["body"]
-        if endpoint == "/api/tags":
-            items = body.models
-        elif endpoint == "/v1/models":
-            items = body.data
-        else:
-            items = []
-
-        for m in items:
-            if endpoint == "/v1/models":
-                key = m.id
-            else:
-                key = m.model
-            seen[key] = m
+        if endpoint == "/v1/models":
+            for m in body:
+                key = m.id
+                seen[key] = m
+        elif endpoint == "/api/tags":
+            for m in body.models:
+                key = m.model
+                seen[key] = m
 
     ordered = [seen[k] for k in sorted(seen.keys())]
     canonical = records[0]
     canonical_req = canonical.get("request", {})
     if isinstance(canonical_req, dict):
         canonical_req["endpoint"] = endpoint
-    if endpoint == "/v1/models":
-        body = {"data": ordered, "object": "list"}
-    else:
-        body = ordered
+    body = ordered
     if endpoint == "/api/tags":
         from ollama import ListResponse
 
         body = ListResponse(models=ordered)
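
The rewritten merge treats a recorded /v1/models body as a plain list of model objects and unions repeated recordings by id before sorting. A small stand-alone illustration with stand-in objects (the model names are arbitrary examples, not taken from any recording):

from types import SimpleNamespace

def union_by_id(recorded_bodies):
    # De-duplicate model entries across recordings, keyed by model id,
    # and return them in a stable sorted order.
    seen = {}
    for body in recorded_bodies:
        for m in body:
            seen[m.id] = m
    return [seen[k] for k in sorted(seen)]

a = [SimpleNamespace(id="gpt-4o-mini"), SimpleNamespace(id="gpt-4o")]
b = [SimpleNamespace(id="gpt-4o"), SimpleNamespace(id="o3-mini")]
print([m.id for m in union_by_id([a, b])])  # ['gpt-4o', 'gpt-4o-mini', 'o3-mini']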
@@ -252,7 +242,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
 
     if _current_mode == InferenceMode.LIVE or _current_storage is None:
         # Normal operation
-        return await original_method(self, *args, **kwargs)
+        if inspect.iscoroutinefunction(original_method):
+            return await original_method(self, *args, **kwargs)
+        else:
+            return original_method(self, *args, **kwargs)
 
     # Get base URL based on client type
     if client_type == "openai":
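
The dispatch added above appears to handle patched targets that are not coroutine functions; `AsyncOpenAI().models.list()`, for instance, is a regular method that returns an awaitable paginator rather than being `async def` itself. The general shape of the check, in isolation (names here are illustrative):

import inspect

async def call_patched(original_method, *args, **kwargs):
    # Coroutine functions must be awaited here; plain methods are called
    # directly so their return value (possibly an awaitable object such as
    # a paginator) is passed back to the caller untouched.
    if inspect.iscoroutinefunction(original_method):
        return await original_method(*args, **kwargs)
    return original_method(*args, **kwargs)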
@@ -282,7 +275,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         if recording:
             response_body = recording["response"]["body"]
 
-            if recording["response"].get("is_streaming", False):
+            if recording["response"].get("is_streaming", False) or endpoint == "/v1/models":
 
                 async def replay_stream():
                     for chunk in response_body:
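
Routing /v1/models replays through the streaming branch means the stored list of model objects is handed back one item at a time via an async generator, roughly like the sketch below (simplified relative to the patched method; the helper name is hypothetical):

async def replay_model_list(response_body):
    # Yield each recorded model object, mimicking the async iteration a
    # caller of models.list() would perform against a live client.
    for chunk in response_body:
        yield chunk

A test would then consume it with something like: async for m in replay_model_list(recorded_body): ...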
@@ -300,7 +293,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
             )
 
     elif _current_mode == InferenceMode.RECORD:
-        response = await original_method(self, *args, **kwargs)
+        if inspect.iscoroutinefunction(original_method):
+            response = await original_method(self, *args, **kwargs)
+        else:
+            response = original_method(self, *args, **kwargs)
 
         request_data = {
             "method": method,
@@ -314,7 +310,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         # Determine if this is a streaming request based on request parameters
         is_streaming = body.get("stream", False)
 
-        if is_streaming:
+        if is_streaming or endpoint == "/v1/models":
             # For streaming responses, we need to collect all chunks immediately before yielding
             # This ensures the recording is saved even if the generator isn't fully consumed
             chunks = []
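
The comments in this hunk spell out the design choice: anything routed through the streaming path, now including /v1/models, is drained eagerly so the recording is persisted even if the caller never finishes iterating. A stripped-down sketch of that pattern (illustrative only; `save` is a placeholder callback):

async def record_stream(stream, save):
    # Drain the upstream generator completely before yielding anything,
    # so the recording is saved regardless of how much is consumed later.
    chunks = [chunk async for chunk in stream]
    save(chunks)

    async def replay():
        for chunk in chunks:
            yield chunk

    return replay()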
@@ -380,10 +376,19 @@ def patch_inference_clients():
             _original_methods["embeddings_create"], self, "openai", "/v1/embeddings", *args, **kwargs
         )
 
-    async def patched_models_list(self, *args, **kwargs):
-        return await _patched_inference_method(
-            _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
-        )
+    def patched_models_list(self, *args, **kwargs):
+        import asyncio
+        import concurrent.futures
+
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future = executor.submit(
+                lambda: asyncio.run(
+                    _patched_inference_method(
+                        _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
+                    )
+                )
+            )
+            return future.result()
 
     # Apply OpenAI patches
     AsyncChatCompletions.create = patched_chat_completions_create
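
The new patched_models_list is a synchronous wrapper: it submits the async recorder call to a worker thread and drives it there with asyncio.run, presumably so that asyncio.run is never invoked on a thread that already has a running event loop. The same shape in isolation (do_work stands in for the awaited _patched_inference_method call):

import asyncio
import concurrent.futures

async def do_work() -> str:
    # Placeholder for the real awaited call.
    await asyncio.sleep(0)
    return "models"

def sync_wrapper() -> str:
    # Run the coroutine on a fresh event loop in a worker thread, keeping
    # this function callable from plain synchronous code.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(lambda: asyncio.run(do_work()))
        return future.result()

print(sync_wrapper())  # "models"

Note that blocking on future.result() from inside a running event loop would still stall that loop; the worker thread only sidesteps the "asyncio.run() cannot be called from a running event loop" error.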