From 8422bd102a672c90f3eb189297f8e36332c4d17e Mon Sep 17 00:00:00 2001
From: Charlie Doern
Date: Thu, 18 Sep 2025 10:10:00 -0400
Subject: [PATCH 1/5] feat: combine ProviderSpec datatypes (#3378)

# What does this PR do?

Currently, `RemoteProviderSpec` has an `AdapterSpec` embedded in it. Remove `AdapterSpec` and fold its remaining fields into `RemoteProviderSpec`.

Additionally, many of the fields were duplicated between `InlineProviderSpec` and `RemoteProviderSpec`; move these up to `ProviderSpec` so they are shared.

Fix up the distro codegen to use `RemoteProviderSpec` directly rather than `remote_provider_spec`, which took an `AdapterSpec` and returned a full provider spec.

## Test Plan

Existing distro tests should pass.

Signed-off-by: Charlie Doern
---
 llama_stack/core/datatypes.py                 |   4 -
 llama_stack/core/distribution.py              |  22 +-
 llama_stack/distributions/starter/starter.py  |   4 +-
 llama_stack/providers/datatypes.py            |  88 +---
 llama_stack/providers/registry/datasetio.py   |  41 +-
 llama_stack/providers/registry/eval.py        |  21 +-
 llama_stack/providers/registry/files.py       |  23 +-
 llama_stack/providers/registry/inference.py   | 380 +++++++++---------
 .../providers/registry/post_training.py       |  17 +-
 llama_stack/providers/registry/safety.py      |  50 ++-
 .../providers/registry/tool_runtime.py        |  88 ++--
 llama_stack/providers/registry/vector_io.py   | 108 +++--
 tests/external/kaze.yaml                      |   9 +-
 .../src/llama_stack_api_weather/weather.py    |  11 +-
 tests/unit/distribution/test_distribution.py  |  18 +-
 15 files changed, 381 insertions(+), 503 deletions(-)

diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py
index faaeefd01..b5558c66f 100644
--- a/llama_stack/core/datatypes.py
+++ b/llama_stack/core/datatypes.py
@@ -121,10 +121,6 @@ class AutoRoutedProviderSpec(ProviderSpec):
         default=None,
     )
 
-    @property
-    def pip_packages(self) -> list[str]:
-        raise AssertionError("Should not be called on AutoRoutedProviderSpec")
-
 
 # Example: /models, /shields
 class RoutingTableProviderSpec(ProviderSpec):
diff --git a/llama_stack/core/distribution.py b/llama_stack/core/distribution.py
index c104b6764..302ecb960 100644
--- a/llama_stack/core/distribution.py
+++ b/llama_stack/core/distribution.py
@@ -16,11 +16,10 @@ from llama_stack.core.datatypes import BuildConfig, DistributionSpec
 from llama_stack.core.external import load_external_apis
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import (
-    AdapterSpec,
     Api,
     InlineProviderSpec,
     ProviderSpec,
-    remote_provider_spec,
+    RemoteProviderSpec,
 )
 
 logger = get_logger(name=__name__, category="core")
@@ -77,27 +76,12 @@ def providable_apis() -> list[Api]:
 
 
 def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
-    adapter = AdapterSpec(**spec_data["adapter"])
-    spec = remote_provider_spec(
-        api=api,
-        adapter=adapter,
-        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
-    )
+    spec = RemoteProviderSpec(api=api, provider_type=f"remote::{spec_data['adapter_type']}", **spec_data)
     return spec
 
 
 def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
-    spec = InlineProviderSpec(
-        api=api,
-        provider_type=f"inline::{provider_name}",
-        pip_packages=spec_data.get("pip_packages", []),
-        module=spec_data["module"],
-        config_class=spec_data["config_class"],
-        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
-        optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
-
provider_data_validator=spec_data.get("provider_data_validator"), - container_image=spec_data.get("container_image"), - ) + spec = InlineProviderSpec(api=api, provider_type=f"inline::{provider_name}", **spec_data) return spec diff --git a/llama_stack/distributions/starter/starter.py b/llama_stack/distributions/starter/starter.py index c2dfe95ad..6bee51ff0 100644 --- a/llama_stack/distributions/starter/starter.py +++ b/llama_stack/distributions/starter/starter.py @@ -78,12 +78,12 @@ def get_remote_inference_providers() -> list[Provider]: remote_providers = [ provider for provider in available_providers() - if isinstance(provider, RemoteProviderSpec) and provider.adapter.adapter_type in ENABLED_INFERENCE_PROVIDERS + if isinstance(provider, RemoteProviderSpec) and provider.adapter_type in ENABLED_INFERENCE_PROVIDERS ] inference_providers = [] for provider_spec in remote_providers: - provider_type = provider_spec.adapter.adapter_type + provider_type = provider_spec.adapter_type if provider_type in INFERENCE_PROVIDER_IDS: provider_id = INFERENCE_PROVIDER_IDS[provider_type] diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 5e15dd8e1..c8ff9cecb 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -131,6 +131,15 @@ class ProviderSpec(BaseModel): """, ) + pip_packages: list[str] = Field( + default_factory=list, + description="The pip dependencies needed for this implementation", + ) + + provider_data_validator: str | None = Field( + default=None, + ) + is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.") # used internally by the resolver; this is a hack for now @@ -145,45 +154,8 @@ class RoutingTable(Protocol): async def get_provider_impl(self, routing_key: str) -> Any: ... -# TODO: this can now be inlined into RemoteProviderSpec -@json_schema_type -class AdapterSpec(BaseModel): - adapter_type: str = Field( - ..., - description="Unique identifier for this adapter", - ) - module: str = Field( - default_factory=str, - description=""" -Fully-qualified name of the module to import. The module is expected to have: - - - `get_adapter_impl(config, deps)`: returns the adapter implementation -""", - ) - pip_packages: list[str] = Field( - default_factory=list, - description="The pip dependencies needed for this implementation", - ) - config_class: str = Field( - description="Fully-qualified classname of the config for this provider", - ) - provider_data_validator: str | None = Field( - default=None, - ) - description: str | None = Field( - default=None, - description=""" -A description of the provider. This is used to display in the documentation. -""", - ) - - @json_schema_type class InlineProviderSpec(ProviderSpec): - pip_packages: list[str] = Field( - default_factory=list, - description="The pip dependencies needed for this implementation", - ) container_image: str | None = Field( default=None, description=""" @@ -191,10 +163,6 @@ The container image to use for this implementation. If one is provided, pip_pack If a provider depends on other providers, the dependencies MUST NOT specify a container image. 
""", ) - # module field is inherited from ProviderSpec - provider_data_validator: str | None = Field( - default=None, - ) description: str | None = Field( default=None, description=""" @@ -223,10 +191,15 @@ class RemoteProviderConfig(BaseModel): @json_schema_type class RemoteProviderSpec(ProviderSpec): - adapter: AdapterSpec = Field( + adapter_type: str = Field( + ..., + description="Unique identifier for this adapter", + ) + + description: str | None = Field( + default=None, description=""" -If some code is needed to convert the remote responses into Llama Stack compatible -API responses, specify the adapter here. +A description of the provider. This is used to display in the documentation. """, ) @@ -234,33 +207,6 @@ API responses, specify the adapter here. def container_image(self) -> str | None: return None - # module field is inherited from ProviderSpec - - @property - def pip_packages(self) -> list[str]: - return self.adapter.pip_packages - - @property - def provider_data_validator(self) -> str | None: - return self.adapter.provider_data_validator - - -def remote_provider_spec( - api: Api, - adapter: AdapterSpec, - api_dependencies: list[Api] | None = None, - optional_api_dependencies: list[Api] | None = None, -) -> RemoteProviderSpec: - return RemoteProviderSpec( - api=api, - provider_type=f"remote::{adapter.adapter_type}", - config_class=adapter.config_class, - module=adapter.module, - adapter=adapter, - api_dependencies=api_dependencies or [], - optional_api_dependencies=optional_api_dependencies or [], - ) - class HealthStatus(StrEnum): OK = "OK" diff --git a/llama_stack/providers/registry/datasetio.py b/llama_stack/providers/registry/datasetio.py index f641b4ce3..a9feb0bac 100644 --- a/llama_stack/providers/registry/datasetio.py +++ b/llama_stack/providers/registry/datasetio.py @@ -6,11 +6,10 @@ from llama_stack.providers.datatypes import ( - AdapterSpec, Api, InlineProviderSpec, ProviderSpec, - remote_provider_spec, + RemoteProviderSpec, ) @@ -25,28 +24,26 @@ def available_providers() -> list[ProviderSpec]: api_dependencies=[], description="Local filesystem-based dataset I/O provider for reading and writing datasets to local storage.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.datasetio, - adapter=AdapterSpec( - adapter_type="huggingface", - pip_packages=[ - "datasets>=4.0.0", - ], - module="llama_stack.providers.remote.datasetio.huggingface", - config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig", - description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.", - ), + adapter_type="huggingface", + provider_type="remote::huggingface", + pip_packages=[ + "datasets>=4.0.0", + ], + module="llama_stack.providers.remote.datasetio.huggingface", + config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig", + description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.datasetio, - adapter=AdapterSpec( - adapter_type="nvidia", - pip_packages=[ - "datasets>=4.0.0", - ], - module="llama_stack.providers.remote.datasetio.nvidia", - config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig", - description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.", - ), + adapter_type="nvidia", + provider_type="remote::nvidia", + module="llama_stack.providers.remote.datasetio.nvidia", + 
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig", + pip_packages=[ + "datasets>=4.0.0", + ], + description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.", ), ] diff --git a/llama_stack/providers/registry/eval.py b/llama_stack/providers/registry/eval.py index 9f0d17916..4ef0bb41f 100644 --- a/llama_stack/providers/registry/eval.py +++ b/llama_stack/providers/registry/eval.py @@ -5,7 +5,7 @@ # the root directory of this source tree. -from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec def available_providers() -> list[ProviderSpec]: @@ -25,17 +25,16 @@ def available_providers() -> list[ProviderSpec]: ], description="Meta's reference implementation of evaluation tasks with support for multiple languages and evaluation metrics.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.eval, - adapter=AdapterSpec( - adapter_type="nvidia", - pip_packages=[ - "requests", - ], - module="llama_stack.providers.remote.eval.nvidia", - config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig", - description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.", - ), + adapter_type="nvidia", + pip_packages=[ + "requests", + ], + provider_type="remote::nvidia", + module="llama_stack.providers.remote.eval.nvidia", + config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig", + description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.", api_dependencies=[ Api.datasetio, Api.datasets, diff --git a/llama_stack/providers/registry/files.py b/llama_stack/providers/registry/files.py index ebe90310c..9acabfacd 100644 --- a/llama_stack/providers/registry/files.py +++ b/llama_stack/providers/registry/files.py @@ -4,13 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from llama_stack.providers.datatypes import ( - AdapterSpec, - Api, - InlineProviderSpec, - ProviderSpec, - remote_provider_spec, -) +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages @@ -25,14 +19,13 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig", description="Local filesystem-based file storage provider for managing files and documents locally.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.files, - adapter=AdapterSpec( - adapter_type="s3", - pip_packages=["boto3"] + sql_store_pip_packages, - module="llama_stack.providers.remote.files.s3", - config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig", - description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.", - ), + provider_type="remote::s3", + adapter_type="s3", + pip_packages=["boto3"] + sql_store_pip_packages, + module="llama_stack.providers.remote.files.s3", + config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig", + description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.", ), ] diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 0eb4cf104..658611698 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -6,11 +6,10 @@ from llama_stack.providers.datatypes import ( - AdapterSpec, Api, InlineProviderSpec, ProviderSpec, - remote_provider_spec, + RemoteProviderSpec, ) META_REFERENCE_DEPS = [ @@ -49,177 +48,167 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig", description="Sentence Transformers inference provider for text embeddings and similarity search.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="cerebras", - pip_packages=[ - "cerebras_cloud_sdk", - ], - module="llama_stack.providers.remote.inference.cerebras", - config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig", - description="Cerebras inference provider for running models on Cerebras Cloud platform.", - ), + adapter_type="cerebras", + provider_type="remote::cerebras", + pip_packages=[ + "cerebras_cloud_sdk", + ], + module="llama_stack.providers.remote.inference.cerebras", + config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig", + description="Cerebras inference provider for running models on Cerebras Cloud platform.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="ollama", - pip_packages=["ollama", "aiohttp", "h11>=0.16.0"], - config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig", - module="llama_stack.providers.remote.inference.ollama", - description="Ollama inference provider for running local models through the Ollama runtime.", - ), + adapter_type="ollama", + provider_type="remote::ollama", + pip_packages=["ollama", "aiohttp", "h11>=0.16.0"], + config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig", + module="llama_stack.providers.remote.inference.ollama", + description="Ollama inference provider for running local models through the Ollama 
runtime.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="vllm", - pip_packages=[], - module="llama_stack.providers.remote.inference.vllm", - config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig", - provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator", - description="Remote vLLM inference provider for connecting to vLLM servers.", - ), + adapter_type="vllm", + provider_type="remote::vllm", + pip_packages=[], + module="llama_stack.providers.remote.inference.vllm", + config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig", + provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator", + description="Remote vLLM inference provider for connecting to vLLM servers.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="tgi", - pip_packages=["huggingface_hub", "aiohttp"], - module="llama_stack.providers.remote.inference.tgi", - config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig", - description="Text Generation Inference (TGI) provider for HuggingFace model serving.", - ), + adapter_type="tgi", + provider_type="remote::tgi", + pip_packages=["huggingface_hub", "aiohttp"], + module="llama_stack.providers.remote.inference.tgi", + config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig", + description="Text Generation Inference (TGI) provider for HuggingFace model serving.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="hf::serverless", - pip_packages=["huggingface_hub", "aiohttp"], - module="llama_stack.providers.remote.inference.tgi", - config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig", - description="HuggingFace Inference API serverless provider for on-demand model inference.", - ), + adapter_type="hf::serverless", + provider_type="remote::hf::serverless", + pip_packages=["huggingface_hub", "aiohttp"], + module="llama_stack.providers.remote.inference.tgi", + config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig", + description="HuggingFace Inference API serverless provider for on-demand model inference.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="hf::endpoint", - pip_packages=["huggingface_hub", "aiohttp"], - module="llama_stack.providers.remote.inference.tgi", - config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig", - description="HuggingFace Inference Endpoints provider for dedicated model serving.", - ), + provider_type="remote::hf::endpoint", + adapter_type="hf::endpoint", + pip_packages=["huggingface_hub", "aiohttp"], + module="llama_stack.providers.remote.inference.tgi", + config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig", + description="HuggingFace Inference Endpoints provider for dedicated model serving.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="fireworks", - pip_packages=[ - "fireworks-ai<=0.17.16", - ], - module="llama_stack.providers.remote.inference.fireworks", - config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig", - provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator", - description="Fireworks AI inference 
provider for Llama models and other AI models on the Fireworks platform.", - ), + adapter_type="fireworks", + provider_type="remote::fireworks", + pip_packages=[ + "fireworks-ai<=0.17.16", + ], + module="llama_stack.providers.remote.inference.fireworks", + config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig", + provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator", + description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="together", - pip_packages=[ - "together", - ], - module="llama_stack.providers.remote.inference.together", - config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig", - provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator", - description="Together AI inference provider for open-source models and collaborative AI development.", - ), + adapter_type="together", + provider_type="remote::together", + pip_packages=[ + "together", + ], + module="llama_stack.providers.remote.inference.together", + config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig", + provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator", + description="Together AI inference provider for open-source models and collaborative AI development.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="bedrock", - pip_packages=["boto3"], - module="llama_stack.providers.remote.inference.bedrock", - config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig", - description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.", - ), + adapter_type="bedrock", + provider_type="remote::bedrock", + pip_packages=["boto3"], + module="llama_stack.providers.remote.inference.bedrock", + config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig", + description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="databricks", - pip_packages=[], - module="llama_stack.providers.remote.inference.databricks", - config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig", - description="Databricks inference provider for running models on Databricks' unified analytics platform.", - ), + adapter_type="databricks", + provider_type="remote::databricks", + pip_packages=[], + module="llama_stack.providers.remote.inference.databricks", + config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig", + description="Databricks inference provider for running models on Databricks' unified analytics platform.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="nvidia", - pip_packages=[], - module="llama_stack.providers.remote.inference.nvidia", - config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig", - description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.", - ), + adapter_type="nvidia", + provider_type="remote::nvidia", + pip_packages=[], + module="llama_stack.providers.remote.inference.nvidia", + 
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig", + description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="runpod", - pip_packages=[], - module="llama_stack.providers.remote.inference.runpod", - config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig", - description="RunPod inference provider for running models on RunPod's cloud GPU platform.", - ), + adapter_type="runpod", + provider_type="remote::runpod", + pip_packages=[], + module="llama_stack.providers.remote.inference.runpod", + config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig", + description="RunPod inference provider for running models on RunPod's cloud GPU platform.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="openai", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.openai", - config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig", - provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator", - description="OpenAI inference provider for accessing GPT models and other OpenAI services.", - ), + adapter_type="openai", + provider_type="remote::openai", + pip_packages=["litellm"], + module="llama_stack.providers.remote.inference.openai", + config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig", + provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator", + description="OpenAI inference provider for accessing GPT models and other OpenAI services.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="anthropic", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.anthropic", - config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig", - provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator", - description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.", - ), + adapter_type="anthropic", + provider_type="remote::anthropic", + pip_packages=["litellm"], + module="llama_stack.providers.remote.inference.anthropic", + config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig", + provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator", + description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="gemini", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.gemini", - config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig", - provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator", - description="Google Gemini inference provider for accessing Gemini models and Google's AI services.", - ), + adapter_type="gemini", + provider_type="remote::gemini", + pip_packages=[ + "litellm", + ], + module="llama_stack.providers.remote.inference.gemini", + config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig", + 
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator", + description="Google Gemini inference provider for accessing Gemini models and Google's AI services.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="vertexai", - pip_packages=["litellm", "google-cloud-aiplatform"], - module="llama_stack.providers.remote.inference.vertexai", - config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig", - provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator", - description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages: + adapter_type="vertexai", + provider_type="remote::vertexai", + pip_packages=[ + "litellm", + "google-cloud-aiplatform", + ], + module="llama_stack.providers.remote.inference.vertexai", + config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig", + provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator", + description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages: • Enterprise-grade security: Uses Google Cloud's security controls and IAM • Better integration: Seamless integration with other Google Cloud services @@ -239,76 +228,73 @@ Available Models: - vertex_ai/gemini-2.0-flash - vertex_ai/gemini-2.5-flash - vertex_ai/gemini-2.5-pro""", - ), ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="groq", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.groq", - config_class="llama_stack.providers.remote.inference.groq.GroqConfig", - provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator", - description="Groq inference provider for ultra-fast inference using Groq's LPU technology.", - ), + adapter_type="groq", + provider_type="remote::groq", + pip_packages=[ + "litellm", + ], + module="llama_stack.providers.remote.inference.groq", + config_class="llama_stack.providers.remote.inference.groq.GroqConfig", + provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator", + description="Groq inference provider for ultra-fast inference using Groq's LPU technology.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="llama-openai-compat", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.llama_openai_compat", - config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig", - provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator", - description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.", - ), + adapter_type="llama-openai-compat", + provider_type="remote::llama-openai-compat", + pip_packages=["litellm"], + module="llama_stack.providers.remote.inference.llama_openai_compat", + config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig", + provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator", + description="Llama OpenAI-compatible provider for 
using Llama models with OpenAI API format.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="sambanova", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.sambanova", - config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig", - provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator", - description="SambaNova inference provider for running models on SambaNova's dataflow architecture.", - ), + adapter_type="sambanova", + provider_type="remote::sambanova", + pip_packages=[ + "litellm", + ], + module="llama_stack.providers.remote.inference.sambanova", + config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig", + provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator", + description="SambaNova inference provider for running models on SambaNova's dataflow architecture.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="passthrough", - pip_packages=[], - module="llama_stack.providers.remote.inference.passthrough", - config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig", - provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator", - description="Passthrough inference provider for connecting to any external inference service not directly supported.", - ), + adapter_type="passthrough", + provider_type="remote::passthrough", + pip_packages=[], + module="llama_stack.providers.remote.inference.passthrough", + config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig", + provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator", + description="Passthrough inference provider for connecting to any external inference service not directly supported.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="watsonx", - pip_packages=["ibm_watsonx_ai"], - module="llama_stack.providers.remote.inference.watsonx", - config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig", - provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator", - description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.", - ), + adapter_type="watsonx", + provider_type="remote::watsonx", + pip_packages=["ibm_watsonx_ai"], + module="llama_stack.providers.remote.inference.watsonx", + config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig", + provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator", + description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="azure", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.azure", - config_class="llama_stack.providers.remote.inference.azure.AzureConfig", - provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator", - description=""" + provider_type="remote::azure", + adapter_type="azure", + pip_packages=["litellm"], + module="llama_stack.providers.remote.inference.azure", + 
config_class="llama_stack.providers.remote.inference.azure.AzureConfig", + provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator", + description=""" Azure OpenAI inference provider for accessing GPT models and other Azure services. Provider documentation https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview """, - ), ), ] diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py index 47aeb401e..2092e3b2d 100644 --- a/llama_stack/providers/registry/post_training.py +++ b/llama_stack/providers/registry/post_training.py @@ -7,7 +7,7 @@ from typing import cast -from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec # We provide two versions of these providers so that distributions can package the appropriate version of torch. # The CPU version is used for distributions that don't have GPU support -- they result in smaller container images. @@ -57,14 +57,13 @@ def available_providers() -> list[ProviderSpec]: ], description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.post_training, - adapter=AdapterSpec( - adapter_type="nvidia", - pip_packages=["requests", "aiohttp"], - module="llama_stack.providers.remote.post_training.nvidia", - config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig", - description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.", - ), + adapter_type="nvidia", + provider_type="remote::nvidia", + pip_packages=["requests", "aiohttp"], + module="llama_stack.providers.remote.post_training.nvidia", + config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig", + description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.", ), ] diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 9dd791bd8..b30074398 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -6,11 +6,10 @@ from llama_stack.providers.datatypes import ( - AdapterSpec, Api, InlineProviderSpec, ProviderSpec, - remote_provider_spec, + RemoteProviderSpec, ) @@ -48,35 +47,32 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig", description="Code Scanner safety provider for detecting security vulnerabilities and unsafe code patterns.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.safety, - adapter=AdapterSpec( - adapter_type="bedrock", - pip_packages=["boto3"], - module="llama_stack.providers.remote.safety.bedrock", - config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig", - description="AWS Bedrock safety provider for content moderation using AWS's safety services.", - ), + adapter_type="bedrock", + provider_type="remote::bedrock", + pip_packages=["boto3"], + module="llama_stack.providers.remote.safety.bedrock", + config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig", + description="AWS Bedrock safety provider for content moderation using AWS's safety services.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.safety, - adapter=AdapterSpec( - adapter_type="nvidia", - 
pip_packages=["requests"], - module="llama_stack.providers.remote.safety.nvidia", - config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig", - description="NVIDIA's safety provider for content moderation and safety filtering.", - ), + adapter_type="nvidia", + provider_type="remote::nvidia", + pip_packages=["requests"], + module="llama_stack.providers.remote.safety.nvidia", + config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig", + description="NVIDIA's safety provider for content moderation and safety filtering.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.safety, - adapter=AdapterSpec( - adapter_type="sambanova", - pip_packages=["litellm", "requests"], - module="llama_stack.providers.remote.safety.sambanova", - config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig", - provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator", - description="SambaNova's safety provider for content moderation and safety filtering.", - ), + adapter_type="sambanova", + provider_type="remote::sambanova", + pip_packages=["litellm", "requests"], + module="llama_stack.providers.remote.safety.sambanova", + config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig", + provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator", + description="SambaNova's safety provider for content moderation and safety filtering.", ), ] diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index 5a58fa7af..ad8c31dfd 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -6,11 +6,10 @@ from llama_stack.providers.datatypes import ( - AdapterSpec, Api, InlineProviderSpec, ProviderSpec, - remote_provider_spec, + RemoteProviderSpec, ) @@ -35,59 +34,54 @@ def available_providers() -> list[ProviderSpec]: api_dependencies=[Api.vector_io, Api.inference, Api.files], description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.tool_runtime, - adapter=AdapterSpec( - adapter_type="brave-search", - module="llama_stack.providers.remote.tool_runtime.brave_search", - config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig", - pip_packages=["requests"], - provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator", - description="Brave Search tool for web search capabilities with privacy-focused results.", - ), + adapter_type="brave-search", + provider_type="remote::brave-search", + module="llama_stack.providers.remote.tool_runtime.brave_search", + config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig", + pip_packages=["requests"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator", + description="Brave Search tool for web search capabilities with privacy-focused results.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.tool_runtime, - adapter=AdapterSpec( - adapter_type="bing-search", - module="llama_stack.providers.remote.tool_runtime.bing_search", - config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig", - pip_packages=["requests"], - 
provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator", - description="Bing Search tool for web search capabilities using Microsoft's search engine.", - ), + adapter_type="bing-search", + provider_type="remote::bing-search", + module="llama_stack.providers.remote.tool_runtime.bing_search", + config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig", + pip_packages=["requests"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator", + description="Bing Search tool for web search capabilities using Microsoft's search engine.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.tool_runtime, - adapter=AdapterSpec( - adapter_type="tavily-search", - module="llama_stack.providers.remote.tool_runtime.tavily_search", - config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig", - pip_packages=["requests"], - provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator", - description="Tavily Search tool for AI-optimized web search with structured results.", - ), + adapter_type="tavily-search", + provider_type="remote::tavily-search", + module="llama_stack.providers.remote.tool_runtime.tavily_search", + config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig", + pip_packages=["requests"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator", + description="Tavily Search tool for AI-optimized web search with structured results.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.tool_runtime, - adapter=AdapterSpec( - adapter_type="wolfram-alpha", - module="llama_stack.providers.remote.tool_runtime.wolfram_alpha", - config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig", - pip_packages=["requests"], - provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator", - description="Wolfram Alpha tool for computational knowledge and mathematical calculations.", - ), + adapter_type="wolfram-alpha", + provider_type="remote::wolfram-alpha", + module="llama_stack.providers.remote.tool_runtime.wolfram_alpha", + config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig", + pip_packages=["requests"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator", + description="Wolfram Alpha tool for computational knowledge and mathematical calculations.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.tool_runtime, - adapter=AdapterSpec( - adapter_type="model-context-protocol", - module="llama_stack.providers.remote.tool_runtime.model_context_protocol", - config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig", - pip_packages=["mcp>=1.8.1"], - provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator", - description="Model Context Protocol (MCP) tool for standardized tool calling and context management.", - ), + adapter_type="model-context-protocol", + provider_type="remote::model-context-protocol", + module="llama_stack.providers.remote.tool_runtime.model_context_protocol", + 
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig", + pip_packages=["mcp>=1.8.1"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator", + description="Model Context Protocol (MCP) tool for standardized tool calling and context management.", ), ] diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index 511734d57..3b82f6eb6 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -6,11 +6,10 @@ from llama_stack.providers.datatypes import ( - AdapterSpec, Api, InlineProviderSpec, ProviderSpec, - remote_provider_spec, + RemoteProviderSpec, ) @@ -300,14 +299,16 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f Please refer to the sqlite-vec provider documentation. """, ), - remote_provider_spec( - Api.vector_io, - AdapterSpec( - adapter_type="chromadb", - pip_packages=["chromadb-client"], - module="llama_stack.providers.remote.vector_io.chroma", - config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig", - description=""" + RemoteProviderSpec( + api=Api.vector_io, + adapter_type="chromadb", + provider_type="remote::chromadb", + pip_packages=["chromadb-client"], + module="llama_stack.providers.remote.vector_io.chroma", + config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig", + api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], + description=""" [Chroma](https://www.trychroma.com/) is an inline and remote vector database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database. That means you're not limited to storing vectors in memory or in a separate service. @@ -340,9 +341,6 @@ pip install chromadb ## Documentation See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general. """, - ), - api_dependencies=[Api.inference], - optional_api_dependencies=[Api.files], ), InlineProviderSpec( api=Api.vector_io, @@ -387,14 +385,16 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti """, ), - remote_provider_spec( - Api.vector_io, - AdapterSpec( - adapter_type="pgvector", - pip_packages=["psycopg2-binary"], - module="llama_stack.providers.remote.vector_io.pgvector", - config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig", - description=""" + RemoteProviderSpec( + api=Api.vector_io, + adapter_type="pgvector", + provider_type="remote::pgvector", + pip_packages=["psycopg2-binary"], + module="llama_stack.providers.remote.vector_io.pgvector", + config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig", + api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], + description=""" [PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It allows you to store and query vectors directly in memory. That means you'll get fast and efficient vector retrieval. @@ -495,19 +495,18 @@ docker pull pgvector/pgvector:pg17 ## Documentation See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general. 
""", - ), + ), + RemoteProviderSpec( + api=Api.vector_io, + adapter_type="weaviate", + provider_type="remote::weaviate", + pip_packages=["weaviate-client"], + module="llama_stack.providers.remote.vector_io.weaviate", + config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig", + provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData", api_dependencies=[Api.inference], optional_api_dependencies=[Api.files], - ), - remote_provider_spec( - Api.vector_io, - AdapterSpec( - adapter_type="weaviate", - pip_packages=["weaviate-client"], - module="llama_stack.providers.remote.vector_io.weaviate", - config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig", - provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData", - description=""" + description=""" [Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack. It allows you to store and query vectors directly within a Weaviate database. That means you're not limited to storing vectors in memory or in a separate service. @@ -538,9 +537,6 @@ To install Weaviate see the [Weaviate quickstart documentation](https://weaviate ## Documentation See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general. """, - ), - api_dependencies=[Api.inference], - optional_api_dependencies=[Api.files], ), InlineProviderSpec( api=Api.vector_io, @@ -594,28 +590,29 @@ docker pull qdrant/qdrant See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general. """, ), - remote_provider_spec( - Api.vector_io, - AdapterSpec( - adapter_type="qdrant", - pip_packages=["qdrant-client"], - module="llama_stack.providers.remote.vector_io.qdrant", - config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig", - description=""" -Please refer to the inline provider documentation. -""", - ), + RemoteProviderSpec( + api=Api.vector_io, + adapter_type="qdrant", + provider_type="remote::qdrant", + pip_packages=["qdrant-client"], + module="llama_stack.providers.remote.vector_io.qdrant", + config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig", api_dependencies=[Api.inference], optional_api_dependencies=[Api.files], + description=""" +Please refer to the inline provider documentation. +""", ), - remote_provider_spec( - Api.vector_io, - AdapterSpec( - adapter_type="milvus", - pip_packages=["pymilvus>=2.4.10"], - module="llama_stack.providers.remote.vector_io.milvus", - config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig", - description=""" + RemoteProviderSpec( + api=Api.vector_io, + adapter_type="milvus", + provider_type="remote::milvus", + pip_packages=["pymilvus>=2.4.10"], + module="llama_stack.providers.remote.vector_io.milvus", + config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig", + api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], + description=""" [Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It allows you to store and query vectors directly within a Milvus database. That means you're not limited to storing vectors in memory or in a separate service. @@ -806,9 +803,6 @@ See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for m For more details on TLS configuration, refer to the [TLS setup guide](https://milvus.io/docs/tls.md). 
""", - ), - api_dependencies=[Api.inference], - optional_api_dependencies=[Api.files], ), InlineProviderSpec( api=Api.vector_io, diff --git a/tests/external/kaze.yaml b/tests/external/kaze.yaml index c61ac0e31..1b42f2e14 100644 --- a/tests/external/kaze.yaml +++ b/tests/external/kaze.yaml @@ -1,6 +1,5 @@ -adapter: - adapter_type: kaze - pip_packages: ["tests/external/llama-stack-provider-kaze"] - config_class: llama_stack_provider_kaze.config.KazeProviderConfig - module: llama_stack_provider_kaze +adapter_type: kaze +pip_packages: ["tests/external/llama-stack-provider-kaze"] +config_class: llama_stack_provider_kaze.config.KazeProviderConfig +module: llama_stack_provider_kaze optional_api_dependencies: [] diff --git a/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/weather.py b/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/weather.py index 4b3bfb641..de1427bfd 100644 --- a/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/weather.py +++ b/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/weather.py @@ -6,7 +6,7 @@ from typing import Protocol -from llama_stack.providers.datatypes import AdapterSpec, Api, ProviderSpec, RemoteProviderSpec +from llama_stack.providers.datatypes import Api, ProviderSpec, RemoteProviderSpec from llama_stack.schema_utils import webmethod @@ -16,12 +16,9 @@ def available_providers() -> list[ProviderSpec]: api=Api.weather, provider_type="remote::kaze", config_class="llama_stack_provider_kaze.KazeProviderConfig", - adapter=AdapterSpec( - adapter_type="kaze", - module="llama_stack_provider_kaze", - pip_packages=["llama_stack_provider_kaze"], - config_class="llama_stack_provider_kaze.KazeProviderConfig", - ), + adapter_type="kaze", + module="llama_stack_provider_kaze", + pip_packages=["llama_stack_provider_kaze"], ), ] diff --git a/tests/unit/distribution/test_distribution.py b/tests/unit/distribution/test_distribution.py index c6c2eb2c7..f24de0644 100644 --- a/tests/unit/distribution/test_distribution.py +++ b/tests/unit/distribution/test_distribution.py @@ -66,10 +66,9 @@ def base_config(tmp_path): def provider_spec_yaml(): """Common provider spec YAML for testing.""" return """ -adapter: - adapter_type: test_provider - config_class: test_provider.config.TestProviderConfig - module: test_provider +adapter_type: test_provider +config_class: test_provider.config.TestProviderConfig +module: test_provider api_dependencies: - safety """ @@ -182,9 +181,9 @@ class TestProviderRegistry: assert Api.inference in registry assert "remote::test_provider" in registry[Api.inference] provider = registry[Api.inference]["remote::test_provider"] - assert provider.adapter.adapter_type == "test_provider" - assert provider.adapter.module == "test_provider" - assert provider.adapter.config_class == "test_provider.config.TestProviderConfig" + assert provider.adapter_type == "test_provider" + assert provider.module == "test_provider" + assert provider.config_class == "test_provider.config.TestProviderConfig" assert Api.safety in provider.api_dependencies def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml): @@ -246,8 +245,7 @@ class TestProviderRegistry: """Test handling of malformed remote provider spec (missing required fields).""" remote_dir, _ = api_directories malformed_spec = """ -adapter: - adapter_type: test_provider +adapter_type: test_provider # Missing required fields api_dependencies: - safety @@ -270,7 +268,7 @@ pip_packages: with open(inline_dir 
/ "malformed.yaml", "w") as f:
             f.write(malformed_spec)
 
-        with pytest.raises(KeyError) as exc_info:
+        with pytest.raises(ValidationError) as exc_info:
             get_provider_registry(base_config)
         assert "config_class" in str(exc_info.value)

From 4c2fcb6b515d9f3cbb068d979eba694bc10e1ff0 Mon Sep 17 00:00:00 2001
From: ehhuang
Date: Thu, 18 Sep 2025 21:11:13 -0700
Subject: [PATCH 2/5] chore: refactor server.main (#3462)

# What does this PR do?

As shown in #3421, we can scale the stack to handle more RPS with k8s replicas. This PR enables a multi-process stack with `uvicorn --workers`, so that we can achieve the same scaling without being in k8s.

To achieve that, we refactor `main` to split out the app-construction logic; this method needs to be non-async. We created a new `Stack` class to house the impls, with a `start()` method that the lifespan handler calls to kick off background tasks, instead of starting them in the old `construct_stack`. This way we avoid having to manage an event loop manually. (A minimal sketch of the resulting app-factory pattern appears after the diffs below.)

## Test Plan

CI

> uv run --with llama-stack python -m llama_stack.core.server.server benchmarking/k8s-benchmark/stack_run_config.yaml

works.

> LLAMA_STACK_CONFIG=benchmarking/k8s-benchmark/stack_run_config.yaml uv run uvicorn llama_stack.core.server.server:create_app --port 8321 --workers 4

works.
---
 benchmarking/k8s-benchmark/apply.sh           |   5 +-
 .../k8s-benchmark/stack-configmap.yaml        |   9 ++
 .../k8s-benchmark/stack-k8s.yaml.template     |  13 +-
 llama_stack/core/library_client.py            |   8 +-
 llama_stack/core/server/server.py             | 152 +++++++++++-------
 llama_stack/core/stack.py                     | 142 ++++++++--------
 .../test_library_client_initialization.py    |  50 ++++--
 7 files changed, 233 insertions(+), 146 deletions(-)

diff --git a/benchmarking/k8s-benchmark/apply.sh b/benchmarking/k8s-benchmark/apply.sh
index 4f2270da8..6e6607663 100755
--- a/benchmarking/k8s-benchmark/apply.sh
+++ b/benchmarking/k8s-benchmark/apply.sh
@@ -17,11 +17,8 @@ export POSTGRES_PASSWORD=llamastack
 export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
 export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
 
-export MOCK_INFERENCE_MODEL=mock-inference
-
-export MOCK_INFERENCE_URL=openai-mock-service:8080
-
 export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL
+export LLAMA_STACK_WORKERS=4
 
 set -euo pipefail
 set -x
diff --git a/benchmarking/k8s-benchmark/stack-configmap.yaml b/benchmarking/k8s-benchmark/stack-configmap.yaml
index bf6109b68..286ba5f77 100644
--- a/benchmarking/k8s-benchmark/stack-configmap.yaml
+++ b/benchmarking/k8s-benchmark/stack-configmap.yaml
@@ -5,6 +5,7 @@ data:
     image_name: kubernetes-benchmark-demo
     apis:
     - agents
+    - files
    - inference
    - files
    - safety
@@ -23,6 +24,14 @@ data:
      - provider_id: sentence-transformers
        provider_type: inline::sentence-transformers
        config: {}
+      files:
+      - provider_id: meta-reference-files
+        provider_type: inline::localfs
+        config:
+          storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
+          metadata_store:
+            type: sqlite
+            db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
      vector_io:
      - provider_id: ${env.ENABLE_CHROMADB:+chromadb}
        provider_type: remote::chromadb
diff --git a/benchmarking/k8s-benchmark/stack-k8s.yaml.template b/benchmarking/k8s-benchmark/stack-k8s.yaml.template
index 9cb1e5be3..8842c0bea 100644
--- a/benchmarking/k8s-benchmark/stack-k8s.yaml.template
+++ b/benchmarking/k8s-benchmark/stack-k8s.yaml.template
@@ -52,9 +52,20 @@ spec:
           value: http://vllm-server-safety.default.svc.cluster.local:8001/v1
         - name: VLLM_TLS_VERIFY
           value: "false"
-        command: ["python", "-m",
"llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"] + - name: LLAMA_STACK_LOGGING + value: "all=WARNING" + - name: LLAMA_STACK_CONFIG + value: "/etc/config/stack_run_config.yaml" + - name: LLAMA_STACK_WORKERS + value: "${LLAMA_STACK_WORKERS}" + command: ["uvicorn", "llama_stack.core.server.server:create_app", "--host", "0.0.0.0", "--port", "8323", "--workers", "$LLAMA_STACK_WORKERS", "--factory"] ports: - containerPort: 8323 + resources: + requests: + cpu: "${LLAMA_STACK_WORKERS}" + limits: + cpu: "${LLAMA_STACK_WORKERS}" volumeMounts: - name: llama-storage mountPath: /root/.llama diff --git a/llama_stack/core/library_client.py b/llama_stack/core/library_client.py index ea5a2ac8e..e722e4de6 100644 --- a/llama_stack/core/library_client.py +++ b/llama_stack/core/library_client.py @@ -40,7 +40,7 @@ from llama_stack.core.request_headers import ( from llama_stack.core.resolver import ProviderRegistry from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls from llama_stack.core.stack import ( - construct_stack, + Stack, get_stack_run_config_from_distro, replace_env_vars, ) @@ -252,7 +252,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): try: self.route_impls = None - self.impls = await construct_stack(self.config, self.custom_provider_registry) + + stack = Stack(self.config, self.custom_provider_registry) + await stack.initialize() + self.impls = stack.impls except ModuleNotFoundError as _e: cprint(_e.msg, color="red", file=sys.stderr) cprint( @@ -289,6 +292,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): ) raise _e + assert self.impls is not None if Api.telemetry in self.impls: setup_logger(self.impls[Api.telemetry]) diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index d3e875fec..9cca42268 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -6,6 +6,7 @@ import argparse import asyncio +import concurrent.futures import functools import inspect import json @@ -50,17 +51,15 @@ from llama_stack.core.request_headers import ( request_provider_data_context, user_from_scope, ) -from llama_stack.core.resolver import InvalidProviderError from llama_stack.core.server.routes import ( find_matching_route, get_all_api_routes, initialize_route_impls, ) from llama_stack.core.stack import ( + Stack, cast_image_name_to_string, - construct_stack, replace_env_vars, - shutdown_stack, validate_env_pair, ) from llama_stack.core.utils.config import redact_sensitive_fields @@ -156,21 +155,34 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro ) -async def shutdown(app): - """Initiate a graceful shutdown of the application. - - Handled by the lifespan context manager. The shutdown process involves - shutting down all implementations registered in the application. +class StackApp(FastAPI): """ - await shutdown_stack(app.__llama_stack_impls__) + A wrapper around the FastAPI application to hold a reference to the Stack instance so that we can + start background tasks (e.g. refresh model registry periodically) from the lifespan context manager. + """ + + def __init__(self, config: StackRunConfig, *args, **kwargs): + super().__init__(*args, **kwargs) + self.stack: Stack = Stack(config) + + # This code is called from a running event loop managed by uvicorn so we cannot simply call + # asyncio.run() to initialize the stack. We cannot await either since this is not an async + # function. 
+        # As a workaround, we use a thread pool executor to run the initialize() method
+        # in a separate thread.
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future = executor.submit(asyncio.run, self.stack.initialize())
+            future.result()
 
 
 @asynccontextmanager
-async def lifespan(app: FastAPI):
+async def lifespan(app: StackApp):
     logger.info("Starting up")
+    assert app.stack is not None
+    app.stack.create_registry_refresh_task()
     yield
     logger.info("Shutting down")
-    await shutdown(app)
+    await app.stack.shutdown()
 
 
 def is_streaming_request(func_name: str, request: Request, **kwargs):
@@ -386,73 +398,61 @@ class ClientVersionMiddleware:
         return await self.app(scope, receive, send)
 
 
-def main(args: argparse.Namespace | None = None):
-    """Start the LlamaStack server."""
-    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
+def create_app(
+    config_file: str | None = None,
+    env_vars: list[str] | None = None,
+) -> StackApp:
+    """Create and configure the FastAPI application.
 
-    add_config_distro_args(parser)
-    parser.add_argument(
-        "--port",
-        type=int,
-        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
-        help="Port to listen on",
-    )
-    parser.add_argument(
-        "--env",
-        action="append",
-        help="Environment variables in KEY=value format. Can be specified multiple times.",
-    )
+    Args:
+        config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
+        env_vars: List of environment variables in KEY=value format.
+        Note: version checking is toggled via the LLAMA_STACK_DISABLE_VERSION_CHECK env var; it is not a parameter of this function.
 
-    # Determine whether the server args are being passed by the "run" command, if this is the case
-    # the args will be passed as a Namespace object to the main function, otherwise they will be
-    # parsed from the command line
-    if args is None:
-        args = parser.parse_args()
+    Returns:
+        Configured StackApp instance.
+ """ + config_file = config_file or os.getenv("LLAMA_STACK_CONFIG") + if config_file is None: + raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set") - config_or_distro = get_config_from_args(args) - config_file = resolve_config_or_distro(config_or_distro, Mode.RUN) + config_file = resolve_config_or_distro(config_file, Mode.RUN) + # Load and process configuration logger_config = None with open(config_file) as fp: config_contents = yaml.safe_load(fp) if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")): logger_config = LoggingConfig(**cfg) logger = get_logger(name=__name__, category="core::server", config=logger_config) - if args.env: - for env_pair in args.env: + + if env_vars: + for env_pair in env_vars: try: key, value = validate_env_pair(env_pair) - logger.info(f"Setting CLI environment variable {key} => {value}") + logger.info(f"Setting environment variable {key} => {value}") os.environ[key] = value except ValueError as e: logger.error(f"Error: {str(e)}") - sys.exit(1) + raise ValueError(f"Invalid environment variable format: {env_pair}") from e + config = replace_env_vars(config_contents) config = StackRunConfig(**cast_image_name_to_string(config)) _log_run_config(run_config=config) - app = FastAPI( + app = StackApp( lifespan=lifespan, docs_url="/docs", redoc_url="/redoc", openapi_url="/openapi.json", + config=config, ) if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"): app.add_middleware(ClientVersionMiddleware) - try: - # Create and set the event loop that will be used for both construction and server runtime - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - # Construct the stack in the persistent event loop - impls = loop.run_until_complete(construct_stack(config)) - - except InvalidProviderError as e: - logger.error(f"Error: {str(e)}") - sys.exit(1) + impls = app.stack.impls if config.server.auth: logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}") @@ -553,9 +553,54 @@ def main(args: argparse.Namespace | None = None): app.exception_handler(RequestValidationError)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler) - app.__llama_stack_impls__ = impls app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis) + return app + + +def main(args: argparse.Namespace | None = None): + """Start the LlamaStack server.""" + parser = argparse.ArgumentParser(description="Start the LlamaStack server.") + + add_config_distro_args(parser) + parser.add_argument( + "--port", + type=int, + default=int(os.getenv("LLAMA_STACK_PORT", 8321)), + help="Port to listen on", + ) + parser.add_argument( + "--env", + action="append", + help="Environment variables in KEY=value format. 
Can be specified multiple times.", + ) + + # Determine whether the server args are being passed by the "run" command, if this is the case + # the args will be passed as a Namespace object to the main function, otherwise they will be + # parsed from the command line + if args is None: + args = parser.parse_args() + + config_or_distro = get_config_from_args(args) + + try: + app = create_app( + config_file=config_or_distro, + env_vars=args.env, + ) + except Exception as e: + logger.error(f"Error creating app: {str(e)}") + sys.exit(1) + + config_file = resolve_config_or_distro(config_or_distro, Mode.RUN) + with open(config_file) as fp: + config_contents = yaml.safe_load(fp) + if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")): + logger_config = LoggingConfig(**cfg) + else: + logger_config = None + config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents))) + import uvicorn # Configure SSL if certificates are provided @@ -593,7 +638,6 @@ def main(args: argparse.Namespace | None = None): if ssl_config: uvicorn_config.update(ssl_config) - # Run uvicorn in the existing event loop to preserve background tasks # We need to catch KeyboardInterrupt because uvicorn's signal handling # re-raises SIGINT signals using signal.raise_signal(), which Python # converts to KeyboardInterrupt. Without this catch, we'd get a confusing @@ -604,13 +648,9 @@ def main(args: argparse.Namespace | None = None): # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own # signal handling but this is quite intrusive and not worth the effort. try: - loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve()) + asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve()) except (KeyboardInterrupt, SystemExit): logger.info("Received interrupt signal, shutting down gracefully...") - finally: - if not loop.is_closed(): - logger.debug("Closing event loop") - loop.close() def _log_run_config(run_config: StackRunConfig): diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py index 7ab8d2c64..a6c5093eb 100644 --- a/llama_stack/core/stack.py +++ b/llama_stack/core/stack.py @@ -315,78 +315,84 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf impls[Api.prompts] = prompts_impl -# Produces a stack of providers for the given run config. Not all APIs may be -# asked for in the run config. -async def construct_stack( - run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None -) -> dict[Api, Any]: - if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ: - from llama_stack.testing.inference_recorder import setup_inference_recording +class Stack: + def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None): + self.run_config = run_config + self.provider_registry = provider_registry + self.impls = None + + # Produces a stack of providers for the given run config. Not all APIs may be + # asked for in the run config. 
+ async def initialize(self): + if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ: + from llama_stack.testing.inference_recorder import setup_inference_recording + + global TEST_RECORDING_CONTEXT + TEST_RECORDING_CONTEXT = setup_inference_recording() + if TEST_RECORDING_CONTEXT: + TEST_RECORDING_CONTEXT.__enter__() + logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}") + + dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name) + policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else [] + impls = await resolve_impls( + self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy + ) + + # Add internal implementations after all other providers are resolved + add_internal_implementations(impls, self.run_config) + + if Api.prompts in impls: + await impls[Api.prompts].initialize() + + await register_resources(self.run_config, impls) + + await refresh_registry_once(impls) + self.impls = impls + + def create_registry_refresh_task(self): + assert self.impls is not None, "Must call initialize() before starting" + + global REGISTRY_REFRESH_TASK + REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls)) + + def cb(task): + import traceback + + if task.cancelled(): + logger.error("Model refresh task cancelled") + elif task.exception(): + logger.error(f"Model refresh task failed: {task.exception()}") + traceback.print_exception(task.exception()) + else: + logger.debug("Model refresh task completed") + + REGISTRY_REFRESH_TASK.add_done_callback(cb) + + async def shutdown(self): + for impl in self.impls.values(): + impl_name = impl.__class__.__name__ + logger.info(f"Shutting down {impl_name}") + try: + if hasattr(impl, "shutdown"): + await asyncio.wait_for(impl.shutdown(), timeout=5) + else: + logger.warning(f"No shutdown method for {impl_name}") + except TimeoutError: + logger.exception(f"Shutdown timeout for {impl_name}") + except (Exception, asyncio.CancelledError) as e: + logger.exception(f"Failed to shutdown {impl_name}: {e}") global TEST_RECORDING_CONTEXT - TEST_RECORDING_CONTEXT = setup_inference_recording() if TEST_RECORDING_CONTEXT: - TEST_RECORDING_CONTEXT.__enter__() - logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}") + try: + TEST_RECORDING_CONTEXT.__exit__(None, None, None) + except Exception as e: + logger.error(f"Error during inference recording cleanup: {e}") - dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name) - policy = run_config.server.auth.access_policy if run_config.server.auth else [] - impls = await resolve_impls( - run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy - ) - - # Add internal implementations after all other providers are resolved - add_internal_implementations(impls, run_config) - - if Api.prompts in impls: - await impls[Api.prompts].initialize() - - await register_resources(run_config, impls) - - await refresh_registry_once(impls) - - global REGISTRY_REFRESH_TASK - REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls)) - - def cb(task): - import traceback - - if task.cancelled(): - logger.error("Model refresh task cancelled") - elif task.exception(): - logger.error(f"Model refresh task failed: {task.exception()}") - traceback.print_exception(task.exception()) - else: - logger.debug("Model refresh task completed") - - 
REGISTRY_REFRESH_TASK.add_done_callback(cb) - return impls - - -async def shutdown_stack(impls: dict[Api, Any]): - for impl in impls.values(): - impl_name = impl.__class__.__name__ - logger.info(f"Shutting down {impl_name}") - try: - if hasattr(impl, "shutdown"): - await asyncio.wait_for(impl.shutdown(), timeout=5) - else: - logger.warning(f"No shutdown method for {impl_name}") - except TimeoutError: - logger.exception(f"Shutdown timeout for {impl_name}") - except (Exception, asyncio.CancelledError) as e: - logger.exception(f"Failed to shutdown {impl_name}: {e}") - - global TEST_RECORDING_CONTEXT - if TEST_RECORDING_CONTEXT: - try: - TEST_RECORDING_CONTEXT.__exit__(None, None, None) - except Exception as e: - logger.error(f"Error during inference recording cleanup: {e}") - - global REGISTRY_REFRESH_TASK - if REGISTRY_REFRESH_TASK: - REGISTRY_REFRESH_TASK.cancel() + global REGISTRY_REFRESH_TASK + if REGISTRY_REFRESH_TASK: + REGISTRY_REFRESH_TASK.cancel() async def refresh_registry_once(impls: dict[Api, Any]): diff --git a/tests/unit/distribution/test_library_client_initialization.py b/tests/unit/distribution/test_library_client_initialization.py index b7e7a1857..b01a5c3e2 100644 --- a/tests/unit/distribution/test_library_client_initialization.py +++ b/tests/unit/distribution/test_library_client_initialization.py @@ -27,13 +27,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization: mock_impls = {} mock_route_impls = RouteImpls({}) - async def mock_construct_stack(config, custom_provider_registry): - return mock_impls + class MockStack: + def __init__(self, config, custom_provider_registry=None): + self.impls = mock_impls + + async def initialize(self): + pass def mock_initialize_route_impls(impls): return mock_route_impls - monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack) + monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack) monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls) client = LlamaStackAsLibraryClient("ci-tests") @@ -46,13 +50,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization: mock_impls = {} mock_route_impls = RouteImpls({}) - async def mock_construct_stack(config, custom_provider_registry): - return mock_impls + class MockStack: + def __init__(self, config, custom_provider_registry=None): + self.impls = mock_impls + + async def initialize(self): + pass def mock_initialize_route_impls(impls): return mock_route_impls - monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack) + monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack) monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls) client = AsyncLlamaStackAsLibraryClient("ci-tests") @@ -68,13 +76,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization: mock_impls = {} mock_route_impls = RouteImpls({}) - async def mock_construct_stack(config, custom_provider_registry): - return mock_impls + class MockStack: + def __init__(self, config, custom_provider_registry=None): + self.impls = mock_impls + + async def initialize(self): + pass def mock_initialize_route_impls(impls): return mock_route_impls - monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack) + monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack) monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls) client = 
LlamaStackAsLibraryClient("ci-tests") @@ -90,13 +102,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization: mock_impls = {} mock_route_impls = RouteImpls({}) - async def mock_construct_stack(config, custom_provider_registry): - return mock_impls + class MockStack: + def __init__(self, config, custom_provider_registry=None): + self.impls = mock_impls + + async def initialize(self): + pass def mock_initialize_route_impls(impls): return mock_route_impls - monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack) + monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack) monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls) client = AsyncLlamaStackAsLibraryClient("ci-tests") @@ -112,13 +128,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization: mock_impls = {} mock_route_impls = RouteImpls({}) - async def mock_construct_stack(config, custom_provider_registry): - return mock_impls + class MockStack: + def __init__(self, config, custom_provider_registry=None): + self.impls = mock_impls + + async def initialize(self): + pass def mock_initialize_route_impls(impls): return mock_route_impls - monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack) + monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack) monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls) sync_client = LlamaStackAsLibraryClient("ci-tests") From 9378bdca43e886c7ba3b79f017693726bed9b740 Mon Sep 17 00:00:00 2001 From: adam-d-young Date: Fri, 19 Sep 2025 10:41:26 -0500 Subject: [PATCH 3/5] docs: Fix incorrect vector_db_id usage in RAG tutorial (#3444) # What does this PR do? This PR fixes a blocking issue in the detailed RAG tutorial where the code fails with a 400 Bad Request error. The root cause is that recent versions of Llama-Stack ignore the client-generated vector_db_id and assign a new server-side ID. The tutorial was not updated to reflect this, causing the rag_tool.insert call to fail. This change updates the code to capture the authoritative ID from the .identifier attribute of the register() method's response. This ensures the tutorial code runs successfully and reflects the current API behavior. ## Test Plan The fix can be verified by running the Python code snippet from the detailed tutorial page. Run the original code (Before this change): Result: The script fails with a 400 Bad Request error on the rag_tool.insert step. Run the updated code (After this change): Result: The script runs successfully to completion. Co-authored-by: Adam Young --- docs/source/getting_started/detailed_tutorial.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/source/getting_started/detailed_tutorial.md b/docs/source/getting_started/detailed_tutorial.md index 14f888628..77a899c48 100644 --- a/docs/source/getting_started/detailed_tutorial.md +++ b/docs/source/getting_started/detailed_tutorial.md @@ -460,10 +460,12 @@ client = LlamaStackClient(base_url="http://localhost:8321") embed_lm = next(m for m in client.models.list() if m.model_type == "embedding") embedding_model = embed_lm.identifier vector_db_id = f"v{uuid.uuid4().hex}" -client.vector_dbs.register( +# The VectorDB API is deprecated; the server now returns its own authoritative ID. +# We capture the correct ID from the response's .identifier attribute. 
+vector_db_id = client.vector_dbs.register(
     vector_db_id=vector_db_id,
     embedding_model=embedding_model,
-)
+).identifier
 
 # Create Documents
 urls = [

From d3600b92d10a3e11126c6120eadf0ba5ef4e1c3c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Han?=
Date: Fri, 19 Sep 2025 22:12:08 +0200
Subject: [PATCH 4/5] fix: force milvus-lite installation for inline::milvus
 (#3488)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# What does this PR do?
pymilvus recently made `milvus-lite` an optional dependency of its package. If someone wants to use the inline provider, we must include the extra dependency.

For more details, see: https://github.com/milvus-io/pymilvus/pull/2976

Signed-off-by: Sébastien Han
---
 docs/source/providers/vector_io/remote_milvus.md |  8 +++++++-
 llama_stack/providers/registry/vector_io.py      | 10 ++++++++--
 2 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/docs/source/providers/vector_io/remote_milvus.md b/docs/source/providers/vector_io/remote_milvus.md
index 075423d04..8974ada10 100644
--- a/docs/source/providers/vector_io/remote_milvus.md
+++ b/docs/source/providers/vector_io/remote_milvus.md
@@ -23,7 +23,13 @@ To use Milvus in your Llama Stack project, follow these steps:
 
 ## Installation
 
-You can install Milvus using pymilvus:
+If you want to use inline Milvus, you can install:
+
+```bash
+pip install pymilvus[milvus-lite]
+```
+
+If you want to use remote Milvus, you can install:
 
 ```bash
 pip install pymilvus
diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py
index 3b82f6eb6..e8237bc62 100644
--- a/llama_stack/providers/registry/vector_io.py
+++ b/llama_stack/providers/registry/vector_io.py
@@ -633,7 +633,13 @@ To use Milvus in your Llama Stack project, follow these steps:
 
 ## Installation
 
-You can install Milvus using pymilvus:
+If you want to use inline Milvus, you can install:
+
+```bash
+pip install pymilvus[milvus-lite]
+```
+
+If you want to use remote Milvus, you can install:
 
 ```bash
 pip install pymilvus
@@ -807,7 +813,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
     InlineProviderSpec(
         api=Api.vector_io,
         provider_type="inline::milvus",
-        pip_packages=["pymilvus>=2.4.10"],
+        pip_packages=["pymilvus[milvus-lite]>=2.4.10"],
         module="llama_stack.providers.inline.vector_io.milvus",
         config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
         api_dependencies=[Api.inference],

From f44eb935c4ff110278758ff2972bd5a5fd544915 Mon Sep 17 00:00:00 2001
From: ehhuang
Date: Fri, 19 Sep 2025 16:13:56 -0700
Subject: [PATCH 5/5] chore: simplify authorized sqlstore (#3496)

# What does this PR do?
This PR is generated with AI and reviewed by me.

Refactors the AuthorizedSqlStore class to store the access policy as an instance variable rather than passing it as a parameter to each method call. This simplifies the API.
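At a call site, the change looks roughly like this (a sketch with placeholder names — `base_sqlstore`, `policy`, and `file_id` stand in for whatever the caller already has; this is not code from the diff):

```python
# Before: the access policy had to be passed to every query.
store = AuthorizedSqlStore(base_sqlstore)
row = await store.fetch_one("openai_files", policy=policy, where={"id": file_id})

# After: the policy is bound once at construction and applied to every query.
store = AuthorizedSqlStore(base_sqlstore, policy)
row = await store.fetch_one("openai_files", where={"id": file_id})
```

Callers that need a different policy (e.g. the owner-only policy exercised in the integration tests) now construct a separate `AuthorizedSqlStore` instance instead of overriding the policy per call.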
# Test Plan existing tests --- .../providers/inline/files/localfs/files.py | 5 ++--- .../providers/remote/files/s3/files.py | 5 ++--- .../utils/inference/inference_store.py | 4 +--- .../utils/responses/responses_store.py | 7 ++----- .../utils/sqlstore/authorized_sqlstore.py | 11 +++++------ .../sqlstore/test_authorized_sqlstore.py | 19 +++++++++++-------- tests/unit/utils/test_authorized_sqlstore.py | 18 +++++++++--------- 7 files changed, 32 insertions(+), 37 deletions(-) diff --git a/llama_stack/providers/inline/files/localfs/files.py b/llama_stack/providers/inline/files/localfs/files.py index 9c610c1ba..65cf8d815 100644 --- a/llama_stack/providers/inline/files/localfs/files.py +++ b/llama_stack/providers/inline/files/localfs/files.py @@ -44,7 +44,7 @@ class LocalfsFilesImpl(Files): storage_path.mkdir(parents=True, exist_ok=True) # Initialize SQL store for metadata - self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store)) + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy) await self.sql_store.create_table( "openai_files", { @@ -74,7 +74,7 @@ class LocalfsFilesImpl(Files): if not self.sql_store: raise RuntimeError("Files provider not initialized") - row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) + row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) if not row: raise ResourceNotFoundError(file_id, "File", "client.files.list()") @@ -150,7 +150,6 @@ class LocalfsFilesImpl(Files): paginated_result = await self.sql_store.fetch_all( table="openai_files", - policy=self.policy, where=where_conditions if where_conditions else None, order_by=[("created_at", order.value)], cursor=("id", after) if after else None, diff --git a/llama_stack/providers/remote/files/s3/files.py b/llama_stack/providers/remote/files/s3/files.py index 54742d900..8ea96af9e 100644 --- a/llama_stack/providers/remote/files/s3/files.py +++ b/llama_stack/providers/remote/files/s3/files.py @@ -137,7 +137,7 @@ class S3FilesImpl(Files): where: dict[str, str | dict] = {"id": file_id} if not return_expired: where["expires_at"] = {">": self._now()} - if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)): + if not (row := await self.sql_store.fetch_one("openai_files", where=where)): raise ResourceNotFoundError(file_id, "File", "files.list()") return row @@ -164,7 +164,7 @@ class S3FilesImpl(Files): self._client = _create_s3_client(self._config) await _create_bucket_if_not_exists(self._client, self._config) - self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store)) + self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy) await self._sql_store.create_table( "openai_files", { @@ -268,7 +268,6 @@ class S3FilesImpl(Files): paginated_result = await self.sql_store.fetch_all( table="openai_files", - policy=self.policy, where=where_conditions, order_by=[("created_at", order.value)], cursor=("id", after) if after else None, diff --git a/llama_stack/providers/utils/inference/inference_store.py b/llama_stack/providers/utils/inference/inference_store.py index 17f4c6268..ffc9f3e11 100644 --- a/llama_stack/providers/utils/inference/inference_store.py +++ b/llama_stack/providers/utils/inference/inference_store.py @@ -54,7 +54,7 @@ class InferenceStore: async def initialize(self): """Create the necessary tables if they don't exist.""" - self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config)) + 
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy) await self.sql_store.create_table( "chat_completions", { @@ -202,7 +202,6 @@ class InferenceStore: order_by=[("created", order.value)], cursor=("id", after) if after else None, limit=limit, - policy=self.policy, ) data = [ @@ -229,7 +228,6 @@ class InferenceStore: row = await self.sql_store.fetch_one( table="chat_completions", where={"id": completion_id}, - policy=self.policy, ) if not row: diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 04778ed1c..829cd8a62 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -28,8 +28,7 @@ class ResponsesStore: sql_store_config = SqliteSqlStoreConfig( db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), ) - self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config)) - self.policy = policy + self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy) async def initialize(self): """Create the necessary tables if they don't exist.""" @@ -87,7 +86,6 @@ class ResponsesStore: order_by=[("created_at", order.value)], cursor=("id", after) if after else None, limit=limit, - policy=self.policy, ) data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data] @@ -105,7 +103,6 @@ class ResponsesStore: row = await self.sql_store.fetch_one( "openai_responses", where={"id": response_id}, - policy=self.policy, ) if not row: @@ -116,7 +113,7 @@ class ResponsesStore: return OpenAIResponseObjectWithInput(**row["response_object"]) async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject: - row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy) + row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}) if not row: raise ValueError(f"Response with id {response_id} not found") await self.sql_store.delete("openai_responses", where={"id": response_id}) diff --git a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index acb688f96..ab67f7052 100644 --- a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -53,13 +53,15 @@ class AuthorizedSqlStore: access control policies, user attribute capture, and SQL filtering optimization. """ - def __init__(self, sql_store: SqlStore): + def __init__(self, sql_store: SqlStore, policy: list[AccessRule]): """ Initialize the authorization layer. 
:param sql_store: Base SqlStore implementation to wrap + :param policy: Access control policy to use for authorization """ self.sql_store = sql_store + self.policy = policy self._detect_database_type() self._validate_sql_optimized_policy() @@ -117,14 +119,13 @@ class AuthorizedSqlStore: async def fetch_all( self, table: str, - policy: list[AccessRule], where: Mapping[str, Any] | None = None, limit: int | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, cursor: tuple[str, str] | None = None, ) -> PaginatedResponse: """Fetch all rows with automatic access control filtering.""" - access_where = self._build_access_control_where_clause(policy) + access_where = self._build_access_control_where_clause(self.policy) rows = await self.sql_store.fetch_all( table=table, where=where, @@ -146,7 +147,7 @@ class AuthorizedSqlStore: str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs) ) - if is_action_allowed(policy, Action.READ, sql_record, current_user): + if is_action_allowed(self.policy, Action.READ, sql_record, current_user): filtered_rows.append(row) return PaginatedResponse( @@ -157,14 +158,12 @@ class AuthorizedSqlStore: async def fetch_one( self, table: str, - policy: list[AccessRule], where: Mapping[str, Any] | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, ) -> dict[str, Any] | None: """Fetch one row with automatic access control checking.""" results = await self.fetch_all( table=table, - policy=policy, where=where, limit=1, order_by=order_by, diff --git a/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py index 4002f2e1f..98bef0f2c 100644 --- a/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py +++ b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py @@ -57,7 +57,7 @@ def authorized_store(backend_config): config = config_func() base_sqlstore = sqlstore_impl(config) - authorized_store = AuthorizedSqlStore(base_sqlstore) + authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy()) yield authorized_store @@ -106,7 +106,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz await authorized_store.insert(table_name, {"id": "1", "data": "public_data"}) # Test fetching with no user - should not error on JSON comparison - result = await authorized_store.fetch_all(table_name, policy=default_policy()) + result = await authorized_store.fetch_all(table_name) assert len(result.data) == 1 assert result.data[0]["id"] == "1" assert result.data[0]["access_attributes"] is None @@ -119,7 +119,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"}) # Fetch all - admin should see both - result = await authorized_store.fetch_all(table_name, policy=default_policy()) + result = await authorized_store.fetch_all(table_name) assert len(result.data) == 2 # Test with non-admin user @@ -127,7 +127,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz mock_get_authenticated_user.return_value = regular_user # Should only see public record - result = await authorized_store.fetch_all(table_name, policy=default_policy()) + result = await authorized_store.fetch_all(table_name) assert len(result.data) == 1 assert result.data[0]["id"] == "1" @@ -156,7 +156,7 @@ async def 
test_authorized_store_attributes(mock_get_authenticated_user, authoriz # Now test with the multi-user who has both roles=admin and teams=dev mock_get_authenticated_user.return_value = multi_user - result = await authorized_store.fetch_all(table_name, policy=default_policy()) + result = await authorized_store.fetch_all(table_name) # Should see: # - public record (1) - no access_attributes @@ -217,21 +217,24 @@ async def test_user_ownership_policy(mock_get_authenticated_user, authorized_sto ), ] + # Create a new authorized store with the owner-only policy + owner_only_store = AuthorizedSqlStore(authorized_store.sql_store, owner_only_policy) + # Test user1 access - should only see their own record mock_get_authenticated_user.return_value = user1 - result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) + result = await owner_only_store.fetch_all(table_name) assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}" assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}" # Test user2 access - should only see their own record mock_get_authenticated_user.return_value = user2 - result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) + result = await owner_only_store.fetch_all(table_name) assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}" assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}" # Test with anonymous user - should see no records mock_get_authenticated_user.return_value = None - result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) + result = await owner_only_store.fetch_all(table_name) assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}" finally: diff --git a/tests/unit/utils/test_authorized_sqlstore.py b/tests/unit/utils/test_authorized_sqlstore.py index 90eb706e4..d85e784a9 100644 --- a/tests/unit/utils/test_authorized_sqlstore.py +++ b/tests/unit/utils/test_authorized_sqlstore.py @@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic db_path=tmp_dir + "/" + db_name, ) ) - sqlstore = AuthorizedSqlStore(base_sqlstore) + sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy()) # Create table with access control await sqlstore.create_table( @@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic mock_get_authenticated_user.return_value = admin_user # Admin should see both documents - result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1}) + result = await sqlstore.fetch_all("documents", where={"id": 1}) assert len(result.data) == 1 assert result.data[0]["title"] == "Admin Document" # User should only see their document mock_get_authenticated_user.return_value = regular_user - result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1}) + result = await sqlstore.fetch_all("documents", where={"id": 1}) assert len(result.data) == 0 - result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2}) + result = await sqlstore.fetch_all("documents", where={"id": 2}) assert len(result.data) == 1 assert result.data[0]["title"] == "User Document" - row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1}) + row = await sqlstore.fetch_one("documents", where={"id": 1}) assert row is None - row = await sqlstore.fetch_one("documents", 
policy=default_policy(), where={"id": 2}) + row = await sqlstore.fetch_one("documents", where={"id": 2}) assert row is not None assert row["title"] == "User Document" @@ -88,7 +88,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user): db_path=tmp_dir + "/" + db_name, ) ) - sqlstore = AuthorizedSqlStore(base_sqlstore) + sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy()) await sqlstore.create_table( table="resources", @@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user): user = User(principal=user_data["principal"], attributes=user_data["attributes"]) mock_get_authenticated_user.return_value = user - sql_results = await sqlstore.fetch_all("resources", policy=policy) + sql_results = await sqlstore.fetch_all("resources") sql_ids = {row["id"] for row in sql_results.data} policy_ids = set() for scenario in test_scenarios: @@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us db_path=tmp_dir + "/" + db_name, ) ) - authorized_store = AuthorizedSqlStore(base_sqlstore) + authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy()) await authorized_store.create_table( table="user_data",