diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py
index faaeefd01..b5558c66f 100644
--- a/llama_stack/core/datatypes.py
+++ b/llama_stack/core/datatypes.py
@@ -121,10 +121,6 @@ class AutoRoutedProviderSpec(ProviderSpec):
         default=None,
     )
 
-    @property
-    def pip_packages(self) -> list[str]:
-        raise AssertionError("Should not be called on AutoRoutedProviderSpec")
-
 
 # Example: /models, /shields
 class RoutingTableProviderSpec(ProviderSpec):
diff --git a/llama_stack/core/distribution.py b/llama_stack/core/distribution.py
index c104b6764..302ecb960 100644
--- a/llama_stack/core/distribution.py
+++ b/llama_stack/core/distribution.py
@@ -16,11 +16,10 @@ from llama_stack.core.datatypes import BuildConfig, DistributionSpec
 from llama_stack.core.external import load_external_apis
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import (
-    AdapterSpec,
     Api,
     InlineProviderSpec,
     ProviderSpec,
-    remote_provider_spec,
+    RemoteProviderSpec,
 )
 
 logger = get_logger(name=__name__, category="core")
@@ -77,27 +76,12 @@ def providable_apis() -> list[Api]:
 
 
 def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
-    adapter = AdapterSpec(**spec_data["adapter"])
-    spec = remote_provider_spec(
-        api=api,
-        adapter=adapter,
-        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
-    )
+    spec = RemoteProviderSpec(api=api, provider_type=f"remote::{spec_data['adapter_type']}", **spec_data)
     return spec
 
 
 def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
-    spec = InlineProviderSpec(
-        api=api,
-        provider_type=f"inline::{provider_name}",
-        pip_packages=spec_data.get("pip_packages", []),
-        module=spec_data["module"],
-        config_class=spec_data["config_class"],
-        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
-        optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
-        provider_data_validator=spec_data.get("provider_data_validator"),
-        container_image=spec_data.get("container_image"),
-    )
+    spec = InlineProviderSpec(api=api, provider_type=f"inline::{provider_name}", **spec_data)
     return spec
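Note: with `AdapterSpec` gone, the loader above forwards an external remote spec's top-level keys straight into `RemoteProviderSpec`. A minimal sketch of that loading path, assuming a hypothetical `acme` spec dict (all names below are illustrative, not part of this patch):

```python
from llama_stack.providers.datatypes import Api, RemoteProviderSpec

# Hypothetical spec as it would arrive from an external provider YAML file,
# already parsed into a dict: flat keys, no nested "adapter:" block.
spec_data = {
    "adapter_type": "acme",
    "module": "llama_stack_provider_acme",
    "config_class": "llama_stack_provider_acme.config.AcmeProviderConfig",
    "pip_packages": ["llama-stack-provider-acme"],
}

# Mirrors the new _load_remote_provider_spec: provider_type is derived from
# adapter_type and the remaining keys map one-to-one onto RemoteProviderSpec fields.
spec = RemoteProviderSpec(
    api=Api.inference,
    provider_type=f"remote::{spec_data['adapter_type']}",
    **spec_data,
)
print(spec.provider_type, spec.pip_packages)
```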
diff --git a/llama_stack/distributions/starter/starter.py b/llama_stack/distributions/starter/starter.py
index c2dfe95ad..6bee51ff0 100644
--- a/llama_stack/distributions/starter/starter.py
+++ b/llama_stack/distributions/starter/starter.py
@@ -78,12 +78,12 @@ def get_remote_inference_providers() -> list[Provider]:
     remote_providers = [
         provider
         for provider in available_providers()
-        if isinstance(provider, RemoteProviderSpec) and provider.adapter.adapter_type in ENABLED_INFERENCE_PROVIDERS
+        if isinstance(provider, RemoteProviderSpec) and provider.adapter_type in ENABLED_INFERENCE_PROVIDERS
     ]
 
     inference_providers = []
     for provider_spec in remote_providers:
-        provider_type = provider_spec.adapter.adapter_type
+        provider_type = provider_spec.adapter_type
 
         if provider_type in INFERENCE_PROVIDER_IDS:
             provider_id = INFERENCE_PROVIDER_IDS[provider_type]
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 5e15dd8e1..c8ff9cecb 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -131,6 +131,15 @@ class ProviderSpec(BaseModel):
 """,
     )
 
+    pip_packages: list[str] = Field(
+        default_factory=list,
+        description="The pip dependencies needed for this implementation",
+    )
+
+    provider_data_validator: str | None = Field(
+        default=None,
+    )
+
     is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")
 
     # used internally by the resolver; this is a hack for now
@@ -145,45 +154,8 @@ class RoutingTable(Protocol):
     async def get_provider_impl(self, routing_key: str) -> Any: ...
 
 
-# TODO: this can now be inlined into RemoteProviderSpec
-@json_schema_type
-class AdapterSpec(BaseModel):
-    adapter_type: str = Field(
-        ...,
-        description="Unique identifier for this adapter",
-    )
-    module: str = Field(
-        default_factory=str,
-        description="""
-Fully-qualified name of the module to import. The module is expected to have:
-
- - `get_adapter_impl(config, deps)`: returns the adapter implementation
-""",
-    )
-    pip_packages: list[str] = Field(
-        default_factory=list,
-        description="The pip dependencies needed for this implementation",
-    )
-    config_class: str = Field(
-        description="Fully-qualified classname of the config for this provider",
-    )
-    provider_data_validator: str | None = Field(
-        default=None,
-    )
-    description: str | None = Field(
-        default=None,
-        description="""
-A description of the provider. This is used to display in the documentation.
-""",
-    )
-
-
 @json_schema_type
 class InlineProviderSpec(ProviderSpec):
-    pip_packages: list[str] = Field(
-        default_factory=list,
-        description="The pip dependencies needed for this implementation",
-    )
     container_image: str | None = Field(
         default=None,
         description="""
@@ -191,10 +163,6 @@ The container image to use for this implementation. If one is provided, pip_pack
 If a provider depends on other providers, the dependencies MUST NOT specify a container image.
 """,
     )
-    # module field is inherited from ProviderSpec
-    provider_data_validator: str | None = Field(
-        default=None,
-    )
     description: str | None = Field(
         default=None,
         description="""
@@ -223,10 +191,15 @@ class RemoteProviderConfig(BaseModel):
 
 
 @json_schema_type
 class RemoteProviderSpec(ProviderSpec):
-    adapter: AdapterSpec = Field(
+    adapter_type: str = Field(
+        ...,
+        description="Unique identifier for this adapter",
+    )
+
+    description: str | None = Field(
+        default=None,
         description="""
-If some code is needed to convert the remote responses into Llama Stack compatible
-API responses, specify the adapter here.
+A description of the provider. This is used to display in the documentation.
 """,
     )
 
@@ -234,33 +207,6 @@ API responses, specify the adapter here.
     def container_image(self) -> str | None:
         return None
 
-    # module field is inherited from ProviderSpec
-
-    @property
-    def pip_packages(self) -> list[str]:
-        return self.adapter.pip_packages
-
-    @property
-    def provider_data_validator(self) -> str | None:
-        return self.adapter.provider_data_validator
-
-
-def remote_provider_spec(
-    api: Api,
-    adapter: AdapterSpec,
-    api_dependencies: list[Api] | None = None,
-    optional_api_dependencies: list[Api] | None = None,
-) -> RemoteProviderSpec:
-    return RemoteProviderSpec(
-        api=api,
-        provider_type=f"remote::{adapter.adapter_type}",
-        config_class=adapter.config_class,
-        module=adapter.module,
-        adapter=adapter,
-        api_dependencies=api_dependencies or [],
-        optional_api_dependencies=optional_api_dependencies or [],
-    )
-
 
 class HealthStatus(StrEnum):
     OK = "OK"
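Note: because `pip_packages` and `provider_data_validator` are now plain fields on `ProviderSpec`, and `adapter_type` sits directly on `RemoteProviderSpec`, dependency handling no longer needs to go through `spec.adapter`. A minimal sketch of the flattened model, using a hypothetical `acme` adapter (module and config-class paths are illustrative):

```python
from llama_stack.providers.datatypes import Api, InlineProviderSpec, RemoteProviderSpec

# Illustrative specs only; the module/config_class strings are placeholders.
remote = RemoteProviderSpec(
    api=Api.inference,
    adapter_type="acme",
    provider_type="remote::acme",  # no longer derived by a remote_provider_spec() helper
    pip_packages=["acme-sdk"],
    module="llama_stack.providers.remote.inference.acme",
    config_class="llama_stack.providers.remote.inference.acme.AcmeImplConfig",
)
inline = InlineProviderSpec(
    api=Api.inference,
    provider_type="inline::acme",
    pip_packages=["acme-local"],
    module="llama_stack.providers.inline.inference.acme",
    config_class="llama_stack.providers.inline.inference.acme.AcmeLocalConfig",
)

# pip_packages is an ordinary ProviderSpec field now, not a property delegating
# to spec.adapter, so both spec kinds can be treated uniformly.
for spec in (remote, inline):
    print(spec.provider_type, spec.pip_packages)
```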
diff --git a/llama_stack/providers/registry/datasetio.py b/llama_stack/providers/registry/datasetio.py
index f641b4ce3..a9feb0bac 100644
--- a/llama_stack/providers/registry/datasetio.py
+++ b/llama_stack/providers/registry/datasetio.py
@@ -6,11 +6,10 @@
 
 from llama_stack.providers.datatypes import (
-    AdapterSpec,
     Api,
     InlineProviderSpec,
     ProviderSpec,
-    remote_provider_spec,
+    RemoteProviderSpec,
 )
 
 
@@ -25,28 +24,26 @@ def available_providers() -> list[ProviderSpec]:
             api_dependencies=[],
             description="Local filesystem-based dataset I/O provider for reading and writing datasets to local storage.",
         ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.datasetio,
-            adapter=AdapterSpec(
-                adapter_type="huggingface",
-                pip_packages=[
-                    "datasets>=4.0.0",
-                ],
-                module="llama_stack.providers.remote.datasetio.huggingface",
-                config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
-                description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
-            ),
+            adapter_type="huggingface",
+            provider_type="remote::huggingface",
+            pip_packages=[
+                "datasets>=4.0.0",
+            ],
+            module="llama_stack.providers.remote.datasetio.huggingface",
+            config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
+            description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
         ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.datasetio,
-            adapter=AdapterSpec(
-                adapter_type="nvidia",
-                pip_packages=[
-                    "datasets>=4.0.0",
-                ],
-                module="llama_stack.providers.remote.datasetio.nvidia",
-                config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
-                description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
-            ),
+            adapter_type="nvidia",
+            provider_type="remote::nvidia",
+            module="llama_stack.providers.remote.datasetio.nvidia",
+            config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
+            pip_packages=[
+                "datasets>=4.0.0",
+            ],
+            description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
         ),
     ]
diff --git a/llama_stack/providers/registry/eval.py b/llama_stack/providers/registry/eval.py
index 9f0d17916..4ef0bb41f 100644
--- a/llama_stack/providers/registry/eval.py
+++ b/llama_stack/providers/registry/eval.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
-from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec def available_providers() -> list[ProviderSpec]: @@ -25,17 +25,16 @@ def available_providers() -> list[ProviderSpec]: ], description="Meta's reference implementation of evaluation tasks with support for multiple languages and evaluation metrics.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.eval, - adapter=AdapterSpec( - adapter_type="nvidia", - pip_packages=[ - "requests", - ], - module="llama_stack.providers.remote.eval.nvidia", - config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig", - description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.", - ), + adapter_type="nvidia", + pip_packages=[ + "requests", + ], + provider_type="remote::nvidia", + module="llama_stack.providers.remote.eval.nvidia", + config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig", + description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.", api_dependencies=[ Api.datasetio, Api.datasets, diff --git a/llama_stack/providers/registry/files.py b/llama_stack/providers/registry/files.py index ebe90310c..9acabfacd 100644 --- a/llama_stack/providers/registry/files.py +++ b/llama_stack/providers/registry/files.py @@ -4,13 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.providers.datatypes import ( - AdapterSpec, - Api, - InlineProviderSpec, - ProviderSpec, - remote_provider_spec, -) +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages @@ -25,14 +19,13 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig", description="Local filesystem-based file storage provider for managing files and documents locally.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.files, - adapter=AdapterSpec( - adapter_type="s3", - pip_packages=["boto3"] + sql_store_pip_packages, - module="llama_stack.providers.remote.files.s3", - config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig", - description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.", - ), + provider_type="remote::s3", + adapter_type="s3", + pip_packages=["boto3"] + sql_store_pip_packages, + module="llama_stack.providers.remote.files.s3", + config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig", + description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.", ), ] diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 0eb4cf104..658611698 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -6,11 +6,10 @@ from llama_stack.providers.datatypes import ( - AdapterSpec, Api, InlineProviderSpec, ProviderSpec, - remote_provider_spec, + RemoteProviderSpec, ) META_REFERENCE_DEPS = [ @@ -49,177 +48,167 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig", description="Sentence 
Transformers inference provider for text embeddings and similarity search.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="cerebras", - pip_packages=[ - "cerebras_cloud_sdk", - ], - module="llama_stack.providers.remote.inference.cerebras", - config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig", - description="Cerebras inference provider for running models on Cerebras Cloud platform.", - ), + adapter_type="cerebras", + provider_type="remote::cerebras", + pip_packages=[ + "cerebras_cloud_sdk", + ], + module="llama_stack.providers.remote.inference.cerebras", + config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig", + description="Cerebras inference provider for running models on Cerebras Cloud platform.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="ollama", - pip_packages=["ollama", "aiohttp", "h11>=0.16.0"], - config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig", - module="llama_stack.providers.remote.inference.ollama", - description="Ollama inference provider for running local models through the Ollama runtime.", - ), + adapter_type="ollama", + provider_type="remote::ollama", + pip_packages=["ollama", "aiohttp", "h11>=0.16.0"], + config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig", + module="llama_stack.providers.remote.inference.ollama", + description="Ollama inference provider for running local models through the Ollama runtime.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="vllm", - pip_packages=[], - module="llama_stack.providers.remote.inference.vllm", - config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig", - provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator", - description="Remote vLLM inference provider for connecting to vLLM servers.", - ), + adapter_type="vllm", + provider_type="remote::vllm", + pip_packages=[], + module="llama_stack.providers.remote.inference.vllm", + config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig", + provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator", + description="Remote vLLM inference provider for connecting to vLLM servers.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="tgi", - pip_packages=["huggingface_hub", "aiohttp"], - module="llama_stack.providers.remote.inference.tgi", - config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig", - description="Text Generation Inference (TGI) provider for HuggingFace model serving.", - ), + adapter_type="tgi", + provider_type="remote::tgi", + pip_packages=["huggingface_hub", "aiohttp"], + module="llama_stack.providers.remote.inference.tgi", + config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig", + description="Text Generation Inference (TGI) provider for HuggingFace model serving.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="hf::serverless", - pip_packages=["huggingface_hub", "aiohttp"], - module="llama_stack.providers.remote.inference.tgi", - config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig", - description="HuggingFace Inference API serverless provider for on-demand model inference.", - 
), + adapter_type="hf::serverless", + provider_type="remote::hf::serverless", + pip_packages=["huggingface_hub", "aiohttp"], + module="llama_stack.providers.remote.inference.tgi", + config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig", + description="HuggingFace Inference API serverless provider for on-demand model inference.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="hf::endpoint", - pip_packages=["huggingface_hub", "aiohttp"], - module="llama_stack.providers.remote.inference.tgi", - config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig", - description="HuggingFace Inference Endpoints provider for dedicated model serving.", - ), + provider_type="remote::hf::endpoint", + adapter_type="hf::endpoint", + pip_packages=["huggingface_hub", "aiohttp"], + module="llama_stack.providers.remote.inference.tgi", + config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig", + description="HuggingFace Inference Endpoints provider for dedicated model serving.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="fireworks", - pip_packages=[ - "fireworks-ai<=0.17.16", - ], - module="llama_stack.providers.remote.inference.fireworks", - config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig", - provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator", - description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.", - ), + adapter_type="fireworks", + provider_type="remote::fireworks", + pip_packages=[ + "fireworks-ai<=0.17.16", + ], + module="llama_stack.providers.remote.inference.fireworks", + config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig", + provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator", + description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="together", - pip_packages=[ - "together", - ], - module="llama_stack.providers.remote.inference.together", - config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig", - provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator", - description="Together AI inference provider for open-source models and collaborative AI development.", - ), + adapter_type="together", + provider_type="remote::together", + pip_packages=[ + "together", + ], + module="llama_stack.providers.remote.inference.together", + config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig", + provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator", + description="Together AI inference provider for open-source models and collaborative AI development.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="bedrock", - pip_packages=["boto3"], - module="llama_stack.providers.remote.inference.bedrock", - config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig", - description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.", - ), + adapter_type="bedrock", + 
provider_type="remote::bedrock", + pip_packages=["boto3"], + module="llama_stack.providers.remote.inference.bedrock", + config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig", + description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="databricks", - pip_packages=[], - module="llama_stack.providers.remote.inference.databricks", - config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig", - description="Databricks inference provider for running models on Databricks' unified analytics platform.", - ), + adapter_type="databricks", + provider_type="remote::databricks", + pip_packages=[], + module="llama_stack.providers.remote.inference.databricks", + config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig", + description="Databricks inference provider for running models on Databricks' unified analytics platform.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="nvidia", - pip_packages=[], - module="llama_stack.providers.remote.inference.nvidia", - config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig", - description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.", - ), + adapter_type="nvidia", + provider_type="remote::nvidia", + pip_packages=[], + module="llama_stack.providers.remote.inference.nvidia", + config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig", + description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="runpod", - pip_packages=[], - module="llama_stack.providers.remote.inference.runpod", - config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig", - description="RunPod inference provider for running models on RunPod's cloud GPU platform.", - ), + adapter_type="runpod", + provider_type="remote::runpod", + pip_packages=[], + module="llama_stack.providers.remote.inference.runpod", + config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig", + description="RunPod inference provider for running models on RunPod's cloud GPU platform.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="openai", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.openai", - config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig", - provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator", - description="OpenAI inference provider for accessing GPT models and other OpenAI services.", - ), + adapter_type="openai", + provider_type="remote::openai", + pip_packages=["litellm"], + module="llama_stack.providers.remote.inference.openai", + config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig", + provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator", + description="OpenAI inference provider for accessing GPT models and other OpenAI services.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="anthropic", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.anthropic", - 
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig", - provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator", - description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.", - ), + adapter_type="anthropic", + provider_type="remote::anthropic", + pip_packages=["litellm"], + module="llama_stack.providers.remote.inference.anthropic", + config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig", + provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator", + description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="gemini", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.gemini", - config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig", - provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator", - description="Google Gemini inference provider for accessing Gemini models and Google's AI services.", - ), + adapter_type="gemini", + provider_type="remote::gemini", + pip_packages=[ + "litellm", + ], + module="llama_stack.providers.remote.inference.gemini", + config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig", + provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator", + description="Google Gemini inference provider for accessing Gemini models and Google's AI services.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="vertexai", - pip_packages=["litellm", "google-cloud-aiplatform"], - module="llama_stack.providers.remote.inference.vertexai", - config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig", - provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator", - description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages: + adapter_type="vertexai", + provider_type="remote::vertexai", + pip_packages=[ + "litellm", + "google-cloud-aiplatform", + ], + module="llama_stack.providers.remote.inference.vertexai", + config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig", + provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator", + description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages: • Enterprise-grade security: Uses Google Cloud's security controls and IAM • Better integration: Seamless integration with other Google Cloud services @@ -239,76 +228,73 @@ Available Models: - vertex_ai/gemini-2.0-flash - vertex_ai/gemini-2.5-flash - vertex_ai/gemini-2.5-pro""", - ), ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="groq", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.groq", - config_class="llama_stack.providers.remote.inference.groq.GroqConfig", - provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator", - description="Groq inference provider for 
ultra-fast inference using Groq's LPU technology.", - ), + adapter_type="groq", + provider_type="remote::groq", + pip_packages=[ + "litellm", + ], + module="llama_stack.providers.remote.inference.groq", + config_class="llama_stack.providers.remote.inference.groq.GroqConfig", + provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator", + description="Groq inference provider for ultra-fast inference using Groq's LPU technology.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="llama-openai-compat", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.llama_openai_compat", - config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig", - provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator", - description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.", - ), + adapter_type="llama-openai-compat", + provider_type="remote::llama-openai-compat", + pip_packages=["litellm"], + module="llama_stack.providers.remote.inference.llama_openai_compat", + config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig", + provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator", + description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="sambanova", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.sambanova", - config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig", - provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator", - description="SambaNova inference provider for running models on SambaNova's dataflow architecture.", - ), + adapter_type="sambanova", + provider_type="remote::sambanova", + pip_packages=[ + "litellm", + ], + module="llama_stack.providers.remote.inference.sambanova", + config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig", + provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator", + description="SambaNova inference provider for running models on SambaNova's dataflow architecture.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="passthrough", - pip_packages=[], - module="llama_stack.providers.remote.inference.passthrough", - config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig", - provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator", - description="Passthrough inference provider for connecting to any external inference service not directly supported.", - ), + adapter_type="passthrough", + provider_type="remote::passthrough", + pip_packages=[], + module="llama_stack.providers.remote.inference.passthrough", + config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig", + provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator", + description="Passthrough inference provider for connecting to any external inference service not directly supported.", ), - 
remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="watsonx", - pip_packages=["ibm_watsonx_ai"], - module="llama_stack.providers.remote.inference.watsonx", - config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig", - provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator", - description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.", - ), + adapter_type="watsonx", + provider_type="remote::watsonx", + pip_packages=["ibm_watsonx_ai"], + module="llama_stack.providers.remote.inference.watsonx", + config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig", + provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator", + description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="azure", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.azure", - config_class="llama_stack.providers.remote.inference.azure.AzureConfig", - provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator", - description=""" + provider_type="remote::azure", + adapter_type="azure", + pip_packages=["litellm"], + module="llama_stack.providers.remote.inference.azure", + config_class="llama_stack.providers.remote.inference.azure.AzureConfig", + provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator", + description=""" Azure OpenAI inference provider for accessing GPT models and other Azure services. Provider documentation https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview """, - ), ), ] diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py index 47aeb401e..2092e3b2d 100644 --- a/llama_stack/providers/registry/post_training.py +++ b/llama_stack/providers/registry/post_training.py @@ -7,7 +7,7 @@ from typing import cast -from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec # We provide two versions of these providers so that distributions can package the appropriate version of torch. # The CPU version is used for distributions that don't have GPU support -- they result in smaller container images. 
@@ -57,14 +57,13 @@ def available_providers() -> list[ProviderSpec]: ], description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.post_training, - adapter=AdapterSpec( - adapter_type="nvidia", - pip_packages=["requests", "aiohttp"], - module="llama_stack.providers.remote.post_training.nvidia", - config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig", - description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.", - ), + adapter_type="nvidia", + provider_type="remote::nvidia", + pip_packages=["requests", "aiohttp"], + module="llama_stack.providers.remote.post_training.nvidia", + config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig", + description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.", ), ] diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 9dd791bd8..b30074398 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -6,11 +6,10 @@ from llama_stack.providers.datatypes import ( - AdapterSpec, Api, InlineProviderSpec, ProviderSpec, - remote_provider_spec, + RemoteProviderSpec, ) @@ -48,35 +47,32 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig", description="Code Scanner safety provider for detecting security vulnerabilities and unsafe code patterns.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.safety, - adapter=AdapterSpec( - adapter_type="bedrock", - pip_packages=["boto3"], - module="llama_stack.providers.remote.safety.bedrock", - config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig", - description="AWS Bedrock safety provider for content moderation using AWS's safety services.", - ), + adapter_type="bedrock", + provider_type="remote::bedrock", + pip_packages=["boto3"], + module="llama_stack.providers.remote.safety.bedrock", + config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig", + description="AWS Bedrock safety provider for content moderation using AWS's safety services.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.safety, - adapter=AdapterSpec( - adapter_type="nvidia", - pip_packages=["requests"], - module="llama_stack.providers.remote.safety.nvidia", - config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig", - description="NVIDIA's safety provider for content moderation and safety filtering.", - ), + adapter_type="nvidia", + provider_type="remote::nvidia", + pip_packages=["requests"], + module="llama_stack.providers.remote.safety.nvidia", + config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig", + description="NVIDIA's safety provider for content moderation and safety filtering.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.safety, - adapter=AdapterSpec( - adapter_type="sambanova", - pip_packages=["litellm", "requests"], - module="llama_stack.providers.remote.safety.sambanova", - config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig", - provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator", - description="SambaNova's safety provider for content moderation and safety filtering.", - ), + adapter_type="sambanova", + 
provider_type="remote::sambanova", + pip_packages=["litellm", "requests"], + module="llama_stack.providers.remote.safety.sambanova", + config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig", + provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator", + description="SambaNova's safety provider for content moderation and safety filtering.", ), ] diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index 5a58fa7af..ad8c31dfd 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -6,11 +6,10 @@ from llama_stack.providers.datatypes import ( - AdapterSpec, Api, InlineProviderSpec, ProviderSpec, - remote_provider_spec, + RemoteProviderSpec, ) @@ -35,59 +34,54 @@ def available_providers() -> list[ProviderSpec]: api_dependencies=[Api.vector_io, Api.inference, Api.files], description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.tool_runtime, - adapter=AdapterSpec( - adapter_type="brave-search", - module="llama_stack.providers.remote.tool_runtime.brave_search", - config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig", - pip_packages=["requests"], - provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator", - description="Brave Search tool for web search capabilities with privacy-focused results.", - ), + adapter_type="brave-search", + provider_type="remote::brave-search", + module="llama_stack.providers.remote.tool_runtime.brave_search", + config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig", + pip_packages=["requests"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator", + description="Brave Search tool for web search capabilities with privacy-focused results.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.tool_runtime, - adapter=AdapterSpec( - adapter_type="bing-search", - module="llama_stack.providers.remote.tool_runtime.bing_search", - config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig", - pip_packages=["requests"], - provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator", - description="Bing Search tool for web search capabilities using Microsoft's search engine.", - ), + adapter_type="bing-search", + provider_type="remote::bing-search", + module="llama_stack.providers.remote.tool_runtime.bing_search", + config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig", + pip_packages=["requests"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator", + description="Bing Search tool for web search capabilities using Microsoft's search engine.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.tool_runtime, - adapter=AdapterSpec( - adapter_type="tavily-search", - module="llama_stack.providers.remote.tool_runtime.tavily_search", - config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig", - pip_packages=["requests"], - 
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator", - description="Tavily Search tool for AI-optimized web search with structured results.", - ), + adapter_type="tavily-search", + provider_type="remote::tavily-search", + module="llama_stack.providers.remote.tool_runtime.tavily_search", + config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig", + pip_packages=["requests"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator", + description="Tavily Search tool for AI-optimized web search with structured results.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.tool_runtime, - adapter=AdapterSpec( - adapter_type="wolfram-alpha", - module="llama_stack.providers.remote.tool_runtime.wolfram_alpha", - config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig", - pip_packages=["requests"], - provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator", - description="Wolfram Alpha tool for computational knowledge and mathematical calculations.", - ), + adapter_type="wolfram-alpha", + provider_type="remote::wolfram-alpha", + module="llama_stack.providers.remote.tool_runtime.wolfram_alpha", + config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig", + pip_packages=["requests"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator", + description="Wolfram Alpha tool for computational knowledge and mathematical calculations.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.tool_runtime, - adapter=AdapterSpec( - adapter_type="model-context-protocol", - module="llama_stack.providers.remote.tool_runtime.model_context_protocol", - config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig", - pip_packages=["mcp>=1.8.1"], - provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator", - description="Model Context Protocol (MCP) tool for standardized tool calling and context management.", - ), + adapter_type="model-context-protocol", + provider_type="remote::model-context-protocol", + module="llama_stack.providers.remote.tool_runtime.model_context_protocol", + config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig", + pip_packages=["mcp>=1.8.1"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator", + description="Model Context Protocol (MCP) tool for standardized tool calling and context management.", ), ] diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index 511734d57..3b82f6eb6 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -6,11 +6,10 @@ from llama_stack.providers.datatypes import ( - AdapterSpec, Api, InlineProviderSpec, ProviderSpec, - remote_provider_spec, + RemoteProviderSpec, ) @@ -300,14 +299,16 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f Please refer to the sqlite-vec provider documentation. 
""", ), - remote_provider_spec( - Api.vector_io, - AdapterSpec( - adapter_type="chromadb", - pip_packages=["chromadb-client"], - module="llama_stack.providers.remote.vector_io.chroma", - config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig", - description=""" + RemoteProviderSpec( + api=Api.vector_io, + adapter_type="chromadb", + provider_type="remote::chromadb", + pip_packages=["chromadb-client"], + module="llama_stack.providers.remote.vector_io.chroma", + config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig", + api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], + description=""" [Chroma](https://www.trychroma.com/) is an inline and remote vector database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database. That means you're not limited to storing vectors in memory or in a separate service. @@ -340,9 +341,6 @@ pip install chromadb ## Documentation See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general. """, - ), - api_dependencies=[Api.inference], - optional_api_dependencies=[Api.files], ), InlineProviderSpec( api=Api.vector_io, @@ -387,14 +385,16 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti """, ), - remote_provider_spec( - Api.vector_io, - AdapterSpec( - adapter_type="pgvector", - pip_packages=["psycopg2-binary"], - module="llama_stack.providers.remote.vector_io.pgvector", - config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig", - description=""" + RemoteProviderSpec( + api=Api.vector_io, + adapter_type="pgvector", + provider_type="remote::pgvector", + pip_packages=["psycopg2-binary"], + module="llama_stack.providers.remote.vector_io.pgvector", + config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig", + api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], + description=""" [PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It allows you to store and query vectors directly in memory. That means you'll get fast and efficient vector retrieval. @@ -495,19 +495,18 @@ docker pull pgvector/pgvector:pg17 ## Documentation See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general. """, - ), + ), + RemoteProviderSpec( + api=Api.vector_io, + adapter_type="weaviate", + provider_type="remote::weaviate", + pip_packages=["weaviate-client"], + module="llama_stack.providers.remote.vector_io.weaviate", + config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig", + provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData", api_dependencies=[Api.inference], optional_api_dependencies=[Api.files], - ), - remote_provider_spec( - Api.vector_io, - AdapterSpec( - adapter_type="weaviate", - pip_packages=["weaviate-client"], - module="llama_stack.providers.remote.vector_io.weaviate", - config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig", - provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData", - description=""" + description=""" [Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack. It allows you to store and query vectors directly within a Weaviate database. 
That means you're not limited to storing vectors in memory or in a separate service. @@ -538,9 +537,6 @@ To install Weaviate see the [Weaviate quickstart documentation](https://weaviate ## Documentation See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general. """, - ), - api_dependencies=[Api.inference], - optional_api_dependencies=[Api.files], ), InlineProviderSpec( api=Api.vector_io, @@ -594,28 +590,29 @@ docker pull qdrant/qdrant See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general. """, ), - remote_provider_spec( - Api.vector_io, - AdapterSpec( - adapter_type="qdrant", - pip_packages=["qdrant-client"], - module="llama_stack.providers.remote.vector_io.qdrant", - config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig", - description=""" -Please refer to the inline provider documentation. -""", - ), + RemoteProviderSpec( + api=Api.vector_io, + adapter_type="qdrant", + provider_type="remote::qdrant", + pip_packages=["qdrant-client"], + module="llama_stack.providers.remote.vector_io.qdrant", + config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig", api_dependencies=[Api.inference], optional_api_dependencies=[Api.files], + description=""" +Please refer to the inline provider documentation. +""", ), - remote_provider_spec( - Api.vector_io, - AdapterSpec( - adapter_type="milvus", - pip_packages=["pymilvus>=2.4.10"], - module="llama_stack.providers.remote.vector_io.milvus", - config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig", - description=""" + RemoteProviderSpec( + api=Api.vector_io, + adapter_type="milvus", + provider_type="remote::milvus", + pip_packages=["pymilvus>=2.4.10"], + module="llama_stack.providers.remote.vector_io.milvus", + config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig", + api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], + description=""" [Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It allows you to store and query vectors directly within a Milvus database. That means you're not limited to storing vectors in memory or in a separate service. @@ -806,9 +803,6 @@ See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for m For more details on TLS configuration, refer to the [TLS setup guide](https://milvus.io/docs/tls.md). 
""", - ), - api_dependencies=[Api.inference], - optional_api_dependencies=[Api.files], ), InlineProviderSpec( api=Api.vector_io, diff --git a/tests/external/kaze.yaml b/tests/external/kaze.yaml index c61ac0e31..1b42f2e14 100644 --- a/tests/external/kaze.yaml +++ b/tests/external/kaze.yaml @@ -1,6 +1,5 @@ -adapter: - adapter_type: kaze - pip_packages: ["tests/external/llama-stack-provider-kaze"] - config_class: llama_stack_provider_kaze.config.KazeProviderConfig - module: llama_stack_provider_kaze +adapter_type: kaze +pip_packages: ["tests/external/llama-stack-provider-kaze"] +config_class: llama_stack_provider_kaze.config.KazeProviderConfig +module: llama_stack_provider_kaze optional_api_dependencies: [] diff --git a/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/weather.py b/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/weather.py index 4b3bfb641..de1427bfd 100644 --- a/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/weather.py +++ b/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/weather.py @@ -6,7 +6,7 @@ from typing import Protocol -from llama_stack.providers.datatypes import AdapterSpec, Api, ProviderSpec, RemoteProviderSpec +from llama_stack.providers.datatypes import Api, ProviderSpec, RemoteProviderSpec from llama_stack.schema_utils import webmethod @@ -16,12 +16,9 @@ def available_providers() -> list[ProviderSpec]: api=Api.weather, provider_type="remote::kaze", config_class="llama_stack_provider_kaze.KazeProviderConfig", - adapter=AdapterSpec( - adapter_type="kaze", - module="llama_stack_provider_kaze", - pip_packages=["llama_stack_provider_kaze"], - config_class="llama_stack_provider_kaze.KazeProviderConfig", - ), + adapter_type="kaze", + module="llama_stack_provider_kaze", + pip_packages=["llama_stack_provider_kaze"], ), ] diff --git a/tests/unit/distribution/test_distribution.py b/tests/unit/distribution/test_distribution.py index c6c2eb2c7..f24de0644 100644 --- a/tests/unit/distribution/test_distribution.py +++ b/tests/unit/distribution/test_distribution.py @@ -66,10 +66,9 @@ def base_config(tmp_path): def provider_spec_yaml(): """Common provider spec YAML for testing.""" return """ -adapter: - adapter_type: test_provider - config_class: test_provider.config.TestProviderConfig - module: test_provider +adapter_type: test_provider +config_class: test_provider.config.TestProviderConfig +module: test_provider api_dependencies: - safety """ @@ -182,9 +181,9 @@ class TestProviderRegistry: assert Api.inference in registry assert "remote::test_provider" in registry[Api.inference] provider = registry[Api.inference]["remote::test_provider"] - assert provider.adapter.adapter_type == "test_provider" - assert provider.adapter.module == "test_provider" - assert provider.adapter.config_class == "test_provider.config.TestProviderConfig" + assert provider.adapter_type == "test_provider" + assert provider.module == "test_provider" + assert provider.config_class == "test_provider.config.TestProviderConfig" assert Api.safety in provider.api_dependencies def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml): @@ -246,8 +245,7 @@ class TestProviderRegistry: """Test handling of malformed remote provider spec (missing required fields).""" remote_dir, _ = api_directories malformed_spec = """ -adapter: - adapter_type: test_provider +adapter_type: test_provider # Missing required fields api_dependencies: - safety @@ -270,7 +268,7 @@ pip_packages: with open(inline_dir 
/ "malformed.yaml", "w") as f: f.write(malformed_spec) - with pytest.raises(KeyError) as exc_info: + with pytest.raises(ValidationError) as exc_info: get_provider_registry(base_config) assert "config_class" in str(exc_info.value)