chore: use RemoteInferenceProviderConfig for remote inference providers (#3668)

# What does this PR do?

On the path to maintainable implementations of inference providers: make all remote
inference provider configs subclasses of RemoteInferenceProviderConfig, so each one
inherits the shared `allowed_models` field.
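
For orientation, a minimal sketch of the pattern, assuming only what the diff below shows (the base class lives in `llama_stack.providers.utils.inference.model_registry` and contributes the `allowed_models` field; any other fields it defines are out of scope here):

```python
from pydantic import BaseModel, Field


class RemoteInferenceProviderConfig(BaseModel):
    # Shared by every remote inference provider config after this change.
    allowed_models: list[str] | None = Field(
        default=None,
        description="List of models that should be registered with the model registry. "
        "If None, all models are allowed.",
    )


# Each provider config subclasses the shared base instead of BaseModel,
# keeping only its provider-specific fields.
class AnthropicConfig(RemoteInferenceProviderConfig):
    api_key: str | None = Field(
        default=None,
        description="API key for Anthropic models",
    )
```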

## Test Plan

CI
Matthew Farrellee, 2025-10-03 11:48:42 -04:00, committed by GitHub
commit ce77c27ff8 (parent a20e8eac8c)
37 changed files with 65 additions and 26 deletions

@@ -14,6 +14,7 @@ Anthropic inference provider for accessing Claude models and Anthropic's AI serv
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `api_key` | `str \| None` | No | | API key for Anthropic models |
## Sample Configuration

@@ -21,6 +21,7 @@ https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `api_key` | `SecretStr` | No | | Azure API key for Azure |
| `api_base` | `HttpUrl` | No | | Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com) |
| `api_version` | `str \| None` | No | | Azure API version for Azure (e.g., 2024-12-01-preview) |

@@ -14,6 +14,7 @@ AWS Bedrock inference provider for accessing various AI models through AWS's man
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Defaults to the environment variable AWS_ACCESS_KEY_ID. |
| `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Defaults to the environment variable AWS_SECRET_ACCESS_KEY. |
| `aws_session_token` | `str \| None` | No | | The AWS session token to use. Defaults to the environment variable AWS_SESSION_TOKEN. |

@@ -14,6 +14,7 @@ Cerebras inference provider for running models on Cerebras Cloud platform.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `base_url` | `str` | No | https://api.cerebras.ai | Base URL for the Cerebras API |
| `api_key` | `SecretStr` | No | | Cerebras API Key |

@@ -14,6 +14,7 @@ Databricks inference provider for running models on Databricks' unified analytic
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `str` | No | | The URL for the Databricks model serving endpoint |
| `api_token` | `SecretStr` | No | | The Databricks API token |

@@ -14,6 +14,7 @@ Google Gemini inference provider for accessing Gemini models and Google's AI ser
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `api_key` | `str \| None` | No | | API key for Gemini models |
## Sample Configuration

@@ -14,6 +14,7 @@ Groq inference provider for ultra-fast inference using Groq's LPU technology.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `api_key` | `str \| None` | No | | The Groq API key |
| `url` | `str` | No | https://api.groq.com | The URL for the Groq AI server |

@@ -14,6 +14,7 @@ Llama OpenAI-compatible provider for using Llama models with OpenAI API format.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `api_key` | `str \| None` | No | | The Llama API key |
| `openai_compat_api_base` | `str` | No | https://api.llama.com/compat/v1/ | The URL for the Llama API server |

@@ -14,6 +14,7 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `str` | No | https://integrate.api.nvidia.com | A base URL for accessing the NVIDIA NIM |
| `api_key` | `SecretStr \| None` | No | | The NVIDIA API key, only needed if using the hosted service |
| `timeout` | `int` | No | 60 | Timeout for the HTTP requests |

@@ -14,6 +14,7 @@ Ollama inference provider for running local models through the Ollama runtime.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `str` | No | http://localhost:11434 | |
| `refresh_models` | `bool` | No | False | Whether to refresh models periodically |

@@ -14,6 +14,7 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `api_key` | `str \| None` | No | | API key for OpenAI models |
| `base_url` | `str` | No | https://api.openai.com/v1 | Base URL for OpenAI API |

@@ -14,6 +14,7 @@ Passthrough inference provider for connecting to any external inference service
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `str` | No | | The URL for the passthrough endpoint |
| `api_key` | `SecretStr \| None` | No | | API Key for the passthrough endpoint |

@@ -14,6 +14,7 @@ RunPod inference provider for running models on RunPod's cloud GPU platform.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `str \| None` | No | | The URL for the Runpod model serving endpoint |
| `api_token` | `str \| None` | No | | The API token |

@@ -14,6 +14,7 @@ SambaNova inference provider for running models on SambaNova's dataflow architec
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `str` | No | https://api.sambanova.ai/v1 | The URL for the SambaNova AI server |
| `api_key` | `SecretStr \| None` | No | | The SambaNova cloud API Key |

@@ -14,6 +14,7 @@ Text Generation Inference (TGI) provider for HuggingFace model serving.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `str` | No | | The URL for the TGI serving endpoint |
## Sample Configuration

@@ -53,6 +53,7 @@ Available Models:
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `project` | `str` | No | | Google Cloud project ID for Vertex AI |
| `location` | `str` | No | us-central1 | Google Cloud location for Vertex AI |

@@ -14,6 +14,7 @@ Remote vLLM inference provider for connecting to vLLM servers.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `str \| None` | No | | The URL for the vLLM model serving endpoint |
| `max_tokens` | `int` | No | 4096 | Maximum number of tokens to generate. |
| `api_token` | `str \| None` | No | fake | The API token |

@@ -14,6 +14,7 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `str` | No | https://us-south.ml.cloud.ibm.com | A base URL for accessing watsonx.ai |
| `api_key` | `SecretStr \| None` | No | | The watsonx API key |
| `project_id` | `str \| None` | No | | The project ID |

@@ -14,6 +14,7 @@ AWS Bedrock safety provider for content moderation using AWS's safety services.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Defaults to the environment variable AWS_ACCESS_KEY_ID. |
| `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Defaults to the environment variable AWS_SECRET_ACCESS_KEY. |
| `aws_session_token` | `str \| None` | No | | The AWS session token to use. Defaults to the environment variable AWS_SESSION_TOKEN. |

@@ -8,6 +8,7 @@ from typing import Any
 from pydantic import BaseModel, Field
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -19,7 +20,7 @@ class AnthropicProviderDataValidator(BaseModel):
 @json_schema_type
-class AnthropicConfig(BaseModel):
+class AnthropicConfig(RemoteInferenceProviderConfig):
     api_key: str | None = Field(
         default=None,
         description="API key for Anthropic models",

@@ -9,6 +9,7 @@ from typing import Any
 from pydantic import BaseModel, Field, HttpUrl, SecretStr
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -30,7 +31,7 @@ class AzureProviderDataValidator(BaseModel):
 @json_schema_type
-class AzureConfig(BaseModel):
+class AzureConfig(RemoteInferenceProviderConfig):
     api_key: SecretStr = Field(
         description="Azure API key for Azure",
     )

@@ -7,15 +7,16 @@
 import os
 from typing import Any
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import Field, SecretStr
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 DEFAULT_BASE_URL = "https://api.cerebras.ai"
 @json_schema_type
-class CerebrasImplConfig(BaseModel):
+class CerebrasImplConfig(RemoteInferenceProviderConfig):
     base_url: str = Field(
         default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
         description="Base URL for the Cerebras API",

@@ -6,13 +6,14 @@
 from typing import Any
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import Field, SecretStr
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 @json_schema_type
-class DatabricksImplConfig(BaseModel):
+class DatabricksImplConfig(RemoteInferenceProviderConfig):
     url: str = Field(
         default=None,
         description="The URL for the Databricks model serving endpoint",

@@ -8,6 +8,7 @@ from typing import Any
 from pydantic import BaseModel, Field
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -19,7 +20,7 @@ class GeminiProviderDataValidator(BaseModel):
 @json_schema_type
-class GeminiConfig(BaseModel):
+class GeminiConfig(RemoteInferenceProviderConfig):
     api_key: str | None = Field(
         default=None,
         description="API key for Gemini models",

@@ -8,6 +8,7 @@ from typing import Any
 from pydantic import BaseModel, Field
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -19,7 +20,7 @@ class GroqProviderDataValidator(BaseModel):
 @json_schema_type
-class GroqConfig(BaseModel):
+class GroqConfig(RemoteInferenceProviderConfig):
     api_key: str | None = Field(
         # The Groq client library loads the GROQ_API_KEY environment variable by default
         default=None,

@@ -8,6 +8,7 @@ from typing import Any
 from pydantic import BaseModel, Field
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -19,7 +20,7 @@ class LlamaProviderDataValidator(BaseModel):
 @json_schema_type
-class LlamaCompatConfig(BaseModel):
+class LlamaCompatConfig(RemoteInferenceProviderConfig):
     api_key: str | None = Field(
         default=None,
         description="The Llama API key",

@@ -7,13 +7,14 @@
 import os
 from typing import Any
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import Field, SecretStr
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 @json_schema_type
-class NVIDIAConfig(BaseModel):
+class NVIDIAConfig(RemoteInferenceProviderConfig):
     """
     Configuration for the NVIDIA NIM inference endpoint.

@@ -6,12 +6,14 @@
 from typing import Any
-from pydantic import BaseModel, Field
+from pydantic import Field
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 DEFAULT_OLLAMA_URL = "http://localhost:11434"
-class OllamaImplConfig(BaseModel):
+class OllamaImplConfig(RemoteInferenceProviderConfig):
     url: str = DEFAULT_OLLAMA_URL
     refresh_models: bool = Field(
         default=False,

@@ -8,6 +8,7 @@ from typing import Any
 from pydantic import BaseModel, Field
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -19,7 +20,7 @@ class OpenAIProviderDataValidator(BaseModel):
 @json_schema_type
-class OpenAIConfig(BaseModel):
+class OpenAIConfig(RemoteInferenceProviderConfig):
     api_key: str | None = Field(
         default=None,
         description="API key for OpenAI models",

@@ -6,13 +6,14 @@
 from typing import Any
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import Field, SecretStr
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 @json_schema_type
-class PassthroughImplConfig(BaseModel):
+class PassthroughImplConfig(RemoteInferenceProviderConfig):
     url: str = Field(
         default=None,
         description="The URL for the passthrough endpoint",

@@ -6,13 +6,14 @@
 from typing import Any
-from pydantic import BaseModel, Field
+from pydantic import Field
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 @json_schema_type
-class RunpodImplConfig(BaseModel):
+class RunpodImplConfig(RemoteInferenceProviderConfig):
     url: str | None = Field(
         default=None,
         description="The URL for the Runpod model serving endpoint",

@@ -8,6 +8,7 @@ from typing import Any
 from pydantic import BaseModel, Field, SecretStr
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -19,7 +20,7 @@ class SambaNovaProviderDataValidator(BaseModel):
 @json_schema_type
-class SambaNovaImplConfig(BaseModel):
+class SambaNovaImplConfig(RemoteInferenceProviderConfig):
     url: str = Field(
         default="https://api.sambanova.ai/v1",
         description="The URL for the SambaNova AI server",

@@ -7,11 +7,12 @@
 from pydantic import BaseModel, Field, SecretStr
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 @json_schema_type
-class TGIImplConfig(BaseModel):
+class TGIImplConfig(RemoteInferenceProviderConfig):
     url: str = Field(
         description="The URL for the TGI serving endpoint",
     )

@@ -8,6 +8,7 @@ from typing import Any
 from pydantic import BaseModel, Field
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -23,7 +24,7 @@ class VertexAIProviderDataValidator(BaseModel):
 @json_schema_type
-class VertexAIConfig(BaseModel):
+class VertexAIConfig(RemoteInferenceProviderConfig):
     project: str = Field(
         description="Google Cloud project ID for Vertex AI",
     )

@@ -6,13 +6,14 @@
 from pathlib import Path
-from pydantic import BaseModel, Field, field_validator
+from pydantic import Field, field_validator
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 @json_schema_type
-class VLLMInferenceAdapterConfig(BaseModel):
+class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
     url: str | None = Field(
         default=None,
         description="The URL for the vLLM model serving endpoint",

@@ -9,6 +9,7 @@ from typing import Any
 from pydantic import BaseModel, Field, SecretStr
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -19,7 +20,7 @@ class WatsonXProviderDataValidator(BaseModel):
 @json_schema_type
-class WatsonXConfig(BaseModel):
+class WatsonXConfig(RemoteInferenceProviderConfig):
     url: str = Field(
         default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
         description="A base URL for accessing watsonx.ai",

@@ -6,10 +6,12 @@
 import os
-from pydantic import BaseModel, Field
+from pydantic import Field
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
-class BedrockBaseConfig(BaseModel):
+class BedrockBaseConfig(RemoteInferenceProviderConfig):
     aws_access_key_id: str | None = Field(
         default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
         description="The AWS access key to use. Defaults to the environment variable AWS_ACCESS_KEY_ID.",