mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
A few bug fixes for covering corner cases
This commit is contained in:
parent
a05599c67a
commit
353c7dc82a
4 changed files with 30 additions and 23 deletions
|
@@ -22,7 +22,7 @@ def available_templates_specs() -> List[BuildConfig]:
|
|||
import yaml
|
||||
|
||||
template_specs = []
|
||||
for p in TEMPLATES_PATH.rglob("*.yaml"):
|
||||
for p in TEMPLATES_PATH.rglob("*build.yaml"):
|
||||
with open(p, "r") as f:
|
||||
build_config = BuildConfig(**yaml.safe_load(f))
|
||||
template_specs.append(build_config)
|
||||
|
|
|
@@ -152,7 +152,7 @@ class StackConfigure(Subcommand):
|
|||
config = StackRunConfig(
|
||||
built_at=datetime.now(),
|
||||
image_name=image_name,
|
||||
apis=[],
|
||||
apis=list(build_config.distribution_spec.providers.keys()),
|
||||
providers={},
|
||||
models=[],
|
||||
shields=[],
|
||||
|
|
|
@@ -7,7 +7,6 @@
|
|||
import argparse
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
class StackRun(Subcommand):
|
||||
|
@@ -49,8 +48,8 @@ class StackRun(Subcommand):
|
|||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.build import ImageType
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
|
||||
|
||||
from llama_stack.distribution.utils.exec import run_with_pty
|
||||
|
||||
if not args.config:
|
||||
|
@@ -78,7 +77,8 @@ class StackRun(Subcommand):
|
|||
|
||||
cprint(f"Using config `{config_file}`", "green")
|
||||
with open(config_file, "r") as f:
|
||||
config = StackRunConfig(**yaml.safe_load(f))
|
||||
config_dict = yaml.safe_load(config_file.read_text())
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
|
||||
if config.docker_image:
|
||||
script = pkg_resources.resource_filename(
|
||||
|
|
|
@@ -64,9 +64,6 @@ def configure_api_providers(
|
|||
) -> StackRunConfig:
|
||||
is_nux = len(config.providers) == 0
|
||||
|
||||
# keep this default so all APIs are served
|
||||
config.apis = []
|
||||
|
||||
if is_nux:
|
||||
print(
|
||||
textwrap.dedent(
|
||||
|
@@ -79,7 +76,12 @@ def configure_api_providers(
|
|||
|
||||
provider_registry = get_provider_registry()
|
||||
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
|
||||
|
||||
if config.apis:
|
||||
apis_to_serve = config.apis
|
||||
else:
|
||||
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
|
||||
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
if api in builtin_apis:
|
||||
|
@@ -153,7 +155,7 @@ def configure_api_providers(
|
|||
"shields": (ShieldDef, configure_shields, "safety"),
|
||||
"memory_banks": (MemoryBankDef, configure_memory_banks, "memory"),
|
||||
}
|
||||
safety_providers = config.providers["safety"]
|
||||
safety_providers = config.providers.get("safety", [])
|
||||
|
||||
for otype, (odef, config_method, api_str) in object_types.items():
|
||||
existing_objects = getattr(config, otype)
|
||||
|
@@ -165,10 +167,16 @@ def configure_api_providers(
|
|||
attrs=["bold"],
|
||||
)
|
||||
updated_objects = existing_objects
|
||||
else:
|
||||
providers = config.providers.get(api_str, [])
|
||||
if not providers:
|
||||
updated_objects = []
|
||||
else:
|
||||
# we are newly configuring this API
|
||||
cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"])
|
||||
updated_objects = config_method(config.providers[api_str], safety_providers)
|
||||
updated_objects = config_method(
|
||||
config.providers[api_str], safety_providers
|
||||
)
|
||||
|
||||
setattr(config, otype, updated_objects)
|
||||
print("")
|
||||
|
@@ -277,7 +285,7 @@ def upgrade_from_routing_table_to_registry(
|
|||
shields = []
|
||||
memory_banks = []
|
||||
|
||||
routing_table = config_dict["routing_table"]
|
||||
routing_table = config_dict.get("routing_table", {})
|
||||
for api_str, entries in routing_table.items():
|
||||
providers = get_providers(entries)
|
||||
providers_by_api[api_str] = providers
|
||||
|
@@ -324,15 +332,13 @@ def upgrade_from_routing_table_to_registry(
|
|||
config_dict["shields"] = shields
|
||||
config_dict["memory_banks"] = memory_banks
|
||||
|
||||
if "api_providers" in config_dict:
|
||||
for api_str, provider in config_dict["api_providers"].items():
|
||||
if api_str in ("inference", "safety", "memory"):
|
||||
continue
|
||||
|
||||
provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
|
||||
if provider_map:
|
||||
for api_str, provider in provider_map.items():
|
||||
if isinstance(provider, dict):
|
||||
providers_by_api[api_str] = [
|
||||
Provider(
|
||||
provider_id=f"{provider['provider_type']}-00",
|
||||
provider_id=f"{provider['provider_type']}",
|
||||
provider_type=provider["provider_type"],
|
||||
config=provider["config"],
|
||||
)
|
||||
|
@@ -340,11 +346,12 @@ def upgrade_from_routing_table_to_registry(
|
|||
|
||||
config_dict["providers"] = providers_by_api
|
||||
|
||||
del config_dict["routing_table"]
|
||||
del config_dict["api_providers"]
|
||||
config_dict.pop("routing_table", None)
|
||||
config_dict.pop("api_providers", None)
|
||||
config_dict.pop("provider_map", None)
|
||||
|
||||
config_dict["apis"] = config_dict["apis_to_serve"]
|
||||
del config_dict["apis_to_serve"]
|
||||
config_dict.pop("apis_to_serve", None)
|
||||
|
||||
return config_dict
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue