chore: use remoteinferenceproviderconfig for remote inference providers (#3668)

# What does this PR do?

on the path to maintainable impls of inference providers. make all
configs instances of RemoteInferenceProviderConfig.

## Test Plan

ci
This commit is contained in:
Matthew Farrellee 2025-10-03 11:48:42 -04:00 committed by GitHub
parent a20e8eac8c
commit ce77c27ff8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
37 changed files with 65 additions and 26 deletions

View file

@ -8,6 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -19,7 +20,7 @@ class AnthropicProviderDataValidator(BaseModel):
@json_schema_type
class AnthropicConfig(BaseModel):
class AnthropicConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
default=None,
description="API key for Anthropic models",

View file

@ -9,6 +9,7 @@ from typing import Any
from pydantic import BaseModel, Field, HttpUrl, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -30,7 +31,7 @@ class AzureProviderDataValidator(BaseModel):
@json_schema_type
class AzureConfig(BaseModel):
class AzureConfig(RemoteInferenceProviderConfig):
api_key: SecretStr = Field(
description="Azure API key for Azure",
)

View file

@ -7,15 +7,16 @@
import os
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
DEFAULT_BASE_URL = "https://api.cerebras.ai"
@json_schema_type
class CerebrasImplConfig(BaseModel):
class CerebrasImplConfig(RemoteInferenceProviderConfig):
base_url: str = Field(
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
description="Base URL for the Cerebras API",

View file

@ -6,13 +6,14 @@
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class DatabricksImplConfig(BaseModel):
class DatabricksImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default=None,
description="The URL for the Databricks model serving endpoint",

View file

@ -8,6 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -19,7 +20,7 @@ class GeminiProviderDataValidator(BaseModel):
@json_schema_type
class GeminiConfig(BaseModel):
class GeminiConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
default=None,
description="API key for Gemini models",

View file

@ -8,6 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -19,7 +20,7 @@ class GroqProviderDataValidator(BaseModel):
@json_schema_type
class GroqConfig(BaseModel):
class GroqConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
# The Groq client library loads the GROQ_API_KEY environment variable by default
default=None,

View file

@ -8,6 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -19,7 +20,7 @@ class LlamaProviderDataValidator(BaseModel):
@json_schema_type
class LlamaCompatConfig(BaseModel):
class LlamaCompatConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
default=None,
description="The Llama API key",

View file

@ -7,13 +7,14 @@
import os
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class NVIDIAConfig(BaseModel):
class NVIDIAConfig(RemoteInferenceProviderConfig):
"""
Configuration for the NVIDIA NIM inference endpoint.

View file

@ -6,12 +6,14 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
DEFAULT_OLLAMA_URL = "http://localhost:11434"
class OllamaImplConfig(BaseModel):
class OllamaImplConfig(RemoteInferenceProviderConfig):
url: str = DEFAULT_OLLAMA_URL
refresh_models: bool = Field(
default=False,

View file

@ -8,6 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -19,7 +20,7 @@ class OpenAIProviderDataValidator(BaseModel):
@json_schema_type
class OpenAIConfig(BaseModel):
class OpenAIConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
default=None,
description="API key for OpenAI models",

View file

@ -6,13 +6,14 @@
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class PassthroughImplConfig(BaseModel):
class PassthroughImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default=None,
description="The URL for the passthrough endpoint",

View file

@ -6,13 +6,14 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class RunpodImplConfig(BaseModel):
class RunpodImplConfig(RemoteInferenceProviderConfig):
url: str | None = Field(
default=None,
description="The URL for the Runpod model serving endpoint",

View file

@ -8,6 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -19,7 +20,7 @@ class SambaNovaProviderDataValidator(BaseModel):
@json_schema_type
class SambaNovaImplConfig(BaseModel):
class SambaNovaImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.sambanova.ai/v1",
description="The URL for the SambaNova AI server",

View file

@ -7,11 +7,12 @@
from pydantic import BaseModel, Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class TGIImplConfig(BaseModel):
class TGIImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
description="The URL for the TGI serving endpoint",
)

View file

@ -8,6 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -23,7 +24,7 @@ class VertexAIProviderDataValidator(BaseModel):
@json_schema_type
class VertexAIConfig(BaseModel):
class VertexAIConfig(RemoteInferenceProviderConfig):
project: str = Field(
description="Google Cloud project ID for Vertex AI",
)

View file

@ -6,13 +6,14 @@
from pathlib import Path
from pydantic import BaseModel, Field, field_validator
from pydantic import Field, field_validator
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class VLLMInferenceAdapterConfig(BaseModel):
class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
url: str | None = Field(
default=None,
description="The URL for the vLLM model serving endpoint",

View file

@ -9,6 +9,7 @@ from typing import Any
from pydantic import BaseModel, Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -19,7 +20,7 @@ class WatsonXProviderDataValidator(BaseModel):
@json_schema_type
class WatsonXConfig(BaseModel):
class WatsonXConfig(RemoteInferenceProviderConfig):
url: str = Field(
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
description="A base url for accessing the watsonx.ai",

View file

@ -6,10 +6,12 @@
import os
from pydantic import BaseModel, Field
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
class BedrockBaseConfig(BaseModel):
class BedrockBaseConfig(RemoteInferenceProviderConfig):
aws_access_key_id: str | None = Field(
default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",