mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
refactor: Align build and run provider datatypes
introduce the concept of a `module` for users to specify for a provider upon build time. In order to facilitate this, align the build and run spec to use `Provider` class rather than the stringified provider_type that build currently uses. Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
21bae296f2
commit
233f8c81bf
6 changed files with 48 additions and 38 deletions
|
@ -91,21 +91,21 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
|
|||
|
||||
logger.info(f"Configuring API `{api_str}`...")
|
||||
updated_providers = []
|
||||
for i, provider_type in enumerate(plist):
|
||||
for i, provider in enumerate(plist):
|
||||
if i >= 1:
|
||||
others = ", ".join(plist[i:])
|
||||
others = ", ".join(p.provider_type for p in plist[i:])
|
||||
logger.info(
|
||||
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
|
||||
)
|
||||
break
|
||||
|
||||
logger.info(f"> Configuring provider `({provider_type})`")
|
||||
logger.info(f"> Configuring provider `({provider.provider_type})`")
|
||||
updated_providers.append(
|
||||
configure_single_provider(
|
||||
provider_registry[api],
|
||||
Provider(
|
||||
provider_id=(f"{provider_type}-{i:02d}" if len(plist) > 1 else provider_type),
|
||||
provider_type=provider_type,
|
||||
provider_id=(f"{provider.provider_id}-{i:02d}" if len(plist) > 1 else provider.provider_id),
|
||||
provider_type=provider.provider_type,
|
||||
config={},
|
||||
),
|
||||
)
|
||||
|
|
|
@ -136,29 +136,40 @@ class RoutingTableProviderSpec(ProviderSpec):
|
|||
pip_packages: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class Provider(BaseModel):
|
||||
# provider_id of None means that the provider is not enabled - this happens
|
||||
# when the provider is enabled via a conditional environment variable
|
||||
provider_id: str | None
|
||||
provider_type: str
|
||||
config: dict[str, Any] = {}
|
||||
module: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Fully-qualified name of the external provider module to import. The module is expected to have:
|
||||
|
||||
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||
|
||||
Example: `module: ramalama_stack`
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class DistributionSpec(BaseModel):
|
||||
description: str | None = Field(
|
||||
default="",
|
||||
description="Description of the distribution",
|
||||
)
|
||||
container_image: str | None = None
|
||||
providers: dict[str, str | list[str]] = Field(
|
||||
providers: dict[str, list[Provider]] = Field(
|
||||
default_factory=dict,
|
||||
description="""
|
||||
Provider Types for each of the APIs provided by this distribution. If you
|
||||
select multiple providers, you should provide an appropriate 'routing_map'
|
||||
in the runtime configuration to help route to the correct provider.""",
|
||||
Provider Types for each of the APIs provided by this distribution. If you
|
||||
select multiple providers, you should provide an appropriate 'routing_map'
|
||||
in the runtime configuration to help route to the correct provider.
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class Provider(BaseModel):
|
||||
# provider_id of None means that the provider is not enabled - this happens
|
||||
# when the provider is enabled via a conditional environment variable
|
||||
provider_id: str | None
|
||||
provider_type: str
|
||||
config: dict[str, Any]
|
||||
|
||||
|
||||
class LoggingConfig(BaseModel):
|
||||
category_levels: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
|
|
|
@ -249,15 +249,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
file=sys.stderr,
|
||||
)
|
||||
if self.config_path_or_template_name.endswith(".yaml"):
|
||||
# Convert Provider objects to their types
|
||||
provider_types: dict[str, str | list[str]] = {}
|
||||
for api, providers in self.config.providers.items():
|
||||
types = [p.provider_type for p in providers]
|
||||
# Convert single-item lists to strings
|
||||
provider_types[api] = types[0] if len(types) == 1 else types
|
||||
build_config = BuildConfig(
|
||||
distribution_spec=DistributionSpec(
|
||||
providers=provider_types,
|
||||
providers=self.config.providers,
|
||||
),
|
||||
external_providers_dir=self.config.external_providers_dir,
|
||||
)
|
||||
|
|
|
@ -345,7 +345,7 @@ async def instantiate_provider(
|
|||
policy: list[AccessRule],
|
||||
):
|
||||
provider_spec = provider.spec
|
||||
if not hasattr(provider_spec, "module"):
|
||||
if not hasattr(provider_spec, "module") or provider_spec.module is None:
|
||||
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
|
||||
|
||||
logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue