A few bug fixes for covering corner cases

This commit is contained in:
Ashwin Bharambe 2024-10-07 13:55:01 -07:00
parent a05599c67a
commit 353c7dc82a
4 changed files with 30 additions and 23 deletions

View file

@ -22,7 +22,7 @@ def available_templates_specs() -> List[BuildConfig]:
import yaml import yaml
template_specs = [] template_specs = []
for p in TEMPLATES_PATH.rglob("*.yaml"): for p in TEMPLATES_PATH.rglob("*build.yaml"):
with open(p, "r") as f: with open(p, "r") as f:
build_config = BuildConfig(**yaml.safe_load(f)) build_config = BuildConfig(**yaml.safe_load(f))
template_specs.append(build_config) template_specs.append(build_config)

View file

@ -152,7 +152,7 @@ class StackConfigure(Subcommand):
config = StackRunConfig( config = StackRunConfig(
built_at=datetime.now(), built_at=datetime.now(),
image_name=image_name, image_name=image_name,
apis=[], apis=list(build_config.distribution_spec.providers.keys()),
providers={}, providers={},
models=[], models=[],
shields=[], shields=[],

View file

@ -7,7 +7,6 @@
import argparse import argparse
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.datatypes import * # noqa: F403
class StackRun(Subcommand): class StackRun(Subcommand):
@ -49,8 +48,8 @@ class StackRun(Subcommand):
from termcolor import cprint from termcolor import cprint
from llama_stack.distribution.build import ImageType from llama_stack.distribution.build import ImageType
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.utils.exec import run_with_pty from llama_stack.distribution.utils.exec import run_with_pty
if not args.config: if not args.config:
@ -78,7 +77,8 @@ class StackRun(Subcommand):
cprint(f"Using config `{config_file}`", "green") cprint(f"Using config `{config_file}`", "green")
with open(config_file, "r") as f: with open(config_file, "r") as f:
config = StackRunConfig(**yaml.safe_load(f)) config_dict = yaml.safe_load(config_file.read_text())
config = parse_and_maybe_upgrade_config(config_dict)
if config.docker_image: if config.docker_image:
script = pkg_resources.resource_filename( script = pkg_resources.resource_filename(

View file

@ -64,9 +64,6 @@ def configure_api_providers(
) -> StackRunConfig: ) -> StackRunConfig:
is_nux = len(config.providers) == 0 is_nux = len(config.providers) == 0
# keep this default so all APIs are served
config.apis = []
if is_nux: if is_nux:
print( print(
textwrap.dedent( textwrap.dedent(
@ -79,7 +76,12 @@ def configure_api_providers(
provider_registry = get_provider_registry() provider_registry = get_provider_registry()
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()] builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
if config.apis:
apis_to_serve = config.apis
else:
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)] apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
for api_str in apis_to_serve: for api_str in apis_to_serve:
api = Api(api_str) api = Api(api_str)
if api in builtin_apis: if api in builtin_apis:
@ -153,7 +155,7 @@ def configure_api_providers(
"shields": (ShieldDef, configure_shields, "safety"), "shields": (ShieldDef, configure_shields, "safety"),
"memory_banks": (MemoryBankDef, configure_memory_banks, "memory"), "memory_banks": (MemoryBankDef, configure_memory_banks, "memory"),
} }
safety_providers = config.providers["safety"] safety_providers = config.providers.get("safety", [])
for otype, (odef, config_method, api_str) in object_types.items(): for otype, (odef, config_method, api_str) in object_types.items():
existing_objects = getattr(config, otype) existing_objects = getattr(config, otype)
@ -165,10 +167,16 @@ def configure_api_providers(
attrs=["bold"], attrs=["bold"],
) )
updated_objects = existing_objects updated_objects = existing_objects
else:
providers = config.providers.get(api_str, [])
if not providers:
updated_objects = []
else: else:
# we are newly configuring this API # we are newly configuring this API
cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"]) cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"])
updated_objects = config_method(config.providers[api_str], safety_providers) updated_objects = config_method(
config.providers[api_str], safety_providers
)
setattr(config, otype, updated_objects) setattr(config, otype, updated_objects)
print("") print("")
@ -277,7 +285,7 @@ def upgrade_from_routing_table_to_registry(
shields = [] shields = []
memory_banks = [] memory_banks = []
routing_table = config_dict["routing_table"] routing_table = config_dict.get("routing_table", {})
for api_str, entries in routing_table.items(): for api_str, entries in routing_table.items():
providers = get_providers(entries) providers = get_providers(entries)
providers_by_api[api_str] = providers providers_by_api[api_str] = providers
@ -324,15 +332,13 @@ def upgrade_from_routing_table_to_registry(
config_dict["shields"] = shields config_dict["shields"] = shields
config_dict["memory_banks"] = memory_banks config_dict["memory_banks"] = memory_banks
if "api_providers" in config_dict: provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
for api_str, provider in config_dict["api_providers"].items(): if provider_map:
if api_str in ("inference", "safety", "memory"): for api_str, provider in provider_map.items():
continue
if isinstance(provider, dict): if isinstance(provider, dict):
providers_by_api[api_str] = [ providers_by_api[api_str] = [
Provider( Provider(
provider_id=f"{provider['provider_type']}-00", provider_id=f"{provider['provider_type']}",
provider_type=provider["provider_type"], provider_type=provider["provider_type"],
config=provider["config"], config=provider["config"],
) )
@ -340,11 +346,12 @@ def upgrade_from_routing_table_to_registry(
config_dict["providers"] = providers_by_api config_dict["providers"] = providers_by_api
del config_dict["routing_table"] config_dict.pop("routing_table", None)
del config_dict["api_providers"] config_dict.pop("api_providers", None)
config_dict.pop("provider_map", None)
config_dict["apis"] = config_dict["apis_to_serve"] config_dict["apis"] = config_dict["apis_to_serve"]
del config_dict["apis_to_serve"] config_dict.pop("apis_to_serve", None)
return config_dict return config_dict