mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
feat: refactor distro codegen
rework template.py to generate all provider_configs in a directory called `provider_configs/API` where API groups configs for a specific API together to avoid naming collisions Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
31cc971503
commit
a274795532
3 changed files with 44 additions and 14 deletions
|
@ -30,8 +30,8 @@ class SqlAlchemySqlStoreConfig(BaseModel):
|
||||||
def engine_str(self) -> str: ...
|
def engine_str(self) -> str: ...
|
||||||
|
|
||||||
# TODO: move this when we have a better way to specify dependencies with internal APIs
|
# TODO: move this when we have a better way to specify dependencies with internal APIs
|
||||||
@property
|
@classmethod
|
||||||
def pip_packages(self) -> list[str]:
|
def pip_packages(cls) -> list[str]:
|
||||||
return ["sqlalchemy[asyncio]"]
|
return ["sqlalchemy[asyncio]"]
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,9 +53,9 @@ class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||||
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
|
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@classmethod
|
||||||
def pip_packages(self) -> list[str]:
|
def pip_packages(cls) -> list[str]:
|
||||||
return super().pip_packages + ["aiosqlite"]
|
return super().pip_packages() + ["aiosqlite"]
|
||||||
|
|
||||||
|
|
||||||
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||||
|
@ -70,9 +70,9 @@ class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||||
def engine_str(self) -> str:
|
def engine_str(self) -> str:
|
||||||
return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
|
return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
|
||||||
|
|
||||||
@property
|
@classmethod
|
||||||
def pip_packages(self) -> list[str]:
|
def pip_packages(cls) -> list[str]:
|
||||||
return super().pip_packages + ["asyncpg"]
|
return super().pip_packages() + ["asyncpg"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, **kwargs):
|
def sample_run_config(cls, **kwargs):
|
||||||
|
|
|
@ -234,7 +234,6 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
|
|
||||||
default_models = get_model_registry(available_models)
|
default_models = get_model_registry(available_models)
|
||||||
|
|
||||||
postgres_store = PostgresSqlStoreConfig.sample_run_config()
|
|
||||||
return DistributionTemplate(
|
return DistributionTemplate(
|
||||||
name=name,
|
name=name,
|
||||||
distro_type="self_hosted",
|
distro_type="self_hosted",
|
||||||
|
@ -243,7 +242,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
template_path=None,
|
template_path=None,
|
||||||
providers=providers,
|
providers=providers,
|
||||||
available_models_by_provider=available_models,
|
available_models_by_provider=available_models,
|
||||||
additional_pip_packages=postgres_store.pip_packages,
|
additional_pip_packages=PostgresSqlStoreConfig.pip_packages(),
|
||||||
run_configs={
|
run_configs={
|
||||||
"run.yaml": RunConfigSettings(
|
"run.yaml": RunConfigSettings(
|
||||||
provider_overrides={
|
provider_overrides={
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
@ -94,12 +95,14 @@ class RunConfigSettings(BaseModel):
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
providers: dict[str, list[str]],
|
providers: dict[str, list[str]],
|
||||||
|
yaml_output_dir: Path | None = None,
|
||||||
container_image: str | None = None,
|
container_image: str | None = None,
|
||||||
) -> StackRunConfig:
|
) -> StackRunConfig:
|
||||||
provider_registry = get_provider_registry()
|
provider_registry = get_provider_registry()
|
||||||
|
|
||||||
provider_configs = {}
|
provider_configs = {}
|
||||||
for api_str, provider_types in providers.items():
|
for api_str, provider_types in providers.items():
|
||||||
|
# TODO: is this necessary with provider configs? all this does is allow you to hardcode a provider in the `get_distribution_template`
|
||||||
if api_providers := self.provider_overrides.get(api_str):
|
if api_providers := self.provider_overrides.get(api_str):
|
||||||
provider_configs[api_str] = api_providers
|
provider_configs[api_str] = api_providers
|
||||||
continue
|
continue
|
||||||
|
@ -123,11 +126,28 @@ class RunConfigSettings(BaseModel):
|
||||||
else:
|
else:
|
||||||
config = {}
|
config = {}
|
||||||
|
|
||||||
|
template_path = None
|
||||||
|
if yaml_output_dir and config:
|
||||||
|
path = os.path.join(yaml_output_dir, "provider_configs", api_str, f"{provider_id}.yaml")
|
||||||
|
template_path = os.path.join(
|
||||||
|
"~/.llama/distributions",
|
||||||
|
yaml_output_dir.name,
|
||||||
|
"provider_configs",
|
||||||
|
api_str,
|
||||||
|
f"{provider_id}.yaml",
|
||||||
|
)
|
||||||
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||||
|
with open(path, "w") as f:
|
||||||
|
yaml.safe_dump(
|
||||||
|
config,
|
||||||
|
f,
|
||||||
|
sort_keys=False,
|
||||||
|
)
|
||||||
provider_configs[api_str].append(
|
provider_configs[api_str].append(
|
||||||
Provider(
|
Provider(
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
provider_type=provider_type,
|
provider_type=provider_type,
|
||||||
config=config,
|
config=template_path if template_path is not None else config,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -191,7 +211,7 @@ class DistributionTemplate(BaseModel):
|
||||||
# We should have a better way to do this by formalizing the concept of "internal" APIs
|
# We should have a better way to do this by formalizing the concept of "internal" APIs
|
||||||
# and providers, with a way to specify dependencies for them.
|
# and providers, with a way to specify dependencies for them.
|
||||||
if run_config_.inference_store:
|
if run_config_.inference_store:
|
||||||
additional_pip_packages.extend(run_config_.inference_store.pip_packages)
|
additional_pip_packages.extend(run_config_.inference_store.pip_packages())
|
||||||
if run_config_.metadata_store:
|
if run_config_.metadata_store:
|
||||||
additional_pip_packages.extend(run_config_.metadata_store.pip_packages)
|
additional_pip_packages.extend(run_config_.metadata_store.pip_packages)
|
||||||
|
|
||||||
|
@ -283,10 +303,21 @@ class DistributionTemplate(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
for yaml_pth, settings in self.run_configs.items():
|
for yaml_pth, settings in self.run_configs.items():
|
||||||
run_config = settings.run_config(self.name, self.providers, self.container_image)
|
run_config = settings.run_config(self.name, self.providers, yaml_output_dir, self.container_image)
|
||||||
with open(yaml_output_dir / yaml_pth, "w") as f:
|
with open(yaml_output_dir / yaml_pth, "w") as f:
|
||||||
|
|
||||||
|
def stringify_paths(obj):
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return {k: stringify_paths(v) for k, v in obj.items()}
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
return [stringify_paths(v) for v in obj]
|
||||||
|
elif isinstance(obj, Path):
|
||||||
|
return str(obj)
|
||||||
|
else:
|
||||||
|
return obj
|
||||||
|
|
||||||
yaml.safe_dump(
|
yaml.safe_dump(
|
||||||
run_config.model_dump(exclude_none=True),
|
stringify_paths(run_config.model_dump(exclude_none=True)),
|
||||||
f,
|
f,
|
||||||
sort_keys=False,
|
sort_keys=False,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue