mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
refactor: Align build and run provider datatypes
introduce the concept of a `module` for users to specify for a provider upon build time. In order to facilitate this, align the build and run spec to use `Provider` class rather than the stringified provider_type that build currently uses. Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
21bae296f2
commit
233f8c81bf
6 changed files with 48 additions and 38 deletions
|
@ -402,7 +402,7 @@ def _run_stack_build_command_from_build_config(
|
||||||
run_config_file = _generate_run_config(build_config, build_dir, image_name)
|
run_config_file = _generate_run_config(build_config, build_dir, image_name)
|
||||||
|
|
||||||
with open(build_file_path, "w") as f:
|
with open(build_file_path, "w") as f:
|
||||||
to_write = json.loads(build_config.model_dump_json())
|
to_write = json.loads(build_config.model_dump_json(exclude_none=True))
|
||||||
f.write(yaml.dump(to_write, sort_keys=False))
|
f.write(yaml.dump(to_write, sort_keys=False))
|
||||||
|
|
||||||
# We first install the external APIs so that the build process can use them and discover the
|
# We first install the external APIs so that the build process can use them and discover the
|
||||||
|
|
|
@ -91,21 +91,21 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
|
||||||
|
|
||||||
logger.info(f"Configuring API `{api_str}`...")
|
logger.info(f"Configuring API `{api_str}`...")
|
||||||
updated_providers = []
|
updated_providers = []
|
||||||
for i, provider_type in enumerate(plist):
|
for i, provider in enumerate(plist):
|
||||||
if i >= 1:
|
if i >= 1:
|
||||||
others = ", ".join(plist[i:])
|
others = ", ".join(p.provider_type for p in plist[i:])
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
|
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
logger.info(f"> Configuring provider `({provider_type})`")
|
logger.info(f"> Configuring provider `({provider.provider_type})`")
|
||||||
updated_providers.append(
|
updated_providers.append(
|
||||||
configure_single_provider(
|
configure_single_provider(
|
||||||
provider_registry[api],
|
provider_registry[api],
|
||||||
Provider(
|
Provider(
|
||||||
provider_id=(f"{provider_type}-{i:02d}" if len(plist) > 1 else provider_type),
|
provider_id=(f"{provider.provider_id}-{i:02d}" if len(plist) > 1 else provider.provider_id),
|
||||||
provider_type=provider_type,
|
provider_type=provider.provider_type,
|
||||||
config={},
|
config={},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
|
@ -136,29 +136,40 @@ class RoutingTableProviderSpec(ProviderSpec):
|
||||||
pip_packages: list[str] = Field(default_factory=list)
|
pip_packages: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class Provider(BaseModel):
|
||||||
|
# provider_id of None means that the provider is not enabled - this happens
|
||||||
|
# when the provider is enabled via a conditional environment variable
|
||||||
|
provider_id: str | None
|
||||||
|
provider_type: str
|
||||||
|
config: dict[str, Any] = {}
|
||||||
|
module: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="""
|
||||||
|
Fully-qualified name of the external provider module to import. The module is expected to have:
|
||||||
|
|
||||||
|
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||||
|
|
||||||
|
Example: `module: ramalama_stack`
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DistributionSpec(BaseModel):
|
class DistributionSpec(BaseModel):
|
||||||
description: str | None = Field(
|
description: str | None = Field(
|
||||||
default="",
|
default="",
|
||||||
description="Description of the distribution",
|
description="Description of the distribution",
|
||||||
)
|
)
|
||||||
container_image: str | None = None
|
container_image: str | None = None
|
||||||
providers: dict[str, str | list[str]] = Field(
|
providers: dict[str, list[Provider]] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
description="""
|
description="""
|
||||||
Provider Types for each of the APIs provided by this distribution. If you
|
Provider Types for each of the APIs provided by this distribution. If you
|
||||||
select multiple providers, you should provide an appropriate 'routing_map'
|
select multiple providers, you should provide an appropriate 'routing_map'
|
||||||
in the runtime configuration to help route to the correct provider.""",
|
in the runtime configuration to help route to the correct provider.
|
||||||
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Provider(BaseModel):
|
|
||||||
# provider_id of None means that the provider is not enabled - this happens
|
|
||||||
# when the provider is enabled via a conditional environment variable
|
|
||||||
provider_id: str | None
|
|
||||||
provider_type: str
|
|
||||||
config: dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class LoggingConfig(BaseModel):
|
class LoggingConfig(BaseModel):
|
||||||
category_levels: dict[str, str] = Field(
|
category_levels: dict[str, str] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
|
|
|
@ -249,15 +249,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
file=sys.stderr,
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
if self.config_path_or_template_name.endswith(".yaml"):
|
if self.config_path_or_template_name.endswith(".yaml"):
|
||||||
# Convert Provider objects to their types
|
|
||||||
provider_types: dict[str, str | list[str]] = {}
|
|
||||||
for api, providers in self.config.providers.items():
|
|
||||||
types = [p.provider_type for p in providers]
|
|
||||||
# Convert single-item lists to strings
|
|
||||||
provider_types[api] = types[0] if len(types) == 1 else types
|
|
||||||
build_config = BuildConfig(
|
build_config = BuildConfig(
|
||||||
distribution_spec=DistributionSpec(
|
distribution_spec=DistributionSpec(
|
||||||
providers=provider_types,
|
providers=self.config.providers,
|
||||||
),
|
),
|
||||||
external_providers_dir=self.config.external_providers_dir,
|
external_providers_dir=self.config.external_providers_dir,
|
||||||
)
|
)
|
||||||
|
|
|
@ -345,7 +345,7 @@ async def instantiate_provider(
|
||||||
policy: list[AccessRule],
|
policy: list[AccessRule],
|
||||||
):
|
):
|
||||||
provider_spec = provider.spec
|
provider_spec = provider.spec
|
||||||
if not hasattr(provider_spec, "module"):
|
if not hasattr(provider_spec, "module") or provider_spec.module is None:
|
||||||
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
|
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
|
||||||
|
|
||||||
logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
|
logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
|
||||||
|
|
|
@ -115,6 +115,19 @@ class ProviderSpec(BaseModel):
|
||||||
description="If this provider is deprecated and does NOT work, specify the error message here",
|
description="If this provider is deprecated and does NOT work, specify the error message here",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
module: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="""
|
||||||
|
Fully-qualified name of the module to import. The module is expected to have:
|
||||||
|
|
||||||
|
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||||
|
|
||||||
|
Example: `module: ramalama_stack`
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")
|
||||||
|
|
||||||
# used internally by the resolver; this is a hack for now
|
# used internally by the resolver; this is a hack for now
|
||||||
deps__: list[str] = Field(default_factory=list)
|
deps__: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
@ -135,7 +148,7 @@ class AdapterSpec(BaseModel):
|
||||||
description="Unique identifier for this adapter",
|
description="Unique identifier for this adapter",
|
||||||
)
|
)
|
||||||
module: str = Field(
|
module: str = Field(
|
||||||
...,
|
default_factory=str,
|
||||||
description="""
|
description="""
|
||||||
Fully-qualified name of the module to import. The module is expected to have:
|
Fully-qualified name of the module to import. The module is expected to have:
|
||||||
|
|
||||||
|
@ -173,14 +186,7 @@ The container image to use for this implementation. If one is provided, pip_pack
|
||||||
If a provider depends on other providers, the dependencies MUST NOT specify a container image.
|
If a provider depends on other providers, the dependencies MUST NOT specify a container image.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
module: str = Field(
|
# module field is inherited from ProviderSpec
|
||||||
...,
|
|
||||||
description="""
|
|
||||||
Fully-qualified name of the module to import. The module is expected to have:
|
|
||||||
|
|
||||||
- `get_provider_impl(config, deps)`: returns the local implementation
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
provider_data_validator: str | None = Field(
|
provider_data_validator: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
@ -223,9 +229,7 @@ API responses, specify the adapter here.
|
||||||
def container_image(self) -> str | None:
|
def container_image(self) -> str | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
# module field is inherited from ProviderSpec
|
||||||
def module(self) -> str:
|
|
||||||
return self.adapter.module
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pip_packages(self) -> list[str]:
|
def pip_packages(self) -> list[str]:
|
||||||
|
@ -243,6 +247,7 @@ def remote_provider_spec(
|
||||||
api=api,
|
api=api,
|
||||||
provider_type=f"remote::{adapter.adapter_type}",
|
provider_type=f"remote::{adapter.adapter_type}",
|
||||||
config_class=adapter.config_class,
|
config_class=adapter.config_class,
|
||||||
|
module=adapter.module,
|
||||||
adapter=adapter,
|
adapter=adapter,
|
||||||
api_dependencies=api_dependencies or [],
|
api_dependencies=api_dependencies or [],
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue