From 233f8c81bf372ab12da0e838bc452a6b41821ac4 Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Sun, 6 Jul 2025 19:56:18 -0400 Subject: [PATCH] refactor: Align build and run provider datatypes introduce the concept of a `module` for users to specify for a provider upon build time. In order to facilitate this, align the build and run spec to use `Provider` class rather than the stringified provider_type that build currently uses. Signed-off-by: Charlie Doern --- llama_stack/cli/stack/_build.py | 2 +- llama_stack/distribution/configure.py | 10 +++---- llama_stack/distribution/datatypes.py | 35 ++++++++++++++-------- llama_stack/distribution/library_client.py | 8 +---- llama_stack/distribution/resolver.py | 2 +- llama_stack/providers/datatypes.py | 29 ++++++++++-------- 6 files changed, 48 insertions(+), 38 deletions(-) diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index 83aefa4a9..e4f3836f0 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -402,7 +402,7 @@ def _run_stack_build_command_from_build_config( run_config_file = _generate_run_config(build_config, build_dir, image_name) with open(build_file_path, "w") as f: - to_write = json.loads(build_config.model_dump_json()) + to_write = json.loads(build_config.model_dump_json(exclude_none=True)) f.write(yaml.dump(to_write, sort_keys=False)) # We first install the external APIs so that the build process can use them and discover the diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index 2238eef93..355233d53 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -91,21 +91,21 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec logger.info(f"Configuring API `{api_str}`...") updated_providers = [] - for i, provider_type in enumerate(plist): + for i, provider in enumerate(plist): if i >= 1: - others = ", ".join(plist[i:]) + others = ", ".join(p.provider_type for p in plist[i:]) logger.info( f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n" ) break - logger.info(f"> Configuring provider `({provider_type})`") + logger.info(f"> Configuring provider `({provider.provider_type})`") updated_providers.append( configure_single_provider( provider_registry[api], Provider( - provider_id=(f"{provider_type}-{i:02d}" if len(plist) > 1 else provider_type), - provider_type=provider_type, + provider_id=(f"{provider.provider_id}-{i:02d}" if len(plist) > 1 else provider.provider_id), + provider_type=provider.provider_type, config={}, ), ) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index f0b18606a..c17aadcc1 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -136,29 +136,40 @@ class RoutingTableProviderSpec(ProviderSpec): pip_packages: list[str] = Field(default_factory=list) +class Provider(BaseModel): + # provider_id of None means that the provider is not enabled - this happens + # when the provider is enabled via a conditional environment variable + provider_id: str | None + provider_type: str + config: dict[str, Any] = {} + module: str | None = Field( + default=None, + description=""" + Fully-qualified name of the external provider module to import. The module is expected to have: + + - `get_adapter_impl(config, deps)`: returns the adapter implementation + + Example: `module: ramalama_stack` + """, + ) + + class DistributionSpec(BaseModel): description: str | None = Field( default="", description="Description of the distribution", ) container_image: str | None = None - providers: dict[str, str | list[str]] = Field( + providers: dict[str, list[Provider]] = Field( default_factory=dict, description=""" -Provider Types for each of the APIs provided by this distribution. If you -select multiple providers, you should provide an appropriate 'routing_map' -in the runtime configuration to help route to the correct provider.""", + Provider Types for each of the APIs provided by this distribution. If you + select multiple providers, you should provide an appropriate 'routing_map' + in the runtime configuration to help route to the correct provider. + """, ) -class Provider(BaseModel): - # provider_id of None means that the provider is not enabled - this happens - # when the provider is enabled via a conditional environment variable - provider_id: str | None - provider_type: str - config: dict[str, Any] - - class LoggingConfig(BaseModel): category_levels: dict[str, str] = Field( default_factory=dict, diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 07949aea7..bcb0b9167 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -249,15 +249,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): file=sys.stderr, ) if self.config_path_or_template_name.endswith(".yaml"): - # Convert Provider objects to their types - provider_types: dict[str, str | list[str]] = {} - for api, providers in self.config.providers.items(): - types = [p.provider_type for p in providers] - # Convert single-item lists to strings - provider_types[api] = types[0] if len(types) == 1 else types build_config = BuildConfig( distribution_spec=DistributionSpec( - providers=provider_types, + providers=self.config.providers, ), external_providers_dir=self.config.external_providers_dir, ) diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 95017debb..db6856ed2 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -345,7 +345,7 @@ async def instantiate_provider( policy: list[AccessRule], ): provider_spec = provider.spec - if not hasattr(provider_spec, "module"): + if not hasattr(provider_spec, "module") or provider_spec.module is None: raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute") logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}") diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 005bfbab8..055bf5232 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -115,6 +115,19 @@ class ProviderSpec(BaseModel): description="If this provider is deprecated and does NOT work, specify the error message here", ) + module: str | None = Field( + default=None, + description=""" + Fully-qualified name of the module to import. The module is expected to have: + + - `get_adapter_impl(config, deps)`: returns the adapter implementation + + Example: `module: ramalama_stack` + """, + ) + + is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.") + # used internally by the resolver; this is a hack for now deps__: list[str] = Field(default_factory=list) @@ -135,7 +148,7 @@ class AdapterSpec(BaseModel): description="Unique identifier for this adapter", ) module: str = Field( - ..., + default_factory=str, description=""" Fully-qualified name of the module to import. The module is expected to have: @@ -173,14 +186,7 @@ The container image to use for this implementation. If one is provided, pip_pack If a provider depends on other providers, the dependencies MUST NOT specify a container image. """, ) - module: str = Field( - ..., - description=""" -Fully-qualified name of the module to import. The module is expected to have: - - - `get_provider_impl(config, deps)`: returns the local implementation -""", - ) + # module field is inherited from ProviderSpec provider_data_validator: str | None = Field( default=None, ) @@ -223,9 +229,7 @@ API responses, specify the adapter here. def container_image(self) -> str | None: return None - @property - def module(self) -> str: - return self.adapter.module + # module field is inherited from ProviderSpec @property def pip_packages(self) -> list[str]: @@ -243,6 +247,7 @@ def remote_provider_spec( api=api, provider_type=f"remote::{adapter.adapter_type}", config_class=adapter.config_class, + module=adapter.module, adapter=adapter, api_dependencies=api_dependencies or [], )