mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 23:51:00 +00:00
wip
This commit is contained in:
parent
a42943c25b
commit
adf5eeafb6
4 changed files with 68 additions and 13 deletions
|
@ -12,6 +12,10 @@ import os
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
|
|
||||||
|
|
||||||
TEMPLATES_PATH = Path(os.path.relpath(__file__)).parent.parent.parent / "templates"
|
TEMPLATES_PATH = Path(os.path.relpath(__file__)).parent.parent.parent / "templates"
|
||||||
|
|
||||||
|
|
||||||
|
@ -176,6 +180,59 @@ class StackBuild(Subcommand):
|
||||||
return
|
return
|
||||||
self._run_stack_build_command_from_build_config(build_config)
|
self._run_stack_build_command_from_build_config(build_config)
|
||||||
|
|
||||||
|
def _generate_run_config(self, build_config: BuildConfig, build_dir: Path) -> None:
|
||||||
|
"""
|
||||||
|
Generate a run.yaml template file for user to edit from a build.yaml file
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from termcolor import cprint
|
||||||
|
|
||||||
|
from llama_stack.distribution.utils.serialize import EnumEncoder
|
||||||
|
|
||||||
|
# TODO: we should make the run.yaml file invisible to users by with ENV variables
|
||||||
|
# but keeping it visible and exposed to users for now
|
||||||
|
# generate a default run.yaml file using the build_config
|
||||||
|
apis = list(build_config.distribution_spec.providers.keys())
|
||||||
|
run_config = StackRunConfig(
|
||||||
|
built_at=datetime.now(),
|
||||||
|
image_name=build_config.name,
|
||||||
|
conda_env=build_config.name,
|
||||||
|
apis=apis,
|
||||||
|
providers={},
|
||||||
|
)
|
||||||
|
# build providers dict
|
||||||
|
provider_registry = get_provider_registry()
|
||||||
|
for api in apis:
|
||||||
|
run_config.providers[api] = []
|
||||||
|
provider_types = build_config.distribution_spec.providers[api]
|
||||||
|
if isinstance(provider_types, str):
|
||||||
|
provider_types = [provider_types]
|
||||||
|
|
||||||
|
for i, provider_type in enumerate(provider_types):
|
||||||
|
print(provider_type)
|
||||||
|
p_spec = Provider(
|
||||||
|
provider_id=f"{provider_type}-{i}",
|
||||||
|
provider_type=provider_type,
|
||||||
|
config={},
|
||||||
|
)
|
||||||
|
config_type = instantiate_class_type(
|
||||||
|
provider_registry[Api(api)][provider_type].config_class
|
||||||
|
)
|
||||||
|
p_spec.config = config_type()
|
||||||
|
run_config.providers[api].append(p_spec)
|
||||||
|
|
||||||
|
run_config_file = build_dir / f"{build_config.name}-run.yaml"
|
||||||
|
with open(run_config_file, "w") as f:
|
||||||
|
to_write = json.loads(json.dumps(run_config.model_dump(), cls=EnumEncoder))
|
||||||
|
f.write(yaml.dump(to_write, sort_keys=False))
|
||||||
|
|
||||||
|
cprint(
|
||||||
|
f"You can now edit {run_config_file} and run `llama stack run {run_config_file}`",
|
||||||
|
color="green",
|
||||||
|
)
|
||||||
|
|
||||||
def _run_stack_build_command_from_build_config(
|
def _run_stack_build_command_from_build_config(
|
||||||
self, build_config: BuildConfig
|
self, build_config: BuildConfig
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -203,19 +260,16 @@ class StackBuild(Subcommand):
|
||||||
build_file_path = build_dir / f"{build_config.name}-build.yaml"
|
build_file_path = build_dir / f"{build_config.name}-build.yaml"
|
||||||
|
|
||||||
with open(build_file_path, "w") as f:
|
with open(build_file_path, "w") as f:
|
||||||
to_write = json.loads(json.dumps(build_config.dict(), cls=EnumEncoder))
|
to_write = json.loads(
|
||||||
|
json.dumps(build_config.model_dump(), cls=EnumEncoder)
|
||||||
|
)
|
||||||
f.write(yaml.dump(to_write, sort_keys=False))
|
f.write(yaml.dump(to_write, sort_keys=False))
|
||||||
|
|
||||||
return_code = build_image(build_config, build_file_path)
|
return_code = build_image(build_config, build_file_path)
|
||||||
if return_code != 0:
|
if return_code != 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
# TODO: we should make the run.yaml file invisible to users by with ENV variables
|
self._generate_run_config(build_config, build_dir)
|
||||||
# but keeping it visiable and exposed to users for now
|
|
||||||
cprint(
|
|
||||||
f"You can now edit run.yaml files in ./llama-stack/distributions/ and run `llama stack run <path/to/run.yaml>`",
|
|
||||||
color="green",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
|
def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
|
||||||
import json
|
import json
|
||||||
|
|
|
@ -12,6 +12,6 @@ from pydantic import BaseModel, Field
|
||||||
class PGVectorConfig(BaseModel):
|
class PGVectorConfig(BaseModel):
|
||||||
host: str = Field(default="localhost")
|
host: str = Field(default="localhost")
|
||||||
port: int = Field(default=5432)
|
port: int = Field(default=5432)
|
||||||
db: str
|
db: str = Field(default="postgres_db")
|
||||||
user: str
|
user: str = Field(default="postgres_user")
|
||||||
password: str
|
password: str = Field(default="postgres_password")
|
||||||
|
|
|
@ -145,7 +145,7 @@ Fully-qualified name of the module to import. The module is expected to have:
|
||||||
|
|
||||||
class RemoteProviderConfig(BaseModel):
|
class RemoteProviderConfig(BaseModel):
|
||||||
host: str = "localhost"
|
host: str = "localhost"
|
||||||
port: int
|
port: int = 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def url(self) -> str:
|
def url(self) -> str:
|
||||||
|
|
|
@ -4,10 +4,11 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from llama_stack.providers.utils.kvstore import KVStoreConfig
|
from llama_stack.providers.utils.kvstore import KVStoreConfig
|
||||||
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceAgentsImplConfig(BaseModel):
|
class MetaReferenceAgentsImplConfig(BaseModel):
|
||||||
persistence_store: KVStoreConfig
|
persistence_store: KVStoreConfig = Field(default=SqliteKVStoreConfig())
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue