refactor: Align build and run provider datatypes

introduce the concept of a `module` for users to specify for a provider upon build time.
In order to facilitate this, align the build and run spec to use `Provider` class rather than the stringified provider_type that build currently uses.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-07-06 19:56:18 -04:00
parent 21bae296f2
commit 233f8c81bf
6 changed files with 48 additions and 38 deletions

View file

@ -402,7 +402,7 @@ def _run_stack_build_command_from_build_config(
run_config_file = _generate_run_config(build_config, build_dir, image_name) run_config_file = _generate_run_config(build_config, build_dir, image_name)
with open(build_file_path, "w") as f: with open(build_file_path, "w") as f:
to_write = json.loads(build_config.model_dump_json()) to_write = json.loads(build_config.model_dump_json(exclude_none=True))
f.write(yaml.dump(to_write, sort_keys=False)) f.write(yaml.dump(to_write, sort_keys=False))
# We first install the external APIs so that the build process can use them and discover the # We first install the external APIs so that the build process can use them and discover the

View file

@ -91,21 +91,21 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
logger.info(f"Configuring API `{api_str}`...") logger.info(f"Configuring API `{api_str}`...")
updated_providers = [] updated_providers = []
for i, provider_type in enumerate(plist): for i, provider in enumerate(plist):
if i >= 1: if i >= 1:
others = ", ".join(plist[i:]) others = ", ".join(p.provider_type for p in plist[i:])
logger.info( logger.info(
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n" f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
) )
break break
logger.info(f"> Configuring provider `({provider_type})`") logger.info(f"> Configuring provider `({provider.provider_type})`")
updated_providers.append( updated_providers.append(
configure_single_provider( configure_single_provider(
provider_registry[api], provider_registry[api],
Provider( Provider(
provider_id=(f"{provider_type}-{i:02d}" if len(plist) > 1 else provider_type), provider_id=(f"{provider.provider_id}-{i:02d}" if len(plist) > 1 else provider.provider_id),
provider_type=provider_type, provider_type=provider.provider_type,
config={}, config={},
), ),
) )

View file

@ -136,29 +136,40 @@ class RoutingTableProviderSpec(ProviderSpec):
pip_packages: list[str] = Field(default_factory=list) pip_packages: list[str] = Field(default_factory=list)
class Provider(BaseModel):
# provider_id of None means that the provider is not enabled - this happens
# when the provider is enabled via a conditional environment variable
provider_id: str | None
provider_type: str
config: dict[str, Any] = {}
module: str | None = Field(
default=None,
description="""
Fully-qualified name of the external provider module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
Example: `module: ramalama_stack`
""",
)
class DistributionSpec(BaseModel): class DistributionSpec(BaseModel):
description: str | None = Field( description: str | None = Field(
default="", default="",
description="Description of the distribution", description="Description of the distribution",
) )
container_image: str | None = None container_image: str | None = None
providers: dict[str, str | list[str]] = Field( providers: dict[str, list[Provider]] = Field(
default_factory=dict, default_factory=dict,
description=""" description="""
Provider Types for each of the APIs provided by this distribution. If you Provider Types for each of the APIs provided by this distribution. If you
select multiple providers, you should provide an appropriate 'routing_map' select multiple providers, you should provide an appropriate 'routing_map'
in the runtime configuration to help route to the correct provider.""", in the runtime configuration to help route to the correct provider.
""",
) )
class Provider(BaseModel):
# provider_id of None means that the provider is not enabled - this happens
# when the provider is enabled via a conditional environment variable
provider_id: str | None
provider_type: str
config: dict[str, Any]
class LoggingConfig(BaseModel): class LoggingConfig(BaseModel):
category_levels: dict[str, str] = Field( category_levels: dict[str, str] = Field(
default_factory=dict, default_factory=dict,

View file

@ -249,15 +249,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
file=sys.stderr, file=sys.stderr,
) )
if self.config_path_or_template_name.endswith(".yaml"): if self.config_path_or_template_name.endswith(".yaml"):
# Convert Provider objects to their types
provider_types: dict[str, str | list[str]] = {}
for api, providers in self.config.providers.items():
types = [p.provider_type for p in providers]
# Convert single-item lists to strings
provider_types[api] = types[0] if len(types) == 1 else types
build_config = BuildConfig( build_config = BuildConfig(
distribution_spec=DistributionSpec( distribution_spec=DistributionSpec(
providers=provider_types, providers=self.config.providers,
), ),
external_providers_dir=self.config.external_providers_dir, external_providers_dir=self.config.external_providers_dir,
) )

View file

@ -345,7 +345,7 @@ async def instantiate_provider(
policy: list[AccessRule], policy: list[AccessRule],
): ):
provider_spec = provider.spec provider_spec = provider.spec
if not hasattr(provider_spec, "module"): if not hasattr(provider_spec, "module") or provider_spec.module is None:
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute") raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}") logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")

View file

@ -115,6 +115,19 @@ class ProviderSpec(BaseModel):
description="If this provider is deprecated and does NOT work, specify the error message here", description="If this provider is deprecated and does NOT work, specify the error message here",
) )
module: str | None = Field(
default=None,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
Example: `module: ramalama_stack`
""",
)
is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")
# used internally by the resolver; this is a hack for now # used internally by the resolver; this is a hack for now
deps__: list[str] = Field(default_factory=list) deps__: list[str] = Field(default_factory=list)
@ -135,7 +148,7 @@ class AdapterSpec(BaseModel):
description="Unique identifier for this adapter", description="Unique identifier for this adapter",
) )
module: str = Field( module: str = Field(
..., default_factory=str,
description=""" description="""
Fully-qualified name of the module to import. The module is expected to have: Fully-qualified name of the module to import. The module is expected to have:
@ -173,14 +186,7 @@ The container image to use for this implementation. If one is provided, pip_pack
If a provider depends on other providers, the dependencies MUST NOT specify a container image. If a provider depends on other providers, the dependencies MUST NOT specify a container image.
""", """,
) )
module: str = Field( # module field is inherited from ProviderSpec
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_provider_impl(config, deps)`: returns the local implementation
""",
)
provider_data_validator: str | None = Field( provider_data_validator: str | None = Field(
default=None, default=None,
) )
@ -223,9 +229,7 @@ API responses, specify the adapter here.
def container_image(self) -> str | None: def container_image(self) -> str | None:
return None return None
@property # module field is inherited from ProviderSpec
def module(self) -> str:
return self.adapter.module
@property @property
def pip_packages(self) -> list[str]: def pip_packages(self) -> list[str]:
@ -243,6 +247,7 @@ def remote_provider_spec(
api=api, api=api,
provider_type=f"remote::{adapter.adapter_type}", provider_type=f"remote::{adapter.adapter_type}",
config_class=adapter.config_class, config_class=adapter.config_class,
module=adapter.module,
adapter=adapter, adapter=adapter,
api_dependencies=api_dependencies or [], api_dependencies=api_dependencies or [],
) )