Redo the { models, shields, memory_banks } typeset

This commit is contained in:
Ashwin Bharambe 2024-10-05 08:41:36 -07:00 committed by Ashwin Bharambe
parent 6b094b72d3
commit f3923e3f0b
15 changed files with 588 additions and 454 deletions

View file

@ -20,7 +20,6 @@ from prompt_toolkit import prompt
from prompt_toolkit.validation import Validator
from termcolor import cprint
from llama_stack.apis.memory.memory import MemoryBankType
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
@ -177,9 +176,6 @@ def configure_api_providers(
)
config.routing_table[api_str] = routing_entries
config.api_providers[api_str] = PlaceholderProviderConfig(
providers=p if isinstance(p, list) else [p]
)
else:
config.api_providers[api_str] = GenericProviderConfig(
provider_type=p,
@ -189,3 +185,102 @@ def configure_api_providers(
print("")
return config
def upgrade_from_routing_table_to_registry(
config_dict: Dict[str, Any],
) -> Dict[str, Any]:
def get_providers(entries):
return [
Provider(
provider_id=f"{entry['provider_type']}-{i:02d}",
provider_type=entry["provider_type"],
config=entry["config"],
)
for i, entry in enumerate(entries)
]
providers_by_api = {}
models = []
shields = []
memory_banks = []
routing_table = config_dict["routing_table"]
for api_str, entries in routing_table.items():
providers = get_providers(entries)
providers_by_api[api_str] = providers
if api_str == "inference":
for entry, provider in zip(entries, providers):
key = entry["routing_key"]
keys = key if isinstance(key, list) else [key]
for key in keys:
models.append(
ModelDef(
identifier=key,
provider_id=provider.provider_id,
llama_model=key,
)
)
elif api_str == "safety":
for entry, provider in zip(entries, providers):
key = entry["routing_key"]
keys = key if isinstance(key, list) else [key]
for key in keys:
shields.append(
ShieldDef(
identifier=key,
type=ShieldType.llama_guard.value,
provider_id=provider.provider_id,
)
)
elif api_str == "memory":
for entry, provider in zip(entries, providers):
key = entry["routing_key"]
keys = key if isinstance(key, list) else [key]
for key in keys:
# we currently only support Vector memory banks so this is OK
memory_banks.append(
VectorMemoryBankDef(
identifier=key,
provider_id=provider.provider_id,
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
)
)
config_dict["models"] = models
config_dict["shields"] = shields
config_dict["memory_banks"] = memory_banks
if "api_providers" in config_dict:
for api_str, provider in config_dict["api_providers"].items():
if isinstance(provider, dict):
providers_by_api[api_str] = [
Provider(
provider_id=f"{provider['provider_type']}-00",
provider_type=provider["provider_type"],
config=provider["config"],
)
]
config_dict["providers"] = providers_by_api
del config_dict["routing_table"]
del config_dict["api_providers"]
return config_dict
def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfig:
version = config_dict.get("version", None)
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
return StackRunConfig(**config_dict)
if "models" not in config_dict:
print("Upgrading config...")
config_dict = upgrade_from_routing_table_to_registry(config_dict)
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
config_dict["built_at"] = datetime.now().isoformat()
return StackRunConfig(**config_dict)