From 1d463e1a361d39578fef39fc5e4e0ad06afa6af6 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 23 Sep 2024 00:49:03 -0700 Subject: [PATCH] configure script works --- llama_stack/distribution/configure.py | 76 +++++++++++++++---- .../impls/meta_reference/safety/safety.py | 1 - 2 files changed, 62 insertions(+), 15 deletions(-) diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index 3e9a0fbeb..036df3ade 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -9,7 +9,7 @@ from typing import Any from pydantic import BaseModel from llama_stack.distribution.datatypes import * # noqa: F403 -from llama_stack.apis.safety.safety import BuiltinShield # noqa: F403 +from llama_stack.apis.memory.memory import MemoryBankType from llama_stack.distribution.distribution import ( api_providers, builtin_automatically_routed_apis, @@ -18,7 +18,11 @@ from llama_stack.distribution.distribution import ( from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.prompt_for_config import prompt_for_config +from llama_stack.providers.impls.meta_reference.safety.config import ( + MetaReferenceShieldType, +) from prompt_toolkit import prompt +from prompt_toolkit.validation import Validator from termcolor import cprint @@ -45,7 +49,6 @@ def get_builtin_apis(provider_backed_apis: List[str]) -> List[str]: def configure_api_providers( config: StackRunConfig, spec: DistributionSpec ) -> StackRunConfig: - cprint(f"configure_api_providers {spec}", "red") apis = config.apis_to_serve or list(spec.providers.keys()) # append the bulitin routing APIs apis += get_builtin_apis(apis) @@ -71,6 +74,13 @@ def configure_api_providers( p = spec.providers[api_str] cprint(f"=== Configuring provider `{p}` for API {api_str}...", "green") + if isinstance(p, list): + cprint( + f"[WARN] Interactive configuration of multiple providers {p} is not supported, configuring {p[0]} only, please manually configure {p[1:]} in routing_table of run.yaml", + "yellow", + ) + p = p[0] + provider_spec = all_providers[api][p] config_type = instantiate_class_type(provider_spec.config_class) try: @@ -86,6 +96,7 @@ def configure_api_providers( if api_str in router_api2builtin_api: # a routing api, we need to infer and assign it a routing_key and put it in the routing_table routing_key = "" + routing_entries = [] if api_str == "inference": if hasattr(cfg, "model"): routing_key = cfg.model @@ -94,22 +105,59 @@ def configure_api_providers( "> Please enter the supported model your provider has for inference: ", default="Meta-Llama3.1-8B-Instruct", ) + routing_entries.append( + RoutableProviderConfig( + routing_key=routing_key, + provider_id=p, + config=cfg.dict(), + ) + ) if api_str == "safety": - # check all supported shields - for shield_type in BuiltinShield: - print(shield_type.value) + # TODO: add support for other safety providers, and simplify safety provider config + if p == "meta-reference": + for shield_type in MetaReferenceShieldType: + routing_entries.append( + RoutableProviderConfig( + routing_key=shield_type.value, + provider_id=p, + config=cfg.dict(), + ) + ) + else: + cprint( + f"[WARN] Interactive configuration of safety provider {p} is not supported, please manually configure safety shields types in routing_table of run.yaml", + "yellow", + ) + routing_entries.append( + RoutableProviderConfig( + routing_key=routing_key, + provider_id=p, + config=cfg.dict(), + ) + ) - # if api_str == "memory": - # # check all supported memory_banks - - config.routing_table[api_str] = [ - RoutableProviderConfig( - routing_key=routing_key, - provider_id=p, - config=cfg.dict(), + if api_str == "memory": + bank_types = list([x.value for x in MemoryBankType]) + routing_key = prompt( + "> Please enter the supported memory bank type your provider has for memory: ", + default="vector", + validator=Validator.from_callable( + lambda x: x in bank_types, + error_message="Invalid provider, please enter one of the following: {}".format( + bank_types + ), + ), ) - ] + routing_entries.append( + RoutableProviderConfig( + routing_key=routing_key, + provider_id=p, + config=cfg.dict(), + ) + ) + + config.routing_table[api_str] = routing_entries else: config.api_providers[api_str] = GenericProviderConfig( provider_id=p, diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py index f768ab773..6eccf47a5 100644 --- a/llama_stack/providers/impls/meta_reference/safety/safety.py +++ b/llama_stack/providers/impls/meta_reference/safety/safety.py @@ -31,7 +31,6 @@ def resolve_and_get_path(model_name: str) -> str: class MetaReferenceSafetyImpl(Safety): def __init__(self, config: SafetyConfig) -> None: - print("Initializing MetaReferenceSafetyImpl w/ config", config) self.config = config async def initialize(self) -> None: