From adf5eeafb6f0a063d090210f57b507085fa89747 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 4 Nov 2024 20:44:06 -0800 Subject: [PATCH] wip --- llama_stack/cli/stack/build.py | 68 +++++++++++++++++-- .../adapters/memory/pgvector/config.py | 6 +- llama_stack/providers/datatypes.py | 2 +- .../impls/meta_reference/agents/config.py | 5 +- 4 files changed, 68 insertions(+), 13 deletions(-) diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py index 93ca4924e..d2dae9319 100644 --- a/llama_stack/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -12,6 +12,10 @@ import os from functools import lru_cache from pathlib import Path +from llama_stack.distribution.distribution import get_provider_registry +from llama_stack.distribution.utils.dynamic import instantiate_class_type + + TEMPLATES_PATH = Path(os.path.relpath(__file__)).parent.parent.parent / "templates" @@ -176,6 +180,59 @@ class StackBuild(Subcommand): return self._run_stack_build_command_from_build_config(build_config) + def _generate_run_config(self, build_config: BuildConfig, build_dir: Path) -> None: + """ + Generate a run.yaml template file for user to edit from a build.yaml file + """ + import json + + import yaml + from termcolor import cprint + + from llama_stack.distribution.utils.serialize import EnumEncoder + + # TODO: we should make the run.yaml file invisible to users by using ENV variables + # but keeping it visible and exposed to users for now + # generate a default run.yaml file using the build_config + apis = list(build_config.distribution_spec.providers.keys()) + run_config = StackRunConfig( + built_at=datetime.now(), + image_name=build_config.name, + conda_env=build_config.name, + apis=apis, + providers={}, + ) + # build providers dict + provider_registry = get_provider_registry() + for api in apis: + run_config.providers[api] = [] + provider_types = build_config.distribution_spec.providers[api] + if isinstance(provider_types, str): + provider_types = 
[provider_types] + + for i, provider_type in enumerate(provider_types): + print(provider_type) + p_spec = Provider( + provider_id=f"{provider_type}-{i}", + provider_type=provider_type, + config={}, + ) + config_type = instantiate_class_type( + provider_registry[Api(api)][provider_type].config_class + ) + p_spec.config = config_type() + run_config.providers[api].append(p_spec) + + run_config_file = build_dir / f"{build_config.name}-run.yaml" + with open(run_config_file, "w") as f: + to_write = json.loads(json.dumps(run_config.model_dump(), cls=EnumEncoder)) + f.write(yaml.dump(to_write, sort_keys=False)) + + cprint( + f"You can now edit {run_config_file} and run `llama stack run {run_config_file}`", + color="green", + ) + def _run_stack_build_command_from_build_config( self, build_config: BuildConfig ) -> None: @@ -203,19 +260,16 @@ class StackBuild(Subcommand): build_file_path = build_dir / f"{build_config.name}-build.yaml" with open(build_file_path, "w") as f: - to_write = json.loads(json.dumps(build_config.dict(), cls=EnumEncoder)) + to_write = json.loads( + json.dumps(build_config.model_dump(), cls=EnumEncoder) + ) f.write(yaml.dump(to_write, sort_keys=False)) return_code = build_image(build_config, build_file_path) if return_code != 0: return - # TODO: we should make the run.yaml file invisible to users by with ENV variables - # but keeping it visiable and exposed to users for now - cprint( - f"You can now edit run.yaml files in ./llama-stack/distributions/ and run `llama stack run `", - color="green", - ) + self._generate_run_config(build_config, build_dir) def _run_template_list_cmd(self, args: argparse.Namespace) -> None: import json diff --git a/llama_stack/providers/adapters/memory/pgvector/config.py b/llama_stack/providers/adapters/memory/pgvector/config.py index 87b2f4a3b..405995842 100644 --- a/llama_stack/providers/adapters/memory/pgvector/config.py +++ b/llama_stack/providers/adapters/memory/pgvector/config.py @@ -12,6 +12,6 @@ from pydantic import 
BaseModel, Field class PGVectorConfig(BaseModel): host: str = Field(default="localhost") port: int = Field(default=5432) - db: str - user: str - password: str + db: str = Field(default="postgres_db") + user: str = Field(default="postgres_user") + password: str = Field(default="postgres_password") diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 9a37a28a9..c0682df35 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -145,7 +145,7 @@ Fully-qualified name of the module to import. The module is expected to have: class RemoteProviderConfig(BaseModel): host: str = "localhost" - port: int + port: int = 0 @property def url(self) -> str: diff --git a/llama_stack/providers/impls/meta_reference/agents/config.py b/llama_stack/providers/impls/meta_reference/agents/config.py index 0146cb436..2770ed13c 100644 --- a/llama_stack/providers/impls/meta_reference/agents/config.py +++ b/llama_stack/providers/impls/meta_reference/agents/config.py @@ -4,10 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from pydantic import BaseModel +from pydantic import BaseModel, Field from llama_stack.providers.utils.kvstore import KVStoreConfig +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig class MetaReferenceAgentsImplConfig(BaseModel): - persistence_store: KVStoreConfig + persistence_store: KVStoreConfig = Field(default=SqliteKVStoreConfig())