fix(inference): enable routing of models with provider_data alone (#3928)

This PR enables routing of fully qualified model IDs of the form
`provider_id/model_id` even when the models are not registered with the
Stack.

Here's the situation: assume a remote inference provider which works
only when users provide their own API keys via
`X-LlamaStack-Provider-Data` header. By definition, we cannot list
models and hence update our routing registry. But because we _require_ a
provider ID in the models now, we can identify which provider to route
to and let that provider decide.

Note that we still try to look up our registry since it may have a
pre-registered alias. Just that we don't outright fail when we are not
able to look it up.

Also, updated inference router so that the responses have the _exact_
model that the request had.

## Test Plan

Added an integration test

Closes #3929

---------

Co-authored-by: ehhuang <ehhuang@users.noreply.github.com>
This commit is contained in:
Ashwin Bharambe 2025-10-28 11:16:37 -07:00 committed by GitHub
parent 94b0592240
commit f88416ef87
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 216 additions and 63 deletions

View file

@ -161,8 +161,7 @@ def test_openai_embeddings_single_string(compat_client, client_with_models, embe
assert response.object == "list"
# Handle provider-scoped model identifiers (e.g., sentence-transformers/nomic-ai/nomic-embed-text-v1.5)
assert response.model == embedding_model_id or response.model.endswith(f"/{embedding_model_id}")
assert response.model == embedding_model_id
assert len(response.data) == 1
assert response.data[0].object == "embedding"
assert response.data[0].index == 0
@ -186,8 +185,7 @@ def test_openai_embeddings_multiple_strings(compat_client, client_with_models, e
assert response.object == "list"
# Handle provider-scoped model identifiers (e.g., sentence-transformers/nomic-ai/nomic-embed-text-v1.5)
assert response.model == embedding_model_id or response.model.endswith(f"/{embedding_model_id}")
assert response.model == embedding_model_id
assert len(response.data) == len(input_texts)
for i, embedding_data in enumerate(response.data):
@ -365,8 +363,7 @@ def test_openai_embeddings_base64_batch_processing(compat_client, client_with_mo
# Validate response structure
assert response.object == "list"
# Handle provider-scoped model identifiers (e.g., sentence-transformers/nomic-ai/nomic-embed-text-v1.5)
assert response.model == embedding_model_id or response.model.endswith(f"/{embedding_model_id}")
assert response.model == embedding_model_id
assert len(response.data) == len(input_texts)
# Validate each embedding in the batch

View file

@ -0,0 +1,133 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Test that models can be routed using provider_id/model_id format
when the provider is configured but the specific model is not registered.
This test validates the fix in src/llama_stack/core/routers/inference.py
that enables routing based on provider_data alone.
"""
from unittest.mock import AsyncMock, patch
import pytest
from llama_stack import LlamaStackAsLibraryClient
from llama_stack.apis.datatypes import Api
from llama_stack.apis.inference.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionUsage,
OpenAIChoice,
)
from llama_stack.core.telemetry.telemetry import MetricEvent
class OpenAIChatCompletionWithMetrics(OpenAIChatCompletion):
metrics: list[MetricEvent] | None = None
def test_unregistered_model_routing_with_provider_data(client_with_models):
"""
Test that a model can be routed using provider_id/model_id format
even when the model is not explicitly registered, as long as the provider
is available.
This validates the fix where the router:
1. Tries to lookup model in routing table
2. If not found, splits model_id by "/" to extract provider_id and provider_resource_id
3. Routes directly to the provider with the provider_resource_id
Without the fix, this would raise ModelNotFoundError immediately.
With the fix, the routing succeeds and the request reaches the provider.
"""
if not isinstance(client_with_models, LlamaStackAsLibraryClient):
pytest.skip("Test requires library client for provider-level patching")
client = client_with_models
# Use a model format that follows provider_id/model_id convention
# We'll use anthropic as an example since it's a remote provider that
# benefits from this pattern
test_model_id = "anthropic/claude-3-5-sonnet-20241022"
# First, verify the model is NOT registered
registered_models = {m.identifier for m in client.models.list()}
assert test_model_id not in registered_models, f"Model {test_model_id} should not be pre-registered for this test"
# Check if anthropic provider is available in ci-tests
providers = {p.provider_id: p for p in client.providers.list()}
if "anthropic" not in providers:
pytest.skip("Anthropic provider not configured in ci-tests - cannot test unregistered model routing")
# Get the actual provider implementation from the library client's stack
inference_router = client.async_client.impls.get(Api.inference)
if not inference_router:
raise RuntimeError("No inference router found")
# The inference router's routing_table.impls_by_provider_id should have anthropic
# Let's patch the anthropic provider's openai_chat_completion method
# to avoid making real API calls
mock_response = OpenAIChatCompletionWithMetrics(
id="chatcmpl-test-123",
created=1234567890,
model="claude-3-5-sonnet-20241022",
choices=[
OpenAIChoice(
index=0,
finish_reason="stop",
message=OpenAIAssistantMessageParam(
content="Mocked response to test routing",
),
)
],
usage=OpenAIChatCompletionUsage(
prompt_tokens=5,
completion_tokens=10,
total_tokens=15,
),
)
# Get the routing table from the inference router
routing_table = inference_router.routing_table
# Patch the anthropic provider's openai_chat_completion method
anthropic_provider = routing_table.impls_by_provider_id.get("anthropic")
if not anthropic_provider:
raise RuntimeError("Anthropic provider not found in routing table even though it's in providers list")
with patch.object(
anthropic_provider,
"openai_chat_completion",
new_callable=AsyncMock,
return_value=mock_response,
) as mock_method:
# Make the request with the unregistered model
response = client.chat.completions.create(
model=test_model_id,
messages=[
{
"role": "user",
"content": "Test message for unregistered model routing",
}
],
stream=False,
)
# Verify the provider's method was called
assert mock_method.called, "Provider's openai_chat_completion should have been called"
# Verify the response came through
assert response.choices[0].message.content == "Mocked response to test routing"
# Verify that the router passed the correct model to the provider
# (without the "anthropic/" prefix)
call_args = mock_method.call_args
params = call_args[0][0] # First positional argument is the params object
assert params.model == "claude-3-5-sonnet-20241022", (
f"Provider should receive model without provider prefix, got {params.model}"
)

View file

@ -64,10 +64,11 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client,
# Verify spans
spans = mock_otlp_collector.get_spans()
assert len(spans) == 5
# Expected spans: 1 root span + 3 autotraced method calls from routing/inference
assert len(spans) == 4, f"Expected 4 spans, got {len(spans)}"
# we only need this captured one time
logged_model_id = None
# Collect all model_ids found in spans
logged_model_ids = []
for span in spans:
attrs = span.attributes
@ -87,10 +88,10 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client,
args = json.loads(attrs["__args__"])
if "model_id" in args:
logged_model_id = args["model_id"]
logged_model_ids.append(args["model_id"])
assert logged_model_id is not None
assert logged_model_id == text_model_id
# At least one span should capture the fully qualified model ID
assert text_model_id in logged_model_ids, f"Expected to find {text_model_id} in spans, but got {logged_model_ids}"
# TODO: re-enable this once metrics get fixed
"""