Merge branch 'main' into fix/nvidia-safety-provider-endpoint-4189

This commit is contained in:
Roy Belio 2025-11-20 13:30:11 +02:00 committed by GitHub
commit f8f28344a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
117 changed files with 16294 additions and 769 deletions

View file

@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from urllib.parse import urljoin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AzureConfig
@ -22,4 +20,4 @@ class AzureInferenceAdapter(OpenAIMixin):
Returns the Azure API base URL from the configuration.
"""
return urljoin(str(self.config.api_base), "/openai/v1")
return str(self.config.base_url)

View file

@ -32,8 +32,9 @@ class AzureProviderDataValidator(BaseModel):
@json_schema_type
class AzureConfig(RemoteInferenceProviderConfig):
api_base: HttpUrl = Field(
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
base_url: HttpUrl | None = Field(
default=None,
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com/openai/v1)",
)
api_version: str | None = Field(
default_factory=lambda: os.getenv("AZURE_API_VERSION"),
@ -48,14 +49,14 @@ class AzureConfig(RemoteInferenceProviderConfig):
def sample_run_config(
cls,
api_key: str = "${env.AZURE_API_KEY:=}",
api_base: str = "${env.AZURE_API_BASE:=}",
base_url: str = "${env.AZURE_API_BASE:=}",
api_version: str = "${env.AZURE_API_VERSION:=}",
api_type: str = "${env.AZURE_API_TYPE:=}",
**kwargs,
) -> dict[str, Any]:
return {
"api_key": api_key,
"api_base": api_base,
"base_url": base_url,
"api_version": api_version,
"api_type": api_type,
}

View file

@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from urllib.parse import urljoin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack_api import (
OpenAIEmbeddingsRequestWithExtraBody,
@ -21,7 +19,7 @@ class CerebrasInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "cerebras_api_key"
def get_base_url(self) -> str:
return urljoin(self.config.base_url, "v1")
return str(self.config.base_url)
async def openai_embeddings(
self,

View file

@ -7,12 +7,12 @@
import os
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
DEFAULT_BASE_URL = "https://api.cerebras.ai"
DEFAULT_BASE_URL = "https://api.cerebras.ai/v1"
class CerebrasProviderDataValidator(BaseModel):
@ -24,8 +24,8 @@ class CerebrasProviderDataValidator(BaseModel):
@json_schema_type
class CerebrasImplConfig(RemoteInferenceProviderConfig):
base_url: str = Field(
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
base_url: HttpUrl | None = Field(
default=HttpUrl(os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL)),
description="Base URL for the Cerebras API",
)

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, Field, HttpUrl, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -21,9 +21,9 @@ class DatabricksProviderDataValidator(BaseModel):
@json_schema_type
class DatabricksImplConfig(RemoteInferenceProviderConfig):
url: str | None = Field(
base_url: HttpUrl | None = Field(
default=None,
description="The URL for the Databricks model serving endpoint",
description="The URL for the Databricks model serving endpoint (should include /serving-endpoints path)",
)
auth_credential: SecretStr | None = Field(
default=None,
@ -34,11 +34,11 @@ class DatabricksImplConfig(RemoteInferenceProviderConfig):
@classmethod
def sample_run_config(
cls,
url: str = "${env.DATABRICKS_HOST:=}",
base_url: str = "${env.DATABRICKS_HOST:=}",
api_token: str = "${env.DATABRICKS_TOKEN:=}",
**kwargs: Any,
) -> dict[str, Any]:
return {
"url": url,
"base_url": base_url,
"api_token": api_token,
}

View file

@ -29,15 +29,21 @@ class DatabricksInferenceAdapter(OpenAIMixin):
}
def get_base_url(self) -> str:
return f"{self.config.url}/serving-endpoints"
return str(self.config.base_url)
async def list_provider_model_ids(self) -> Iterable[str]:
# Filter out None values from endpoint names
api_token = self._get_api_key_from_config_or_provider_data()
# WorkspaceClient expects base host without /serving-endpoints suffix
base_url_str = str(self.config.base_url)
if base_url_str.endswith("/serving-endpoints"):
host = base_url_str[:-18] # Remove '/serving-endpoints'
else:
host = base_url_str
return [
endpoint.name # type: ignore[misc]
for endpoint in WorkspaceClient(
host=self.config.url, token=api_token
host=host, token=api_token
).serving_endpoints.list() # TODO: this is not async
]

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import Field
from pydantic import Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -14,14 +14,14 @@ from llama_stack_api import json_schema_type
@json_schema_type
class FireworksImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.fireworks.ai/inference/v1",
base_url: HttpUrl | None = Field(
default=HttpUrl("https://api.fireworks.ai/inference/v1"),
description="The URL for the Fireworks server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.fireworks.ai/inference/v1",
"base_url": "https://api.fireworks.ai/inference/v1",
"api_key": api_key,
}

View file

@ -24,4 +24,4 @@ class FireworksInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "fireworks_api_key"
def get_base_url(self) -> str:
return "https://api.fireworks.ai/inference/v1"
return str(self.config.base_url)

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -21,14 +21,14 @@ class GroqProviderDataValidator(BaseModel):
@json_schema_type
class GroqConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.groq.com",
base_url: HttpUrl | None = Field(
default=HttpUrl("https://api.groq.com/openai/v1"),
description="The URL for the Groq AI server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.groq.com",
"base_url": "https://api.groq.com/openai/v1",
"api_key": api_key,
}

View file

@ -15,4 +15,4 @@ class GroqInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "groq_api_key"
def get_base_url(self) -> str:
return f"{self.config.url}/openai/v1"
return str(self.config.base_url)

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -21,14 +21,14 @@ class LlamaProviderDataValidator(BaseModel):
@json_schema_type
class LlamaCompatConfig(RemoteInferenceProviderConfig):
openai_compat_api_base: str = Field(
default="https://api.llama.com/compat/v1/",
base_url: HttpUrl | None = Field(
default=HttpUrl("https://api.llama.com/compat/v1/"),
description="The URL for the Llama API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.LLAMA_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.llama.com/compat/v1/",
"base_url": "https://api.llama.com/compat/v1/",
"api_key": api_key,
}

View file

@ -31,7 +31,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
:return: The Llama API base URL
"""
return self.config.openai_compat_api_base
return str(self.config.base_url)
async def openai_completion(
self,

View file

@ -7,7 +7,7 @@
import os
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -44,18 +44,14 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
URL of your running NVIDIA NIM and do not need to set the api_key.
"""
url: str = Field(
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
base_url: HttpUrl | None = Field(
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com/v1"),
description="A base url for accessing the NVIDIA NIM",
)
timeout: int = Field(
default=60,
description="Timeout for the HTTP requests",
)
append_api_version: bool = Field(
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
)
rerank_model_to_url: dict[str, str] = Field(
default_factory=lambda: {
"nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
@ -68,13 +64,11 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
@classmethod
def sample_run_config(
cls,
url: str = "${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}",
base_url: HttpUrl | None = "${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}",
api_key: str = "${env.NVIDIA_API_KEY:=}",
append_api_version: bool = "${env.NVIDIA_APPEND_API_VERSION:=True}",
**kwargs,
) -> dict[str, Any]:
return {
"url": url,
"base_url": base_url,
"api_key": api_key,
"append_api_version": append_api_version,
}

View file

@ -44,7 +44,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
}
async def initialize(self) -> None:
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.base_url})...")
if _is_nvidia_hosted(self.config):
if not self.config.auth_credential:
@ -72,7 +72,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
:return: The NVIDIA API base URL
"""
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
return str(self.config.base_url)
async def list_provider_model_ids(self) -> Iterable[str]:
"""

View file

@ -8,4 +8,4 @@ from . import NVIDIAConfig
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
return "integrate.api.nvidia.com" in config.url
return "integrate.api.nvidia.com" in str(config.base_url)

View file

@ -6,20 +6,22 @@
from typing import Any
from pydantic import Field, SecretStr
from pydantic import Field, HttpUrl, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
DEFAULT_OLLAMA_URL = "http://localhost:11434"
DEFAULT_OLLAMA_URL = "http://localhost:11434/v1"
class OllamaImplConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True)
url: str = DEFAULT_OLLAMA_URL
base_url: HttpUrl | None = Field(default=HttpUrl(DEFAULT_OLLAMA_URL))
@classmethod
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
def sample_run_config(
cls, base_url: str = "${env.OLLAMA_URL:=http://localhost:11434/v1}", **kwargs
) -> dict[str, Any]:
return {
"url": url,
"base_url": base_url,
}

View file

@ -55,17 +55,23 @@ class OllamaInferenceAdapter(OpenAIMixin):
# ollama client attaches itself to the current event loop (sadly?)
loop = asyncio.get_running_loop()
if loop not in self._clients:
self._clients[loop] = AsyncOllamaClient(host=self.config.url)
# Ollama client expects base URL without /v1 suffix
base_url_str = str(self.config.base_url)
if base_url_str.endswith("/v1"):
host = base_url_str[:-3]
else:
host = base_url_str
self._clients[loop] = AsyncOllamaClient(host=host)
return self._clients[loop]
def get_api_key(self):
return "NO KEY REQUIRED"
def get_base_url(self):
return self.config.url.rstrip("/") + "/v1"
return str(self.config.base_url)
async def initialize(self) -> None:
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
logger.info(f"checking connectivity to Ollama at `{self.config.base_url}`...")
r = await self.health()
if r["status"] == HealthStatus.ERROR:
logger.warning(

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -21,8 +21,8 @@ class OpenAIProviderDataValidator(BaseModel):
@json_schema_type
class OpenAIConfig(RemoteInferenceProviderConfig):
base_url: str = Field(
default="https://api.openai.com/v1",
base_url: HttpUrl | None = Field(
default=HttpUrl("https://api.openai.com/v1"),
description="Base URL for OpenAI API",
)

View file

@ -35,4 +35,4 @@ class OpenAIInferenceAdapter(OpenAIMixin):
Returns the OpenAI API base URL from the configuration.
"""
return self.config.base_url
return str(self.config.base_url)

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import Field
from pydantic import Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -14,16 +14,16 @@ from llama_stack_api import json_schema_type
@json_schema_type
class PassthroughImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
base_url: HttpUrl | None = Field(
default=None,
description="The URL for the passthrough endpoint",
)
@classmethod
def sample_run_config(
cls, url: str = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
cls, base_url: HttpUrl | None = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
) -> dict[str, Any]:
return {
"url": url,
"base_url": base_url,
"api_key": api_key,
}

View file

@ -82,8 +82,8 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
def _get_passthrough_url(self) -> str:
"""Get the passthrough URL from config or provider data."""
if self.config.url is not None:
return self.config.url
if self.config.base_url is not None:
return str(self.config.base_url)
provider_data = self.get_request_provider_data()
if provider_data is None:

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, Field, HttpUrl, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -21,7 +21,7 @@ class RunpodProviderDataValidator(BaseModel):
@json_schema_type
class RunpodImplConfig(RemoteInferenceProviderConfig):
url: str | None = Field(
base_url: HttpUrl | None = Field(
default=None,
description="The URL for the Runpod model serving endpoint",
)
@ -34,6 +34,6 @@ class RunpodImplConfig(RemoteInferenceProviderConfig):
@classmethod
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
return {
"url": "${env.RUNPOD_URL:=}",
"base_url": "${env.RUNPOD_URL:=}",
"api_token": "${env.RUNPOD_API_TOKEN}",
}

View file

@ -28,7 +28,7 @@ class RunpodInferenceAdapter(OpenAIMixin):
def get_base_url(self) -> str:
"""Get base URL for OpenAI client."""
return self.config.url
return str(self.config.base_url)
async def openai_chat_completion(
self,

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -21,14 +21,14 @@ class SambaNovaProviderDataValidator(BaseModel):
@json_schema_type
class SambaNovaImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.sambanova.ai/v1",
base_url: HttpUrl | None = Field(
default=HttpUrl("https://api.sambanova.ai/v1"),
description="The URL for the SambaNova AI server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.sambanova.ai/v1",
"base_url": "https://api.sambanova.ai/v1",
"api_key": api_key,
}

View file

@ -25,4 +25,4 @@ class SambaNovaInferenceAdapter(OpenAIMixin):
:return: The SambaNova base URL
"""
return self.config.url
return str(self.config.base_url)

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, Field, HttpUrl, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -15,18 +15,19 @@ from llama_stack_api import json_schema_type
class TGIImplConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True)
url: str = Field(
description="The URL for the TGI serving endpoint",
base_url: HttpUrl | None = Field(
default=None,
description="The URL for the TGI serving endpoint (should include /v1 path)",
)
@classmethod
def sample_run_config(
cls,
url: str = "${env.TGI_URL:=}",
base_url: str = "${env.TGI_URL:=}",
**kwargs,
):
return {
"url": url,
"base_url": base_url,
}

View file

@ -8,7 +8,7 @@
from collections.abc import Iterable
from huggingface_hub import AsyncInferenceClient, HfApi
from pydantic import SecretStr
from pydantic import HttpUrl, SecretStr
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -23,7 +23,7 @@ log = get_logger(name=__name__, category="inference::tgi")
class _HfAdapter(OpenAIMixin):
url: str
base_url: HttpUrl
api_key: SecretStr
hf_client: AsyncInferenceClient
@ -36,7 +36,7 @@ class _HfAdapter(OpenAIMixin):
return "NO KEY REQUIRED"
def get_base_url(self):
return self.url
return self.base_url
async def list_provider_model_ids(self) -> Iterable[str]:
return [self.model_id]
@ -50,14 +50,20 @@ class _HfAdapter(OpenAIMixin):
class TGIAdapter(_HfAdapter):
async def initialize(self, config: TGIImplConfig) -> None:
if not config.url:
if not config.base_url:
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
log.info(f"Initializing TGI client with url={config.url}")
self.hf_client = AsyncInferenceClient(model=config.url, provider="hf-inference")
log.info(f"Initializing TGI client with url={config.base_url}")
# Extract base URL without /v1 for HF client initialization
base_url_str = str(config.base_url).rstrip("/")
if base_url_str.endswith("/v1"):
base_url_for_client = base_url_str[:-3]
else:
base_url_for_client = base_url_str
self.hf_client = AsyncInferenceClient(model=base_url_for_client, provider="hf-inference")
endpoint_info = await self.hf_client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
self.url = f"{config.url.rstrip('/')}/v1"
self.base_url = config.base_url
self.api_key = SecretStr("NO_KEY")

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import Field
from pydantic import Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -14,14 +14,14 @@ from llama_stack_api import json_schema_type
@json_schema_type
class TogetherImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.together.xyz/v1",
base_url: HttpUrl | None = Field(
default=HttpUrl("https://api.together.xyz/v1"),
description="The URL for the Together AI server",
)
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"url": "https://api.together.xyz/v1",
"base_url": "https://api.together.xyz/v1",
"api_key": "${env.TOGETHER_API_KEY:=}",
}

View file

@ -9,7 +9,6 @@ from collections.abc import Iterable
from typing import Any, cast
from together import AsyncTogether # type: ignore[import-untyped]
from together.constants import BASE_URL # type: ignore[import-untyped]
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
@ -42,7 +41,7 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
provider_data_api_key_field: str = "together_api_key"
def get_base_url(self):
return BASE_URL
return str(self.config.base_url)
def _get_client(self) -> AsyncTogether:
together_api_key = None

View file

@ -6,7 +6,7 @@
from pathlib import Path
from pydantic import Field, SecretStr, field_validator
from pydantic import Field, HttpUrl, SecretStr, field_validator
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -14,7 +14,7 @@ from llama_stack_api import json_schema_type
@json_schema_type
class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
url: str | None = Field(
base_url: HttpUrl | None = Field(
default=None,
description="The URL for the vLLM model serving endpoint",
)
@ -48,11 +48,11 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
@classmethod
def sample_run_config(
cls,
url: str = "${env.VLLM_URL:=}",
base_url: str = "${env.VLLM_URL:=}",
**kwargs,
):
return {
"url": url,
"base_url": base_url,
"max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
"api_token": "${env.VLLM_API_TOKEN:=fake}",
"tls_verify": "${env.VLLM_TLS_VERIFY:=true}",

View file

@ -39,12 +39,12 @@ class VLLMInferenceAdapter(OpenAIMixin):
def get_base_url(self) -> str:
"""Get the base URL from config."""
if not self.config.url:
if not self.config.base_url:
raise ValueError("No base URL configured")
return self.config.url
return str(self.config.base_url)
async def initialize(self) -> None:
if not self.config.url:
if not self.config.base_url:
raise ValueError(
"You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
)

View file

@ -7,7 +7,7 @@
import os
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@ -23,7 +23,7 @@ class WatsonXProviderDataValidator(BaseModel):
@json_schema_type
class WatsonXConfig(RemoteInferenceProviderConfig):
url: str = Field(
base_url: HttpUrl | None = Field(
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
description="A base url for accessing the watsonx.ai",
)
@ -39,7 +39,7 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"url": "${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}",
"base_url": "${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}",
"api_key": "${env.WATSONX_API_KEY:=}",
"project_id": "${env.WATSONX_PROJECT_ID:=}",
}

View file

@ -255,7 +255,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
)
def get_base_url(self) -> str:
return self.config.url
return str(self.config.base_url)
# Copied from OpenAIMixin
async def check_model_availability(self, model: str) -> bool:
@ -316,7 +316,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
"""
Retrieves foundation model specifications from the watsonx.ai API.
"""
url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
url = f"{str(self.config.base_url)}/ml/v1/foundation_model_specs?version=2023-10-25"
headers = {
# Note that there is no authorization header. Listing models does not require authentication.
"Content-Type": "application/json",