refactor: Align build and run provider datatypes

introduce the concept of a `module` for users to specify for a provider upon build time.
In order to facilitate this, align the build and run spec to use `Provider` class rather than the stringified provider_type that build currently uses.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-07-06 19:56:18 -04:00
parent 21bae296f2
commit 233f8c81bf
6 changed files with 48 additions and 38 deletions

View file

@ -402,7 +402,7 @@ def _run_stack_build_command_from_build_config(
run_config_file = _generate_run_config(build_config, build_dir, image_name)
with open(build_file_path, "w") as f:
to_write = json.loads(build_config.model_dump_json())
to_write = json.loads(build_config.model_dump_json(exclude_none=True))
f.write(yaml.dump(to_write, sort_keys=False))
# We first install the external APIs so that the build process can use them and discover the

View file

@ -91,21 +91,21 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
logger.info(f"Configuring API `{api_str}`...")
updated_providers = []
for i, provider_type in enumerate(plist):
for i, provider in enumerate(plist):
if i >= 1:
others = ", ".join(plist[i:])
others = ", ".join(p.provider_type for p in plist[i:])
logger.info(
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
)
break
logger.info(f"> Configuring provider `({provider_type})`")
logger.info(f"> Configuring provider `({provider.provider_type})`")
updated_providers.append(
configure_single_provider(
provider_registry[api],
Provider(
provider_id=(f"{provider_type}-{i:02d}" if len(plist) > 1 else provider_type),
provider_type=provider_type,
provider_id=(f"{provider.provider_id}-{i:02d}" if len(plist) > 1 else provider.provider_id),
provider_type=provider.provider_type,
config={},
),
)

View file

@ -136,29 +136,40 @@ class RoutingTableProviderSpec(ProviderSpec):
pip_packages: list[str] = Field(default_factory=list)
class Provider(BaseModel):
# provider_id of None means that the provider is not enabled - this happens
# when the provider is enabled via a conditional environment variable
provider_id: str | None
provider_type: str
config: dict[str, Any] = {}
module: str | None = Field(
default=None,
description="""
Fully-qualified name of the external provider module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
Example: `module: ramalama_stack`
""",
)
class DistributionSpec(BaseModel):
description: str | None = Field(
default="",
description="Description of the distribution",
)
container_image: str | None = None
providers: dict[str, str | list[str]] = Field(
providers: dict[str, list[Provider]] = Field(
default_factory=dict,
description="""
Provider Types for each of the APIs provided by this distribution. If you
select multiple providers, you should provide an appropriate 'routing_map'
in the runtime configuration to help route to the correct provider.""",
Provider Types for each of the APIs provided by this distribution. If you
select multiple providers, you should provide an appropriate 'routing_map'
in the runtime configuration to help route to the correct provider.
""",
)
class Provider(BaseModel):
# provider_id of None means that the provider is not enabled - this happens
# when the provider is enabled via a conditional environment variable
provider_id: str | None
provider_type: str
config: dict[str, Any]
class LoggingConfig(BaseModel):
category_levels: dict[str, str] = Field(
default_factory=dict,

View file

@ -249,15 +249,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
file=sys.stderr,
)
if self.config_path_or_template_name.endswith(".yaml"):
# Convert Provider objects to their types
provider_types: dict[str, str | list[str]] = {}
for api, providers in self.config.providers.items():
types = [p.provider_type for p in providers]
# Convert single-item lists to strings
provider_types[api] = types[0] if len(types) == 1 else types
build_config = BuildConfig(
distribution_spec=DistributionSpec(
providers=provider_types,
providers=self.config.providers,
),
external_providers_dir=self.config.external_providers_dir,
)

View file

@ -345,7 +345,7 @@ async def instantiate_provider(
policy: list[AccessRule],
):
provider_spec = provider.spec
if not hasattr(provider_spec, "module"):
if not hasattr(provider_spec, "module") or provider_spec.module is None:
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")

View file

@ -115,6 +115,19 @@ class ProviderSpec(BaseModel):
description="If this provider is deprecated and does NOT work, specify the error message here",
)
module: str | None = Field(
default=None,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
Example: `module: ramalama_stack`
""",
)
is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")
# used internally by the resolver; this is a hack for now
deps__: list[str] = Field(default_factory=list)
@ -135,7 +148,7 @@ class AdapterSpec(BaseModel):
description="Unique identifier for this adapter",
)
module: str = Field(
...,
default_factory=str,
description="""
Fully-qualified name of the module to import. The module is expected to have:
@ -173,14 +186,7 @@ The container image to use for this implementation. If one is provided, pip_pack
If a provider depends on other providers, the dependencies MUST NOT specify a container image.
""",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_provider_impl(config, deps)`: returns the local implementation
""",
)
# module field is inherited from ProviderSpec
provider_data_validator: str | None = Field(
default=None,
)
@ -223,9 +229,7 @@ API responses, specify the adapter here.
def container_image(self) -> str | None:
return None
@property
def module(self) -> str:
return self.adapter.module
# module field is inherited from ProviderSpec
@property
def pip_packages(self) -> list[str]:
@ -243,6 +247,7 @@ def remote_provider_spec(
api=api,
provider_type=f"remote::{adapter.adapter_type}",
config_class=adapter.config_class,
module=adapter.module,
adapter=adapter,
api_dependencies=api_dependencies or [],
)