mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)
fix(inference): enable routing of models with provider_data alone (#3928)
This PR enables routing of fully qualified model IDs of the form `provider_id/model_id` even when the models are not registered with the Stack.

Here's the situation: assume a remote inference provider that works only when users provide their own API keys via the `X-LlamaStack-Provider-Data` header. By definition, we cannot list its models and hence cannot update our routing registry. But because we now _require_ a provider ID in model IDs, we can identify which provider to route to and let that provider decide. Note that we still try to look up our registry first, since it may contain a pre-registered alias; we just no longer fail outright when the lookup comes up empty.

Also updated the inference router so that responses carry the _exact_ model ID that the request used.

## Test Plan

Added an integration test.

Closes #3929

---------

Co-authored-by: ehhuang <ehhuang@users.noreply.github.com>
This commit is contained in:
parent 94b0592240
commit f88416ef87

6 changed files with 216 additions and 63 deletions
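To make the new behavior concrete, here is a rough client-side sketch of the flow this change enables. It is illustrative only: the base URL/path and the `anthropic_api_key` field inside the provider-data payload are assumptions, not taken from this commit; only the `X-LlamaStack-Provider-Data` header and the `provider_id/model_id` form come from the description above.

```python
# Hedged sketch: the endpoint path and the provider-data key name are assumed, not from this commit.
import json

from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8321/v1/openai/v1",  # assumed OpenAI-compatible Llama Stack endpoint
    api_key="unused",  # credentials travel in the provider-data header instead
)

response = client.chat.completions.create(
    # Fully qualified provider_id/model_id; the model is NOT registered with the Stack.
    model="anthropic/claude-3-5-sonnet-20241022",
    messages=[{"role": "user", "content": "Hello"}],
    extra_headers={
        # Per-request API key for the remote provider, forwarded by the Stack.
        "X-LlamaStack-Provider-Data": json.dumps({"anthropic_api_key": "sk-ant-..."}),
    },
)

# With this PR, the response echoes the exact model ID from the request.
print(response.model)  # "anthropic/claude-3-5-sonnet-20241022"
```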
@@ -110,7 +110,8 @@ class InferenceRouter(Inference):
         prompt_tokens: int,
         completion_tokens: int,
         total_tokens: int,
-        model: Model,
+        fully_qualified_model_id: str,
+        provider_id: str,
     ) -> list[MetricEvent]:
         """Constructs a list of MetricEvent objects containing token usage metrics.
@@ -118,7 +119,8 @@ class InferenceRouter(Inference):
             prompt_tokens: Number of tokens in the prompt
             completion_tokens: Number of tokens in the completion
             total_tokens: Total number of tokens used
-            model: Model object containing model_id and provider_id
+            fully_qualified_model_id:
+            provider_id: The provider identifier

         Returns:
             List of MetricEvent objects with token usage metrics
@@ -144,8 +146,8 @@ class InferenceRouter(Inference):
                     timestamp=datetime.now(UTC),
                     unit="tokens",
                     attributes={
-                        "model_id": model.model_id,
-                        "provider_id": model.provider_id,
+                        "model_id": fully_qualified_model_id,
+                        "provider_id": provider_id,
                     },
                 )
             )
@@ -158,7 +160,9 @@ class InferenceRouter(Inference):
         total_tokens: int,
         model: Model,
     ) -> list[MetricInResponse]:
-        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
+        metrics = self._construct_metrics(
+            prompt_tokens, completion_tokens, total_tokens, model.model_id, model.provider_id
+        )
         if self.telemetry_enabled:
             for metric in metrics:
                 enqueue_event(metric)
@@ -178,14 +182,25 @@ class InferenceRouter(Inference):
         encoded = self.formatter.encode_content(messages)
         return len(encoded.tokens) if encoded and encoded.tokens else 0

-    async def _get_model(self, model_id: str, expected_model_type: str) -> Model:
-        """takes a model id and gets model after ensuring that it is accessible and of the correct type"""
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
-            raise ModelNotFoundError(model_id)
-        if model.model_type != expected_model_type:
-            raise ModelTypeError(model_id, model.model_type, expected_model_type)
-        return model
+    async def _get_model_provider(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
+        model = await self.routing_table.get_object_by_identifier("model", model_id)
+        if model:
+            if model.model_type != expected_model_type:
+                raise ModelTypeError(model_id, model.model_type, expected_model_type)
+
+            provider = await self.routing_table.get_provider_impl(model.identifier)
+            return provider, model.provider_resource_id
+
+        splits = model_id.split("/", maxsplit=1)
+        if len(splits) != 2:
+            raise ModelNotFoundError(model_id)
+
+        provider_id, provider_resource_id = splits
+        if provider_id not in self.routing_table.impls_by_provider_id:
+            logger.warning(f"Provider {provider_id} not found for model {model_id}")
+            raise ModelNotFoundError(model_id)
+
+        return self.routing_table.impls_by_provider_id[provider_id], provider_resource_id

     async def rerank(
         self,
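For a self-contained view of the fallback above, the sketch below mirrors the `provider_id/model_id` parsing rule outside the router. The names here are illustrative; the real logic is `InferenceRouter._get_model_provider` in the hunk above.

```python
# Illustration of the fallback parsing rule; not the router code itself.
def split_fully_qualified_model_id(model_id: str, known_providers: set[str]) -> tuple[str, str]:
    """Return (provider_id, provider_resource_id), or raise if the ID cannot be routed."""
    splits = model_id.split("/", maxsplit=1)
    if len(splits) != 2:
        # Not registered and not of the provider_id/model_id form -> cannot route.
        raise LookupError(f"model not found: {model_id}")
    provider_id, provider_resource_id = splits
    if provider_id not in known_providers:
        raise LookupError(f"provider not found for model: {model_id}")
    return provider_id, provider_resource_id


assert split_fully_qualified_model_id(
    "anthropic/claude-3-5-sonnet-20241022", {"anthropic"}
) == ("anthropic", "claude-3-5-sonnet-20241022")
```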
@@ -195,14 +210,8 @@ class InferenceRouter(Inference):
         max_num_results: int | None = None,
     ) -> RerankResponse:
         logger.debug(f"InferenceRouter.rerank: {model}")
-        model_obj = await self._get_model(model, ModelType.rerank)
-        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
-        return await provider.rerank(
-            model=model_obj.identifier,
-            query=query,
-            items=items,
-            max_num_results=max_num_results,
-        )
+        provider, provider_resource_id = await self._get_model_provider(model, ModelType.rerank)
+        return await provider.rerank(provider_resource_id, query, items, max_num_results)

     async def openai_completion(
         self,
@@ -211,24 +220,24 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
         )
-        model_obj = await self._get_model(params.model, ModelType.llm)
-
-        # Update params with the resolved model identifier
-        params.model = model_obj.identifier
-
-        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
+        request_model_id = params.model
+        provider, provider_resource_id = await self._get_model_provider(params.model, ModelType.llm)
+        params.model = provider_resource_id
+
         if params.stream:
             return await provider.openai_completion(params)
         # TODO: Metrics do NOT work with openai_completion stream=True due to the fact
         # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.

         response = await provider.openai_completion(params)
+        response.model = request_model_id
         if self.telemetry_enabled:
             metrics = self._construct_metrics(
                 prompt_tokens=response.usage.prompt_tokens,
                 completion_tokens=response.usage.completion_tokens,
                 total_tokens=response.usage.total_tokens,
-                model=model_obj,
+                fully_qualified_model_id=request_model_id,
+                provider_id=provider.__provider_id__,
             )
             for metric in metrics:
                 enqueue_event(metric)
@@ -246,7 +255,9 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
         )
-        model_obj = await self._get_model(params.model, ModelType.llm)
+        request_model_id = params.model
+        provider, provider_resource_id = await self._get_model_provider(params.model, ModelType.llm)
+        params.model = provider_resource_id

         # Use the OpenAI client for a bit of extra input validation without
         # exposing the OpenAI client itself as part of our API surface
@@ -264,10 +275,6 @@ class InferenceRouter(Inference):
             params.tool_choice = None
             params.tools = None

-        # Update params with the resolved model identifier
-        params.model = model_obj.identifier
-
-        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
         if params.stream:
             response_stream = await provider.openai_chat_completion(params)

@@ -275,11 +282,13 @@ class InferenceRouter(Inference):
             # We need to add metrics to each chunk and store the final completion
             return self.stream_tokens_and_compute_metrics_openai_chat(
                 response=response_stream,
-                model=model_obj,
+                fully_qualified_model_id=request_model_id,
+                provider_id=provider.__provider_id__,
                 messages=params.messages,
             )

         response = await self._nonstream_openai_chat_completion(provider, params)
+        response.model = request_model_id

         # Store the response with the ID that will be returned to the client
         if self.store:
@@ -290,7 +299,8 @@ class InferenceRouter(Inference):
                 prompt_tokens=response.usage.prompt_tokens,
                 completion_tokens=response.usage.completion_tokens,
                 total_tokens=response.usage.total_tokens,
-                model=model_obj,
+                fully_qualified_model_id=request_model_id,
+                provider_id=provider.__provider_id__,
             )
             for metric in metrics:
                 enqueue_event(metric)
@@ -307,13 +317,13 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.openai_embeddings: model={params.model}, input_type={type(params.input)}, encoding_format={params.encoding_format}, dimensions={params.dimensions}",
         )
-        model_obj = await self._get_model(params.model, ModelType.embedding)
-
-        # Update model to use resolved identifier
-        params.model = model_obj.identifier
-
-        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
-        return await provider.openai_embeddings(params)
+        request_model_id = params.model
+        provider, provider_resource_id = await self._get_model_provider(params.model, ModelType.embedding)
+        params.model = provider_resource_id
+
+        response = await provider.openai_embeddings(params)
+        response.model = request_model_id
+        return response

     async def list_chat_completions(
         self,
@@ -369,7 +379,8 @@ class InferenceRouter(Inference):
         self,
         response,
         prompt_tokens,
-        model,
+        fully_qualified_model_id: str,
+        provider_id: str,
         tool_prompt_format: ToolPromptFormat | None = None,
     ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
         completion_text = ""
@@ -407,7 +418,8 @@ class InferenceRouter(Inference):
                         prompt_tokens=prompt_tokens,
                         completion_tokens=completion_tokens,
                         total_tokens=total_tokens,
-                        model=model,
+                        fully_qualified_model_id=fully_qualified_model_id,
+                        provider_id=provider_id,
                     )
                     for metric in completion_metrics:
                         if metric.metric in [
@@ -427,7 +439,8 @@ class InferenceRouter(Inference):
                 prompt_tokens or 0,
                 completion_tokens or 0,
                 total_tokens,
-                model,
+                fully_qualified_model_id=fully_qualified_model_id,
+                provider_id=provider_id,
             )
             async_metrics = [
                 MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
@@ -439,7 +452,8 @@ class InferenceRouter(Inference):
         self,
         response: ChatCompletionResponse | CompletionResponse,
         prompt_tokens,
-        model,
+        fully_qualified_model_id: str,
+        provider_id: str,
         tool_prompt_format: ToolPromptFormat | None = None,
     ):
         if isinstance(response, ChatCompletionResponse):
@@ -456,7 +470,8 @@ class InferenceRouter(Inference):
                 prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
                 total_tokens=total_tokens,
-                model=model,
+                fully_qualified_model_id=fully_qualified_model_id,
+                provider_id=provider_id,
             )
             for metric in completion_metrics:
                 if metric.metric in ["completion_tokens", "total_tokens"]:  # Only log completion and total tokens
@@ -470,14 +485,16 @@ class InferenceRouter(Inference):
             prompt_tokens or 0,
             completion_tokens or 0,
             total_tokens,
-            model,
+            fully_qualified_model_id=fully_qualified_model_id,
+            provider_id=provider_id,
         )
         return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]

     async def stream_tokens_and_compute_metrics_openai_chat(
         self,
         response: AsyncIterator[OpenAIChatCompletionChunk],
-        model: Model,
+        fully_qualified_model_id: str,
+        provider_id: str,
         messages: list[OpenAIMessageParam] | None = None,
     ) -> AsyncIterator[OpenAIChatCompletionChunk]:
         """Stream OpenAI chat completion chunks, compute metrics, and store the final completion."""
@@ -497,6 +514,8 @@ class InferenceRouter(Inference):
             if created is None and chunk.created:
                 created = chunk.created

+            chunk.model = fully_qualified_model_id
+
             # Accumulate choice data for final assembly
             if chunk.choices:
                 for choice_delta in chunk.choices:
@@ -553,7 +572,8 @@ class InferenceRouter(Inference):
                         prompt_tokens=chunk.usage.prompt_tokens,
                         completion_tokens=chunk.usage.completion_tokens,
                         total_tokens=chunk.usage.total_tokens,
-                        model=model,
+                        model_id=fully_qualified_model_id,
+                        provider_id=provider_id,
                     )
                     for metric in metrics:
                         enqueue_event(metric)
@@ -601,7 +621,7 @@ class InferenceRouter(Inference):
             id=id,
             choices=assembled_choices,
             created=created or int(time.time()),
-            model=model.identifier,
+            model=fully_qualified_model_id,
             object="chat.completion",
         )
         logger.debug(f"InferenceRouter.completion_response: {final_response}")
@@ -46,8 +46,7 @@ class SentenceTransformerEmbeddingMixin:
             raise ValueError("Empty list not supported")

         # Get the model and generate embeddings
-        model_obj = await self.model_store.get_model(params.model)
-        embedding_model = await self._load_sentence_transformer_model(model_obj.provider_resource_id)
+        embedding_model = await self._load_sentence_transformer_model(params.model)
         embeddings = await asyncio.to_thread(embedding_model.encode, input_list, show_progress_bar=False)

         # Convert embeddings to the requested format
@@ -226,8 +226,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
        :param model: The registered model name/identifier
        :return: The provider-specific model ID (e.g., "gpt-4")
        """
-        # Look up the registered model to get the provider-specific model ID
         # self.model_store is injected by the distribution system at runtime
+        if not await self.model_store.has_model(model):  # type: ignore[attr-defined]
+            return model
+
+        # Look up the registered model to get the provider-specific model ID
         model_obj: Model = await self.model_store.get_model(model)  # type: ignore[attr-defined]
         # provider_resource_id is str | None, but we expect it to be str for OpenAI calls
         if model_obj.provider_resource_id is None:
@@ -161,8 +161,7 @@ def test_openai_embeddings_single_string(compat_client, client_with_models, embe

    assert response.object == "list"

-    # Handle provider-scoped model identifiers (e.g., sentence-transformers/nomic-ai/nomic-embed-text-v1.5)
-    assert response.model == embedding_model_id or response.model.endswith(f"/{embedding_model_id}")
+    assert response.model == embedding_model_id
    assert len(response.data) == 1
    assert response.data[0].object == "embedding"
    assert response.data[0].index == 0
@@ -186,8 +185,7 @@ def test_openai_embeddings_multiple_strings(compat_client, client_with_models, e

    assert response.object == "list"

-    # Handle provider-scoped model identifiers (e.g., sentence-transformers/nomic-ai/nomic-embed-text-v1.5)
-    assert response.model == embedding_model_id or response.model.endswith(f"/{embedding_model_id}")
+    assert response.model == embedding_model_id
    assert len(response.data) == len(input_texts)

    for i, embedding_data in enumerate(response.data):
@@ -365,8 +363,7 @@ def test_openai_embeddings_base64_batch_processing(compat_client, client_with_mo
    # Validate response structure
    assert response.object == "list"

-    # Handle provider-scoped model identifiers (e.g., sentence-transformers/nomic-ai/nomic-embed-text-v1.5)
-    assert response.model == embedding_model_id or response.model.endswith(f"/{embedding_model_id}")
+    assert response.model == embedding_model_id
    assert len(response.data) == len(input_texts)

    # Validate each embedding in the batch
|
|||
133
tests/integration/inference/test_provider_data_routing.py
Normal file
133
tests/integration/inference/test_provider_data_routing.py
Normal file
|
|
@@ -0,0 +1,133 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""
+Test that models can be routed using provider_id/model_id format
+when the provider is configured but the specific model is not registered.
+
+This test validates the fix in src/llama_stack/core/routers/inference.py
+that enables routing based on provider_data alone.
+"""
+
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from llama_stack import LlamaStackAsLibraryClient
+from llama_stack.apis.datatypes import Api
+from llama_stack.apis.inference.inference import (
+    OpenAIAssistantMessageParam,
+    OpenAIChatCompletion,
+    OpenAIChatCompletionUsage,
+    OpenAIChoice,
+)
+from llama_stack.core.telemetry.telemetry import MetricEvent
+
+
+class OpenAIChatCompletionWithMetrics(OpenAIChatCompletion):
+    metrics: list[MetricEvent] | None = None
+
+
+def test_unregistered_model_routing_with_provider_data(client_with_models):
+    """
+    Test that a model can be routed using provider_id/model_id format
+    even when the model is not explicitly registered, as long as the provider
+    is available.
+
+    This validates the fix where the router:
+    1. Tries to lookup model in routing table
+    2. If not found, splits model_id by "/" to extract provider_id and provider_resource_id
+    3. Routes directly to the provider with the provider_resource_id
+
+    Without the fix, this would raise ModelNotFoundError immediately.
+    With the fix, the routing succeeds and the request reaches the provider.
+    """
+    if not isinstance(client_with_models, LlamaStackAsLibraryClient):
+        pytest.skip("Test requires library client for provider-level patching")
+
+    client = client_with_models
+
+    # Use a model format that follows provider_id/model_id convention
+    # We'll use anthropic as an example since it's a remote provider that
+    # benefits from this pattern
+    test_model_id = "anthropic/claude-3-5-sonnet-20241022"
+
+    # First, verify the model is NOT registered
+    registered_models = {m.identifier for m in client.models.list()}
+    assert test_model_id not in registered_models, f"Model {test_model_id} should not be pre-registered for this test"
+
+    # Check if anthropic provider is available in ci-tests
+    providers = {p.provider_id: p for p in client.providers.list()}
+    if "anthropic" not in providers:
+        pytest.skip("Anthropic provider not configured in ci-tests - cannot test unregistered model routing")
+
+    # Get the actual provider implementation from the library client's stack
+    inference_router = client.async_client.impls.get(Api.inference)
+    if not inference_router:
+        raise RuntimeError("No inference router found")
+
+    # The inference router's routing_table.impls_by_provider_id should have anthropic
+    # Let's patch the anthropic provider's openai_chat_completion method
+    # to avoid making real API calls
+    mock_response = OpenAIChatCompletionWithMetrics(
+        id="chatcmpl-test-123",
+        created=1234567890,
+        model="claude-3-5-sonnet-20241022",
+        choices=[
+            OpenAIChoice(
+                index=0,
+                finish_reason="stop",
+                message=OpenAIAssistantMessageParam(
+                    content="Mocked response to test routing",
+                ),
+            )
+        ],
+        usage=OpenAIChatCompletionUsage(
+            prompt_tokens=5,
+            completion_tokens=10,
+            total_tokens=15,
+        ),
+    )
+
+    # Get the routing table from the inference router
+    routing_table = inference_router.routing_table
+
+    # Patch the anthropic provider's openai_chat_completion method
+    anthropic_provider = routing_table.impls_by_provider_id.get("anthropic")
+    if not anthropic_provider:
+        raise RuntimeError("Anthropic provider not found in routing table even though it's in providers list")
+
+    with patch.object(
+        anthropic_provider,
+        "openai_chat_completion",
+        new_callable=AsyncMock,
+        return_value=mock_response,
+    ) as mock_method:
+        # Make the request with the unregistered model
+        response = client.chat.completions.create(
+            model=test_model_id,
+            messages=[
+                {
+                    "role": "user",
+                    "content": "Test message for unregistered model routing",
+                }
+            ],
+            stream=False,
+        )
+
+        # Verify the provider's method was called
+        assert mock_method.called, "Provider's openai_chat_completion should have been called"
+
+        # Verify the response came through
+        assert response.choices[0].message.content == "Mocked response to test routing"
+
+        # Verify that the router passed the correct model to the provider
+        # (without the "anthropic/" prefix)
+        call_args = mock_method.call_args
+        params = call_args[0][0]  # First positional argument is the params object
+        assert params.model == "claude-3-5-sonnet-20241022", (
+            f"Provider should receive model without provider prefix, got {params.model}"
+        )
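The new test skips itself unless the library client is in use and the `anthropic` provider is configured; it can be selected directly with, for example, `pytest tests/integration/inference/test_provider_data_routing.py` (exact stack-configuration flags depend on the integration test harness and are not shown here).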
@@ -64,10 +64,11 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client,

    # Verify spans
    spans = mock_otlp_collector.get_spans()
-    assert len(spans) == 5
+    # Expected spans: 1 root span + 3 autotraced method calls from routing/inference
+    assert len(spans) == 4, f"Expected 4 spans, got {len(spans)}"

-    # we only need this captured one time
-    logged_model_id = None
+    # Collect all model_ids found in spans
+    logged_model_ids = []

    for span in spans:
        attrs = span.attributes
@@ -87,10 +88,10 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client,

            args = json.loads(attrs["__args__"])
            if "model_id" in args:
-                logged_model_id = args["model_id"]
+                logged_model_ids.append(args["model_id"])

-    assert logged_model_id is not None
-    assert logged_model_id == text_model_id
+    # At least one span should capture the fully qualified model ID
+    assert text_model_id in logged_model_ids, f"Expected to find {text_model_id} in spans, but got {logged_model_ids}"

    # TODO: re-enable this once metrics get fixed
    """
Loading…
Add table
Add a link
Reference in a new issue