mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
feat: refactor distro codegen
rework template.py to generate all provider_configs in a directory called `provider_configs/API` where API groups configs for a specific API together to avoid naming collisions Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
31cc971503
commit
a274795532
3 changed files with 44 additions and 14 deletions
|
@ -30,8 +30,8 @@ class SqlAlchemySqlStoreConfig(BaseModel):
|
|||
def engine_str(self) -> str: ...
|
||||
|
||||
# TODO: move this when we have a better way to specify dependencies with internal APIs
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
@classmethod
|
||||
def pip_packages(cls) -> list[str]:
|
||||
return ["sqlalchemy[asyncio]"]
|
||||
|
||||
|
||||
|
@ -53,9 +53,9 @@ class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
|||
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
|
||||
}
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
return super().pip_packages + ["aiosqlite"]
|
||||
@classmethod
|
||||
def pip_packages(cls) -> list[str]:
|
||||
return super().pip_packages() + ["aiosqlite"]
|
||||
|
||||
|
||||
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||
|
@ -70,9 +70,9 @@ class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
|||
def engine_str(self) -> str:
|
||||
return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
return super().pip_packages + ["asyncpg"]
|
||||
@classmethod
|
||||
def pip_packages(cls) -> list[str]:
|
||||
return super().pip_packages() + ["asyncpg"]
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs):
|
||||
|
|
|
@ -234,7 +234,6 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
|
||||
default_models = get_model_registry(available_models)
|
||||
|
||||
postgres_store = PostgresSqlStoreConfig.sample_run_config()
|
||||
return DistributionTemplate(
|
||||
name=name,
|
||||
distro_type="self_hosted",
|
||||
|
@ -243,7 +242,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
template_path=None,
|
||||
providers=providers,
|
||||
available_models_by_provider=available_models,
|
||||
additional_pip_packages=postgres_store.pip_packages,
|
||||
additional_pip_packages=PostgresSqlStoreConfig.pip_packages(),
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
|
@ -94,12 +95,14 @@ class RunConfigSettings(BaseModel):
|
|||
self,
|
||||
name: str,
|
||||
providers: dict[str, list[str]],
|
||||
yaml_output_dir: Path | None = None,
|
||||
container_image: str | None = None,
|
||||
) -> StackRunConfig:
|
||||
provider_registry = get_provider_registry()
|
||||
|
||||
provider_configs = {}
|
||||
for api_str, provider_types in providers.items():
|
||||
# TODO: is this necessary with provider configs? all this does is allow you to hardcode a provider in the `get_distribution_template`
|
||||
if api_providers := self.provider_overrides.get(api_str):
|
||||
provider_configs[api_str] = api_providers
|
||||
continue
|
||||
|
@ -123,11 +126,28 @@ class RunConfigSettings(BaseModel):
|
|||
else:
|
||||
config = {}
|
||||
|
||||
template_path = None
|
||||
if yaml_output_dir and config:
|
||||
path = os.path.join(yaml_output_dir, "provider_configs", api_str, f"{provider_id}.yaml")
|
||||
template_path = os.path.join(
|
||||
"~/.llama/distributions",
|
||||
yaml_output_dir.name,
|
||||
"provider_configs",
|
||||
api_str,
|
||||
f"{provider_id}.yaml",
|
||||
)
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
with open(path, "w") as f:
|
||||
yaml.safe_dump(
|
||||
config,
|
||||
f,
|
||||
sort_keys=False,
|
||||
)
|
||||
provider_configs[api_str].append(
|
||||
Provider(
|
||||
provider_id=provider_id,
|
||||
provider_type=provider_type,
|
||||
config=config,
|
||||
config=template_path if template_path is not None else config,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -191,7 +211,7 @@ class DistributionTemplate(BaseModel):
|
|||
# We should have a better way to do this by formalizing the concept of "internal" APIs
|
||||
# and providers, with a way to specify dependencies for them.
|
||||
if run_config_.inference_store:
|
||||
additional_pip_packages.extend(run_config_.inference_store.pip_packages)
|
||||
additional_pip_packages.extend(run_config_.inference_store.pip_packages())
|
||||
if run_config_.metadata_store:
|
||||
additional_pip_packages.extend(run_config_.metadata_store.pip_packages)
|
||||
|
||||
|
@ -283,10 +303,21 @@ class DistributionTemplate(BaseModel):
|
|||
)
|
||||
|
||||
for yaml_pth, settings in self.run_configs.items():
|
||||
run_config = settings.run_config(self.name, self.providers, self.container_image)
|
||||
run_config = settings.run_config(self.name, self.providers, yaml_output_dir, self.container_image)
|
||||
with open(yaml_output_dir / yaml_pth, "w") as f:
|
||||
|
||||
def stringify_paths(obj):
|
||||
if isinstance(obj, dict):
|
||||
return {k: stringify_paths(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [stringify_paths(v) for v in obj]
|
||||
elif isinstance(obj, Path):
|
||||
return str(obj)
|
||||
else:
|
||||
return obj
|
||||
|
||||
yaml.safe_dump(
|
||||
run_config.model_dump(exclude_none=True),
|
||||
stringify_paths(run_config.model_dump(exclude_none=True)),
|
||||
f,
|
||||
sort_keys=False,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue