mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-08 19:10:56 +00:00
A few bug fixes for covering corner cases
This commit is contained in:
parent
a05599c67a
commit
353c7dc82a
4 changed files with 30 additions and 23 deletions
|
|
@ -64,9 +64,6 @@ def configure_api_providers(
|
|||
) -> StackRunConfig:
|
||||
is_nux = len(config.providers) == 0
|
||||
|
||||
# keep this default so all APIs are served
|
||||
config.apis = []
|
||||
|
||||
if is_nux:
|
||||
print(
|
||||
textwrap.dedent(
|
||||
|
|
@ -79,7 +76,12 @@ def configure_api_providers(
|
|||
|
||||
provider_registry = get_provider_registry()
|
||||
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
|
||||
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
|
||||
|
||||
if config.apis:
|
||||
apis_to_serve = config.apis
|
||||
else:
|
||||
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
|
||||
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
if api in builtin_apis:
|
||||
|
|
@ -153,7 +155,7 @@ def configure_api_providers(
|
|||
"shields": (ShieldDef, configure_shields, "safety"),
|
||||
"memory_banks": (MemoryBankDef, configure_memory_banks, "memory"),
|
||||
}
|
||||
safety_providers = config.providers["safety"]
|
||||
safety_providers = config.providers.get("safety", [])
|
||||
|
||||
for otype, (odef, config_method, api_str) in object_types.items():
|
||||
existing_objects = getattr(config, otype)
|
||||
|
|
@ -166,9 +168,15 @@ def configure_api_providers(
|
|||
)
|
||||
updated_objects = existing_objects
|
||||
else:
|
||||
# we are newly configuring this API
|
||||
cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"])
|
||||
updated_objects = config_method(config.providers[api_str], safety_providers)
|
||||
providers = config.providers.get(api_str, [])
|
||||
if not providers:
|
||||
updated_objects = []
|
||||
else:
|
||||
# we are newly configuring this API
|
||||
cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"])
|
||||
updated_objects = config_method(
|
||||
config.providers[api_str], safety_providers
|
||||
)
|
||||
|
||||
setattr(config, otype, updated_objects)
|
||||
print("")
|
||||
|
|
@ -277,7 +285,7 @@ def upgrade_from_routing_table_to_registry(
|
|||
shields = []
|
||||
memory_banks = []
|
||||
|
||||
routing_table = config_dict["routing_table"]
|
||||
routing_table = config_dict.get("routing_table", {})
|
||||
for api_str, entries in routing_table.items():
|
||||
providers = get_providers(entries)
|
||||
providers_by_api[api_str] = providers
|
||||
|
|
@ -324,15 +332,13 @@ def upgrade_from_routing_table_to_registry(
|
|||
config_dict["shields"] = shields
|
||||
config_dict["memory_banks"] = memory_banks
|
||||
|
||||
if "api_providers" in config_dict:
|
||||
for api_str, provider in config_dict["api_providers"].items():
|
||||
if api_str in ("inference", "safety", "memory"):
|
||||
continue
|
||||
|
||||
provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
|
||||
if provider_map:
|
||||
for api_str, provider in provider_map.items():
|
||||
if isinstance(provider, dict):
|
||||
providers_by_api[api_str] = [
|
||||
Provider(
|
||||
provider_id=f"{provider['provider_type']}-00",
|
||||
provider_id=f"{provider['provider_type']}",
|
||||
provider_type=provider["provider_type"],
|
||||
config=provider["config"],
|
||||
)
|
||||
|
|
@ -340,11 +346,12 @@ def upgrade_from_routing_table_to_registry(
|
|||
|
||||
config_dict["providers"] = providers_by_api
|
||||
|
||||
del config_dict["routing_table"]
|
||||
del config_dict["api_providers"]
|
||||
config_dict.pop("routing_table", None)
|
||||
config_dict.pop("api_providers", None)
|
||||
config_dict.pop("provider_map", None)
|
||||
|
||||
config_dict["apis"] = config_dict["apis_to_serve"]
|
||||
del config_dict["apis_to_serve"]
|
||||
config_dict.pop("apis_to_serve", None)
|
||||
|
||||
return config_dict
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue