mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-07 14:26:44 +00:00
fix: store configs (#2593)
# What does this PR do? https://github.com/meta-llama/llama-stack/pull/2490 broke postgres_demo, as the config expected a str but the value was converted to int. This PR: 1. Updates the type of port in sqlstore to be int. 2. Updates template generation to use `dict` instead of `StackRunConfig`, so as to avoid failing pydantic typechecks. 3. Adds `replace_env_vars` to the StackRunConfig instantiation in `configure.py` (it is unclear why this wasn't needed before). ## Test Plan `llama stack build --template postgres_demo --image-type conda --run`
This commit is contained in:
parent
aa273944fd
commit
3c43a2f529
47 changed files with 110 additions and 223 deletions
|
@ -15,6 +15,7 @@ from pydantic import BaseModel, Field
|
|||
from llama_stack.apis.datasets import DatasetPurpose
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import (
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
Api,
|
||||
BenchmarkInput,
|
||||
BuildConfig,
|
||||
|
@ -23,14 +24,15 @@ from llama_stack.distribution.datatypes import (
|
|||
ModelInput,
|
||||
Provider,
|
||||
ShieldInput,
|
||||
StackRunConfig,
|
||||
ToolGroupInput,
|
||||
)
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.config import get_pip_packages as get_kv_pip_packages
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages
|
||||
|
||||
|
||||
def get_model_registry(
|
||||
|
@ -87,21 +89,24 @@ class RunConfigSettings(BaseModel):
|
|||
default_tool_groups: list[ToolGroupInput] | None = None
|
||||
default_datasets: list[DatasetInput] | None = None
|
||||
default_benchmarks: list[BenchmarkInput] | None = None
|
||||
metadata_store: KVStoreConfig | None = None
|
||||
inference_store: SqlStoreConfig | None = None
|
||||
metadata_store: dict | None = None
|
||||
inference_store: dict | None = None
|
||||
|
||||
def run_config(
|
||||
self,
|
||||
name: str,
|
||||
providers: dict[str, list[str]],
|
||||
container_image: str | None = None,
|
||||
) -> StackRunConfig:
|
||||
) -> dict:
|
||||
provider_registry = get_provider_registry()
|
||||
|
||||
provider_configs = {}
|
||||
for api_str, provider_types in providers.items():
|
||||
if api_providers := self.provider_overrides.get(api_str):
|
||||
provider_configs[api_str] = api_providers
|
||||
# Convert Provider objects to dicts for YAML serialization
|
||||
provider_configs[api_str] = [
|
||||
p.model_dump(exclude_none=True) if isinstance(p, Provider) else p for p in api_providers
|
||||
]
|
||||
continue
|
||||
|
||||
provider_configs[api_str] = []
|
||||
|
@ -128,33 +133,40 @@ class RunConfigSettings(BaseModel):
|
|||
provider_id=provider_id,
|
||||
provider_type=provider_type,
|
||||
config=config,
|
||||
)
|
||||
).model_dump(exclude_none=True)
|
||||
)
|
||||
|
||||
# Get unique set of APIs from providers
|
||||
apis = sorted(providers.keys())
|
||||
|
||||
return StackRunConfig(
|
||||
image_name=name,
|
||||
container_image=container_image,
|
||||
apis=apis,
|
||||
providers=provider_configs,
|
||||
metadata_store=self.metadata_store
|
||||
# Return a dict that matches StackRunConfig structure
|
||||
return {
|
||||
"version": LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
"image_name": name,
|
||||
"container_image": container_image,
|
||||
"apis": apis,
|
||||
"providers": provider_configs,
|
||||
"metadata_store": self.metadata_store
|
||||
or SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=f"~/.llama/distributions/{name}",
|
||||
db_name="registry.db",
|
||||
),
|
||||
inference_store=self.inference_store
|
||||
"inference_store": self.inference_store
|
||||
or SqliteSqlStoreConfig.sample_run_config(
|
||||
__distro_dir__=f"~/.llama/distributions/{name}",
|
||||
db_name="inference_store.db",
|
||||
),
|
||||
models=self.default_models or [],
|
||||
shields=self.default_shields or [],
|
||||
tool_groups=self.default_tool_groups or [],
|
||||
datasets=self.default_datasets or [],
|
||||
benchmarks=self.default_benchmarks or [],
|
||||
)
|
||||
"models": [m.model_dump(exclude_none=True) for m in (self.default_models or [])],
|
||||
"shields": [s.model_dump(exclude_none=True) for s in (self.default_shields or [])],
|
||||
"vector_dbs": [],
|
||||
"datasets": [d.model_dump(exclude_none=True) for d in (self.default_datasets or [])],
|
||||
"scoring_fns": [],
|
||||
"benchmarks": [b.model_dump(exclude_none=True) for b in (self.default_benchmarks or [])],
|
||||
"tool_groups": [t.model_dump(exclude_none=True) for t in (self.default_tool_groups or [])],
|
||||
"server": {
|
||||
"port": 8321,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class DistributionTemplate(BaseModel):
|
||||
|
@ -190,10 +202,12 @@ class DistributionTemplate(BaseModel):
|
|||
# TODO: This is a hack to get the dependencies for internal APIs into build
|
||||
# We should have a better way to do this by formalizing the concept of "internal" APIs
|
||||
# and providers, with a way to specify dependencies for them.
|
||||
if run_config_.inference_store:
|
||||
additional_pip_packages.extend(run_config_.inference_store.pip_packages)
|
||||
if run_config_.metadata_store:
|
||||
additional_pip_packages.extend(run_config_.metadata_store.pip_packages)
|
||||
|
||||
if run_config_.get("inference_store"):
|
||||
additional_pip_packages.extend(get_sql_pip_packages(run_config_["inference_store"]))
|
||||
|
||||
if run_config_.get("metadata_store"):
|
||||
additional_pip_packages.extend(get_kv_pip_packages(run_config_["metadata_store"]))
|
||||
|
||||
if self.additional_pip_packages:
|
||||
additional_pip_packages.extend(self.additional_pip_packages)
|
||||
|
@ -286,7 +300,7 @@ class DistributionTemplate(BaseModel):
|
|||
run_config = settings.run_config(self.name, self.providers, self.container_image)
|
||||
with open(yaml_output_dir / yaml_pth, "w") as f:
|
||||
yaml.safe_dump(
|
||||
run_config.model_dump(exclude_none=True),
|
||||
{k: v for k, v in run_config.items() if v is not None},
|
||||
f,
|
||||
sort_keys=False,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue