diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py
index af2a46739..fbf4871c4 100644
--- a/llama_stack/cli/stack/_build.py
+++ b/llama_stack/cli/stack/_build.py
@@ -31,6 +31,7 @@ from llama_stack.distribution.build import (
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.datatypes import (
     BuildConfig,
+    BuildProvider,
     DistributionSpec,
     Provider,
     StackRunConfig,
@@ -94,7 +95,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
         )
         sys.exit(1)
     elif args.providers:
-        provider_list: dict[str, list[Provider]] = dict()
+        provider_list: dict[str, list[BuildProvider]] = dict()
         for api_provider in args.providers.split(","):
             if "=" not in api_provider:
                 cprint(
@@ -113,10 +114,8 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                 )
                 sys.exit(1)
             if provider_type in providers_for_api:
-                provider = Provider(
+                provider = BuildProvider(
                     provider_type=provider_type,
-                    provider_id=provider_type.split("::")[1],
-                    config={},
                     module=None,
                 )
                 provider_list.setdefault(api, []).append(provider)
@@ -189,7 +188,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
 
         cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr)
 
-        providers: dict[str, list[Provider]] = dict()
+        providers: dict[str, list[BuildProvider]] = dict()
         for api, providers_for_api in get_provider_registry().items():
             available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")]
             if not available_providers:
@@ -204,7 +203,10 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                 ),
             )
 
