chore(recorder, tests): add support for openai /v1/models

Matthew Farrellee 2025-09-12 12:12:12 -04:00
parent 48dda8bed8
commit a14f42f1b8


@@ -7,6 +7,7 @@
 from __future__ import annotations  # for forward references
 
 import hashlib
+import inspect
 import json
 import os
 from collections.abc import Generator
@@ -198,16 +199,11 @@ def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
         Supported endpoints:
         - '/api/tags' (Ollama): response body has 'models': [ { name/model/digest/id/... }, ... ]
-        - '/v1/models' (OpenAI): response body has 'data': [ { id: ... }, ... ]
+        - '/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
 
         Returns a list of unique identifiers or None if structure doesn't match.
         """
-        body = response["body"]
-        if endpoint == "/api/tags":
-            items = body.get("models")
-            idents = [m.model for m in items]
-        else:
-            items = body.get("data")
-            idents = [m.id for m in items]
+        items = response["body"]
+        idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
         return sorted(set(idents))
 
     identifiers = _extract_model_identifiers()
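
For context, a minimal sketch of what the simplified identifier extraction does per endpoint. The SimpleNamespace stand-ins and model names below are hypothetical; they only mirror the attributes the recorded bodies are assumed to expose (.model for Ollama's /api/tags, .id for OpenAI's /v1/models):

from types import SimpleNamespace


def extract_identifiers(endpoint, response):
    items = response["body"]
    idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
    return sorted(set(idents))


ollama_response = {"body": [SimpleNamespace(model="llama3.2:3b"), SimpleNamespace(model="all-minilm:l6-v2")]}
openai_response = {"body": [SimpleNamespace(id="gpt-4o-mini"), SimpleNamespace(id="gpt-4o-mini")]}

print(extract_identifiers("/api/tags", ollama_response))   # ['all-minilm:l6-v2', 'llama3.2:3b']
print(extract_identifiers("/v1/models", openai_response))  # duplicates collapse -> ['gpt-4o-mini']
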
@@ -219,28 +215,22 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
     seen: dict[str, dict[str, Any]] = {}
     for rec in records:
         body = rec["response"]["body"]
-        if endpoint == "/api/tags":
-            items = body.models
-        elif endpoint == "/v1/models":
-            items = body.data
-        else:
-            items = []
-        for m in items:
-            if endpoint == "/v1/models":
-                key = m.id
-            else:
-                key = m.model
-            seen[key] = m
+        if endpoint == "/v1/models":
+            for m in body:
+                key = m.id
+                seen[key] = m
+        elif endpoint == "/api/tags":
+            for m in body.models:
+                key = m.model
+                seen[key] = m
 
     ordered = [seen[k] for k in sorted(seen.keys())]
     canonical = records[0]
     canonical_req = canonical.get("request", {})
     if isinstance(canonical_req, dict):
         canonical_req["endpoint"] = endpoint
-    if endpoint == "/v1/models":
-        body = {"data": ordered, "object": "list"}
-    else:
+    body = ordered
+    if endpoint == "/api/tags":
         from ollama import ListResponse
 
         body = ListResponse(models=ordered)
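
A rough sketch of the dedup-and-order behaviour that _combine_model_list_responses now applies to /v1/models records; the inline records data is made up for illustration, and later records overwrite earlier ones when an id repeats:

from types import SimpleNamespace

records = [
    {"response": {"body": [SimpleNamespace(id="model-b"), SimpleNamespace(id="model-a")]}},
    {"response": {"body": [SimpleNamespace(id="model-a"), SimpleNamespace(id="model-c")]}},
]

seen = {}
for rec in records:
    for m in rec["response"]["body"]:  # /v1/models bodies are stored as plain lists now
        seen[m.id] = m                 # later records win when an id repeats

ordered = [seen[k] for k in sorted(seen)]
print([m.id for m in ordered])         # ['model-a', 'model-b', 'model-c']
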
@@ -252,7 +242,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     if _current_mode == InferenceMode.LIVE or _current_storage is None:
         # Normal operation
-        return await original_method(self, *args, **kwargs)
+        if inspect.iscoroutinefunction(original_method):
+            return await original_method(self, *args, **kwargs)
+        else:
+            return original_method(self, *args, **kwargs)
 
     # Get base URL based on client type
     if client_type == "openai":
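
The wrapped method may now be either an async client method or a plain synchronous one, so the patched path dispatches on inspect.iscoroutinefunction before deciding whether to await. A generic sketch of that dispatch pattern, with hypothetical sync_list/async_list stand-ins:

import asyncio
import inspect


async def call_original(original_method, *args, **kwargs):
    # Await only when the wrapped method is actually a coroutine function;
    # otherwise call it synchronously and return the result as-is.
    if inspect.iscoroutinefunction(original_method):
        return await original_method(*args, **kwargs)
    return original_method(*args, **kwargs)


def sync_list():
    return ["model-a"]


async def async_list():
    return ["model-a"]


assert asyncio.run(call_original(sync_list)) == asyncio.run(call_original(async_list))
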
@@ -282,7 +275,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         if recording:
             response_body = recording["response"]["body"]
 
-            if recording["response"].get("is_streaming", False):
+            if recording["response"].get("is_streaming", False) or endpoint == "/v1/models":
 
                 async def replay_stream():
                     for chunk in response_body:
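
Replaying a /v1/models recording reuses the streaming path: the stored body is a plain list of model objects, and replay_stream hands them back one at a time as an async iterator. A standalone sketch of that idea, with a made-up recorded_body:

import asyncio
from types import SimpleNamespace

recorded_body = [SimpleNamespace(id="model-a"), SimpleNamespace(id="model-b")]


async def replay_models():
    # Hand the recorded list back one item at a time, mimicking an async-iterable result.
    for chunk in recorded_body:
        yield chunk


async def main():
    print([m.id async for m in replay_models()])  # ['model-a', 'model-b']


asyncio.run(main())
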
@@ -300,7 +293,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
             )
 
     elif _current_mode == InferenceMode.RECORD:
-        response = await original_method(self, *args, **kwargs)
+        if inspect.iscoroutinefunction(original_method):
+            response = await original_method(self, *args, **kwargs)
+        else:
+            response = original_method(self, *args, **kwargs)
 
         request_data = {
             "method": method,
@@ -314,7 +310,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         # Determine if this is a streaming request based on request parameters
         is_streaming = body.get("stream", False)
 
-        if is_streaming:
+        if is_streaming or endpoint == "/v1/models":
             # For streaming responses, we need to collect all chunks immediately before yielding
             # This ensures the recording is saved even if the generator isn't fully consumed
             chunks = []
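
The comments above carry the reasoning: the stream (and now the /v1/models listing) is drained into chunks up front so the recording is persisted even if the caller stops iterating early. A minimal sketch of that collect-then-replay pattern, where save_recording and fake_stream are hypothetical stand-ins for the real storage call and upstream response:

import asyncio


async def record_and_replay(stream, save_recording):
    # Drain the upstream iterator eagerly so the recording exists even if the
    # caller never finishes consuming the replayed stream.
    chunks = [chunk async for chunk in stream]
    save_recording(chunks)

    async def replay():
        for chunk in chunks:
            yield chunk

    return replay()


async def main():
    async def fake_stream():
        yield "chunk-1"
        yield "chunk-2"

    replayed = await record_and_replay(fake_stream(), save_recording=print)  # prints ['chunk-1', 'chunk-2']
    async for chunk in replayed:
        print(chunk)  # chunk-1
        break         # stopping early; the recording above was already saved


asyncio.run(main())
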
@@ -380,10 +376,19 @@ def patch_inference_clients():
             _original_methods["embeddings_create"], self, "openai", "/v1/embeddings", *args, **kwargs
         )
 
-    async def patched_models_list(self, *args, **kwargs):
-        return await _patched_inference_method(
-            _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
-        )
+    def patched_models_list(self, *args, **kwargs):
+        import asyncio
+        import concurrent.futures
+
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future = executor.submit(
+                lambda: asyncio.run(
+                    _patched_inference_method(
+                        _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
+                    )
+                )
+            )
+            return future.result()
 
     # Apply OpenAI patches
     AsyncChatCompletions.create = patched_chat_completions_create
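
The new patched_models_list is a synchronous wrapper that runs the shared async recorder logic to completion on a worker thread, presumably so asyncio.run is never invoked on a thread that already has a running event loop. A generic sketch of that sync-over-async bridge, with a hypothetical run_async helper and fake_models_list coroutine:

import asyncio
import concurrent.futures


def run_async(coro_fn, *args, **kwargs):
    # asyncio.run() fails on a thread that already has a running event loop, so the
    # coroutine is executed in a fresh loop on a short-lived worker thread instead.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(lambda: asyncio.run(coro_fn(*args, **kwargs)))
        return future.result()  # block until the worker thread finishes


async def fake_models_list():
    return ["model-a", "model-b"]


print(run_async(fake_models_list))  # ['model-a', 'model-b']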