Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-04 20:14:13 +00:00
chore(recorder, tests): add support for openai /v1/models
parent 48dda8bed8
commit a14f42f1b8
1 changed file with 35 additions and 30 deletions
@@ -7,6 +7,7 @@
 from __future__ import annotations  # for forward references
 
 import hashlib
+import inspect
 import json
 import os
 from collections.abc import Generator
@@ -198,16 +199,11 @@ def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
 
         Supported endpoints:
         - '/api/tags' (Ollama): response body has 'models': [ { name/model/digest/id/... }, ... ]
-        - '/v1/models' (OpenAI): response body has 'data': [ { id: ... }, ... ]
+        - '/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
         Returns a list of unique identifiers or None if structure doesn't match.
         """
-        body = response["body"]
-        if endpoint == "/api/tags":
-            items = body.get("models")
-            idents = [m.model for m in items]
-        else:
-            items = body.get("data")
-            idents = [m.id for m in items]
+        items = response["body"]
+        idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
         return sorted(set(idents))
 
     identifiers = _extract_model_identifiers()
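Note: after this change the digest helper treats the stored body as a flat iterable of model objects for both endpoints, keyed by .model for Ollama's /api/tags and by .id for OpenAI's /v1/models. A minimal sketch of the resulting behavior, using types.SimpleNamespace stand-ins and made-up model names rather than the real client types:

    from types import SimpleNamespace

    def extract_model_identifiers(endpoint: str, response: dict) -> list[str]:
        # Mirrors the simplified helper above: one comprehension covers both endpoints.
        items = response["body"]
        idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
        return sorted(set(idents))

    ollama_body = [SimpleNamespace(model="llama3.2:3b"), SimpleNamespace(model="all-minilm:l6-v2")]
    openai_body = [SimpleNamespace(id="gpt-4o-mini"), SimpleNamespace(id="gpt-4o-mini")]  # duplicate collapses

    print(extract_model_identifiers("/api/tags", {"body": ollama_body}))   # ['all-minilm:l6-v2', 'llama3.2:3b']
    print(extract_model_identifiers("/v1/models", {"body": openai_body}))  # ['gpt-4o-mini']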
@@ -219,28 +215,22 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
     seen: dict[str, dict[str, Any]] = {}
     for rec in records:
         body = rec["response"]["body"]
-        if endpoint == "/api/tags":
-            items = body.models
-        elif endpoint == "/v1/models":
-            items = body.data
-        else:
-            items = []
-
-        for m in items:
-            if endpoint == "/v1/models":
-                key = m.id
-            else:
-                key = m.model
-            seen[key] = m
+        if endpoint == "/v1/models":
+            for m in body:
+                key = m.id
+                seen[key] = m
+        elif endpoint == "/api/tags":
+            for m in body.models:
+                key = m.model
+                seen[key] = m
 
     ordered = [seen[k] for k in sorted(seen.keys())]
     canonical = records[0]
     canonical_req = canonical.get("request", {})
     if isinstance(canonical_req, dict):
         canonical_req["endpoint"] = endpoint
-    if endpoint == "/v1/models":
-        body = {"data": ordered, "object": "list"}
-    else:
+    body = ordered
+    if endpoint == "/api/tags":
         from ollama import ListResponse
 
         body = ListResponse(models=ordered)
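Note: for /v1/models the unioned recording now stores a bare, id-sorted list of model objects instead of an OpenAI-style {"data": ..., "object": "list"} wrapper; only /api/tags is re-wrapped in an ollama ListResponse. A rough sketch of the merge step under that reading, again with SimpleNamespace stand-ins and invented model names:

    from types import SimpleNamespace

    def union_v1_models(records: list[dict]) -> list:
        # De-duplicate by model id across several recorded /v1/models responses,
        # then return the survivors in sorted key order as the combined body.
        seen: dict[str, object] = {}
        for rec in records:
            for m in rec["response"]["body"]:
                seen[m.id] = m
        return [seen[k] for k in sorted(seen)]

    rec_a = {"response": {"body": [SimpleNamespace(id="llama3.2:3b"), SimpleNamespace(id="nomic-embed-text")]}}
    rec_b = {"response": {"body": [SimpleNamespace(id="llama3.2:3b"), SimpleNamespace(id="llama-guard3:1b")]}}
    print([m.id for m in union_v1_models([rec_a, rec_b])])
    # ['llama-guard3:1b', 'llama3.2:3b', 'nomic-embed-text']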
@@ -252,7 +242,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
 
     if _current_mode == InferenceMode.LIVE or _current_storage is None:
         # Normal operation
-        return await original_method(self, *args, **kwargs)
+        if inspect.iscoroutinefunction(original_method):
+            return await original_method(self, *args, **kwargs)
+        else:
+            return original_method(self, *args, **kwargs)
 
     # Get base URL based on client type
     if client_type == "openai":
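Note: the inspect.iscoroutinefunction check lets one interception path handle original methods that are async def coroutines as well as ones that are not, awaiting only in the former case. A minimal, self-contained illustration of that dispatch pattern; both backend functions here are placeholders, not recorder code:

    import asyncio
    import inspect

    async def async_backend(x: int) -> int:   # stands in for an 'async def' client method
        return x + 1

    def sync_backend(x: int) -> int:          # stands in for a plain synchronous method
        return x + 1

    async def call_either(fn, x: int) -> int:
        # Await only when the target is a coroutine function; otherwise call it directly.
        if inspect.iscoroutinefunction(fn):
            return await fn(x)
        return fn(x)

    print(asyncio.run(call_either(async_backend, 1)))  # 2
    print(asyncio.run(call_either(sync_backend, 1)))   # 2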
@@ -282,7 +275,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         if recording:
             response_body = recording["response"]["body"]
 
-            if recording["response"].get("is_streaming", False):
+            if recording["response"].get("is_streaming", False) or endpoint == "/v1/models":
 
                 async def replay_stream():
                     for chunk in response_body:
@@ -300,7 +293,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
             )
 
     elif _current_mode == InferenceMode.RECORD:
-        response = await original_method(self, *args, **kwargs)
+        if inspect.iscoroutinefunction(original_method):
+            response = await original_method(self, *args, **kwargs)
+        else:
+            response = original_method(self, *args, **kwargs)
 
         request_data = {
             "method": method,
@@ -314,7 +310,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         # Determine if this is a streaming request based on request parameters
         is_streaming = body.get("stream", False)
 
-        if is_streaming:
+        if is_streaming or endpoint == "/v1/models":
             # For streaming responses, we need to collect all chunks immediately before yielding
             # This ensures the recording is saved even if the generator isn't fully consumed
             chunks = []
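Note: flagging /v1/models as streaming on the RECORD side appears intended to drain the response into a list of chunks up front, so the recording is saved even if the caller never finishes iterating, while the REPLAY side (earlier hunk) wraps the stored list back into an async generator. A rough sketch of that record-then-replay round trip, with a fake async iterator standing in for the real client response:

    import asyncio

    async def fake_model_list():
        # Stand-in for an async iterator of model objects returned by a client.
        for name in ("model-a", "model-b"):
            yield name

    async def record_then_replay():
        # Drain the iterator eagerly so the full payload can be saved...
        chunks = [chunk async for chunk in fake_model_list()]
        # ...(persist `chunks` somewhere here)...

        # ...then hand callers an async generator that replays the saved chunks.
        async def replay_stream():
            for chunk in chunks:
                yield chunk

        return replay_stream()

    async def main():
        stream = await record_then_replay()
        print([m async for m in stream])  # ['model-a', 'model-b']

    asyncio.run(main())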
@@ -380,10 +376,19 @@ def patch_inference_clients():
             _original_methods["embeddings_create"], self, "openai", "/v1/embeddings", *args, **kwargs
         )
 
-    async def patched_models_list(self, *args, **kwargs):
-        return await _patched_inference_method(
-            _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
-        )
+    def patched_models_list(self, *args, **kwargs):
+        import asyncio
+        import concurrent.futures
+
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future = executor.submit(
+                lambda: asyncio.run(
+                    _patched_inference_method(
+                        _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
+                    )
+                )
+            )
+            return future.result()
 
     # Apply OpenAI patches
     AsyncChatCompletions.create = patched_chat_completions_create
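Note: patched_models_list is now a plain def, so it cannot await the async interception helper directly; instead it runs the coroutine to completion on its own event loop inside a worker thread, presumably so asyncio.run() is never invoked on a thread that already has a running event loop (where it would raise). A stripped-down sketch of the same sync-over-async bridge; the helper coroutine below is a placeholder, not recorder code:

    import asyncio
    import concurrent.futures

    async def some_async_helper(x: int) -> int:
        # Placeholder for the async interception helper being bridged.
        await asyncio.sleep(0)
        return x * 2

    def sync_wrapper(x: int) -> int:
        # Run the coroutine on its own event loop in a worker thread and block
        # until it finishes, so the wrapper stays callable from synchronous code.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future = executor.submit(lambda: asyncio.run(some_async_helper(x)))
            return future.result()

    print(sync_wrapper(21))  # 42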