-            providers[api.value] = api_provider
+            string_providers = api_provider.split(" ")
+
+            for provider in string_providers:
+                providers.setdefault(api.value, []).append(BuildProvider(provider_type=provider))
 
         description = prompt(
             "\n > (Optional) Enter a short description for your Llama Stack: ",
@@ -307,7 +309,7 @@ def _generate_run_config(
         providers = build_config.distribution_spec.providers[api]
 
         for provider in providers:
-            pid = provider.provider_id
+            pid = provider.provider_type.split("::")[-1]
 
             p = provider_registry[Api(api)][provider.provider_type]
             if p.deprecation_error:
diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py
index 355233d53..20be040a0 100644
--- a/llama_stack/distribution/configure.py
+++ b/llama_stack/distribution/configure.py
@@ -100,11 +100,12 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
                 break
 
             logger.info(f"> Configuring provider `({provider.provider_type})`")
+            pid = provider.provider_type.split("::")[-1]
             updated_providers.append(
                 configure_single_provider(
                     provider_registry[api],
                     Provider(
-                        provider_id=(f"{provider.provider_id}-{i:02d}" if len(plist) > 1 else provider.provider_id),
+                        provider_id=(f"{pid}-{i:02d}" if len(plist) > 1 else pid),
                         provider_type=provider.provider_type,
                         config={},
                     ),
diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index c17aadcc1..60c317337 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -154,13 +154,27 @@ class Provider(BaseModel):
     )
 
 
+class BuildProvider(BaseModel):
+    provider_type: str
+    module: str | None = Field(
+        default=None,
+        description="""
+Fully-qualified name of the external provider module to import. The module is expected to have:
+
+ - `get_adapter_impl(config, deps)`: returns the adapter implementation
+
+Example: `module: ramalama_stack`
+""",
+    )
+
+
 class DistributionSpec(BaseModel):
     description: str | None = Field(
         default="",
         description="Description of the distribution",
     )
     container_image: str | None = None
-    providers: dict[str, list[Provider]] = Field(
+    providers: dict[str, list[BuildProvider]] = Field(
         default_factory=dict,
         description="""
 Provider Types for each of the APIs provided by this distribution. If you
diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py
index bcb0b9167..1c28983cf 100644
--- a/llama_stack/distribution/library_client.py
+++ b/llama_stack/distribution/library_client.py
@@ -33,7 +33,7 @@ from termcolor import cprint
 
 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
-from llama_stack.distribution.datatypes import Api, BuildConfig, DistributionSpec
+from llama_stack.distribution.datatypes import Api, BuildConfig, BuildProvider, DistributionSpec
 from llama_stack.distribution.request_headers import (
     PROVIDER_DATA_VAR,
     request_provider_data_context,
@@ -249,9 +249,16 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                     file=sys.stderr,
                 )
                 if self.config_path_or_template_name.endswith(".yaml"):
+                    providers: dict[str, list[BuildProvider]] = {}
+                    for api, run_providers in self.config.providers.items():
+                        for provider in run_providers:
+                            providers.setdefault(api, []).append(
+                                BuildProvider(provider_type=provider.provider_type, module=provider.module)
+                            )
+                    providers = dict(providers)
                     build_config = BuildConfig(
                         distribution_spec=DistributionSpec(
-                            providers=self.config.providers,
+                            providers=providers,
                         ),
                         external_providers_dir=self.config.external_providers_dir,
                     )
diff --git a/llama_stack/templates/ci-tests/build.yaml b/llama_stack/templates/ci-tests/build.yaml
index 2f18e5d26..2421842ec 100644
--- a/llama_stack/templates/ci-tests/build.yaml
+++ b/llama_stack/templates/ci-tests/build.yaml
@@ -3,96 +3,56 @@ distribution_spec:
   description: CI tests for Llama Stack
   providers:
     inference:
-    - provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
-      provider_type: remote::cerebras
-    - provider_id: ${env.ENABLE_OLLAMA:=__disabled__}
-      provider_type: remote::ollama
-    - provider_id: ${env.ENABLE_VLLM:=__disabled__}
-      provider_type: remote::vllm
-    - provider_id: ${env.ENABLE_TGI:=__disabled__}
-      provider_type: remote::tgi
-    - provider_id: ${env.ENABLE_HF_SERVERLESS:=__disabled__}
-      provider_type: remote::hf::serverless
-    - provider_id: ${env.ENABLE_HF_ENDPOINT:=__disabled__}
-      provider_type: remote::hf::endpoint
-    - provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-      provider_type: remote::fireworks
-    - provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
-      provider_type: remote::together
-    - provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
-      provider_type: remote::bedrock
-    - provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
-      provider_type: remote::databricks
-    - provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
-      provider_type: remote::nvidia
-    - provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
-      provider_type: remote::runpod
-    - provider_id: ${env.ENABLE_OPENAI:=__disabled__}
-      provider_type: remote::openai
-    - provider_id: ${env.ENABLE_ANTHROPIC:=__disabled__}
-      provider_type: remote::anthropic
-    - provider_id: ${env.ENABLE_GEMINI:=__disabled__}
-      provider_type: remote::gemini
-    - provider_id: ${env.ENABLE_GROQ:=__disabled__}
-      provider_type: remote::groq
-    - provider_id: ${env.ENABLE_LLAMA_OPENAI_COMPAT:=__disabled__}
-      provider_type: remote::llama-openai-compat
-    - provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
-      provider_type: remote::sambanova
-    - provider_id: ${env.ENABLE_PASSTHROUGH:=__disabled__}
-      provider_type: remote::passthrough
-    - provider_id: sentence-transformers
-      provider_type: inline::sentence-transformers
+    - provider_type: remote::cerebras
+    - provider_type: remote::ollama
+    - provider_type: remote::vllm
+    - provider_type: remote::tgi
+    - provider_type: remote::hf::serverless
+    - provider_type: remote::hf::endpoint
+    - provider_type: remote::fireworks
+    - provider_type: remote::together
+    - provider_type: remote::bedrock
+    - provider_type: remote::databricks
+    - provider_type: remote::nvidia
+    - provider_type: remote::runpod
+    - provider_type: remote::openai
+    - provider_type: remote::anthropic
+    - provider_type: remote::gemini
+    - provider_type: remote::groq
+    - provider_type: remote::llama-openai-compat
+    - provider_type: remote::sambanova
+    - provider_type: remote::passthrough
+    - provider_type: inline::sentence-transformers
     vector_io:
-    - provider_id: ${env.ENABLE_FAISS:=faiss}
-      provider_type: inline::faiss
-    - provider_id: ${env.ENABLE_SQLITE_VEC:=__disabled__}
-      provider_type: inline::sqlite-vec
-    - provider_id: ${env.ENABLE_MILVUS:=__disabled__}
-      provider_type: inline::milvus
-    - provider_id: ${env.ENABLE_CHROMADB:=__disabled__}
-      provider_type: remote::chromadb
-    - provider_id: ${env.ENABLE_PGVECTOR:=__disabled__}
-      provider_type: remote::pgvector
+    - provider_type: inline::faiss
+    - provider_type: inline::sqlite-vec
+    - provider_type: inline::milvus
+    - provider_type: remote::chromadb
+    - provider_type: remote::pgvector
     files:
-    - provider_id: localfs
-      provider_type: inline::localfs
+    - provider_type: inline::localfs
     safety:
-    - provider_id: llama-guard
-      provider_type: inline::llama-guard
+    - provider_type: inline::llama-guard
     agents:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     telemetry:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
    post_training:
-    - provider_id: huggingface
-      provider_type: inline::huggingface
+    - provider_type: inline::huggingface
     eval:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     datasetio:
-    - provider_id: huggingface
-      provider_type: remote::huggingface
-    - provider_id: localfs
-      provider_type: inline::localfs
+    - provider_type: remote::huggingface
+    - provider_type: inline::localfs
     scoring:
-    - provider_id: basic
-      provider_type: inline::basic
-    - provider_id: llm-as-judge
-      provider_type: inline::llm-as-judge
-    - provider_id: braintrust
-      provider_type: inline::braintrust
+    - provider_type: inline::basic
+    - provider_type: inline::llm-as-judge
+    - provider_type: inline::braintrust
     tool_runtime:
-    - provider_id: brave-search
-      provider_type: remote::brave-search
-    - provider_id: tavily-search
-      provider_type: remote::tavily-search
-    - provider_id: rag-runtime
-      provider_type: inline::rag-runtime
-    - provider_id: model-context-protocol
-      provider_type: remote::model-context-protocol
+    - provider_type: remote::brave-search
+    - provider_type: remote::tavily-search
+    - provider_type: inline::rag-runtime
+    - provider_type: remote::model-context-protocol
 image_type: conda
 image_name: ci-tests
 additional_pip_packages:
diff --git a/llama_stack/templates/dell/build.yaml b/llama_stack/templates/dell/build.yaml
index d19934ee5..4a4ecb37c 100644
--- a/llama_stack/templates/dell/build.yaml
+++ b/llama_stack/templates/dell/build.yaml
@@ -4,48 +4,31 @@
     container
   providers:
     inference:
-    - provider_id: tgi
-      provider_type: remote::tgi
-    - provider_id: sentence-transformers
-      provider_type: inline::sentence-transformers
+    - provider_type: remote::tgi
+    - provider_type: inline::sentence-transformers
     vector_io:
-    - provider_id: faiss
-      provider_type: inline::faiss
-    - provider_id: chromadb
-      provider_type: remote::chromadb
-    - provider_id: pgvector
-      provider_type: remote::pgvector
+    - provider_type: inline::faiss
+    - provider_type: remote::chromadb
+    - provider_type: remote::pgvector
     safety:
-    - provider_id: llama-guard
-      provider_type: inline::llama-guard
+    - provider_type: inline::llama-guard
     agents:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     telemetry:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     eval:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     datasetio:
-    - provider_id: huggingface
-      provider_type: remote::huggingface
-    - provider_id: localfs
-      provider_type: inline::localfs
+    - provider_type: remote::huggingface
+    - provider_type: inline::localfs
     scoring:
-    - provider_id: basic
-      provider_type: inline::basic
-    - provider_id: llm-as-judge
-      provider_type: inline::llm-as-judge
-    - provider_id: braintrust
-      provider_type: inline::braintrust
+    - provider_type: inline::basic
+    - provider_type: inline::llm-as-judge
+    - provider_type: inline::braintrust
     tool_runtime:
-    - provider_id: brave-search
-      provider_type: remote::brave-search
-    - provider_id: tavily-search
-      provider_type: remote::tavily-search
-    - provider_id: rag-runtime
-      provider_type: inline::rag-runtime
+    - provider_type: remote::brave-search
+    - provider_type: remote::tavily-search
+    - provider_type: inline::rag-runtime
 image_type: conda
 image_name: dell
 additional_pip_packages:
diff --git a/llama_stack/templates/dell/dell.py b/llama_stack/templates/dell/dell.py
index b2210e7dc..64e01535c 100644
--- a/llama_stack/templates/dell/dell.py
+++ b/llama_stack/templates/dell/dell.py
@@ -6,6 +6,7 @@
 
 from llama_stack.apis.models import ModelType
 from llama_stack.distribution.datatypes import (
+    BuildProvider,
     ModelInput,
     Provider,
     ShieldInput,
@@ -20,31 +21,31 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
 def get_distribution_template() -> DistributionTemplate:
     providers = {
         "inference": [
-            Provider(provider_id="tgi", provider_type="remote::tgi"),
-            Provider(provider_id="sentence-transformers", provider_type="inline::sentence-transformers"),
+            BuildProvider(provider_type="remote::tgi"),
+            BuildProvider(provider_type="inline::sentence-transformers"),
         ],
         "vector_io": [
-            Provider(provider_id="faiss", provider_type="inline::faiss"),
-            Provider(provider_id="chromadb", provider_type="remote::chromadb"),
-            Provider(provider_id="pgvector", provider_type="remote::pgvector"),
+            BuildProvider(provider_type="inline::faiss"),
+            BuildProvider(provider_type="remote::chromadb"),
+            BuildProvider(provider_type="remote::pgvector"),
         ],
-        "safety": [Provider(provider_id="llama-guard", provider_type="inline::llama-guard")],
-        "agents": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
- "telemetry": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], - "eval": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], + "safety": [BuildProvider(provider_type="inline::llama-guard")], + "agents": [BuildProvider(provider_type="inline::meta-reference")], + "telemetry": [BuildProvider(provider_type="inline::meta-reference")], + "eval": [BuildProvider(provider_type="inline::meta-reference")], "datasetio": [ - Provider(provider_id="huggingface", provider_type="remote::huggingface"), - Provider(provider_id="localfs", provider_type="inline::localfs"), + BuildProvider(provider_type="remote::huggingface"), + BuildProvider(provider_type="inline::localfs"), ], "scoring": [ - Provider(provider_id="basic", provider_type="inline::basic"), - Provider(provider_id="llm-as-judge", provider_type="inline::llm-as-judge"), - Provider(provider_id="braintrust", provider_type="inline::braintrust"), + BuildProvider(provider_type="inline::basic"), + BuildProvider(provider_type="inline::llm-as-judge"), + BuildProvider(provider_type="inline::braintrust"), ], "tool_runtime": [ - Provider(provider_id="brave-search", provider_type="remote::brave-search"), - Provider(provider_id="tavily-search", provider_type="remote::tavily-search"), - Provider(provider_id="rag-runtime", provider_type="inline::rag-runtime"), + BuildProvider(provider_type="remote::brave-search"), + BuildProvider(provider_type="remote::tavily-search"), + BuildProvider(provider_type="inline::rag-runtime"), ], } name = "dell" diff --git a/llama_stack/templates/meta-reference-gpu/build.yaml b/llama_stack/templates/meta-reference-gpu/build.yaml index 0a0bc0aea..e929f1683 100644 --- a/llama_stack/templates/meta-reference-gpu/build.yaml +++ b/llama_stack/templates/meta-reference-gpu/build.yaml @@ -3,48 +3,31 @@ distribution_spec: description: Use Meta Reference for running LLM inference providers: inference: - - provider_id: meta-reference - provider_type: inline::meta-reference + - provider_type: inline::meta-reference vector_io: - - provider_id: faiss - provider_type: inline::faiss - - provider_id: chromadb - provider_type: remote::chromadb - - provider_id: pgvector - provider_type: remote::pgvector + - provider_type: inline::faiss + - provider_type: remote::chromadb + - provider_type: remote::pgvector safety: - - provider_id: llama-guard - provider_type: inline::llama-guard + - provider_type: inline::llama-guard agents: - - provider_id: meta-reference - provider_type: inline::meta-reference + - provider_type: inline::meta-reference telemetry: - - provider_id: meta-reference - provider_type: inline::meta-reference + - provider_type: inline::meta-reference eval: - - provider_id: meta-reference - provider_type: inline::meta-reference + - provider_type: inline::meta-reference datasetio: - - provider_id: huggingface - provider_type: remote::huggingface - - provider_id: localfs - provider_type: inline::localfs + - provider_type: remote::huggingface + - provider_type: inline::localfs scoring: - - provider_id: basic - provider_type: inline::basic - - provider_id: llm-as-judge - provider_type: inline::llm-as-judge - - provider_id: braintrust - provider_type: inline::braintrust + - provider_type: inline::basic + - provider_type: inline::llm-as-judge + - provider_type: inline::braintrust tool_runtime: - - provider_id: brave-search - provider_type: remote::brave-search - - provider_id: tavily-search - provider_type: remote::tavily-search - - provider_id: rag-runtime - provider_type: 
-    - provider_id: model-context-protocol
-      provider_type: remote::model-context-protocol
+    - provider_type: remote::brave-search
+    - provider_type: remote::tavily-search
+    - provider_type: inline::rag-runtime
+    - provider_type: remote::model-context-protocol
 image_type: conda
 image_name: meta-reference-gpu
 additional_pip_packages:
diff --git a/llama_stack/templates/meta-reference-gpu/meta_reference.py b/llama_stack/templates/meta-reference-gpu/meta_reference.py
index 6ca500eff..981c66bf5 100644
--- a/llama_stack/templates/meta-reference-gpu/meta_reference.py
+++ b/llama_stack/templates/meta-reference-gpu/meta_reference.py
@@ -8,6 +8,7 @@ from pathlib import Path
 
 from llama_stack.apis.models import ModelType
 from llama_stack.distribution.datatypes import (
+    BuildProvider,
     ModelInput,
     Provider,
     ShieldInput,
@@ -25,91 +26,30 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
 
 def get_distribution_template() -> DistributionTemplate:
     providers = {
-        "inference": [
-            Provider(
-                provider_id="meta-reference",
-                provider_type="inline::meta-reference",
-            )
-        ],
+        "inference": [BuildProvider(provider_type="inline::meta-reference")],
         "vector_io": [
-            Provider(
-                provider_id="faiss",
-                provider_type="inline::faiss",
-            ),
-            Provider(
-                provider_id="chromadb",
-                provider_type="remote::chromadb",
-            ),
-            Provider(
-                provider_id="pgvector",
-                provider_type="remote::pgvector",
-            ),
-        ],
-        "safety": [
-            Provider(
-                provider_id="llama-guard",
-                provider_type="inline::llama-guard",
-            )
-        ],
-        "agents": [
-            Provider(
-                provider_id="meta-reference",
-                provider_type="inline::meta-reference",
-            )
-        ],
-        "telemetry": [
-            Provider(
-                provider_id="meta-reference",
-                provider_type="inline::meta-reference",
-            )
-        ],
-        "eval": [
-            Provider(
-                provider_id="meta-reference",
-                provider_type="inline::meta-reference",
-            )
+            BuildProvider(provider_type="inline::faiss"),
+            BuildProvider(provider_type="remote::chromadb"),
+            BuildProvider(provider_type="remote::pgvector"),
         ],
+        "safety": [BuildProvider(provider_type="inline::llama-guard")],
+        "agents": [BuildProvider(provider_type="inline::meta-reference")],
+        "telemetry": [BuildProvider(provider_type="inline::meta-reference")],
+        "eval": [BuildProvider(provider_type="inline::meta-reference")],
         "datasetio": [
-            Provider(
-                provider_id="huggingface",
-                provider_type="remote::huggingface",
-            ),
-            Provider(
-                provider_id="localfs",
-                provider_type="inline::localfs",
-            ),
+            BuildProvider(provider_type="remote::huggingface"),
+            BuildProvider(provider_type="inline::localfs"),
         ],
         "scoring": [
-            Provider(
-                provider_id="basic",
-                provider_type="inline::basic",
-            ),
-            Provider(
-                provider_id="llm-as-judge",
-                provider_type="inline::llm-as-judge",
-            ),
-            Provider(
-                provider_id="braintrust",
-                provider_type="inline::braintrust",
-            ),
+            BuildProvider(provider_type="inline::basic"),
+            BuildProvider(provider_type="inline::llm-as-judge"),
+            BuildProvider(provider_type="inline::braintrust"),
         ],
         "tool_runtime": [
-            Provider(
-                provider_id="brave-search",
-                provider_type="remote::brave-search",
-            ),
-            Provider(
-                provider_id="tavily-search",
-                provider_type="remote::tavily-search",
-            ),
-            Provider(
-                provider_id="rag-runtime",
-                provider_type="inline::rag-runtime",
-            ),
-            Provider(
-                provider_id="model-context-protocol",
-                provider_type="remote::model-context-protocol",
-            ),
+            BuildProvider(provider_type="remote::brave-search"),
+            BuildProvider(provider_type="remote::tavily-search"),
+            BuildProvider(provider_type="inline::rag-runtime"),
+            BuildProvider(provider_type="remote::model-context-protocol"),
BuildProvider(provider_type="remote::model-context-protocol"), ], } name = "meta-reference-gpu" diff --git a/llama_stack/templates/nvidia/build.yaml b/llama_stack/templates/nvidia/build.yaml index 572a70408..7d9a6140f 100644 --- a/llama_stack/templates/nvidia/build.yaml +++ b/llama_stack/templates/nvidia/build.yaml @@ -3,37 +3,26 @@ distribution_spec: description: Use NVIDIA NIM for running LLM inference, evaluation and safety providers: inference: - - provider_id: nvidia - provider_type: remote::nvidia + - provider_type: remote::nvidia vector_io: - - provider_id: faiss - provider_type: inline::faiss + - provider_type: inline::faiss safety: - - provider_id: nvidia - provider_type: remote::nvidia + - provider_type: remote::nvidia agents: - - provider_id: meta-reference - provider_type: inline::meta-reference + - provider_type: inline::meta-reference telemetry: - - provider_id: meta-reference - provider_type: inline::meta-reference + - provider_type: inline::meta-reference eval: - - provider_id: nvidia - provider_type: remote::nvidia + - provider_type: remote::nvidia post_training: - - provider_id: nvidia - provider_type: remote::nvidia + - provider_type: remote::nvidia datasetio: - - provider_id: localfs - provider_type: inline::localfs - - provider_id: nvidia - provider_type: remote::nvidia + - provider_type: inline::localfs + - provider_type: remote::nvidia scoring: - - provider_id: basic - provider_type: inline::basic + - provider_type: inline::basic tool_runtime: - - provider_id: rag-runtime - provider_type: inline::rag-runtime + - provider_type: inline::rag-runtime image_type: conda image_name: nvidia additional_pip_packages: diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py index 25beeae75..df82cf7c0 100644 --- a/llama_stack/templates/nvidia/nvidia.py +++ b/llama_stack/templates/nvidia/nvidia.py @@ -6,7 +6,7 @@ from pathlib import Path -from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput +from llama_stack.distribution.datatypes import BuildProvider, ModelInput, Provider, ShieldInput, ToolGroupInput from llama_stack.providers.remote.datasetio.nvidia import NvidiaDatasetIOConfig from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig @@ -17,65 +17,19 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin def get_distribution_template() -> DistributionTemplate: providers = { - "inference": [ - Provider( - provider_id="nvidia", - provider_type="remote::nvidia", - ) - ], - "vector_io": [ - Provider( - provider_id="faiss", - provider_type="inline::faiss", - ) - ], - "safety": [ - Provider( - provider_id="nvidia", - provider_type="remote::nvidia", - ) - ], - "agents": [ - Provider( - provider_id="meta-reference", - provider_type="inline::meta-reference", - ) - ], - "telemetry": [ - Provider( - provider_id="meta-reference", - provider_type="inline::meta-reference", - ) - ], - "eval": [ - Provider( - provider_id="nvidia", - provider_type="remote::nvidia", - ) - ], - "post_training": [Provider(provider_id="nvidia", provider_type="remote::nvidia", config={})], + "inference": [BuildProvider(provider_type="remote::nvidia")], + "vector_io": [BuildProvider(provider_type="inline::faiss")], + "safety": [BuildProvider(provider_type="remote::nvidia")], + "agents": [BuildProvider(provider_type="inline::meta-reference")], + "telemetry": [BuildProvider(provider_type="inline::meta-reference")], + "eval": 
[BuildProvider(provider_type="remote::nvidia")], + "post_training": [BuildProvider(provider_type="remote::nvidia")], "datasetio": [ - Provider( - provider_id="localfs", - provider_type="inline::localfs", - ), - Provider( - provider_id="nvidia", - provider_type="remote::nvidia", - ), - ], - "scoring": [ - Provider( - provider_id="basic", - provider_type="inline::basic", - ) - ], - "tool_runtime": [ - Provider( - provider_id="rag-runtime", - provider_type="inline::rag-runtime", - ) + BuildProvider(provider_type="inline::localfs"), + BuildProvider(provider_type="remote::nvidia"), ], + "scoring": [BuildProvider(provider_type="inline::basic")], + "tool_runtime": [BuildProvider(provider_type="inline::rag-runtime")], } inference_provider = Provider( diff --git a/llama_stack/templates/open-benchmark/build.yaml b/llama_stack/templates/open-benchmark/build.yaml index 6647b471c..1c7966cdd 100644 --- a/llama_stack/templates/open-benchmark/build.yaml +++ b/llama_stack/templates/open-benchmark/build.yaml @@ -3,56 +3,35 @@ distribution_spec: description: Distribution for running open benchmarks providers: inference: - - provider_id: openai - provider_type: remote::openai - - provider_id: anthropic - provider_type: remote::anthropic - - provider_id: gemini - provider_type: remote::gemini - - provider_id: groq - provider_type: remote::groq - - provider_id: together - provider_type: remote::together + - provider_type: remote::openai + - provider_type: remote::anthropic + - provider_type: remote::gemini + - provider_type: remote::groq + - provider_type: remote::together vector_io: - - provider_id: sqlite-vec - provider_type: inline::sqlite-vec - - provider_id: chromadb - provider_type: remote::chromadb - - provider_id: pgvector - provider_type: remote::pgvector + - provider_type: inline::sqlite-vec + - provider_type: remote::chromadb + - provider_type: remote::pgvector safety: - - provider_id: llama-guard - provider_type: inline::llama-guard + - provider_type: inline::llama-guard agents: - - provider_id: meta-reference - provider_type: inline::meta-reference + - provider_type: inline::meta-reference telemetry: - - provider_id: meta-reference - provider_type: inline::meta-reference + - provider_type: inline::meta-reference eval: - - provider_id: meta-reference - provider_type: inline::meta-reference + - provider_type: inline::meta-reference datasetio: - - provider_id: huggingface - provider_type: remote::huggingface - - provider_id: localfs - provider_type: inline::localfs + - provider_type: remote::huggingface + - provider_type: inline::localfs scoring: - - provider_id: basic - provider_type: inline::basic - - provider_id: llm-as-judge - provider_type: inline::llm-as-judge - - provider_id: braintrust - provider_type: inline::braintrust + - provider_type: inline::basic + - provider_type: inline::llm-as-judge + - provider_type: inline::braintrust tool_runtime: - - provider_id: brave-search - provider_type: remote::brave-search - - provider_id: tavily-search - provider_type: remote::tavily-search - - provider_id: rag-runtime - provider_type: inline::rag-runtime - - provider_id: model-context-protocol - provider_type: remote::model-context-protocol + - provider_type: remote::brave-search + - provider_type: remote::tavily-search + - provider_type: inline::rag-runtime + - provider_type: remote::model-context-protocol image_type: conda image_name: open-benchmark additional_pip_packages: diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py index 
--- a/llama_stack/templates/open-benchmark/open_benchmark.py
+++ b/llama_stack/templates/open-benchmark/open_benchmark.py
@@ -9,6 +9,7 @@ from llama_stack.apis.datasets import DatasetPurpose, URIDataSource
 from llama_stack.apis.models import ModelType
 from llama_stack.distribution.datatypes import (
     BenchmarkInput,
+    BuildProvider,
     DatasetInput,
     ModelInput,
     Provider,
@@ -96,33 +97,30 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo
 def get_distribution_template() -> DistributionTemplate:
     inference_providers, available_models = get_inference_providers()
     providers = {
-        "inference": inference_providers,
+        "inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in inference_providers],
         "vector_io": [
-            Provider(provider_id="sqlite-vec", provider_type="inline::sqlite-vec"),
-            Provider(provider_id="chromadb", provider_type="remote::chromadb"),
-            Provider(provider_id="pgvector", provider_type="remote::pgvector"),
+            BuildProvider(provider_type="inline::sqlite-vec"),
+            BuildProvider(provider_type="remote::chromadb"),
+            BuildProvider(provider_type="remote::pgvector"),
         ],
-        "safety": [Provider(provider_id="llama-guard", provider_type="inline::llama-guard")],
-        "agents": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
-        "telemetry": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
-        "eval": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
+        "safety": [BuildProvider(provider_type="inline::llama-guard")],
+        "agents": [BuildProvider(provider_type="inline::meta-reference")],
+        "telemetry": [BuildProvider(provider_type="inline::meta-reference")],
+        "eval": [BuildProvider(provider_type="inline::meta-reference")],
         "datasetio": [
-            Provider(provider_id="huggingface", provider_type="remote::huggingface"),
-            Provider(provider_id="localfs", provider_type="inline::localfs"),
+            BuildProvider(provider_type="remote::huggingface"),
+            BuildProvider(provider_type="inline::localfs"),
         ],
         "scoring": [
-            Provider(provider_id="basic", provider_type="inline::basic"),
-            Provider(provider_id="llm-as-judge", provider_type="inline::llm-as-judge"),
-            Provider(provider_id="braintrust", provider_type="inline::braintrust"),
+            BuildProvider(provider_type="inline::basic"),
+            BuildProvider(provider_type="inline::llm-as-judge"),
+            BuildProvider(provider_type="inline::braintrust"),
         ],
         "tool_runtime": [
-            Provider(provider_id="brave-search", provider_type="remote::brave-search"),
-            Provider(provider_id="tavily-search", provider_type="remote::tavily-search"),
-            Provider(provider_id="rag-runtime", provider_type="inline::rag-runtime"),
-            Provider(
-                provider_id="model-context-protocol",
-                provider_type="remote::model-context-protocol",
-            ),
+            BuildProvider(provider_type="remote::brave-search"),
+            BuildProvider(provider_type="remote::tavily-search"),
+            BuildProvider(provider_type="inline::rag-runtime"),
+            BuildProvider(provider_type="remote::model-context-protocol"),
         ],
     }
     name = "open-benchmark"
diff --git a/llama_stack/templates/postgres-demo/build.yaml b/llama_stack/templates/postgres-demo/build.yaml
index d5e816a54..5a1c9129a 100644
--- a/llama_stack/templates/postgres-demo/build.yaml
+++ b/llama_stack/templates/postgres-demo/build.yaml
@@ -3,31 +3,21 @@ distribution_spec:
   description: Quick start template for running Llama Stack with several popular providers
   providers:
     inference:
-    - provider_id: vllm-inference
-      provider_type: remote::vllm
-    - provider_id: sentence-transformers
-      provider_type: inline::sentence-transformers
+    - provider_type: remote::vllm
+    - provider_type: inline::sentence-transformers
     vector_io:
-    - provider_id: chromadb
-      provider_type: remote::chromadb
+    - provider_type: remote::chromadb
     safety:
-    - provider_id: llama-guard
-      provider_type: inline::llama-guard
+    - provider_type: inline::llama-guard
     agents:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     telemetry:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     tool_runtime:
-    - provider_id: brave-search
-      provider_type: remote::brave-search
-    - provider_id: tavily-search
-      provider_type: remote::tavily-search
-    - provider_id: rag-runtime
-      provider_type: inline::rag-runtime
-    - provider_id: model-context-protocol
-      provider_type: remote::model-context-protocol
+    - provider_type: remote::brave-search
+    - provider_type: remote::tavily-search
+    - provider_type: inline::rag-runtime
+    - provider_type: remote::model-context-protocol
 image_type: conda
 image_name: postgres-demo
 additional_pip_packages:
diff --git a/llama_stack/templates/postgres-demo/postgres_demo.py b/llama_stack/templates/postgres-demo/postgres_demo.py
index 24e3f6f27..d9ded9a86 100644
--- a/llama_stack/templates/postgres-demo/postgres_demo.py
+++ b/llama_stack/templates/postgres-demo/postgres_demo.py
@@ -7,6 +7,7 @@
 
 from llama_stack.apis.models import ModelType
 from llama_stack.distribution.datatypes import (
+    BuildProvider,
     ModelInput,
     Provider,
     ShieldInput,
@@ -34,24 +35,19 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
     providers = {
-        "inference": inference_providers
-        + [
-            Provider(provider_id="sentence-transformers", provider_type="inline::sentence-transformers"),
+        "inference": [
+            BuildProvider(provider_type="remote::vllm"),
+            BuildProvider(provider_type="inline::sentence-transformers"),
         ],
-        "vector_io": [
-            Provider(provider_id="chromadb", provider_type="remote::chromadb"),
-        ],
-        "safety": [Provider(provider_id="llama-guard", provider_type="inline::llama-guard")],
-        "agents": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
-        "telemetry": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
+        "vector_io": [BuildProvider(provider_type="remote::chromadb")],
+        "safety": [BuildProvider(provider_type="inline::llama-guard")],
+        "agents": [BuildProvider(provider_type="inline::meta-reference")],
+        "telemetry": [BuildProvider(provider_type="inline::meta-reference")],
         "tool_runtime": [
-            Provider(provider_id="brave-search", provider_type="remote::brave-search"),
-            Provider(provider_id="tavily-search", provider_type="remote::tavily-search"),
-            Provider(provider_id="rag-runtime", provider_type="inline::rag-runtime"),
-            Provider(
-                provider_id="model-context-protocol",
-                provider_type="remote::model-context-protocol",
-            ),
+            BuildProvider(provider_type="remote::brave-search"),
+            BuildProvider(provider_type="remote::tavily-search"),
+            BuildProvider(provider_type="inline::rag-runtime"),
+            BuildProvider(provider_type="remote::model-context-protocol"),
         ],
     }
     name = "postgres-demo"
diff --git a/llama_stack/templates/starter/build.yaml b/llama_stack/templates/starter/build.yaml
index 9b540ab62..14e7e157b 100644
--- a/llama_stack/templates/starter/build.yaml
+++ b/llama_stack/templates/starter/build.yaml
@@ -3,96 +3,56 @@ distribution_spec:
   description: Quick start template for running Llama Stack with several popular providers
   providers:
     inference:
-    - provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
-      provider_type: remote::cerebras
-    - provider_id: ${env.ENABLE_OLLAMA:=__disabled__}
-      provider_type: remote::ollama
-    - provider_id: ${env.ENABLE_VLLM:=__disabled__}
-      provider_type: remote::vllm
-    - provider_id: ${env.ENABLE_TGI:=__disabled__}
-      provider_type: remote::tgi
-    - provider_id: ${env.ENABLE_HF_SERVERLESS:=__disabled__}
-      provider_type: remote::hf::serverless
-    - provider_id: ${env.ENABLE_HF_ENDPOINT:=__disabled__}
-      provider_type: remote::hf::endpoint
-    - provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-      provider_type: remote::fireworks
-    - provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
-      provider_type: remote::together
-    - provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
-      provider_type: remote::bedrock
-    - provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
-      provider_type: remote::databricks
-    - provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
-      provider_type: remote::nvidia
-    - provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
-      provider_type: remote::runpod
-    - provider_id: ${env.ENABLE_OPENAI:=__disabled__}
-      provider_type: remote::openai
-    - provider_id: ${env.ENABLE_ANTHROPIC:=__disabled__}
-      provider_type: remote::anthropic
-    - provider_id: ${env.ENABLE_GEMINI:=__disabled__}
-      provider_type: remote::gemini
-    - provider_id: ${env.ENABLE_GROQ:=__disabled__}
-      provider_type: remote::groq
-    - provider_id: ${env.ENABLE_LLAMA_OPENAI_COMPAT:=__disabled__}
-      provider_type: remote::llama-openai-compat
-    - provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
-      provider_type: remote::sambanova
-    - provider_id: ${env.ENABLE_PASSTHROUGH:=__disabled__}
-      provider_type: remote::passthrough
-    - provider_id: sentence-transformers
-      provider_type: inline::sentence-transformers
+    - provider_type: remote::cerebras
+    - provider_type: remote::ollama
+    - provider_type: remote::vllm
+    - provider_type: remote::tgi
+    - provider_type: remote::hf::serverless
+    - provider_type: remote::hf::endpoint
+    - provider_type: remote::fireworks
+    - provider_type: remote::together
+    - provider_type: remote::bedrock
+    - provider_type: remote::databricks
+    - provider_type: remote::nvidia
+    - provider_type: remote::runpod
+    - provider_type: remote::openai
+    - provider_type: remote::anthropic
+    - provider_type: remote::gemini
+    - provider_type: remote::groq
+    - provider_type: remote::llama-openai-compat
+    - provider_type: remote::sambanova
+    - provider_type: remote::passthrough
+    - provider_type: inline::sentence-transformers
     vector_io:
-    - provider_id: ${env.ENABLE_FAISS:=faiss}
-      provider_type: inline::faiss
-    - provider_id: ${env.ENABLE_SQLITE_VEC:=__disabled__}
-      provider_type: inline::sqlite-vec
-    - provider_id: ${env.ENABLE_MILVUS:=__disabled__}
-      provider_type: inline::milvus
-    - provider_id: ${env.ENABLE_CHROMADB:=__disabled__}
-      provider_type: remote::chromadb
-    - provider_id: ${env.ENABLE_PGVECTOR:=__disabled__}
-      provider_type: remote::pgvector
+    - provider_type: inline::faiss
+    - provider_type: inline::sqlite-vec
+    - provider_type: inline::milvus
+    - provider_type: remote::chromadb
+    - provider_type: remote::pgvector
     files:
-    - provider_id: localfs
-      provider_type: inline::localfs
+    - provider_type: inline::localfs
     safety:
-    - provider_id: llama-guard
-      provider_type: inline::llama-guard
+    - provider_type: inline::llama-guard
     agents:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     telemetry:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     post_training:
-    - provider_id: huggingface
-      provider_type: inline::huggingface
+    - provider_type: inline::huggingface
     eval:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     datasetio:
-    - provider_id: huggingface
-      provider_type: remote::huggingface
-    - provider_id: localfs
-      provider_type: inline::localfs
+    - provider_type: remote::huggingface
+    - provider_type: inline::localfs
     scoring:
-    - provider_id: basic
-      provider_type: inline::basic
-    - provider_id: llm-as-judge
-      provider_type: inline::llm-as-judge
-    - provider_id: braintrust
-      provider_type: inline::braintrust
+    - provider_type: inline::basic
+    - provider_type: inline::llm-as-judge
+    - provider_type: inline::braintrust
     tool_runtime:
-    - provider_id: brave-search
-      provider_type: remote::brave-search
-    - provider_id: tavily-search
-      provider_type: remote::tavily-search
-    - provider_id: rag-runtime
-      provider_type: inline::rag-runtime
-    - provider_id: model-context-protocol
-      provider_type: remote::model-context-protocol
+    - provider_type: remote::brave-search
+    - provider_type: remote::tavily-search
+    - provider_type: inline::rag-runtime
+    - provider_type: remote::model-context-protocol
 image_type: conda
 image_name: starter
 additional_pip_packages:
diff --git a/llama_stack/templates/starter/starter.py b/llama_stack/templates/starter/starter.py
index 489117702..6c4ffbc06 100644
--- a/llama_stack/templates/starter/starter.py
+++ b/llama_stack/templates/starter/starter.py
@@ -9,6 +9,7 @@ from typing import Any
 
 from llama_stack.apis.models import ModelType
 from llama_stack.distribution.datatypes import (
+    BuildProvider,
     ModelInput,
     Provider,
     ProviderSpec,
@@ -213,131 +214,38 @@ def get_safety_models_for_providers(providers: list[Provider]) -> dict[str, list
 
 def get_distribution_template() -> DistributionTemplate:
     remote_inference_providers, available_models = get_remote_inference_providers()
-    name = "starter"
-
-    vector_io_providers = [
-        Provider(
-            provider_id="${env.ENABLE_FAISS:=faiss}",
-            provider_type="inline::faiss",
-            config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
-        ),
-        Provider(
-            provider_id="${env.ENABLE_SQLITE_VEC:=__disabled__}",
-            provider_type="inline::sqlite-vec",
-            config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
-        ),
-        Provider(
-            provider_id="${env.ENABLE_MILVUS:=__disabled__}",
-            provider_type="inline::milvus",
-            config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
-        ),
-        Provider(
-            provider_id="${env.ENABLE_CHROMADB:=__disabled__}",
-            provider_type="remote::chromadb",
-            config=ChromaVectorIOConfig.sample_run_config(
-                f"~/.llama/distributions/{name}/",
-                url="${env.CHROMADB_URL:=}",
-            ),
-        ),
-        Provider(
-            provider_id="${env.ENABLE_PGVECTOR:=__disabled__}",
-            provider_type="remote::pgvector",
-            config=PGVectorVectorIOConfig.sample_run_config(
-                f"~/.llama/distributions/{name}",
-                db="${env.PGVECTOR_DB:=}",
-                user="${env.PGVECTOR_USER:=}",
-                password="${env.PGVECTOR_PASSWORD:=}",
-            ),
-        ),
-    ]
-
+    # For build config, use BuildProvider with only provider_type and module
     providers = {
-        "inference": remote_inference_providers
-        + [
-            Provider(
-                provider_id="sentence-transformers",
-                provider_type="inline::sentence-transformers",
-            )
-        ],
-        "vector_io": vector_io_providers,
-        "files": [
-            Provider(
-                provider_id="localfs",
-                provider_type="inline::localfs",
-            )
- ], - "safety": [ - Provider( - provider_id="llama-guard", - provider_type="inline::llama-guard", - ) - ], - "agents": [ - Provider( - provider_id="meta-reference", - provider_type="inline::meta-reference", - ) - ], - "telemetry": [ - Provider( - provider_id="meta-reference", - provider_type="inline::meta-reference", - ) - ], - "post_training": [ - Provider( - provider_id="huggingface", - provider_type="inline::huggingface", - ) - ], - "eval": [ - Provider( - provider_id="meta-reference", - provider_type="inline::meta-reference", - ) + "inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in remote_inference_providers] + + [BuildProvider(provider_type="inline::sentence-transformers")], + "vector_io": [ + BuildProvider(provider_type="inline::faiss"), + BuildProvider(provider_type="inline::sqlite-vec"), + BuildProvider(provider_type="inline::milvus"), + BuildProvider(provider_type="remote::chromadb"), + BuildProvider(provider_type="remote::pgvector"), ], + "files": [BuildProvider(provider_type="inline::localfs")], + "safety": [BuildProvider(provider_type="inline::llama-guard")], + "agents": [BuildProvider(provider_type="inline::meta-reference")], + "telemetry": [BuildProvider(provider_type="inline::meta-reference")], + "post_training": [BuildProvider(provider_type="inline::huggingface")], + "eval": [BuildProvider(provider_type="inline::meta-reference")], "datasetio": [ - Provider( - provider_id="huggingface", - provider_type="remote::huggingface", - ), - Provider( - provider_id="localfs", - provider_type="inline::localfs", - ), + BuildProvider(provider_type="remote::huggingface"), + BuildProvider(provider_type="inline::localfs"), ], "scoring": [ - Provider( - provider_id="basic", - provider_type="inline::basic", - ), - Provider( - provider_id="llm-as-judge", - provider_type="inline::llm-as-judge", - ), - Provider( - provider_id="braintrust", - provider_type="inline::braintrust", - ), + BuildProvider(provider_type="inline::basic"), + BuildProvider(provider_type="inline::llm-as-judge"), + BuildProvider(provider_type="inline::braintrust"), ], "tool_runtime": [ - Provider( - provider_id="brave-search", - provider_type="remote::brave-search", - ), - Provider( - provider_id="tavily-search", - provider_type="remote::tavily-search", - ), - Provider( - provider_id="rag-runtime", - provider_type="inline::rag-runtime", - ), - Provider( - provider_id="model-context-protocol", - provider_type="remote::model-context-protocol", - ), + BuildProvider(provider_type="remote::brave-search"), + BuildProvider(provider_type="remote::tavily-search"), + BuildProvider(provider_type="inline::rag-runtime"), + BuildProvider(provider_type="remote::model-context-protocol"), ], } files_provider = Provider( @@ -392,7 +300,41 @@ def get_distribution_template() -> DistributionTemplate: "run.yaml": RunConfigSettings( provider_overrides={ "inference": remote_inference_providers + [embedding_provider], - "vector_io": vector_io_providers, + "vector_io": [ + Provider( + provider_id="${env.ENABLE_FAISS:=faiss}", + provider_type="inline::faiss", + config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"), + ), + Provider( + provider_id="${env.ENABLE_SQLITE_VEC:=__disabled__}", + provider_type="inline::sqlite-vec", + config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"), + ), + Provider( + provider_id="${env.ENABLE_MILVUS:=__disabled__}", + provider_type="inline::milvus", + config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"), + ), 
+ Provider( + provider_id="${env.ENABLE_CHROMADB:=__disabled__}", + provider_type="remote::chromadb", + config=ChromaVectorIOConfig.sample_run_config( + f"~/.llama/distributions/{name}/", + url="${env.CHROMADB_URL:=}", + ), + ), + Provider( + provider_id="${env.ENABLE_PGVECTOR:=__disabled__}", + provider_type="remote::pgvector", + config=PGVectorVectorIOConfig.sample_run_config( + f"~/.llama/distributions/{name}", + db="${env.PGVECTOR_DB:=}", + user="${env.PGVECTOR_USER:=}", + password="${env.PGVECTOR_PASSWORD:=}", + ), + ), + ], "files": [files_provider], "post_training": [post_training_provider], }, diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py index e9054f95d..084996cd4 100644 --- a/llama_stack/templates/template.py +++ b/llama_stack/templates/template.py @@ -19,6 +19,7 @@ from llama_stack.distribution.datatypes import ( Api, BenchmarkInput, BuildConfig, + BuildProvider, DatasetInput, DistributionSpec, ModelInput, @@ -183,7 +184,7 @@ class RunConfigSettings(BaseModel): def run_config( self, name: str, - providers: dict[str, list[Provider]], + providers: dict[str, list[BuildProvider]], container_image: str | None = None, ) -> dict: provider_registry = get_provider_registry() @@ -199,7 +200,7 @@ class RunConfigSettings(BaseModel): api = Api(api_str) if provider.provider_type not in provider_registry[api]: raise ValueError(f"Unknown provider type: {provider.provider_type} for API: {api_str}") - + provider_id = provider.provider_type.split("::")[-1] config_class = provider_registry[api][provider.provider_type].config_class assert config_class is not None, ( f"No config class for provider type: {provider.provider_type} for API: {api_str}" @@ -210,10 +211,14 @@ class RunConfigSettings(BaseModel): config = config_class.sample_run_config(__distro_dir__=f"~/.llama/distributions/{name}") else: config = {} - - provider.config = config - # Convert Provider object to dict for YAML serialization - provider_configs[api_str].append(provider.model_dump(exclude_none=True)) + # BuildProvider does not have a config attribute; skip assignment + provider_configs[api_str].append( + Provider( + provider_id=provider_id, + provider_type=provider.provider_type, + config=config, + ).model_dump(exclude_none=True) + ) # Get unique set of APIs from providers apis = sorted(providers.keys()) @@ -257,7 +262,8 @@ class DistributionTemplate(BaseModel): description: str distro_type: Literal["self_hosted", "remote_hosted", "ondevice"] - providers: dict[str, list[Provider]] + # Now uses BuildProvider for build config, not Provider + providers: dict[str, list[BuildProvider]] run_configs: dict[str, RunConfigSettings] template_path: Path | None = None @@ -295,11 +301,9 @@ class DistributionTemplate(BaseModel): for api, providers in self.providers.items(): build_providers[api] = [] for provider in providers: - # Create a minimal provider object with only essential build information - build_provider = Provider( - provider_id=provider.provider_id, + # Create a minimal build provider object with only essential build information + build_provider = BuildProvider( provider_type=provider.provider_type, - config={}, # Empty config for build module=provider.module, ) build_providers[api].append(build_provider) @@ -323,50 +327,52 @@ class DistributionTemplate(BaseModel): providers_str = ", ".join(f"`{p.provider_type}`" for p in providers) providers_table += f"| {api} | {providers_str} |\n" - template = self.template_path.read_text() - comment = "\n" - orphantext = "---\norphan: true\n---\n" + if 
+            template = self.template_path.read_text()
+            comment = "<!-- This file was auto-generated by distro_codegen.py, please edit source -->\n"
+            orphantext = "---\norphan: true\n---\n"
 
-        if template.startswith(orphantext):
-            template = template.replace(orphantext, orphantext + comment)
-        else:
-            template = comment + template
+            if template.startswith(orphantext):
+                template = template.replace(orphantext, orphantext + comment)
+            else:
+                template = comment + template
 
-        # Render template with rich-generated table
-        env = jinja2.Environment(
-            trim_blocks=True,
-            lstrip_blocks=True,
-            # NOTE: autoescape is required to prevent XSS attacks
-            autoescape=True,
-        )
-        template = env.from_string(template)
+            # Render template with rich-generated table
+            env = jinja2.Environment(
+                trim_blocks=True,
+                lstrip_blocks=True,
+                # NOTE: autoescape is required to prevent XSS attacks
+                autoescape=True,
+            )
+            template = env.from_string(template)
 
-        default_models = []
-        if self.available_models_by_provider:
-            has_multiple_providers = len(self.available_models_by_provider.keys()) > 1
-            for provider_id, model_entries in self.available_models_by_provider.items():
-                for model_entry in model_entries:
-                    doc_parts = []
-                    if model_entry.aliases:
-                        doc_parts.append(f"aliases: {', '.join(model_entry.aliases)}")
-                    if has_multiple_providers:
-                        doc_parts.append(f"provider: {provider_id}")
+            default_models = []
+            if self.available_models_by_provider:
+                has_multiple_providers = len(self.available_models_by_provider.keys()) > 1
+                for provider_id, model_entries in self.available_models_by_provider.items():
+                    for model_entry in model_entries:
+                        doc_parts = []
+                        if model_entry.aliases:
+                            doc_parts.append(f"aliases: {', '.join(model_entry.aliases)}")
+                        if has_multiple_providers:
+                            doc_parts.append(f"provider: {provider_id}")
 
-                    default_models.append(
-                        DefaultModel(
-                            model_id=model_entry.provider_model_id,
-                            doc_string=(f"({' -- '.join(doc_parts)})" if doc_parts else ""),
+                        default_models.append(
+                            DefaultModel(
+                                model_id=model_entry.provider_model_id,
+                                doc_string=(f"({' -- '.join(doc_parts)})" if doc_parts else ""),
+                            )
                         )
-                    )
 
-        return template.render(
-            name=self.name,
-            description=self.description,
-            providers=self.providers,
-            providers_table=providers_table,
-            run_config_env_vars=self.run_config_env_vars,
-            default_models=default_models,
-        )
+            return template.render(
+                name=self.name,
+                description=self.description,
+                providers=self.providers,
+                providers_table=providers_table,
+                run_config_env_vars=self.run_config_env_vars,
+                default_models=default_models,
+            )
+        return ""
 
     def save_distribution(self, yaml_output_dir: Path, doc_output_dir: Path) -> None:
         def enum_representer(dumper, data):
diff --git a/llama_stack/templates/watsonx/watsonx.py b/llama_stack/templates/watsonx/watsonx.py
index c13bbea36..5d8332c4f 100644
--- a/llama_stack/templates/watsonx/watsonx.py
+++ b/llama_stack/templates/watsonx/watsonx.py
@@ -7,7 +7,7 @@
 from pathlib import Path
 
 from llama_stack.apis.models import ModelType
-from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput
+from llama_stack.distribution.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
@@ -19,86 +19,28 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
 
 def get_distribution_template() -> DistributionTemplate:
     providers = {
         "inference": [
-            Provider(
provider_id="sentence-transformers", - provider_type="inline::sentence-transformers", - ), - ], - "vector_io": [ - Provider( - provider_id="faiss", - provider_type="inline::faiss", - ) - ], - "safety": [ - Provider( - provider_id="llama-guard", - provider_type="inline::llama-guard", - ) - ], - "agents": [ - Provider( - provider_id="meta-reference", - provider_type="inline::meta-reference", - ) - ], - "telemetry": [ - Provider( - provider_id="meta-reference", - provider_type="inline::meta-reference", - ) - ], - "eval": [ - Provider( - provider_id="meta-reference", - provider_type="inline::meta-reference", - ) + BuildProvider(provider_type="remote::watsonx"), + BuildProvider(provider_type="inline::sentence-transformers"), ], + "vector_io": [BuildProvider(provider_type="inline::faiss")], + "safety": [BuildProvider(provider_type="inline::llama-guard")], + "agents": [BuildProvider(provider_type="inline::meta-reference")], + "telemetry": [BuildProvider(provider_type="inline::meta-reference")], + "eval": [BuildProvider(provider_type="inline::meta-reference")], "datasetio": [ - Provider( - provider_id="huggingface", - provider_type="remote::huggingface", - ), - Provider( - provider_id="localfs", - provider_type="inline::localfs", - ), + BuildProvider(provider_type="remote::huggingface"), + BuildProvider(provider_type="inline::localfs"), ], "scoring": [ - Provider( - provider_id="basic", - provider_type="inline::basic", - ), - Provider( - provider_id="llm-as-judge", - provider_type="inline::llm-as-judge", - ), - Provider( - provider_id="braintrust", - provider_type="inline::braintrust", - ), + BuildProvider(provider_type="inline::basic"), + BuildProvider(provider_type="inline::llm-as-judge"), + BuildProvider(provider_type="inline::braintrust"), ], "tool_runtime": [ - Provider( - provider_id="brave-search", - provider_type="remote::brave-search", - ), - Provider( - provider_id="tavily-search", - provider_type="remote::tavily-search", - ), - Provider( - provider_id="rag-runtime", - provider_type="inline::rag-runtime", - ), - Provider( - provider_id="model-context-protocol", - provider_type="remote::model-context-protocol", - ), + BuildProvider(provider_type="remote::brave-search"), + BuildProvider(provider_type="remote::tavily-search"), + BuildProvider(provider_type="inline::rag-runtime"), + BuildProvider(provider_type="remote::model-context-protocol"), ], }