From 71d67a983e1733f5e537673b6268925864c1843e Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Thu, 2 Oct 2025 18:13:50 -0400
Subject: [PATCH] chore: make all remote inference provider configs
 RemoteInferenceProviderConfigs
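
The RemoteInferenceProviderConfig base class in
llama_stack.providers.utils.inference.model_registry carries the shared
`allowed_models` field that every provider below now inherits. As a
minimal sketch of the shape involved (inferred from the regenerated
provider docs in this patch, not copied from the source), the base
class amounts to roughly:

    from pydantic import BaseModel, Field

    class RemoteInferenceProviderConfig(BaseModel):
        # Sketch only: field shape inferred from the generated docs.
        # None means "register every model the provider reports".
        allowed_models: list[str] | None = Field(
            default=None,
            description="List of models that should be registered with "
            "the model registry. If None, all models are allowed.",
        )

Each config switches its base from pydantic's BaseModel to this class,
so the field and its docs row show up uniformly across providers.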
---
 docs/docs/providers/inference/remote_anthropic.mdx | 1 +
 docs/docs/providers/inference/remote_azure.mdx | 1 +
 docs/docs/providers/inference/remote_bedrock.mdx | 1 +
 docs/docs/providers/inference/remote_cerebras.mdx | 1 +
 docs/docs/providers/inference/remote_databricks.mdx | 1 +
 docs/docs/providers/inference/remote_gemini.mdx | 1 +
 docs/docs/providers/inference/remote_groq.mdx | 1 +
 .../docs/providers/inference/remote_llama-openai-compat.mdx | 1 +
 docs/docs/providers/inference/remote_nvidia.mdx | 1 +
 docs/docs/providers/inference/remote_ollama.mdx | 1 +
 docs/docs/providers/inference/remote_openai.mdx | 1 +
 docs/docs/providers/inference/remote_passthrough.mdx | 1 +
 docs/docs/providers/inference/remote_runpod.mdx | 1 +
 docs/docs/providers/inference/remote_sambanova.mdx | 1 +
 docs/docs/providers/inference/remote_tgi.mdx | 1 +
 docs/docs/providers/inference/remote_vertexai.mdx | 1 +
 docs/docs/providers/inference/remote_vllm.mdx | 1 +
 docs/docs/providers/inference/remote_watsonx.mdx | 1 +
 docs/docs/providers/safety/remote_bedrock.mdx | 1 +
 llama_stack/providers/remote/inference/anthropic/config.py | 3 ++-
 llama_stack/providers/remote/inference/azure/config.py | 3 ++-
 llama_stack/providers/remote/inference/cerebras/config.py | 5 +++--
 llama_stack/providers/remote/inference/databricks/config.py | 5 +++--
 llama_stack/providers/remote/inference/gemini/config.py | 3 ++-
 llama_stack/providers/remote/inference/groq/config.py | 3 ++-
 .../remote/inference/llama_openai_compat/config.py | 3 ++-
 llama_stack/providers/remote/inference/nvidia/config.py | 5 +++--
 llama_stack/providers/remote/inference/ollama/config.py | 6 ++++--
 llama_stack/providers/remote/inference/openai/config.py | 3 ++-
 .../providers/remote/inference/passthrough/config.py | 5 +++--
 llama_stack/providers/remote/inference/runpod/config.py | 5 +++--
 llama_stack/providers/remote/inference/sambanova/config.py | 3 ++-
 llama_stack/providers/remote/inference/tgi/config.py | 3 ++-
 llama_stack/providers/remote/inference/vertexai/config.py | 3 ++-
 llama_stack/providers/remote/inference/vllm/config.py | 5 +++--
 llama_stack/providers/remote/inference/watsonx/config.py | 3 ++-
 llama_stack/providers/utils/bedrock/config.py | 6 ++++--
 37 files changed, 65 insertions(+), 26 deletions(-)

diff --git a/docs/docs/providers/inference/remote_anthropic.mdx b/docs/docs/providers/inference/remote_anthropic.mdx
index 6bd636c92..96162d25c 100644
--- a/docs/docs/providers/inference/remote_anthropic.mdx
+++ b/docs/docs/providers/inference/remote_anthropic.mdx
@@ -14,6 +14,7 @@ Anthropic inference provider for accessing Claude models and Anthropic's AI serv
 
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `api_key` | `str \| None` | No | | API key for Anthropic models |
 
 ## Sample Configuration
diff --git a/docs/docs/providers/inference/remote_azure.mdx b/docs/docs/providers/inference/remote_azure.mdx
index 0eb0ea755..721fe429c 100644
--- a/docs/docs/providers/inference/remote_azure.mdx
+++ b/docs/docs/providers/inference/remote_azure.mdx
@@ -21,6 +21,7 @@ https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
 
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `api_key` | `` | No | | Azure API key for Azure |
 | `api_base` | `` | No | | Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com) |
 | `api_version` | `str \| None` | No | | Azure API version for Azure (e.g., 2024-12-01-preview) |
diff --git a/docs/docs/providers/inference/remote_bedrock.mdx b/docs/docs/providers/inference/remote_bedrock.mdx
index 04c2154a9..2a5d1b74d 100644
--- a/docs/docs/providers/inference/remote_bedrock.mdx
+++ b/docs/docs/providers/inference/remote_bedrock.mdx
@@ -14,6 +14,7 @@ AWS Bedrock inference provider for accessing various AI models through AWS's man
 
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
 | `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
 | `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
diff --git a/docs/docs/providers/inference/remote_cerebras.mdx b/docs/docs/providers/inference/remote_cerebras.mdx
index d9cc93aef..1a543389d 100644
--- a/docs/docs/providers/inference/remote_cerebras.mdx
+++ b/docs/docs/providers/inference/remote_cerebras.mdx
@@ -14,6 +14,7 @@ Cerebras inference provider for running models on Cerebras Cloud platform.
 
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `base_url` | `` | No | https://api.cerebras.ai | Base URL for the Cerebras API |
 | `api_key` | `` | No | | Cerebras API Key |
 
diff --git a/docs/docs/providers/inference/remote_databricks.mdx b/docs/docs/providers/inference/remote_databricks.mdx
index 7f736db9d..995eb72c1 100644
--- a/docs/docs/providers/inference/remote_databricks.mdx
+++ b/docs/docs/providers/inference/remote_databricks.mdx
@@ -14,6 +14,7 @@ Databricks inference provider for running models on Databricks' unified analytic
 
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `` | No | | The URL for the Databricks model serving endpoint |
 | `api_token` | `` | No | | The Databricks API token |
 
diff --git a/docs/docs/providers/inference/remote_gemini.mdx b/docs/docs/providers/inference/remote_gemini.mdx
index 0505c69da..5222eaa89 100644
--- a/docs/docs/providers/inference/remote_gemini.mdx
+++ b/docs/docs/providers/inference/remote_gemini.mdx
@@ -14,6 +14,7 @@ Google Gemini inference provider for accessing Gemini models and Google's AI ser
 
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `api_key` | `str \| None` | No | | API key for Gemini models |
 
 ## Sample Configuration
diff --git a/docs/docs/providers/inference/remote_groq.mdx b/docs/docs/providers/inference/remote_groq.mdx
index 1797035c1..77516ed1f 100644
--- a/docs/docs/providers/inference/remote_groq.mdx
+++ b/docs/docs/providers/inference/remote_groq.mdx
@@ -14,6 +14,7 @@ Groq inference provider for ultra-fast inference using Groq's LPU technology.
 
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `api_key` | `str \| None` | No | | The Groq API key |
 | `url` | `` | No | https://api.groq.com | The URL for the Groq AI server |
 
diff --git a/docs/docs/providers/inference/remote_llama-openai-compat.mdx b/docs/docs/providers/inference/remote_llama-openai-compat.mdx
index cb624ad87..bcd50f772 100644
--- a/docs/docs/providers/inference/remote_llama-openai-compat.mdx
+++ b/docs/docs/providers/inference/remote_llama-openai-compat.mdx
@@ -14,6 +14,7 @@ Llama OpenAI-compatible provider for using Llama models with OpenAI API format.
 
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `api_key` | `str \| None` | No | | The Llama API key |
 | `openai_compat_api_base` | `` | No | https://api.llama.com/compat/v1/ | The URL for the Llama API server |
 
diff --git a/docs/docs/providers/inference/remote_nvidia.mdx b/docs/docs/providers/inference/remote_nvidia.mdx
index 4a8be5d03..348a42e59 100644
--- a/docs/docs/providers/inference/remote_nvidia.mdx
+++ b/docs/docs/providers/inference/remote_nvidia.mdx
@@ -14,6 +14,7 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.
 
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | The NVIDIA API key, only needed of using the hosted service |
 | `timeout` | `` | No | 60 | Timeout for the HTTP requests |
diff --git a/docs/docs/providers/inference/remote_ollama.mdx b/docs/docs/providers/inference/remote_ollama.mdx
index 5d9a4ad6c..f075607d8 100644
--- a/docs/docs/providers/inference/remote_ollama.mdx
+++ b/docs/docs/providers/inference/remote_ollama.mdx
@@ -14,6 +14,7 @@ Ollama inference provider for running local models through the Ollama runtime.
 
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `` | No | http://localhost:11434 | |
 | `refresh_models` | `` | No | False | Whether to refresh models periodically |
 
diff --git a/docs/docs/providers/inference/remote_openai.mdx b/docs/docs/providers/inference/remote_openai.mdx
index 56ca94233..b795d02b1 100644
--- a/docs/docs/providers/inference/remote_openai.mdx
+++ b/docs/docs/providers/inference/remote_openai.mdx
@@ -14,6 +14,7 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.
 
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `api_key` | `str \| None` | No | | API key for OpenAI models |
 | `base_url` | `` | No | https://api.openai.com/v1 | Base URL for OpenAI API |
 
diff --git a/docs/docs/providers/inference/remote_passthrough.mdx b/docs/docs/providers/inference/remote_passthrough.mdx
index 972cc2a08..58d5619b8 100644
--- a/docs/docs/providers/inference/remote_passthrough.mdx
+++ b/docs/docs/providers/inference/remote_passthrough.mdx
@@ -14,6 +14,7 @@ Passthrough inference provider for connecting to any external inference service
 
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `` | No | | The URL for the passthrough endpoint |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | API Key for the passthrouth endpoint |
 
diff --git a/docs/docs/providers/inference/remote_runpod.mdx b/docs/docs/providers/inference/remote_runpod.mdx
index 2e8847dc5..92cc66eb1 100644
--- a/docs/docs/providers/inference/remote_runpod.mdx
+++ b/docs/docs/providers/inference/remote_runpod.mdx
@@ -14,6 +14,7 @@ RunPod inference provider for running models on RunPod's cloud GPU platform.
 
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `str \| None` | No | | The URL for the Runpod model serving endpoint |
 | `api_token` | `str \| None` | No | | The API token |
 
diff --git a/docs/docs/providers/inference/remote_sambanova.mdx b/docs/docs/providers/inference/remote_sambanova.mdx
index 6ee28b400..b28471890 100644
--- a/docs/docs/providers/inference/remote_sambanova.mdx
+++ b/docs/docs/providers/inference/remote_sambanova.mdx
@@ -14,6 +14,7 @@ SambaNova inference provider for running models on SambaNova's dataflow architec
 
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `` | No | https://api.sambanova.ai/v1 | The URL for the SambaNova AI server |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | The SambaNova cloud API Key |
 
diff --git a/docs/docs/providers/inference/remote_tgi.mdx b/docs/docs/providers/inference/remote_tgi.mdx
index 3a348056f..6ff82cc2b 100644
--- a/docs/docs/providers/inference/remote_tgi.mdx
+++ b/docs/docs/providers/inference/remote_tgi.mdx
@@ -14,6 +14,7 @@ Text Generation Inference (TGI) provider for HuggingFace model serving.
 
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `` | No | | The URL for the TGI serving endpoint |
 
 ## Sample Configuration
diff --git a/docs/docs/providers/inference/remote_vertexai.mdx b/docs/docs/providers/inference/remote_vertexai.mdx
index 13a910d43..48da6be24 100644
--- a/docs/docs/providers/inference/remote_vertexai.mdx
+++ b/docs/docs/providers/inference/remote_vertexai.mdx
@@ -53,6 +53,7 @@ Available Models:
 
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `project` | `` | No | | Google Cloud project ID for Vertex AI |
 | `location` | `` | No | us-central1 | Google Cloud location for Vertex AI |
 
diff --git a/docs/docs/providers/inference/remote_vllm.mdx b/docs/docs/providers/inference/remote_vllm.mdx
index 77b8e1355..598f97b19 100644
--- a/docs/docs/providers/inference/remote_vllm.mdx
+++ b/docs/docs/providers/inference/remote_vllm.mdx
@@ -14,6 +14,7 @@ Remote vLLM inference provider for connecting to vLLM servers.
 
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `str \| None` | No | | The URL for the vLLM model serving endpoint |
 | `max_tokens` | `` | No | 4096 | Maximum number of tokens to generate. |
 | `api_token` | `str \| None` | No | fake | The API token |
diff --git a/docs/docs/providers/inference/remote_watsonx.mdx b/docs/docs/providers/inference/remote_watsonx.mdx
index 1ceccc3ed..8cd3b2869 100644
--- a/docs/docs/providers/inference/remote_watsonx.mdx
+++ b/docs/docs/providers/inference/remote_watsonx.mdx
@@ -14,6 +14,7 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform
 
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx API key |
 | `project_id` | `str \| None` | No | | The Project ID key |
diff --git a/docs/docs/providers/safety/remote_bedrock.mdx b/docs/docs/providers/safety/remote_bedrock.mdx
index 5461d7cdc..530a208b5 100644
--- a/docs/docs/providers/safety/remote_bedrock.mdx
+++ b/docs/docs/providers/safety/remote_bedrock.mdx
@@ -14,6 +14,7 @@ AWS Bedrock safety provider for content moderation using AWS's safety services.
 
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
 | `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
 | `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
diff --git a/llama_stack/providers/remote/inference/anthropic/config.py b/llama_stack/providers/remote/inference/anthropic/config.py
index a74b97a9e..de523ca5a 100644
--- a/llama_stack/providers/remote/inference/anthropic/config.py
+++ b/llama_stack/providers/remote/inference/anthropic/config.py
@@ -8,6 +8,7 @@ from typing import Any
 
 from pydantic import BaseModel, Field
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
@@ -19,7 +20,7 @@ class AnthropicProviderDataValidator(BaseModel):
 
 
 @json_schema_type
-class AnthropicConfig(BaseModel):
+class AnthropicConfig(RemoteInferenceProviderConfig):
     api_key: str | None = Field(
         default=None,
         description="API key for Anthropic models",
diff --git a/llama_stack/providers/remote/inference/azure/config.py b/llama_stack/providers/remote/inference/azure/config.py
index fe9d61d53..8bc7335a3 100644
--- a/llama_stack/providers/remote/inference/azure/config.py
+++ b/llama_stack/providers/remote/inference/azure/config.py
@@ -9,6 +9,7 @@ from typing import Any
 
 from pydantic import BaseModel, Field, HttpUrl, SecretStr
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
@@ -30,7 +31,7 @@ class AzureProviderDataValidator(BaseModel):
 
 
 @json_schema_type
-class AzureConfig(BaseModel):
+class AzureConfig(RemoteInferenceProviderConfig):
     api_key: SecretStr = Field(
         description="Azure API key for Azure",
     )
diff --git a/llama_stack/providers/remote/inference/cerebras/config.py b/llama_stack/providers/remote/inference/cerebras/config.py
index 519bd9119..9e7aeb411 100644
--- a/llama_stack/providers/remote/inference/cerebras/config.py
+++ b/llama_stack/providers/remote/inference/cerebras/config.py
@@ -7,15 +7,16 @@ import os
 from typing import Any
 
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import Field, SecretStr
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 DEFAULT_BASE_URL = "https://api.cerebras.ai"
 
 
 @json_schema_type
-class CerebrasImplConfig(BaseModel):
+class CerebrasImplConfig(RemoteInferenceProviderConfig):
     base_url: str = Field(
         default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
         description="Base URL for the Cerebras API",
diff --git a/llama_stack/providers/remote/inference/databricks/config.py b/llama_stack/providers/remote/inference/databricks/config.py
index 67cd0480c..b5406a1c5 100644
--- a/llama_stack/providers/remote/inference/databricks/config.py
+++ b/llama_stack/providers/remote/inference/databricks/config.py
@@ -6,13 +6,14 @@
 
 from typing import Any
 
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import Field, SecretStr
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
 @json_schema_type
-class DatabricksImplConfig(BaseModel):
+class DatabricksImplConfig(RemoteInferenceProviderConfig):
     url: str = Field(
         default=None,
         description="The URL for the Databricks model serving endpoint",
diff --git a/llama_stack/providers/remote/inference/gemini/config.py b/llama_stack/providers/remote/inference/gemini/config.py
index c897777f7..c7dacec96 100644
--- a/llama_stack/providers/remote/inference/gemini/config.py
+++ b/llama_stack/providers/remote/inference/gemini/config.py
@@ -8,6 +8,7 @@ from typing import Any
 
 from pydantic import BaseModel, Field
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
@@ -19,7 +20,7 @@ class GeminiProviderDataValidator(BaseModel):
 
 
 @json_schema_type
-class GeminiConfig(BaseModel):
+class GeminiConfig(RemoteInferenceProviderConfig):
     api_key: str | None = Field(
         default=None,
         description="API key for Gemini models",
diff --git a/llama_stack/providers/remote/inference/groq/config.py b/llama_stack/providers/remote/inference/groq/config.py
index 67e9fa358..23deba22e 100644
--- a/llama_stack/providers/remote/inference/groq/config.py
+++ b/llama_stack/providers/remote/inference/groq/config.py
@@ -8,6 +8,7 @@ from typing import Any
 
 from pydantic import BaseModel, Field
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
@@ -19,7 +20,7 @@ class GroqProviderDataValidator(BaseModel):
 
 
 @json_schema_type
-class GroqConfig(BaseModel):
+class GroqConfig(RemoteInferenceProviderConfig):
     api_key: str | None = Field(
         # The Groq client library loads the GROQ_API_KEY environment variable by default
         default=None,
diff --git a/llama_stack/providers/remote/inference/llama_openai_compat/config.py b/llama_stack/providers/remote/inference/llama_openai_compat/config.py
index 57bc7240d..0697c041d 100644
--- a/llama_stack/providers/remote/inference/llama_openai_compat/config.py
+++ b/llama_stack/providers/remote/inference/llama_openai_compat/config.py
@@ -8,6 +8,7 @@ from typing import Any
 
 from pydantic import BaseModel, Field
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
@@ -19,7 +20,7 @@ class LlamaProviderDataValidator(BaseModel):
 
 
 @json_schema_type
-class LlamaCompatConfig(BaseModel):
+class LlamaCompatConfig(RemoteInferenceProviderConfig):
     api_key: str | None = Field(
         default=None,
         description="The Llama API key",
diff --git a/llama_stack/providers/remote/inference/nvidia/config.py b/llama_stack/providers/remote/inference/nvidia/config.py
index e1b791719..4b310d770 100644
--- a/llama_stack/providers/remote/inference/nvidia/config.py
+++ b/llama_stack/providers/remote/inference/nvidia/config.py
@@ -7,13 +7,14 @@ import os
 from typing import Any
 
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import Field, SecretStr
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
 @json_schema_type
-class NVIDIAConfig(BaseModel):
+class NVIDIAConfig(RemoteInferenceProviderConfig):
     """
     Configuration for the NVIDIA NIM inference endpoint.
diff --git a/llama_stack/providers/remote/inference/ollama/config.py b/llama_stack/providers/remote/inference/ollama/config.py
index ce13f0d83..d2f104e1e 100644
--- a/llama_stack/providers/remote/inference/ollama/config.py
+++ b/llama_stack/providers/remote/inference/ollama/config.py
@@ -6,12 +6,14 @@
 
 from typing import Any
 
-from pydantic import BaseModel, Field
+from pydantic import Field
+
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 
 DEFAULT_OLLAMA_URL = "http://localhost:11434"
 
 
-class OllamaImplConfig(BaseModel):
+class OllamaImplConfig(RemoteInferenceProviderConfig):
     url: str = DEFAULT_OLLAMA_URL
     refresh_models: bool = Field(
         default=False,
diff --git a/llama_stack/providers/remote/inference/openai/config.py b/llama_stack/providers/remote/inference/openai/config.py
index ad25cdfa5..e494e967b 100644
--- a/llama_stack/providers/remote/inference/openai/config.py
+++ b/llama_stack/providers/remote/inference/openai/config.py
@@ -8,6 +8,7 @@ from typing import Any
 
 from pydantic import BaseModel, Field
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
@@ -19,7 +20,7 @@ class OpenAIProviderDataValidator(BaseModel):
 
 
 @json_schema_type
-class OpenAIConfig(BaseModel):
+class OpenAIConfig(RemoteInferenceProviderConfig):
     api_key: str | None = Field(
         default=None,
         description="API key for OpenAI models",
diff --git a/llama_stack/providers/remote/inference/passthrough/config.py b/llama_stack/providers/remote/inference/passthrough/config.py
index 647b2db46..f8e8b8ce5 100644
--- a/llama_stack/providers/remote/inference/passthrough/config.py
+++ b/llama_stack/providers/remote/inference/passthrough/config.py
@@ -6,13 +6,14 @@
 
 from typing import Any
 
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import Field, SecretStr
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
 @json_schema_type
-class PassthroughImplConfig(BaseModel):
+class PassthroughImplConfig(RemoteInferenceProviderConfig):
     url: str = Field(
         default=None,
         description="The URL for the passthrough endpoint",
diff --git a/llama_stack/providers/remote/inference/runpod/config.py b/llama_stack/providers/remote/inference/runpod/config.py
index 7bc9e8485..cdfe0f885 100644
--- a/llama_stack/providers/remote/inference/runpod/config.py
+++ b/llama_stack/providers/remote/inference/runpod/config.py
@@ -6,13 +6,14 @@
 
 from typing import Any
 
-from pydantic import BaseModel, Field
+from pydantic import Field
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
 @json_schema_type
-class RunpodImplConfig(BaseModel):
+class RunpodImplConfig(RemoteInferenceProviderConfig):
     url: str | None = Field(
         default=None,
         description="The URL for the Runpod model serving endpoint",
diff --git a/llama_stack/providers/remote/inference/sambanova/config.py b/llama_stack/providers/remote/inference/sambanova/config.py
index 50ad53d06..a614663dc 100644
--- a/llama_stack/providers/remote/inference/sambanova/config.py
+++ b/llama_stack/providers/remote/inference/sambanova/config.py
@@ -8,6 +8,7 @@ from typing import Any
 
 from pydantic import BaseModel, Field, SecretStr
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
@@ -19,7 +20,7 @@ class SambaNovaProviderDataValidator(BaseModel):
 
 
 @json_schema_type
-class SambaNovaImplConfig(BaseModel):
+class SambaNovaImplConfig(RemoteInferenceProviderConfig):
     url: str = Field(
         default="https://api.sambanova.ai/v1",
         description="The URL for the SambaNova AI server",
diff --git a/llama_stack/providers/remote/inference/tgi/config.py b/llama_stack/providers/remote/inference/tgi/config.py
index 55136c8ba..d3110b2af 100644
--- a/llama_stack/providers/remote/inference/tgi/config.py
+++ b/llama_stack/providers/remote/inference/tgi/config.py
@@ -7,11 +7,12 @@
 from pydantic import BaseModel, Field, SecretStr
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
 @json_schema_type
-class TGIImplConfig(BaseModel):
+class TGIImplConfig(RemoteInferenceProviderConfig):
     url: str = Field(
         description="The URL for the TGI serving endpoint",
     )
diff --git a/llama_stack/providers/remote/inference/vertexai/config.py b/llama_stack/providers/remote/inference/vertexai/config.py
index 659de653e..97d0852a8 100644
--- a/llama_stack/providers/remote/inference/vertexai/config.py
+++ b/llama_stack/providers/remote/inference/vertexai/config.py
@@ -8,6 +8,7 @@ from typing import Any
 
 from pydantic import BaseModel, Field
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
@@ -23,7 +24,7 @@ class VertexAIProviderDataValidator(BaseModel):
 
 
 @json_schema_type
-class VertexAIConfig(BaseModel):
+class VertexAIConfig(RemoteInferenceProviderConfig):
     project: str = Field(
         description="Google Cloud project ID for Vertex AI",
     )
diff --git a/llama_stack/providers/remote/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py
index a5bf0e4bc..86ef3fe26 100644
--- a/llama_stack/providers/remote/inference/vllm/config.py
+++ b/llama_stack/providers/remote/inference/vllm/config.py
@@ -6,13 +6,14 @@
 
 from pathlib import Path
 
-from pydantic import BaseModel, Field, field_validator
+from pydantic import Field, field_validator
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
 @json_schema_type
-class VLLMInferenceAdapterConfig(BaseModel):
+class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
     url: str | None = Field(
         default=None,
         description="The URL for the vLLM model serving endpoint",
diff --git a/llama_stack/providers/remote/inference/watsonx/config.py b/llama_stack/providers/remote/inference/watsonx/config.py
index 42c25d93e..4bc0173c4 100644
--- a/llama_stack/providers/remote/inference/watsonx/config.py
+++ b/llama_stack/providers/remote/inference/watsonx/config.py
@@ -9,6 +9,7 @@ from typing import Any
 
 from pydantic import BaseModel, Field, SecretStr
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
@@ -19,7 +20,7 @@ class WatsonXProviderDataValidator(BaseModel):
 
 
 @json_schema_type
-class WatsonXConfig(BaseModel):
+class WatsonXConfig(RemoteInferenceProviderConfig):
     url: str = Field(
         default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
         description="A base url for accessing the watsonx.ai",
diff --git a/llama_stack/providers/utils/bedrock/config.py b/llama_stack/providers/utils/bedrock/config.py
index 2745c88cb..418cf381b 100644
--- a/llama_stack/providers/utils/bedrock/config.py
+++ b/llama_stack/providers/utils/bedrock/config.py
@@ -6,10 +6,12 @@
 
 import os
 
-from pydantic import BaseModel, Field
+from pydantic import Field
+
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 
 
-class BedrockBaseConfig(BaseModel):
+class BedrockBaseConfig(RemoteInferenceProviderConfig):
     aws_access_key_id: str | None = Field(
         default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
         description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",