feat!: standardize base_url for inference (#4177)

# What does this PR do?

Completes #3732 by removing runtime URL transformations and requiring
users to provide full URLs in configuration. All providers now use
'base_url' consistently and respect the exact URL provided without
appending paths like /v1 or /openai/v1 at runtime.

BREAKING CHANGE: Users must update configs to include full URL paths
(e.g., http://localhost:11434/v1 instead of http://localhost:11434).

Closes #3732 

## Test Plan

Existing tests should pass even with the URL changes, due to default
URLs being altered.

Add unit test to enforce URL standardization across remote inference
providers (verifies all use 'base_url' field with HttpUrl | None type)

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-11-19 11:44:28 -05:00 committed by GitHub
parent 91f1b352b4
commit d5cd0eea14
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
67 changed files with 282 additions and 227 deletions

View file

@ -17,32 +17,32 @@ providers:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
base_url: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
base_url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
base_url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
@ -52,9 +52,8 @@ providers:
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
@ -76,18 +75,18 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
base_url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
base_url: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers

View file

@ -17,32 +17,32 @@ providers:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
base_url: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
base_url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
base_url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
@ -52,9 +52,8 @@ providers:
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
@ -76,18 +75,18 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
base_url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
base_url: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers

View file

@ -16,9 +16,8 @@ providers:
- provider_id: nvidia
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: nvidia
provider_type: remote::nvidia
config:

View file

@ -16,9 +16,8 @@ providers:
- provider_id: nvidia
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
vector_io:
- provider_id: faiss
provider_type: inline::faiss

View file

@ -27,12 +27,12 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
vector_io:
- provider_id: sqlite-vec

View file

@ -11,7 +11,7 @@ providers:
- provider_id: vllm-inference
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=http://localhost:8000/v1}
base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}

View file

@ -17,32 +17,32 @@ providers:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
base_url: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
base_url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
base_url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
@ -52,9 +52,8 @@ providers:
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
@ -76,18 +75,18 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
base_url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
base_url: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers

View file

@ -17,32 +17,32 @@ providers:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
base_url: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
base_url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
base_url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
@ -52,9 +52,8 @@ providers:
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
@ -76,18 +75,18 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
base_url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
base_url: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers

View file

@ -17,32 +17,32 @@ providers:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
base_url: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
base_url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
base_url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
@ -52,9 +52,8 @@ providers:
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
@ -76,18 +75,18 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
base_url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
base_url: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers

View file

@ -17,32 +17,32 @@ providers:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
base_url: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
base_url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
base_url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
@ -52,9 +52,8 @@ providers:
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
@ -76,18 +75,18 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
base_url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
base_url: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers

View file

@ -15,7 +15,7 @@ providers:
- provider_id: watsonx
provider_type: remote::watsonx
config:
url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
base_url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
api_key: ${env.WATSONX_API_KEY:=}
project_id: ${env.WATSONX_PROJECT_ID:=}
vector_io:

View file

@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from urllib.parse import urljoin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AzureConfig
@ -22,4 +20,4 @@ class AzureInferenceAdapter(OpenAIMixin):
Returns the Azure API base URL from the configuration.
"""
return urljoin(str(self.config.api_base), "/openai/v1")
return str(self.config.base_url)

View file

@ -32,8 +32,9 @@ class AzureProviderDataValidator(BaseModel):
@json_schema_type
class AzureConfig(RemoteInferenceProviderConfig):
api_base: HttpUrl = Field(
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
base_url: HttpUrl | None = Field(
default=None,
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com/openai/v1)",
)
api_version: str | None = Field(
default_factory=lambda: os.getenv("AZURE_API_VERSION"),
@ -48,14 +49,14 @@ class AzureConfig(RemoteInferenceProviderConfig):
def sample_run_config(
cls,
api_key: str = "${env.AZURE_API_KEY:=}",
api_base: str = "${env.AZURE_API_BASE:=}",
base_url: str = "${env.AZURE_API_BASE:=}",
api_version: str = "${env.AZURE_API_VERSION:=}",
api_type: str = "${env.AZURE_API_TYPE:=}",
**kwargs,
) -> dict[str, Any]:
return {
"api_key": api_key,
"api_base": api_base,
"base_url": base_url,
"api_version": api_version,
"api_type": api_type,
}

View file

@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from urllib.parse import urljoin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack_api import (
OpenAIEmbeddingsRequestWithExtraBody,
@ -21,7 +19,7 @@ class CerebrasInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "cerebras_api_key"
def get_base_url(self) -> str:
return urljoin(self.config.base_url, "v1")
return str(self.config.base_url)
async def openai_embeddings(
self,

View file

@ -7,12 +7,12 @@
import os
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
DEFAULT_BASE_URL = "https://api.cerebras.ai"
DEFAULT_BASE_URL = "https://api.cerebras.ai/v1"
class CerebrasProviderDataValidator(BaseModel):
@ -24,8 +24,8 @@ class CerebrasProviderDataValidator(BaseModel):
@json_schema_type
class CerebrasImplConfig(RemoteInferenceProviderConfig):
base_url: str = Field(
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
base_url: HttpUrl | None = Field(
default=HttpUrl(os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL)),
description="Base URL for the Cerebras API",
)

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, Field, HttpUrl, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -21,9 +21,9 @@ class DatabricksProviderDataValidator(BaseModel):
@json_schema_type
class DatabricksImplConfig(RemoteInferenceProviderConfig):
url: str | None = Field(
base_url: HttpUrl | None = Field(
default=None,
description="The URL for the Databricks model serving endpoint",
description="The URL for the Databricks model serving endpoint (should include /serving-endpoints path)",
)
auth_credential: SecretStr | None = Field(
default=None,
@ -34,11 +34,11 @@ class DatabricksImplConfig(RemoteInferenceProviderConfig):
@classmethod
def sample_run_config(
cls,
url: str = "${env.DATABRICKS_HOST:=}",
base_url: str = "${env.DATABRICKS_HOST:=}",
api_token: str = "${env.DATABRICKS_TOKEN:=}",
**kwargs: Any,
) -> dict[str, Any]:
return {
"url": url,
"base_url": base_url,
"api_token": api_token,
}

View file

@ -29,15 +29,21 @@ class DatabricksInferenceAdapter(OpenAIMixin):
}
def get_base_url(self) -> str:
return f"{self.config.url}/serving-endpoints"
return str(self.config.base_url)
async def list_provider_model_ids(self) -> Iterable[str]:
# Filter out None values from endpoint names
api_token = self._get_api_key_from_config_or_provider_data()
# WorkspaceClient expects base host without /serving-endpoints suffix
base_url_str = str(self.config.base_url)
if base_url_str.endswith("/serving-endpoints"):
host = base_url_str[:-18] # Remove '/serving-endpoints'
else:
host = base_url_str
return [
endpoint.name # type: ignore[misc]
for endpoint in WorkspaceClient(
host=self.config.url, token=api_token
host=host, token=api_token
).serving_endpoints.list() # TODO: this is not async
]

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import Field
from pydantic import Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -14,14 +14,14 @@ from llama_stack_api import json_schema_type
@json_schema_type
class FireworksImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.fireworks.ai/inference/v1",
base_url: HttpUrl | None = Field(
default=HttpUrl("https://api.fireworks.ai/inference/v1"),
description="The URL for the Fireworks server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.fireworks.ai/inference/v1",
"base_url": "https://api.fireworks.ai/inference/v1",
"api_key": api_key,
}

View file

@ -24,4 +24,4 @@ class FireworksInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "fireworks_api_key"
def get_base_url(self) -> str:
return "https://api.fireworks.ai/inference/v1"
return str(self.config.base_url)

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -21,14 +21,14 @@ class GroqProviderDataValidator(BaseModel):
@json_schema_type
class GroqConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.groq.com",
base_url: HttpUrl | None = Field(
default=HttpUrl("https://api.groq.com/openai/v1"),
description="The URL for the Groq AI server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.groq.com",
"base_url": "https://api.groq.com/openai/v1",
"api_key": api_key,
}

View file

@ -15,4 +15,4 @@ class GroqInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "groq_api_key"
def get_base_url(self) -> str:
return f"{self.config.url}/openai/v1"
return str(self.config.base_url)

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -21,14 +21,14 @@ class LlamaProviderDataValidator(BaseModel):
@json_schema_type
class LlamaCompatConfig(RemoteInferenceProviderConfig):
openai_compat_api_base: str = Field(
default="https://api.llama.com/compat/v1/",
base_url: HttpUrl | None = Field(
default=HttpUrl("https://api.llama.com/compat/v1/"),
description="The URL for the Llama API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.LLAMA_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.llama.com/compat/v1/",
"base_url": "https://api.llama.com/compat/v1/",
"api_key": api_key,
}

View file

@ -31,7 +31,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
:return: The Llama API base URL
"""
return self.config.openai_compat_api_base
return str(self.config.base_url)
async def openai_completion(
self,

View file

@ -7,7 +7,7 @@
import os
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -44,18 +44,14 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
URL of your running NVIDIA NIM and do not need to set the api_key.
"""
url: str = Field(
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
base_url: HttpUrl | None = Field(
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com/v1"),
description="A base url for accessing the NVIDIA NIM",
)
timeout: int = Field(
default=60,
description="Timeout for the HTTP requests",
)
append_api_version: bool = Field(
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
)
rerank_model_to_url: dict[str, str] = Field(
default_factory=lambda: {
"nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
@ -68,13 +64,11 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
@classmethod
def sample_run_config(
cls,
url: str = "${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}",
base_url: HttpUrl | None = "${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}",
api_key: str = "${env.NVIDIA_API_KEY:=}",
append_api_version: bool = "${env.NVIDIA_APPEND_API_VERSION:=True}",
**kwargs,
) -> dict[str, Any]:
return {
"url": url,
"base_url": base_url,
"api_key": api_key,
"append_api_version": append_api_version,
}

View file

@ -44,7 +44,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
}
async def initialize(self) -> None:
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.base_url})...")
if _is_nvidia_hosted(self.config):
if not self.config.auth_credential:
@ -72,7 +72,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
:return: The NVIDIA API base URL
"""
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
return str(self.config.base_url)
async def list_provider_model_ids(self) -> Iterable[str]:
"""

View file

@ -8,4 +8,4 @@ from . import NVIDIAConfig
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
return "integrate.api.nvidia.com" in config.url
return "integrate.api.nvidia.com" in str(config.base_url)

View file

@ -6,20 +6,22 @@
from typing import Any
from pydantic import Field, SecretStr
from pydantic import Field, HttpUrl, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
DEFAULT_OLLAMA_URL = "http://localhost:11434"
DEFAULT_OLLAMA_URL = "http://localhost:11434/v1"
class OllamaImplConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True)
url: str = DEFAULT_OLLAMA_URL
base_url: HttpUrl | None = Field(default=HttpUrl(DEFAULT_OLLAMA_URL))
@classmethod
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
def sample_run_config(
cls, base_url: str = "${env.OLLAMA_URL:=http://localhost:11434/v1}", **kwargs
) -> dict[str, Any]:
return {
"url": url,
"base_url": base_url,
}

View file

@ -55,17 +55,23 @@ class OllamaInferenceAdapter(OpenAIMixin):
# ollama client attaches itself to the current event loop (sadly?)
loop = asyncio.get_running_loop()
if loop not in self._clients:
self._clients[loop] = AsyncOllamaClient(host=self.config.url)
# Ollama client expects base URL without /v1 suffix
base_url_str = str(self.config.base_url)
if base_url_str.endswith("/v1"):
host = base_url_str[:-3]
else:
host = base_url_str
self._clients[loop] = AsyncOllamaClient(host=host)
return self._clients[loop]
def get_api_key(self):
return "NO KEY REQUIRED"
def get_base_url(self):
return self.config.url.rstrip("/") + "/v1"
return str(self.config.base_url)
async def initialize(self) -> None:
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
logger.info(f"checking connectivity to Ollama at `{self.config.base_url}`...")
r = await self.health()
if r["status"] == HealthStatus.ERROR:
logger.warning(

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -21,8 +21,8 @@ class OpenAIProviderDataValidator(BaseModel):
@json_schema_type
class OpenAIConfig(RemoteInferenceProviderConfig):
base_url: str = Field(
default="https://api.openai.com/v1",
base_url: HttpUrl | None = Field(
default=HttpUrl("https://api.openai.com/v1"),
description="Base URL for OpenAI API",
)

View file

@ -35,4 +35,4 @@ class OpenAIInferenceAdapter(OpenAIMixin):
Returns the OpenAI API base URL from the configuration.
"""
return self.config.base_url
return str(self.config.base_url)

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import Field
from pydantic import Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -14,16 +14,16 @@ from llama_stack_api import json_schema_type
@json_schema_type
class PassthroughImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
base_url: HttpUrl | None = Field(
default=None,
description="The URL for the passthrough endpoint",
)
@classmethod
def sample_run_config(
cls, url: str = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
cls, base_url: HttpUrl | None = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
) -> dict[str, Any]:
return {
"url": url,
"base_url": base_url,
"api_key": api_key,
}

View file

@ -82,8 +82,8 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
def _get_passthrough_url(self) -> str:
"""Get the passthrough URL from config or provider data."""
if self.config.url is not None:
return self.config.url
if self.config.base_url is not None:
return str(self.config.base_url)
provider_data = self.get_request_provider_data()
if provider_data is None:

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, Field, HttpUrl, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -21,7 +21,7 @@ class RunpodProviderDataValidator(BaseModel):
@json_schema_type
class RunpodImplConfig(RemoteInferenceProviderConfig):
url: str | None = Field(
base_url: HttpUrl | None = Field(
default=None,
description="The URL for the Runpod model serving endpoint",
)
@ -34,6 +34,6 @@ class RunpodImplConfig(RemoteInferenceProviderConfig):
@classmethod
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
return {
"url": "${env.RUNPOD_URL:=}",
"base_url": "${env.RUNPOD_URL:=}",
"api_token": "${env.RUNPOD_API_TOKEN}",
}

View file

@ -28,7 +28,7 @@ class RunpodInferenceAdapter(OpenAIMixin):
def get_base_url(self) -> str:
"""Get base URL for OpenAI client."""
return self.config.url
return str(self.config.base_url)
async def openai_chat_completion(
self,

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -21,14 +21,14 @@ class SambaNovaProviderDataValidator(BaseModel):
@json_schema_type
class SambaNovaImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.sambanova.ai/v1",
base_url: HttpUrl | None = Field(
default=HttpUrl("https://api.sambanova.ai/v1"),
description="The URL for the SambaNova AI server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.sambanova.ai/v1",
"base_url": "https://api.sambanova.ai/v1",
"api_key": api_key,
}

View file

@ -25,4 +25,4 @@ class SambaNovaInferenceAdapter(OpenAIMixin):
:return: The SambaNova base URL
"""
return self.config.url
return str(self.config.base_url)

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, Field, HttpUrl, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -15,18 +15,19 @@ from llama_stack_api import json_schema_type
class TGIImplConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True)
url: str = Field(
description="The URL for the TGI serving endpoint",
base_url: HttpUrl | None = Field(
default=None,
description="The URL for the TGI serving endpoint (should include /v1 path)",
)
@classmethod
def sample_run_config(
cls,
url: str = "${env.TGI_URL:=}",
base_url: str = "${env.TGI_URL:=}",
**kwargs,
):
return {
"url": url,
"base_url": base_url,
}

View file

@ -8,7 +8,7 @@
from collections.abc import Iterable
from huggingface_hub import AsyncInferenceClient, HfApi
from pydantic import SecretStr
from pydantic import HttpUrl, SecretStr
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -23,7 +23,7 @@ log = get_logger(name=__name__, category="inference::tgi")
class _HfAdapter(OpenAIMixin):
url: str
base_url: HttpUrl
api_key: SecretStr
hf_client: AsyncInferenceClient
@ -36,7 +36,7 @@ class _HfAdapter(OpenAIMixin):
return "NO KEY REQUIRED"
def get_base_url(self):
return self.url
return self.base_url
async def list_provider_model_ids(self) -> Iterable[str]:
return [self.model_id]
@ -50,14 +50,20 @@ class _HfAdapter(OpenAIMixin):
class TGIAdapter(_HfAdapter):
async def initialize(self, config: TGIImplConfig) -> None:
if not config.url:
if not config.base_url:
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
log.info(f"Initializing TGI client with url={config.url}")
self.hf_client = AsyncInferenceClient(model=config.url, provider="hf-inference")
log.info(f"Initializing TGI client with url={config.base_url}")
# Extract base URL without /v1 for HF client initialization
base_url_str = str(config.base_url).rstrip("/")
if base_url_str.endswith("/v1"):
base_url_for_client = base_url_str[:-3]
else:
base_url_for_client = base_url_str
self.hf_client = AsyncInferenceClient(model=base_url_for_client, provider="hf-inference")
endpoint_info = await self.hf_client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
self.url = f"{config.url.rstrip('/')}/v1"
self.base_url = config.base_url
self.api_key = SecretStr("NO_KEY")

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import Field
from pydantic import Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -14,14 +14,14 @@ from llama_stack_api import json_schema_type
@json_schema_type
class TogetherImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.together.xyz/v1",
base_url: HttpUrl | None = Field(
default=HttpUrl("https://api.together.xyz/v1"),
description="The URL for the Together AI server",
)
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"url": "https://api.together.xyz/v1",
"base_url": "https://api.together.xyz/v1",
"api_key": "${env.TOGETHER_API_KEY:=}",
}

View file

@ -9,7 +9,6 @@ from collections.abc import Iterable
from typing import Any, cast
from together import AsyncTogether # type: ignore[import-untyped]
from together.constants import BASE_URL # type: ignore[import-untyped]
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
@ -42,7 +41,7 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
provider_data_api_key_field: str = "together_api_key"
def get_base_url(self):
return BASE_URL
return str(self.config.base_url)
def _get_client(self) -> AsyncTogether:
together_api_key = None

View file

@ -6,7 +6,7 @@
from pathlib import Path
from pydantic import Field, SecretStr, field_validator
from pydantic import Field, HttpUrl, SecretStr, field_validator
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -14,7 +14,7 @@ from llama_stack_api import json_schema_type
@json_schema_type
class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
url: str | None = Field(
base_url: HttpUrl | None = Field(
default=None,
description="The URL for the vLLM model serving endpoint",
)
@ -48,11 +48,11 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
@classmethod
def sample_run_config(
cls,
url: str = "${env.VLLM_URL:=}",
base_url: str = "${env.VLLM_URL:=}",
**kwargs,
):
return {
"url": url,
"base_url": base_url,
"max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
"api_token": "${env.VLLM_API_TOKEN:=fake}",
"tls_verify": "${env.VLLM_TLS_VERIFY:=true}",

View file

@ -39,12 +39,12 @@ class VLLMInferenceAdapter(OpenAIMixin):
def get_base_url(self) -> str:
"""Get the base URL from config."""
if not self.config.url:
if not self.config.base_url:
raise ValueError("No base URL configured")
return self.config.url
return str(self.config.base_url)
async def initialize(self) -> None:
if not self.config.url:
if not self.config.base_url:
raise ValueError(
"You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
)

View file

@ -7,7 +7,7 @@
import os
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -23,7 +23,7 @@ class WatsonXProviderDataValidator(BaseModel):
@json_schema_type
class WatsonXConfig(RemoteInferenceProviderConfig):
url: str = Field(
base_url: HttpUrl | None = Field(
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
description="A base url for accessing the watsonx.ai",
)
@ -39,7 +39,7 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"url": "${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}",
"base_url": "${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}",
"api_key": "${env.WATSONX_API_KEY:=}",
"project_id": "${env.WATSONX_PROJECT_ID:=}",
}

View file

@ -255,7 +255,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
)
def get_base_url(self) -> str:
return self.config.url
return str(self.config.base_url)
# Copied from OpenAIMixin
async def check_model_availability(self, model: str) -> bool:
@ -316,7 +316,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
"""
Retrieves foundation model specifications from the watsonx.ai API.
"""
url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
url = f"{str(self.config.base_url)}/ml/v1/foundation_model_specs?version=2023-10-25"
headers = {
# Note that there is no authorization header. Listing models does not require authentication.
"Content-Type": "application/json",