configure script works

This commit is contained in:
Xi Yan 2024-09-23 00:49:03 -07:00
parent b224fcf9ab
commit 1d463e1a36
2 changed files with 62 additions and 15 deletions

View file

@ -9,7 +9,7 @@ from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.apis.safety.safety import BuiltinShield # noqa: F403 from llama_stack.apis.memory.memory import MemoryBankType
from llama_stack.distribution.distribution import ( from llama_stack.distribution.distribution import (
api_providers, api_providers,
builtin_automatically_routed_apis, builtin_automatically_routed_apis,
@ -18,7 +18,11 @@ from llama_stack.distribution.distribution import (
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
from llama_stack.providers.impls.meta_reference.safety.config import (
MetaReferenceShieldType,
)
from prompt_toolkit import prompt from prompt_toolkit import prompt
from prompt_toolkit.validation import Validator
from termcolor import cprint from termcolor import cprint
@ -45,7 +49,6 @@ def get_builtin_apis(provider_backed_apis: List[str]) -> List[str]:
def configure_api_providers( def configure_api_providers(
config: StackRunConfig, spec: DistributionSpec config: StackRunConfig, spec: DistributionSpec
) -> StackRunConfig: ) -> StackRunConfig:
cprint(f"configure_api_providers {spec}", "red")
apis = config.apis_to_serve or list(spec.providers.keys()) apis = config.apis_to_serve or list(spec.providers.keys())
# append the builtin routing APIs # append the builtin routing APIs
apis += get_builtin_apis(apis) apis += get_builtin_apis(apis)
@ -71,6 +74,13 @@ def configure_api_providers(
p = spec.providers[api_str] p = spec.providers[api_str]
cprint(f"=== Configuring provider `{p}` for API {api_str}...", "green") cprint(f"=== Configuring provider `{p}` for API {api_str}...", "green")
if isinstance(p, list):
cprint(
f"[WARN] Interactive configuration of multiple providers {p} is not supported, configuring {p[0]} only, please manually configure {p[1:]} in routing_table of run.yaml",
"yellow",
)
p = p[0]
provider_spec = all_providers[api][p] provider_spec = all_providers[api][p]
config_type = instantiate_class_type(provider_spec.config_class) config_type = instantiate_class_type(provider_spec.config_class)
try: try:
@ -86,6 +96,7 @@ def configure_api_providers(
if api_str in router_api2builtin_api: if api_str in router_api2builtin_api:
# a routing api, we need to infer and assign it a routing_key and put it in the routing_table # a routing api, we need to infer and assign it a routing_key and put it in the routing_table
routing_key = "<PLEASE_FILL_ROUTING_KEY>" routing_key = "<PLEASE_FILL_ROUTING_KEY>"
routing_entries = []
if api_str == "inference": if api_str == "inference":
if hasattr(cfg, "model"): if hasattr(cfg, "model"):
routing_key = cfg.model routing_key = cfg.model
@ -94,22 +105,59 @@ def configure_api_providers(
"> Please enter the supported model your provider has for inference: ", "> Please enter the supported model your provider has for inference: ",
default="Meta-Llama3.1-8B-Instruct", default="Meta-Llama3.1-8B-Instruct",
) )
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
provider_id=p,
config=cfg.dict(),
)
)
if api_str == "safety": if api_str == "safety":
# check all supported shields # TODO: add support for other safety providers, and simplify safety provider config
for shield_type in BuiltinShield: if p == "meta-reference":
print(shield_type.value) for shield_type in MetaReferenceShieldType:
routing_entries.append(
RoutableProviderConfig(
routing_key=shield_type.value,
provider_id=p,
config=cfg.dict(),
)
)
else:
cprint(
f"[WARN] Interactive configuration of safety provider {p} is not supported, please manually configure safety shields types in routing_table of run.yaml",
"yellow",
)
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
provider_id=p,
config=cfg.dict(),
)
)
# if api_str == "memory": if api_str == "memory":
# # check all supported memory_banks bank_types = list([x.value for x in MemoryBankType])
routing_key = prompt(
config.routing_table[api_str] = [ "> Please enter the supported memory bank type your provider has for memory: ",
RoutableProviderConfig( default="vector",
routing_key=routing_key, validator=Validator.from_callable(
provider_id=p, lambda x: x in bank_types,
config=cfg.dict(), error_message="Invalid provider, please enter one of the following: {}".format(
bank_types
),
),
) )
] routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
provider_id=p,
config=cfg.dict(),
)
)
config.routing_table[api_str] = routing_entries
else: else:
config.api_providers[api_str] = GenericProviderConfig( config.api_providers[api_str] = GenericProviderConfig(
provider_id=p, provider_id=p,

View file

@ -31,7 +31,6 @@ def resolve_and_get_path(model_name: str) -> str:
class MetaReferenceSafetyImpl(Safety): class MetaReferenceSafetyImpl(Safety):
def __init__(self, config: SafetyConfig) -> None: def __init__(self, config: SafetyConfig) -> None:
print("Initializing MetaReferenceSafetyImpl w/ config", config)
self.config = config self.config = config
async def initialize(self) -> None: async def initialize(self) -> None: