feat!: standardize base_url for inference (#4177)

# What does this PR do?

Completes #3732 by removing runtime URL transformations and requiring
users to provide full URLs in configuration. All providers now use
'base_url' consistently and respect the exact URL provided without
appending paths like /v1 or /openai/v1 at runtime.

BREAKING CHANGE: Users must update configs to include full URL paths
(e.g., http://localhost:11434/v1 instead of http://localhost:11434).

Closes #3732 

## Test Plan

Existing tests should pass even with the URL changes, due to default
URLs being altered.

Add unit test to enforce URL standardization across remote inference
providers (verifies all use 'base_url' field with HttpUrl | None type)

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-11-19 11:44:28 -05:00 committed by GitHub
parent 91f1b352b4
commit d5cd0eea14
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
67 changed files with 282 additions and 227 deletions

View file

@ -50,7 +50,7 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
name="ollama",
description="Local Ollama provider with text + safety models",
env={
"OLLAMA_URL": "http://0.0.0.0:11434",
"OLLAMA_URL": "http://0.0.0.0:11434/v1",
"SAFETY_MODEL": "ollama/llama-guard3:1b",
},
defaults={
@ -64,7 +64,7 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
name="ollama",
description="Local Ollama provider with a vision model",
env={
"OLLAMA_URL": "http://0.0.0.0:11434",
"OLLAMA_URL": "http://0.0.0.0:11434/v1",
},
defaults={
"vision_model": "ollama/llama3.2-vision:11b",
@ -75,7 +75,7 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
name="ollama-postgres",
description="Server-mode tests with Postgres-backed persistence",
env={
"OLLAMA_URL": "http://0.0.0.0:11434",
"OLLAMA_URL": "http://0.0.0.0:11434/v1",
"SAFETY_MODEL": "ollama/llama-guard3:1b",
"POSTGRES_HOST": "127.0.0.1",
"POSTGRES_PORT": "5432",

View file

@ -120,7 +120,7 @@ from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInfere
VLLMInferenceAdapter,
"llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
{
"url": "http://fake",
"base_url": "http://fake",
},
),
],
@ -153,7 +153,7 @@ def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_valid
"""Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the
assumption that there is an OpenAI-compatible client object."""
inference_adapter = adapter_cls(config=config_cls())
inference_adapter = adapter_cls(config=config_cls(base_url="http://fake"))
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator

View file

@ -40,7 +40,7 @@ from llama_stack_api import (
@pytest.fixture(scope="function")
async def vllm_inference_adapter():
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
inference_adapter = VLLMInferenceAdapter(config=config)
inference_adapter.model_store = AsyncMock()
await inference_adapter.initialize()
@ -204,7 +204,7 @@ async def test_vllm_completion_extra_body():
via extra_body to the underlying OpenAI client through the InferenceRouter.
"""
# Set up the vLLM adapter
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
vllm_adapter = VLLMInferenceAdapter(config=config)
vllm_adapter.__provider_id__ = "vllm"
await vllm_adapter.initialize()
@ -277,7 +277,7 @@ async def test_vllm_chat_completion_extra_body():
via extra_body to the underlying OpenAI client through the InferenceRouter for chat completion.
"""
# Set up the vLLM adapter
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
vllm_adapter = VLLMInferenceAdapter(config=config)
vllm_adapter.__provider_id__ = "vllm"
await vllm_adapter.initialize()

View file

@ -146,7 +146,7 @@ async def test_hosted_model_not_in_endpoint_mapping():
async def test_self_hosted_ignores_endpoint():
adapter = create_adapter(
config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
config=NVIDIAConfig(base_url="http://localhost:8000", api_key=None),
rerank_endpoints={"test-model": "https://model.endpoint/rerank"}, # This should be ignored for self-hosted.
)
mock_session = MockSession(MockResponse())

View file

@ -4,8 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import get_args, get_origin
import pytest
from pydantic import BaseModel
from pydantic import BaseModel, HttpUrl
from llama_stack.core.distribution import get_provider_registry, providable_apis
from llama_stack.core.utils.dynamic import instantiate_class_type
@ -41,3 +43,55 @@ class TestProviderConfigurations:
sample_config = config_type.sample_run_config(__distro_dir__="foobarbaz")
assert isinstance(sample_config, dict), f"{config_class_name}.sample_run_config() did not return a dict"
def test_remote_inference_url_standardization(self):
"""Verify all remote inference providers use standardized base_url configuration."""
provider_registry = get_provider_registry()
inference_providers = provider_registry.get("inference", {})
# Filter for remote providers only
remote_providers = {k: v for k, v in inference_providers.items() if k.startswith("remote::")}
failures = []
for provider_type, provider_spec in remote_providers.items():
try:
config_class_name = provider_spec.config_class
config_type = instantiate_class_type(config_class_name)
# Check that config has base_url field (not url)
if hasattr(config_type, "model_fields"):
fields = config_type.model_fields
# Should NOT have 'url' field (old pattern)
if "url" in fields:
failures.append(
f"{provider_type}: Uses deprecated 'url' field instead of 'base_url'. "
f"Please rename to 'base_url' for consistency."
)
# Should have 'base_url' field with HttpUrl | None type
if "base_url" in fields:
field_info = fields["base_url"]
annotation = field_info.annotation
# Check if it's HttpUrl or HttpUrl | None
# get_origin() returns Union for (X | Y), None for plain types
# get_args() returns the types inside Union, e.g. (HttpUrl, NoneType)
is_valid = False
if get_origin(annotation) is not None: # It's a Union/Optional
if HttpUrl in get_args(annotation):
is_valid = True
elif annotation == HttpUrl: # Plain HttpUrl without | None
is_valid = True
if not is_valid:
failures.append(
f"{provider_type}: base_url field has incorrect type annotation. "
f"Expected 'HttpUrl | None', got '{annotation}'"
)
except Exception as e:
failures.append(f"{provider_type}: Error checking URL standardization: {str(e)}")
if failures:
pytest.fail("URL standardization violations found:\n" + "\n".join(f" - {f}" for f in failures))