feat: refactor distro codegen

rework template.py to generate all provider_configs in a directory called `provider_configs/API` where API groups configs for a specific API together to avoid naming collisions

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-07-02 13:20:47 -04:00
parent 31cc971503
commit a274795532
3 changed files with 44 additions and 14 deletions

View file

@ -30,8 +30,8 @@ class SqlAlchemySqlStoreConfig(BaseModel):
def engine_str(self) -> str: ... def engine_str(self) -> str: ...
# TODO: move this when we have a better way to specify dependencies with internal APIs # TODO: move this when we have a better way to specify dependencies with internal APIs
@property @classmethod
def pip_packages(self) -> list[str]: def pip_packages(cls) -> list[str]:
return ["sqlalchemy[asyncio]"] return ["sqlalchemy[asyncio]"]
@ -53,9 +53,9 @@ class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name, "db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
} }
@property @classmethod
def pip_packages(self) -> list[str]: def pip_packages(cls) -> list[str]:
return super().pip_packages + ["aiosqlite"] return super().pip_packages() + ["aiosqlite"]
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig): class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
@ -70,9 +70,9 @@ class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
def engine_str(self) -> str: def engine_str(self) -> str:
return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}" return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
@property @classmethod
def pip_packages(self) -> list[str]: def pip_packages(cls) -> list[str]:
return super().pip_packages + ["asyncpg"] return super().pip_packages() + ["asyncpg"]
@classmethod @classmethod
def sample_run_config(cls, **kwargs): def sample_run_config(cls, **kwargs):

View file

@ -234,7 +234,6 @@ def get_distribution_template() -> DistributionTemplate:
default_models = get_model_registry(available_models) default_models = get_model_registry(available_models)
postgres_store = PostgresSqlStoreConfig.sample_run_config()
return DistributionTemplate( return DistributionTemplate(
name=name, name=name,
distro_type="self_hosted", distro_type="self_hosted",
@ -243,7 +242,7 @@ def get_distribution_template() -> DistributionTemplate:
template_path=None, template_path=None,
providers=providers, providers=providers,
available_models_by_provider=available_models, available_models_by_provider=available_models,
additional_pip_packages=postgres_store.pip_packages, additional_pip_packages=PostgresSqlStoreConfig.pip_packages(),
run_configs={ run_configs={
"run.yaml": RunConfigSettings( "run.yaml": RunConfigSettings(
provider_overrides={ provider_overrides={

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import os
from pathlib import Path from pathlib import Path
from typing import Literal from typing import Literal
@ -94,12 +95,14 @@ class RunConfigSettings(BaseModel):
self, self,
name: str, name: str,
providers: dict[str, list[str]], providers: dict[str, list[str]],
yaml_output_dir: Path | None = None,
container_image: str | None = None, container_image: str | None = None,
) -> StackRunConfig: ) -> StackRunConfig:
provider_registry = get_provider_registry() provider_registry = get_provider_registry()
provider_configs = {} provider_configs = {}
for api_str, provider_types in providers.items(): for api_str, provider_types in providers.items():
# TODO: is this necessary with provider configs? all this does is allow you to hardcode a provider in the `get_distribution_template`
if api_providers := self.provider_overrides.get(api_str): if api_providers := self.provider_overrides.get(api_str):
provider_configs[api_str] = api_providers provider_configs[api_str] = api_providers
continue continue
@ -123,11 +126,28 @@ class RunConfigSettings(BaseModel):
else: else:
config = {} config = {}
template_path = None
if yaml_output_dir and config:
path = os.path.join(yaml_output_dir, "provider_configs", api_str, f"{provider_id}.yaml")
template_path = os.path.join(
"~/.llama/distributions",
yaml_output_dir.name,
"provider_configs",
api_str,
f"{provider_id}.yaml",
)
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "w") as f:
yaml.safe_dump(
config,
f,
sort_keys=False,
)
provider_configs[api_str].append( provider_configs[api_str].append(
Provider( Provider(
provider_id=provider_id, provider_id=provider_id,
provider_type=provider_type, provider_type=provider_type,
config=config, config=template_path if template_path is not None else config,
) )
) )
@ -191,7 +211,7 @@ class DistributionTemplate(BaseModel):
# We should have a better way to do this by formalizing the concept of "internal" APIs # We should have a better way to do this by formalizing the concept of "internal" APIs
# and providers, with a way to specify dependencies for them. # and providers, with a way to specify dependencies for them.
if run_config_.inference_store: if run_config_.inference_store:
additional_pip_packages.extend(run_config_.inference_store.pip_packages) additional_pip_packages.extend(run_config_.inference_store.pip_packages())
if run_config_.metadata_store: if run_config_.metadata_store:
additional_pip_packages.extend(run_config_.metadata_store.pip_packages) additional_pip_packages.extend(run_config_.metadata_store.pip_packages)
@ -283,10 +303,21 @@ class DistributionTemplate(BaseModel):
) )
for yaml_pth, settings in self.run_configs.items(): for yaml_pth, settings in self.run_configs.items():
run_config = settings.run_config(self.name, self.providers, self.container_image) run_config = settings.run_config(self.name, self.providers, yaml_output_dir, self.container_image)
with open(yaml_output_dir / yaml_pth, "w") as f: with open(yaml_output_dir / yaml_pth, "w") as f:
def stringify_paths(obj):
if isinstance(obj, dict):
return {k: stringify_paths(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [stringify_paths(v) for v in obj]
elif isinstance(obj, Path):
return str(obj)
else:
return obj
yaml.safe_dump( yaml.safe_dump(
run_config.model_dump(exclude_none=True), stringify_paths(run_config.model_dump(exclude_none=True)),
f, f,
sort_keys=False, sort_keys=False,
) )