mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-26 11:01:58 +00:00
feat: refactor distro codegen
rework template.py to generate all provider_configs in a directory called `provider_configs/API` where API groups configs for a specific API together to avoid naming collisions Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
31cc971503
commit
a274795532
3 changed files with 44 additions and 14 deletions
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
|
|
@ -94,12 +95,14 @@ class RunConfigSettings(BaseModel):
|
|||
self,
|
||||
name: str,
|
||||
providers: dict[str, list[str]],
|
||||
yaml_output_dir: Path | None = None,
|
||||
container_image: str | None = None,
|
||||
) -> StackRunConfig:
|
||||
provider_registry = get_provider_registry()
|
||||
|
||||
provider_configs = {}
|
||||
for api_str, provider_types in providers.items():
|
||||
# TODO: is this necessary with provider configs? all this does is allow you to hardcode a provider in the `get_distribution_template`
|
||||
if api_providers := self.provider_overrides.get(api_str):
|
||||
provider_configs[api_str] = api_providers
|
||||
continue
|
||||
|
|
@ -123,11 +126,28 @@ class RunConfigSettings(BaseModel):
|
|||
else:
|
||||
config = {}
|
||||
|
||||
template_path = None
|
||||
if yaml_output_dir and config:
|
||||
path = os.path.join(yaml_output_dir, "provider_configs", api_str, f"{provider_id}.yaml")
|
||||
template_path = os.path.join(
|
||||
"~/.llama/distributions",
|
||||
yaml_output_dir.name,
|
||||
"provider_configs",
|
||||
api_str,
|
||||
f"{provider_id}.yaml",
|
||||
)
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
with open(path, "w") as f:
|
||||
yaml.safe_dump(
|
||||
config,
|
||||
f,
|
||||
sort_keys=False,
|
||||
)
|
||||
provider_configs[api_str].append(
|
||||
Provider(
|
||||
provider_id=provider_id,
|
||||
provider_type=provider_type,
|
||||
config=config,
|
||||
config=template_path if template_path is not None else config,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -191,7 +211,7 @@ class DistributionTemplate(BaseModel):
|
|||
# We should have a better way to do this by formalizing the concept of "internal" APIs
|
||||
# and providers, with a way to specify dependencies for them.
|
||||
if run_config_.inference_store:
|
||||
additional_pip_packages.extend(run_config_.inference_store.pip_packages)
|
||||
additional_pip_packages.extend(run_config_.inference_store.pip_packages())
|
||||
if run_config_.metadata_store:
|
||||
additional_pip_packages.extend(run_config_.metadata_store.pip_packages)
|
||||
|
||||
|
|
@ -283,10 +303,21 @@ class DistributionTemplate(BaseModel):
|
|||
)
|
||||
|
||||
for yaml_pth, settings in self.run_configs.items():
|
||||
run_config = settings.run_config(self.name, self.providers, self.container_image)
|
||||
run_config = settings.run_config(self.name, self.providers, yaml_output_dir, self.container_image)
|
||||
with open(yaml_output_dir / yaml_pth, "w") as f:
|
||||
|
||||
def stringify_paths(obj):
|
||||
if isinstance(obj, dict):
|
||||
return {k: stringify_paths(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [stringify_paths(v) for v in obj]
|
||||
elif isinstance(obj, Path):
|
||||
return str(obj)
|
||||
else:
|
||||
return obj
|
||||
|
||||
yaml.safe_dump(
|
||||
run_config.model_dump(exclude_none=True),
|
||||
stringify_paths(run_config.model_dump(exclude_none=True)),
|
||||
f,
|
||||
sort_keys=False,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue