chore(recorder, tests): add test for openai /v1/models (#3426)

# What does this PR do?

- [x] adds a test for the recorder's handling of /v1/models
- [x] adds a fix for /v1/models handling (the calling pattern involved is sketched below)
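
Context for the fix: unlike `chat.completions.create`, the OpenAI client's `models.list` is a plain (non-async) method that returns an async paginator, which callers consume with `async for`. A minimal sketch of the calling pattern the recorder has to preserve (the helper name is illustrative, not from this PR):

```python
from openai import AsyncOpenAI

async def list_model_ids(client: AsyncOpenAI) -> list[str]:
    # models.list() is not awaited for a plain list; it returns an async
    # paginator that is iterated (the SDK's auto-pagination).
    return [model.id async for model in client.models.list()]
```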

## Test Plan

ci
Author: Matthew Farrellee (2025-09-12 17:59:56 -04:00), committed by GitHub
Parent: f67081d2d6
Commit: 3de9ad0a87
2 changed files with 79 additions and 32 deletions


```diff
@@ -7,6 +7,7 @@
 from __future__ import annotations  # for forward references
 
 import hashlib
+import inspect
 import json
 import os
 from collections.abc import Generator
```
```diff
@@ -198,16 +199,11 @@ def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
         Supported endpoints:
         - '/api/tags' (Ollama): response body has 'models': [ { name/model/digest/id/... }, ... ]
-        - '/v1/models' (OpenAI): response body has 'data': [ { id: ... }, ... ]
+        - '/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
 
         Returns a list of unique identifiers or None if structure doesn't match.
         """
-        body = response["body"]
-        if endpoint == "/api/tags":
-            items = body.get("models")
-            idents = [m.model for m in items]
-        else:
-            items = body.get("data")
-            idents = [m.id for m in items]
+        items = response["body"]
+        idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
         return sorted(set(idents))
 
     identifiers = _extract_model_identifiers()
```
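Note: a minimal sketch of the normalization the rewritten helper performs, mirroring the code above (`SimpleNamespace` stands in for the SDK model objects in a recorded body; the data is illustrative):

```python
from types import SimpleNamespace

def idents(endpoint: str, response: dict) -> list[str]:
    # same shape as the helper: a single comprehension covers both endpoints
    items = response["body"]
    return sorted({m.model if endpoint == "/api/tags" else m.id for m in items})

print(idents("/api/tags", {"body": [SimpleNamespace(model="llama3.2:3b")]}))  # ['llama3.2:3b']
print(idents("/v1/models", {"body": [SimpleNamespace(id="gpt-4o-mini")]}))    # ['gpt-4o-mini']
```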
```diff
@@ -219,28 +215,22 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
     seen: dict[str, dict[str, Any]] = {}
     for rec in records:
         body = rec["response"]["body"]
-        if endpoint == "/api/tags":
-            items = body.models
-        elif endpoint == "/v1/models":
-            items = body.data
-        else:
-            items = []
-
-        for m in items:
-            if endpoint == "/v1/models":
+        if endpoint == "/v1/models":
+            for m in body:
                 key = m.id
-            else:
+                seen[key] = m
+        elif endpoint == "/api/tags":
+            for m in body.models:
                 key = m.model
-            seen[key] = m
+                seen[key] = m
 
     ordered = [seen[k] for k in sorted(seen.keys())]
     canonical = records[0]
     canonical_req = canonical.get("request", {})
     if isinstance(canonical_req, dict):
         canonical_req["endpoint"] = endpoint
-    if endpoint == "/v1/models":
-        body = {"data": ordered, "object": "list"}
-    else:
+    body = ordered
+    if endpoint == "/api/tags":
         from ollama import ListResponse
 
         body = ListResponse(models=ordered)
```
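Note: the combine step now branches once per endpoint and de-duplicates by identifier, with later records winning. A self-contained sketch of the dedupe-and-sort behavior for `/v1/models` records (data illustrative):

```python
from types import SimpleNamespace

records = [
    {"response": {"body": [SimpleNamespace(id="b"), SimpleNamespace(id="a")]}},
    {"response": {"body": [SimpleNamespace(id="a"), SimpleNamespace(id="c")]}},
]

seen: dict[str, SimpleNamespace] = {}
for rec in records:
    for m in rec["response"]["body"]:
        seen[m.id] = m  # the last record seen wins for a duplicate id

ordered = [seen[k] for k in sorted(seen)]
print([m.id for m in ordered])  # ['a', 'b', 'c']
```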
```diff
@@ -252,7 +242,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
 
     if _current_mode == InferenceMode.LIVE or _current_storage is None:
         # Normal operation
-        return await original_method(self, *args, **kwargs)
+        if inspect.iscoroutinefunction(original_method):
+            return await original_method(self, *args, **kwargs)
+        else:
+            return original_method(self, *args, **kwargs)
 
     # Get base URL based on client type
     if client_type == "openai":
```
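Note: the `inspect.iscoroutinefunction` guard is needed because not every patched method is a coroutine function; in the OpenAI SDK, `models.list` is a plain `def` that returns its paginator directly. A small demonstration of the distinction the guard relies on (both functions are stand-ins):

```python
import inspect

async def coro_style(x):  # shaped like chat.completions.create
    return x

def paginator_style(x):  # shaped like models.list: a plain def returning an async iterable
    async def _gen():
        yield x
    return _gen()

print(inspect.iscoroutinefunction(coro_style))       # True
print(inspect.iscoroutinefunction(paginator_style))  # False
```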
```diff
@@ -300,7 +293,14 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         )
 
     elif _current_mode == InferenceMode.RECORD:
-        response = await original_method(self, *args, **kwargs)
+        if inspect.iscoroutinefunction(original_method):
+            response = await original_method(self, *args, **kwargs)
+        else:
+            response = original_method(self, *args, **kwargs)
+
+        # we want to store the result of the iterator, not the iterator itself
+        if endpoint == "/v1/models":
+            response = [m async for m in response]
 
         request_data = {
             "method": method,
```
```diff
@@ -380,10 +380,14 @@ def patch_inference_clients():
             _original_methods["embeddings_create"], self, "openai", "/v1/embeddings", *args, **kwargs
         )
 
-    async def patched_models_list(self, *args, **kwargs):
-        return await _patched_inference_method(
-            _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
-        )
+    def patched_models_list(self, *args, **kwargs):
+        async def _iter():
+            for item in await _patched_inference_method(
+                _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
+            ):
+                yield item
+
+        return _iter()
 
     # Apply OpenAI patches
     AsyncChatCompletions.create = patched_chat_completions_create
```
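
Note: `patched_models_list` is now a plain function returning an async generator, so callers can keep writing `async for m in client.models.list()` in both record and replay modes. A self-contained sketch of the wrapper's shape (the stub stands in for `_patched_inference_method`):

```python
import asyncio

async def _stub_patched_method(*args, **kwargs):
    # stands in for _patched_inference_method resolving to a stored list of models
    return ["gpt-4o-mini", "gpt-4o"]

def patched_models_list(*args, **kwargs):
    # a plain def returning an async generator, matching the SDK's models.list shape
    async def _iter():
        for item in await _stub_patched_method(*args, **kwargs):
            yield item

    return _iter()

async def main() -> None:
    async for model in patched_models_list():
        print(model)

asyncio.run(main())
```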