mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
chore: make all remote inference provider configs RemoteInferenceProviderConfigs
This commit is contained in:
parent
4dfbe46954
commit
71d67a983e
37 changed files with 65 additions and 26 deletions
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
@ -19,7 +20,7 @@ class AnthropicProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class AnthropicConfig(BaseModel):
|
||||
class AnthropicConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Anthropic models",
|
||||
|
|
|
@ -9,6 +9,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field, HttpUrl, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
@ -30,7 +31,7 @@ class AzureProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class AzureConfig(BaseModel):
|
||||
class AzureConfig(RemoteInferenceProviderConfig):
|
||||
api_key: SecretStr = Field(
|
||||
description="Azure API key for Azure",
|
||||
)
|
||||
|
|
|
@ -7,15 +7,16 @@
|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
DEFAULT_BASE_URL = "https://api.cerebras.ai"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CerebrasImplConfig(BaseModel):
|
||||
class CerebrasImplConfig(RemoteInferenceProviderConfig):
|
||||
base_url: str = Field(
|
||||
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
|
||||
description="Base URL for the Cerebras API",
|
||||
|
|
|
@ -6,13 +6,14 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DatabricksImplConfig(BaseModel):
|
||||
class DatabricksImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default=None,
|
||||
description="The URL for the Databricks model serving endpoint",
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
@ -19,7 +20,7 @@ class GeminiProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class GeminiConfig(BaseModel):
|
||||
class GeminiConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Gemini models",
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
@ -19,7 +20,7 @@ class GroqProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class GroqConfig(BaseModel):
|
||||
class GroqConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
# The Groq client library loads the GROQ_API_KEY environment variable by default
|
||||
default=None,
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
@ -19,7 +20,7 @@ class LlamaProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class LlamaCompatConfig(BaseModel):
|
||||
class LlamaCompatConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="The Llama API key",
|
||||
|
|
|
@ -7,13 +7,14 @@
|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class NVIDIAConfig(BaseModel):
|
||||
class NVIDIAConfig(RemoteInferenceProviderConfig):
|
||||
"""
|
||||
Configuration for the NVIDIA NIM inference endpoint.
|
||||
|
||||
|
|
|
@ -6,12 +6,14 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
|
||||
DEFAULT_OLLAMA_URL = "http://localhost:11434"
|
||||
|
||||
|
||||
class OllamaImplConfig(BaseModel):
|
||||
class OllamaImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = DEFAULT_OLLAMA_URL
|
||||
refresh_models: bool = Field(
|
||||
default=False,
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
@ -19,7 +20,7 @@ class OpenAIProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIConfig(BaseModel):
|
||||
class OpenAIConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for OpenAI models",
|
||||
|
|
|
@ -6,13 +6,14 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PassthroughImplConfig(BaseModel):
|
||||
class PassthroughImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default=None,
|
||||
description="The URL for the passthrough endpoint",
|
||||
|
|
|
@ -6,13 +6,14 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RunpodImplConfig(BaseModel):
|
||||
class RunpodImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str | None = Field(
|
||||
default=None,
|
||||
description="The URL for the Runpod model serving endpoint",
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
@ -19,7 +20,7 @@ class SambaNovaProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class SambaNovaImplConfig(BaseModel):
|
||||
class SambaNovaImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default="https://api.sambanova.ai/v1",
|
||||
description="The URL for the SambaNova AI server",
|
||||
|
|
|
@ -7,11 +7,12 @@
|
|||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TGIImplConfig(BaseModel):
|
||||
class TGIImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
description="The URL for the TGI serving endpoint",
|
||||
)
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
@ -23,7 +24,7 @@ class VertexAIProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class VertexAIConfig(BaseModel):
|
||||
class VertexAIConfig(RemoteInferenceProviderConfig):
|
||||
project: str = Field(
|
||||
description="Google Cloud project ID for Vertex AI",
|
||||
)
|
||||
|
|
|
@ -6,13 +6,14 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VLLMInferenceAdapterConfig(BaseModel):
|
||||
class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
|
||||
url: str | None = Field(
|
||||
default=None,
|
||||
description="The URL for the vLLM model serving endpoint",
|
||||
|
|
|
@ -9,6 +9,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
@ -19,7 +20,7 @@ class WatsonXProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class WatsonXConfig(BaseModel):
|
||||
class WatsonXConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
|
||||
description="A base url for accessing the watsonx.ai",
|
||||
|
|
|
@ -6,10 +6,12 @@
|
|||
|
||||
import os
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
|
||||
|
||||
class BedrockBaseConfig(BaseModel):
|
||||
class BedrockBaseConfig(RemoteInferenceProviderConfig):
|
||||
aws_access_key_id: str | None = Field(
|
||||
default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
|
||||
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue