mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
A few bug fixes for covering corner cases
This commit is contained in:
parent
a05599c67a
commit
353c7dc82a
4 changed files with 30 additions and 23 deletions
|
@ -22,7 +22,7 @@ def available_templates_specs() -> List[BuildConfig]:
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
template_specs = []
|
template_specs = []
|
||||||
for p in TEMPLATES_PATH.rglob("*.yaml"):
|
for p in TEMPLATES_PATH.rglob("*build.yaml"):
|
||||||
with open(p, "r") as f:
|
with open(p, "r") as f:
|
||||||
build_config = BuildConfig(**yaml.safe_load(f))
|
build_config = BuildConfig(**yaml.safe_load(f))
|
||||||
template_specs.append(build_config)
|
template_specs.append(build_config)
|
||||||
|
|
|
@ -152,7 +152,7 @@ class StackConfigure(Subcommand):
|
||||||
config = StackRunConfig(
|
config = StackRunConfig(
|
||||||
built_at=datetime.now(),
|
built_at=datetime.now(),
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
apis=[],
|
apis=list(build_config.distribution_spec.providers.keys()),
|
||||||
providers={},
|
providers={},
|
||||||
models=[],
|
models=[],
|
||||||
shields=[],
|
shields=[],
|
||||||
|
|
|
@ -7,7 +7,6 @@
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from llama_stack.cli.subcommand import Subcommand
|
from llama_stack.cli.subcommand import Subcommand
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
|
||||||
|
|
||||||
|
|
||||||
class StackRun(Subcommand):
|
class StackRun(Subcommand):
|
||||||
|
@ -49,8 +48,8 @@ class StackRun(Subcommand):
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_stack.distribution.build import ImageType
|
from llama_stack.distribution.build import ImageType
|
||||||
|
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||||
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
|
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
|
||||||
|
|
||||||
from llama_stack.distribution.utils.exec import run_with_pty
|
from llama_stack.distribution.utils.exec import run_with_pty
|
||||||
|
|
||||||
if not args.config:
|
if not args.config:
|
||||||
|
@ -78,7 +77,8 @@ class StackRun(Subcommand):
|
||||||
|
|
||||||
cprint(f"Using config `{config_file}`", "green")
|
cprint(f"Using config `{config_file}`", "green")
|
||||||
with open(config_file, "r") as f:
|
with open(config_file, "r") as f:
|
||||||
config = StackRunConfig(**yaml.safe_load(f))
|
config_dict = yaml.safe_load(config_file.read_text())
|
||||||
|
config = parse_and_maybe_upgrade_config(config_dict)
|
||||||
|
|
||||||
if config.docker_image:
|
if config.docker_image:
|
||||||
script = pkg_resources.resource_filename(
|
script = pkg_resources.resource_filename(
|
||||||
|
|
|
@ -64,9 +64,6 @@ def configure_api_providers(
|
||||||
) -> StackRunConfig:
|
) -> StackRunConfig:
|
||||||
is_nux = len(config.providers) == 0
|
is_nux = len(config.providers) == 0
|
||||||
|
|
||||||
# keep this default so all APIs are served
|
|
||||||
config.apis = []
|
|
||||||
|
|
||||||
if is_nux:
|
if is_nux:
|
||||||
print(
|
print(
|
||||||
textwrap.dedent(
|
textwrap.dedent(
|
||||||
|
@ -79,7 +76,12 @@ def configure_api_providers(
|
||||||
|
|
||||||
provider_registry = get_provider_registry()
|
provider_registry = get_provider_registry()
|
||||||
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
|
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
|
||||||
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
|
|
||||||
|
if config.apis:
|
||||||
|
apis_to_serve = config.apis
|
||||||
|
else:
|
||||||
|
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
|
||||||
|
|
||||||
for api_str in apis_to_serve:
|
for api_str in apis_to_serve:
|
||||||
api = Api(api_str)
|
api = Api(api_str)
|
||||||
if api in builtin_apis:
|
if api in builtin_apis:
|
||||||
|
@ -153,7 +155,7 @@ def configure_api_providers(
|
||||||
"shields": (ShieldDef, configure_shields, "safety"),
|
"shields": (ShieldDef, configure_shields, "safety"),
|
||||||
"memory_banks": (MemoryBankDef, configure_memory_banks, "memory"),
|
"memory_banks": (MemoryBankDef, configure_memory_banks, "memory"),
|
||||||
}
|
}
|
||||||
safety_providers = config.providers["safety"]
|
safety_providers = config.providers.get("safety", [])
|
||||||
|
|
||||||
for otype, (odef, config_method, api_str) in object_types.items():
|
for otype, (odef, config_method, api_str) in object_types.items():
|
||||||
existing_objects = getattr(config, otype)
|
existing_objects = getattr(config, otype)
|
||||||
|
@ -166,9 +168,15 @@ def configure_api_providers(
|
||||||
)
|
)
|
||||||
updated_objects = existing_objects
|
updated_objects = existing_objects
|
||||||
else:
|
else:
|
||||||
# we are newly configuring this API
|
providers = config.providers.get(api_str, [])
|
||||||
cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"])
|
if not providers:
|
||||||
updated_objects = config_method(config.providers[api_str], safety_providers)
|
updated_objects = []
|
||||||
|
else:
|
||||||
|
# we are newly configuring this API
|
||||||
|
cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"])
|
||||||
|
updated_objects = config_method(
|
||||||
|
config.providers[api_str], safety_providers
|
||||||
|
)
|
||||||
|
|
||||||
setattr(config, otype, updated_objects)
|
setattr(config, otype, updated_objects)
|
||||||
print("")
|
print("")
|
||||||
|
@ -277,7 +285,7 @@ def upgrade_from_routing_table_to_registry(
|
||||||
shields = []
|
shields = []
|
||||||
memory_banks = []
|
memory_banks = []
|
||||||
|
|
||||||
routing_table = config_dict["routing_table"]
|
routing_table = config_dict.get("routing_table", {})
|
||||||
for api_str, entries in routing_table.items():
|
for api_str, entries in routing_table.items():
|
||||||
providers = get_providers(entries)
|
providers = get_providers(entries)
|
||||||
providers_by_api[api_str] = providers
|
providers_by_api[api_str] = providers
|
||||||
|
@ -324,15 +332,13 @@ def upgrade_from_routing_table_to_registry(
|
||||||
config_dict["shields"] = shields
|
config_dict["shields"] = shields
|
||||||
config_dict["memory_banks"] = memory_banks
|
config_dict["memory_banks"] = memory_banks
|
||||||
|
|
||||||
if "api_providers" in config_dict:
|
provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
|
||||||
for api_str, provider in config_dict["api_providers"].items():
|
if provider_map:
|
||||||
if api_str in ("inference", "safety", "memory"):
|
for api_str, provider in provider_map.items():
|
||||||
continue
|
|
||||||
|
|
||||||
if isinstance(provider, dict):
|
if isinstance(provider, dict):
|
||||||
providers_by_api[api_str] = [
|
providers_by_api[api_str] = [
|
||||||
Provider(
|
Provider(
|
||||||
provider_id=f"{provider['provider_type']}-00",
|
provider_id=f"{provider['provider_type']}",
|
||||||
provider_type=provider["provider_type"],
|
provider_type=provider["provider_type"],
|
||||||
config=provider["config"],
|
config=provider["config"],
|
||||||
)
|
)
|
||||||
|
@ -340,11 +346,12 @@ def upgrade_from_routing_table_to_registry(
|
||||||
|
|
||||||
config_dict["providers"] = providers_by_api
|
config_dict["providers"] = providers_by_api
|
||||||
|
|
||||||
del config_dict["routing_table"]
|
config_dict.pop("routing_table", None)
|
||||||
del config_dict["api_providers"]
|
config_dict.pop("api_providers", None)
|
||||||
|
config_dict.pop("provider_map", None)
|
||||||
|
|
||||||
config_dict["apis"] = config_dict["apis_to_serve"]
|
config_dict["apis"] = config_dict["apis_to_serve"]
|
||||||
del config_dict["apis_to_serve"]
|
config_dict.pop("apis_to_serve", None)
|
||||||
|
|
||||||
return config_dict
|
return config_dict
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue