diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py
index 95df6a737..0cedbe901 100644
--- a/llama_stack/cli/stack/build.py
+++ b/llama_stack/cli/stack/build.py
@@ -150,9 +150,6 @@ class StackBuild(Subcommand):
 
     def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
         import json
-
-        import yaml
-
         from llama_stack.cli.table import print_table
 
         # eventually, this should query a registry at llama.meta.com/llamastack/distributions
diff --git a/llama_stack/cli/stack/configure.py b/llama_stack/cli/stack/configure.py
index e1b0aa39f..76ade470e 100644
--- a/llama_stack/cli/stack/configure.py
+++ b/llama_stack/cli/stack/configure.py
@@ -148,14 +148,17 @@ class StackConfigure(Subcommand):
                 "yellow",
                 attrs=["bold"],
             )
-            config_dict = yaml.safe_load(config_file.read_text())
+            config_dict = yaml.safe_load(run_config_file.read_text())
             config = parse_and_maybe_upgrade_config(config_dict)
         else:
             config = StackRunConfig(
                 built_at=datetime.now(),
                 image_name=image_name,
                 apis_to_serve=[],
-                api_providers={},
+                providers={},
+                models=[],
+                shields=[],
+                memory_banks=[],
             )
 
         config = configure_api_providers(config, build_config.distribution_spec)
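Note: the `else` branch above now seeds a registry-style run config (`providers` plus
`models`/`shields`/`memory_banks`) in place of the old `api_providers` map. A minimal,
self-contained sketch of that shape, assuming pydantic is available; `RunConfigSketch`
is a hypothetical stand-in for the real `StackRunConfig`:

    from datetime import datetime
    from typing import Any, Dict, List

    from pydantic import BaseModel


    class RunConfigSketch(BaseModel):  # hypothetical stand-in for StackRunConfig
        built_at: datetime
        image_name: str
        apis_to_serve: List[str] = []
        providers: Dict[str, List[Any]] = {}  # api name -> provider instances
        models: List[Any] = []        # ModelDef registry entries
        shields: List[Any] = []       # ShieldDef registry entries
        memory_banks: List[Any] = []  # MemoryBankDef registry entries


    # The wizard (configure_api_providers) then fills in the empty fields.
    seed = RunConfigSketch(built_at=datetime.now(), image_name="local")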
diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py
index 1fdde3092..b40cff242 100644
--- a/llama_stack/distribution/configure.py
+++ b/llama_stack/distribution/configure.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import textwrap
 
 from typing import Any
 
@@ -14,7 +15,6 @@ from llama_models.sku_list import (
     safety_models,
 )
 
-from pydantic import BaseModel
 from llama_stack.distribution.datatypes import *  # noqa: F403
 from prompt_toolkit import prompt
 from prompt_toolkit.validation import Validator
@@ -23,14 +23,14 @@ from termcolor import cprint
 from llama_stack.distribution.distribution import (
     builtin_automatically_routed_apis,
     get_provider_registry,
-    stack_apis,
 )
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
-
 from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
-from llama_stack.providers.impls.meta_reference.safety.config import (
-    MetaReferenceShieldType,
-)
+
+
+from llama_stack.apis.models import *  # noqa: F403
+from llama_stack.apis.shields import *  # noqa: F403
+from llama_stack.apis.memory_banks import *  # noqa: F403
 
 
 ALLOWED_MODELS = (
@@ -38,162 +38,233 @@ ALLOWED_MODELS = (
 )
 
 
-def make_routing_entry_type(config_class: Any):
-    class BaseModelWithConfig(BaseModel):
-        routing_key: str
-        config: config_class
+def configure_single_provider(
+    registry: Dict[str, ProviderSpec], provider: Provider
+) -> Provider:
+    provider_spec = registry[provider.provider_type]
+    config_type = instantiate_class_type(provider_spec.config_class)
+    try:
+        if provider.config:
+            existing = config_type(**provider.config)
+        else:
+            existing = None
+    except Exception:
+        existing = None
 
-    return BaseModelWithConfig
+    cfg = prompt_for_config(config_type, existing)
+    return Provider(
+        provider_id=provider.provider_id,
+        provider_type=provider.provider_type,
+        config=cfg.dict(),
+    )
 
 
-def get_builtin_apis(provider_backed_apis: List[str]) -> List[str]:
-    """Get corresponding builtin APIs given provider backed APIs"""
-    res = []
-    for inf in builtin_automatically_routed_apis():
-        if inf.router_api.value in provider_backed_apis:
-            res.append(inf.routing_table_api.value)
-
-    return res
-
-
-# TODO: make sure we can deal with existing configuration values correctly
-# instead of just overwriting them
 def configure_api_providers(
-    config: StackRunConfig, spec: DistributionSpec
+    config: StackRunConfig, build_spec: DistributionSpec
 ) -> StackRunConfig:
-    apis = config.apis_to_serve or list(spec.providers.keys())
-    # append the bulitin routing APIs
-    apis += get_builtin_apis(apis)
+    is_nux = len(config.providers) == 0
 
-    router_api2builtin_api = {
-        inf.router_api.value: inf.routing_table_api.value
-        for inf in builtin_automatically_routed_apis()
-    }
+    apis = set(config.apis_to_serve or list(build_spec.providers.keys()))
+    config.apis_to_serve = [a for a in apis if a != "telemetry"]
 
-    config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
+    if is_nux:
+        print(
+            textwrap.dedent(
+                """
+        Llama Stack is composed of several APIs working together. For each API served by the Stack,
+        we need to configure the providers (implementations) you want to use for these APIs.
+"""
+            )
+        )
 
-    apis = [v.value for v in stack_apis()]
-    all_providers = get_provider_registry()
-
-    # configure simple case for with non-routing providers to api_providers
-    for api_str in spec.providers.keys():
-        if api_str not in apis:
+    provider_registry = get_provider_registry()
+    builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
+    for api_str in config.apis_to_serve:
+        api = Api(api_str)
+        if api in builtin_apis:
+            continue
+        if api not in provider_registry:
             raise ValueError(f"Unknown API `{api_str}`")
 
-        cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
-        api = Api(api_str)
-
-        p = spec.providers[api_str]
-        cprint(f"=== Configuring provider `{p}` for API {api_str}...", "green")
-
-        if isinstance(p, list):
+        existing_providers = config.providers.get(api_str, [])
+        if existing_providers:
             cprint(
-                f"[WARN] Interactive configuration of multiple providers {p} is not supported, configuring {p[0]} only, please manually configure {p[1:]} in routing_table of run.yaml",
-                "yellow",
+                f"Re-configuring existing providers for API `{api_str}`...",
+                "green",
+                attrs=["bold"],
             )
-            p = p[0]
-
-        provider_spec = all_providers[api][p]
-        config_type = instantiate_class_type(provider_spec.config_class)
-        try:
-            provider_config = config.api_providers.get(api_str)
-            if provider_config:
-                existing = config_type(**provider_config.config)
-            else:
-                existing = None
-        except Exception:
-            existing = None
-        cfg = prompt_for_config(config_type, existing)
-
-        if api_str in router_api2builtin_api:
-            # a routing api, we need to infer and assign it a routing_key and put it in the routing_table
-            routing_key = ""
-            routing_entries = []
-            if api_str == "inference":
-                if hasattr(cfg, "model"):
-                    routing_key = cfg.model
-                else:
-                    routing_key = prompt(
-                        "> Please enter the supported model your provider has for inference: ",
-                        default="Llama3.1-8B-Instruct",
-                        validator=Validator.from_callable(
-                            lambda x: resolve_model(x) is not None,
-                            error_message="Model must be: {}".format(
-                                [x.descriptor() for x in ALLOWED_MODELS]
-                            ),
-                        ),
-                    )
-                routing_entries.append(
-                    RoutableProviderConfig(
-                        routing_key=routing_key,
-                        provider_type=p,
-                        config=cfg.dict(),
-                    )
+            updated_providers = []
+            for p in existing_providers:
+                print(f"> Configuring provider `({p.provider_type})`")
+                updated_providers.append(
+                    configure_single_provider(provider_registry[api], p)
                 )
-
-            if api_str == "safety":
-                # TODO: add support for other safety providers, and simplify safety provider config
-                if p == "meta-reference":
-                    routing_entries.append(
-                        RoutableProviderConfig(
-                            routing_key=[s.value for s in MetaReferenceShieldType],
-                            provider_type=p,
-                            config=cfg.dict(),
-                        )
-                    )
-                else:
-                    cprint(
-                        f"[WARN] Interactive configuration of safety provider {p} is not supported. Please look for `{routing_key}` in run.yaml and replace it appropriately.",
-                        "yellow",
-                        attrs=["bold"],
-                    )
-                    routing_entries.append(
-                        RoutableProviderConfig(
-                            routing_key=routing_key,
-                            provider_type=p,
-                            config=cfg.dict(),
-                        )
-                    )
-
-            if api_str == "memory":
-                bank_types = list([x.value for x in MemoryBankType])
-                routing_key = prompt(
-                    "> Please enter the supported memory bank type your provider has for memory: ",
-                    default="vector",
-                    validator=Validator.from_callable(
-                        lambda x: x in bank_types,
-                        error_message="Invalid provider, please enter one of the following: {}".format(
-                            bank_types
-                        ),
-                    ),
-                )
-                routing_entries.append(
-                    RoutableProviderConfig(
-                        routing_key=routing_key,
-                        provider_type=p,
-                        config=cfg.dict(),
-                    )
-                )
-
-            config.routing_table[api_str] = routing_entries
+            print("")
         else:
-            config.api_providers[api_str] = GenericProviderConfig(
-                provider_type=p,
-                config=cfg.dict(),
-            )
+            # we are newly configuring this API
+            plist = build_spec.providers.get(api_str, [])
+            plist = plist if isinstance(plist, list) else [plist]
+            if not plist:
+                raise ValueError(f"No provider configured for API {api_str}")
+
+            cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
+            updated_providers = []
+            for i, provider_type in enumerate(plist):
+                print(f"> Configuring provider `({provider_type})`")
+                updated_providers.append(
+                    configure_single_provider(
+                        provider_registry[api],
+                        Provider(
+                            provider_id=(
+                                f"{provider_type}-{i:02d}"
+                                if len(plist) > 1
+                                else provider_type
+                            ),
+                            provider_type=provider_type,
+                            config={},
+                        ),
+                    )
+                )
+            print("")
+
+        config.providers[api_str] = updated_providers
+
+    if is_nux:
+        print(
+            textwrap.dedent(
+                """
+        =========================================================================================
+        Now let's configure the `objects` you will be serving via the stack. These are:
+
+        - Models: the Llama model SKUs you expect to run inference on (e.g., Llama3.2-1B-Instruct)
+        - Shields: the safety models you expect to use (e.g., Llama-Guard-3-1B)
+        - Memory Banks: the memory banks you expect to use (e.g., vector stores)
+
+        This wizard will guide you through setting up one of each of these objects. You can
+        always add more later by editing the run.yaml file.
+        """
+            )
+        )
+
+    object_types = {
+        "models": (ModelDef, configure_models, "inference"),
+        "shields": (ShieldDef, configure_shields, "safety"),
+        "memory_banks": (MemoryBankDef, configure_memory_banks, "memory"),
+    }
+    safety_providers = config.providers.get("safety", [])
+
+    for otype, (odef, config_method, api_str) in object_types.items():
+        existing_objects = getattr(config, otype)
+
+        if existing_objects:
+            cprint(
+                f"{len(existing_objects)} {otype} exist. Skipping...",
+                "blue",
+                attrs=["bold"],
+            )
+            updated_objects = existing_objects
+        else:
+            # we are newly configuring these objects
+            cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"])
+            updated_objects = config_method(config.providers[api_str], safety_providers)
+
+        setattr(config, otype, updated_objects)
         print("")
 
     return config
 
 
+def get_llama_guard_model(safety_providers: List[Provider]) -> Optional[str]:
+    if not safety_providers:
+        return None
+
+    provider = safety_providers[0]
+    assert provider.provider_type == "meta-reference"
+
+    cfg = provider.config.get("llama_guard_shield")
+    if not cfg:
+        return None
+    return cfg["model"]
+
+
+def configure_models(
+    providers: List[Provider], safety_providers: List[Provider]
+) -> List[ModelDef]:
+    model = prompt(
+        "> Please enter the model you want to serve: ",
+        default="Llama3.2-1B-Instruct",
+        validator=Validator.from_callable(
+            lambda x: resolve_model(x) is not None,
+            error_message="Model must be: {}".format(
+                [x.descriptor() for x in ALLOWED_MODELS]
+            ),
+        ),
+    )
+    model = ModelDef(
+        identifier=model,
+        llama_model=model,
+        provider_id=providers[0].provider_id,
+    )
+
+    ret = [model]
+    if llama_guard := get_llama_guard_model(safety_providers):
+        ret.append(
+            ModelDef(
+                identifier=llama_guard,
+                llama_model=llama_guard,
+                provider_id=providers[0].provider_id,
+            )
+        )
+
+    return ret
+
+
+def configure_shields(
+    providers: List[Provider], safety_providers: List[Provider]
+) -> List[ShieldDef]:
+    if get_llama_guard_model(safety_providers):
+        return [
+            ShieldDef(
+                identifier="llama_guard",
+                type="llama_guard",
+                provider_id=providers[0].provider_id,
+                params={},
+            )
+        ]
+
+    return []
+
+
+def configure_memory_banks(
+    providers: List[Provider], safety_providers: List[Provider]
+) -> List[MemoryBankDef]:
+    bank_name = prompt(
+        "> Please enter a name for your memory bank: ",
+        default="my-memory-bank",
+    )
+
+    return [
+        VectorMemoryBankDef(
+            identifier=bank_name,
+            provider_id=providers[0].provider_id,
+            embedding_model="all-MiniLM-L6-v2",
+            chunk_size_in_tokens=512,
+        )
+    ]
+
+
 def upgrade_from_routing_table_to_registry(
     config_dict: Dict[str, Any],
 ) -> Dict[str, Any]:
     def get_providers(entries):
         return [
             Provider(
-                provider_id=f"{entry['provider_type']}-{i:02d}",
+                provider_id=(
+                    f"{entry['provider_type']}-{i:02d}"
+                    if len(entries) > 1
+                    else entry["provider_type"]
+                ),
                 provider_type=entry["provider_type"],
                 config=entry["config"],
             )
@@ -254,6 +325,9 @@ def upgrade_from_routing_table_to_registry(
 
     if "api_providers" in config_dict:
         for api_str, provider in config_dict["api_providers"].items():
+            if api_str in ("inference", "safety", "memory"):
+                continue
+
             if isinstance(provider, dict):
                 providers_by_api[api_str] = [
                     Provider(
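Note: the provider-ID convention above appears twice (in `configure_api_providers` and
in `upgrade_from_routing_table_to_registry`): a single provider for an API keeps the
bare provider type as its ID, and only multi-provider APIs get a `-NN` suffix. A
self-contained sketch of just that rule; `provider_ids` is a hypothetical helper, not
part of the codebase:

    from typing import Dict, List


    def provider_ids(entries: List[Dict[str, str]]) -> List[str]:
        # Mirrors the conditional used in the diff: suffix only when there are
        # multiple providers configured for the same API.
        return [
            f"{e['provider_type']}-{i:02d}" if len(entries) > 1 else e["provider_type"]
            for i, e in enumerate(entries)
        ]


    assert provider_ids([{"provider_type": "meta-reference"}]) == ["meta-reference"]
    assert provider_ids(
        [{"provider_type": "remote::ollama"}, {"provider_type": "meta-reference"}]
    ) == ["remote::ollama-00", "meta-reference-01"]

Keeping the bare type for the common single-provider case means hand-written run.yaml
files and upgraded ones produce the same IDs.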
diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index bccb7d705..0ee03175c 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -75,6 +75,7 @@ in the runtime configuration to help route to the correct provider.""",
 )
 
 
+# TODO: rename as ProviderInstanceConfig
 class Provider(BaseModel):
     provider_id: str
     provider_type: str
@@ -108,8 +109,8 @@ The list of APIs to serve. If not specified, all APIs specified in the provider_
     providers: Dict[str, List[Provider]]
 
     models: List[ModelDef]
-    memory_banks: List[MemoryBankDef]
     shields: List[ShieldDef]
+    memory_banks: List[MemoryBankDef]
 
 
 # api_providers: Dict[
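Note: the field reorder above makes `StackRunConfig` list its registries in the order
the wizard fills them (models, then shields, then memory banks). For orientation, a
hedged sketch of the matching run.yaml sections once parsed into Python; all
identifiers and values here are illustrative:

    run_config_registry = {
        "providers": {
            "inference": [
                {"provider_id": "meta-reference", "provider_type": "meta-reference", "config": {}},
            ],
        },
        "models": [
            {
                "identifier": "Llama3.2-1B-Instruct",
                "llama_model": "Llama3.2-1B-Instruct",
                "provider_id": "meta-reference",
            },
        ],
        "shields": [],
        "memory_banks": [],
    }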
Skipping...", + "blue", + attrs=["bold"], + ) + updated_objects = existing_objects + else: + # we are newly configuring this API + cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"]) + updated_objects = config_method(config.providers[api_str], safety_providers) + + setattr(config, otype, updated_objects) print("") return config +def get_llama_guard_model(safety_providers: List[Provider]) -> Optional[str]: + if not safety_providers: + return None + + provider = safety_providers[0] + assert provider.provider_type == "meta-reference" + + cfg = provider.config["llama_guard_shield"] + if not cfg: + return None + return cfg["model"] + + +def configure_models( + providers: List[Provider], safety_providers: List[Provider] +) -> List[ModelDef]: + model = prompt( + "> Please enter the model you want to serve: ", + default="Llama3.2-1B-Instruct", + validator=Validator.from_callable( + lambda x: resolve_model(x) is not None, + error_message="Model must be: {}".format( + [x.descriptor() for x in ALLOWED_MODELS] + ), + ), + ) + model = ModelDef( + identifier=model, + llama_model=model, + provider_id=providers[0].provider_id, + ) + + ret = [model] + if llama_guard := get_llama_guard_model(safety_providers): + ret.append( + ModelDef( + identifier=llama_guard, + llama_model=llama_guard, + provider_id=providers[0].provider_id, + ) + ) + + return ret + + +def configure_shields( + providers: List[Provider], safety_providers: List[Provider] +) -> List[ShieldDef]: + if get_llama_guard_model(safety_providers): + return [ + ShieldDef( + identifier="llama_guard", + type="llama_guard", + provider_id=providers[0].provider_id, + params={}, + ) + ] + + return [] + + +def configure_memory_banks( + providers: List[Provider], safety_providers: List[Provider] +) -> List[MemoryBankDef]: + bank_name = prompt( + "> Please enter a name for your memory bank: ", + default="my-memory-bank", + ) + + return [ + VectorMemoryBankDef( + identifier=bank_name, + provider_id=providers[0].provider_id, + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + ) + ] + + def upgrade_from_routing_table_to_registry( config_dict: Dict[str, Any], ) -> Dict[str, Any]: def get_providers(entries): return [ Provider( - provider_id=f"{entry['provider_type']}-{i:02d}", + provider_id=( + f"{entry['provider_type']}-{i:02d}" + if len(entries) > 1 + else entry["provider_type"] + ), provider_type=entry["provider_type"], config=entry["config"], ) @@ -254,6 +325,9 @@ def upgrade_from_routing_table_to_registry( if "api_providers" in config_dict: for api_str, provider in config_dict["api_providers"].items(): + if api_str in ("inference", "safety", "memory"): + continue + if isinstance(provider, dict): providers_by_api[api_str] = [ Provider( diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index bccb7d705..0ee03175c 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -75,6 +75,7 @@ in the runtime configuration to help route to the correct provider.""", ) +# TODO: rename as ProviderInstanceConfig class Provider(BaseModel): provider_id: str provider_type: str @@ -108,8 +109,8 @@ The list of APIs to serve. 
diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py b/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
index f98d95c43..19a20a899 100644
--- a/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
+++ b/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
@@ -113,8 +113,6 @@ class LlamaGuardShield(ShieldBase):
         model: str,
         inference_api: Inference,
         excluded_categories: List[str] = None,
-        disable_input_check: bool = False,
-        disable_output_check: bool = False,
         on_violation_action: OnViolationAction = OnViolationAction.RAISE,
     ):
         super().__init__(on_violation_action)
@@ -132,8 +130,6 @@ class LlamaGuardShield(ShieldBase):
         self.model = model
         self.inference_api = inference_api
         self.excluded_categories = excluded_categories
-        self.disable_input_check = disable_input_check
-        self.disable_output_check = disable_output_check
 
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
@@ -180,12 +176,6 @@ class LlamaGuardShield(ShieldBase):
 
     async def run(self, messages: List[Message]) -> ShieldResponse:
         messages = self.validate_messages(messages)
-        if self.disable_input_check and messages[-1].role == Role.user.value:
-            return ShieldResponse(is_violation=False)
-        elif self.disable_output_check and messages[-1].role == Role.assistant.value:
-            return ShieldResponse(
-                is_violation=False,
-            )
 
         if self.model == CoreModelId.llama_guard_3_11b_vision.value:
             shield_input_message = self.build_vision_shield_input(messages)
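Note: `run()` now always evaluates the last turn, whether it came from the user or the
assistant, so every call ends at the verdict-parsing step. A standalone sketch of that
parsing rule, mirroring the `check_unsafe_response` regex visible in context above;
`parse_unsafe` is a hypothetical helper, not part of the codebase:

    import re
    from typing import Optional


    def parse_unsafe(response: str) -> Optional[str]:
        # Llama Guard replies "safe", or "unsafe\n<comma-separated categories>".
        match = re.match(r"^unsafe\n(.*)$", response)
        return match.group(1) if match else None


    assert parse_unsafe("safe") is None
    assert parse_unsafe("unsafe\nS1,S9") == "S1,S9"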