Significantly upgrade the interactive configuration experience

Ashwin Bharambe 2024-10-05 11:12:46 -07:00 committed by Ashwin Bharambe
parent 8d157a8197
commit 5a7b01d292
7 changed files with 217 additions and 156 deletions

View file

@@ -150,9 +150,6 @@ class StackBuild(Subcommand):
def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
import json
import yaml
from llama_stack.cli.table import print_table
# eventually, this should query a registry at llama.meta.com/llamastack/distributions

View file

@@ -148,14 +148,17 @@ class StackConfigure(Subcommand):
"yellow",
attrs=["bold"],
)
config_dict = yaml.safe_load(config_file.read_text())
config_dict = yaml.safe_load(run_config_file.read_text())
config = parse_and_maybe_upgrade_config(config_dict)
else:
config = StackRunConfig(
built_at=datetime.now(),
image_name=image_name,
apis_to_serve=[],
api_providers={},
providers={},
models=[],
shields=[],
memory_banks=[],
)
config = configure_api_providers(config, build_config.distribution_spec)
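For orientation, a minimal sketch of what the parse_and_maybe_upgrade_config call above could look like, assuming it keys off the old api_providers field to decide whether the registry upgrade is needed (the dispatch itself is an assumption; only upgrade_from_routing_table_to_registry appears in this commit):

# Hypothetical sketch, not the committed implementation: detect the old
# routing-table layout by its `api_providers` key and upgrade it before
# validating against the new StackRunConfig schema.
def parse_and_maybe_upgrade_config(config_dict: dict) -> StackRunConfig:
    if "api_providers" in config_dict:  # pre-registry layout (assumption)
        config_dict = upgrade_from_routing_table_to_registry(config_dict)
    return StackRunConfig(**config_dict)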

View file

@@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import textwrap
from typing import Any
@@ -14,7 +15,6 @@ from llama_models.sku_list import (
safety_models,
)
from pydantic import BaseModel
from llama_stack.distribution.datatypes import * # noqa: F403
from prompt_toolkit import prompt
from prompt_toolkit.validation import Validator
@@ -23,14 +23,14 @@ from termcolor import cprint
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
stack_apis,
)
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
from llama_stack.providers.impls.meta_reference.safety.config import (
MetaReferenceShieldType,
)
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
ALLOWED_MODELS = (
@@ -38,84 +38,162 @@ ALLOWED_MODELS = (
)
def make_routing_entry_type(config_class: Any):
class BaseModelWithConfig(BaseModel):
routing_key: str
config: config_class
return BaseModelWithConfig
def get_builtin_apis(provider_backed_apis: List[str]) -> List[str]:
"""Get corresponding builtin APIs given provider backed APIs"""
res = []
for inf in builtin_automatically_routed_apis():
if inf.router_api.value in provider_backed_apis:
res.append(inf.routing_table_api.value)
return res
# TODO: make sure we can deal with existing configuration values correctly
# instead of just overwriting them
def configure_api_providers(
config: StackRunConfig, spec: DistributionSpec
) -> StackRunConfig:
apis = config.apis_to_serve or list(spec.providers.keys())
# append the builtin routing APIs
apis += get_builtin_apis(apis)
router_api2builtin_api = {
inf.router_api.value: inf.routing_table_api.value
for inf in builtin_automatically_routed_apis()
}
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
apis = [v.value for v in stack_apis()]
all_providers = get_provider_registry()
# configure the simple case: non-routing providers go straight into api_providers
for api_str in spec.providers.keys():
if api_str not in apis:
raise ValueError(f"Unknown API `{api_str}`")
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
api = Api(api_str)
p = spec.providers[api_str]
cprint(f"=== Configuring provider `{p}` for API {api_str}...", "green")
if isinstance(p, list):
cprint(
f"[WARN] Interactive configuration of multiple providers {p} is not supported, configuring {p[0]} only, please manually configure {p[1:]} in routing_table of run.yaml",
"yellow",
)
p = p[0]
provider_spec = all_providers[api][p]
def configure_single_provider(
registry: Dict[str, ProviderSpec], provider: Provider
) -> Provider:
provider_spec = registry[provider.provider_type]
config_type = instantiate_class_type(provider_spec.config_class)
try:
provider_config = config.api_providers.get(api_str)
if provider_config:
existing = config_type(**provider_config.config)
if provider.config:
existing = config_type(**provider.config)
else:
existing = None
except Exception:
existing = None
cfg = prompt_for_config(config_type, existing)
if api_str in router_api2builtin_api:
# a routing API: infer and assign a routing_key, then put it in the routing_table
routing_key = "<PLEASE_FILL_ROUTING_KEY>"
routing_entries = []
if api_str == "inference":
if hasattr(cfg, "model"):
routing_key = cfg.model
cfg = prompt_for_config(config_type, existing)
return Provider(
provider_id=provider.provider_id,
provider_type=provider.provider_type,
config=cfg.dict(),
)
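A hedged usage sketch of configure_single_provider: passing a Provider that already carries config seeds prompt_for_config with the existing values as defaults, and the provider_id is preserved (the id and config fields below are illustrative):

registry = get_provider_registry()[Api("inference")]
existing = Provider(
    provider_id="meta-reference",        # example id
    provider_type="meta-reference",
    config={"max_seq_len": 4096},        # hypothetical existing values
)
updated = configure_single_provider(registry, existing)  # re-prompts; id kept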
def configure_api_providers(
config: StackRunConfig, build_spec: DistributionSpec
) -> StackRunConfig:
is_nux = len(config.providers) == 0
apis = set((config.apis_to_serve or list(build_spec.providers.keys())))
config.apis_to_serve = [a for a in apis if a != "telemetry"]
if is_nux:
print(
textwrap.dedent(
"""
Llama Stack is composed of several APIs working together. For each API served by the Stack,
we need to configure the providers (implementations) you want to use for these APIs.
"""
)
)
provider_registry = get_provider_registry()
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
for api_str in config.apis_to_serve:
api = Api(api_str)
if api in builtin_apis:
continue
if api not in provider_registry:
raise ValueError(f"Unknown API `{api_str}`")
existing_providers = config.providers.get(api_str, [])
if existing_providers:
cprint(
f"Re-configuring existing providers for API `{api_str}`...",
"green",
attrs=["bold"],
)
updated_providers = []
for p in existing_providers:
print(f"> Configuring provider `({p.provider_type})`")
updated_providers.append(
configure_single_provider(provider_registry[api], p)
)
print("")
else:
routing_key = prompt(
"> Please enter the supported model your provider has for inference: ",
default="Llama3.1-8B-Instruct",
# we are newly configuring this API
plist = build_spec.providers.get(api_str, [])
plist = plist if isinstance(plist, list) else [plist]
if not plist:
raise ValueError(f"No provider configured for API {api_str}?")
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
updated_providers = []
for i, provider_type in enumerate(plist):
print(f"> Configuring provider `({provider_type})`")
updated_providers.append(
configure_single_provider(
provider_registry[api],
Provider(
provider_id=(
f"{provider_type}-{i:02d}"
if len(plist) > 1
else provider_type
),
provider_type=provider_type,
config={},
),
)
)
print("")
config.providers[api_str] = updated_providers
if is_nux:
print(
textwrap.dedent(
"""
=========================================================================================
Now let's configure the `objects` you will be serving via the stack. These are:
- Models: the Llama model SKUs you expect to serve for inference (e.g., Llama3.2-1B-Instruct)
- Shields: the safety models you expect to run (e.g., Llama-Guard-3-1B)
- Memory Banks: the memory banks you expect to use (e.g., vector stores)
This wizard will guide you through setting up one of each of these objects. You can
always add more later by editing the run.yaml file.
"""
)
)
object_types = {
"models": (ModelDef, configure_models, "inference"),
"shields": (ShieldDef, configure_shields, "safety"),
"memory_banks": (MemoryBankDef, configure_memory_banks, "memory"),
}
safety_providers = config.providers.get("safety", [])
for otype, (odef, config_method, api_str) in object_types.items():
existing_objects = getattr(config, otype)
if existing_objects:
cprint(
f"{len(existing_objects)} {otype} exist. Skipping...",
"blue",
attrs=["bold"],
)
updated_objects = existing_objects
else:
# we are newly configuring this API
cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"])
updated_objects = config_method(config.providers[api_str], safety_providers)
setattr(config, otype, updated_objects)
print("")
return config
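The provider_id naming used in the loop above adds a zero-padded index only when an API has more than one provider; a small self-contained sketch of that rule (the provider types are examples):

plist = ["meta-reference", "remote::ollama"]  # example provider types
ids = [
    f"{ptype}-{i:02d}" if len(plist) > 1 else ptype
    for i, ptype in enumerate(plist)
]
assert ids == ["meta-reference-00", "remote::ollama-01"]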
def get_llama_guard_model(safety_providers: List[Provider]) -> Optional[str]:
if not safety_providers:
return None
provider = safety_providers[0]
assert provider.provider_type == "meta-reference"
cfg = provider.config.get("llama_guard_shield")
if not cfg:
return None
return cfg["model"]
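get_llama_guard_model expects the meta-reference safety provider to nest its shield model under a llama_guard_shield key; a sketch of the config shape it reads (values are examples):

safety = Provider(
    provider_id="meta-reference",
    provider_type="meta-reference",
    config={"llama_guard_shield": {"model": "Llama-Guard-3-1B"}},
)
assert get_llama_guard_model([safety]) == "Llama-Guard-3-1B"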
def configure_models(
providers: List[Provider], safety_providers: List[Provider]
) -> List[ModelDef]:
model = prompt(
"> Please enter the model you want to serve: ",
default="Llama3.2-1B-Instruct",
validator=Validator.from_callable(
lambda x: resolve_model(x) is not None,
error_message="Model must be: {}".format(
@@ -123,68 +201,57 @@ def configure_api_providers(
),
),
)
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
provider_type=p,
config=cfg.dict(),
model = ModelDef(
identifier=model,
llama_model=model,
provider_id=providers[0].provider_id,
)
ret = [model]
if llama_guard := get_llama_guard_model(safety_providers):
ret.append(
ModelDef(
identifier=llama_guard,
llama_model=llama_guard,
provider_id=providers[0].provider_id,
)
)
if api_str == "safety":
# TODO: add support for other safety providers, and simplify safety provider config
if p == "meta-reference":
routing_entries.append(
RoutableProviderConfig(
routing_key=[s.value for s in MetaReferenceShieldType],
provider_type=p,
config=cfg.dict(),
)
)
else:
cprint(
f"[WARN] Interactive configuration of safety provider {p} is not supported. Please look for `{routing_key}` in run.yaml and replace it appropriately.",
"yellow",
attrs=["bold"],
)
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
provider_type=p,
config=cfg.dict(),
return ret
def configure_shields(
providers: List[Provider], safety_providers: List[Provider]
) -> List[ShieldDef]:
if get_llama_guard_model(safety_providers):
return [
ShieldDef(
identifier="llama_guard",
type="llama_guard",
provider_id=providers[0].provider_id,
params={},
)
]
return []
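Taken together with configure_models, a Llama Guard-backed safety provider yields both the guard model and a matching shield; a hedged sketch of how the wizard wires these two helpers (mirroring the object_types table above):

# Models are registered against the inference providers, shields against
# the safety providers; both helpers also inspect the safety providers.
safety = config.providers.get("safety", [])
models = configure_models(config.providers["inference"], safety)
shields = configure_shields(safety, safety)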
def configure_memory_banks(
providers: List[Provider], safety_providers: List[Provider]
) -> List[MemoryBankDef]:
bank_name = prompt(
"> Please enter a name for your memory bank: ",
default="my-memory-bank",
)
if api_str == "memory":
bank_types = list([x.value for x in MemoryBankType])
routing_key = prompt(
"> Please enter the supported memory bank type your provider has for memory: ",
default="vector",
validator=Validator.from_callable(
lambda x: x in bank_types,
error_message="Invalid provider, please enter one of the following: {}".format(
bank_types
),
),
return [
VectorMemoryBankDef(
identifier=bank_name,
provider_id=providers[0].provider_id,
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
)
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
provider_type=p,
config=cfg.dict(),
)
)
config.routing_table[api_str] = routing_entries
else:
config.api_providers[api_str] = GenericProviderConfig(
provider_type=p,
config=cfg.dict(),
)
print("")
return config
]
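All the interactive prompts in this file share one prompt_toolkit pattern: a default value plus a callable validator that rejects bad input before the prompt returns; a minimal standalone sketch (the allowed values are illustrative):

from prompt_toolkit import prompt
from prompt_toolkit.validation import Validator

allowed = ["vector", "keyvalue", "keyword", "graph"]  # example choices
choice = prompt(
    "> Memory bank type: ",
    default="vector",
    validator=Validator.from_callable(
        lambda x: x in allowed,
        error_message="Enter one of: {}".format(allowed),
    ),
)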
def upgrade_from_routing_table_to_registry(
@@ -193,7 +260,11 @@ def upgrade_from_routing_table_to_registry(
def get_providers(entries):
return [
Provider(
provider_id=f"{entry['provider_type']}-{i:02d}",
provider_id=(
f"{entry['provider_type']}-{i:02d}"
if len(entries) > 1
else entry["provider_type"]
),
provider_type=entry["provider_type"],
config=entry["config"],
)
@@ -254,6 +325,9 @@ def upgrade_from_routing_table_to_registry(
if "api_providers" in config_dict:
for api_str, provider in config_dict["api_providers"].items():
if api_str in ("inference", "safety", "memory"):
continue
if isinstance(provider, dict):
providers_by_api[api_str] = [
Provider(

View file

@@ -75,6 +75,7 @@ in the runtime configuration to help route to the correct provider.""",
)
# TODO: rename as ProviderInstanceConfig
class Provider(BaseModel):
provider_id: str
provider_type: str
@@ -108,8 +109,8 @@ The list of APIs to serve. If not specified, all APIs specified in the provider_
providers: Dict[str, List[Provider]]
models: List[ModelDef]
memory_banks: List[MemoryBankDef]
shields: List[ShieldDef]
memory_banks: List[MemoryBankDef]
# api_providers: Dict[
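For context, a minimal sketch of the registry-style StackRunConfig these fields describe, with values drawn from the StackConfigure hunk above (everything concrete here is illustrative):

from datetime import datetime

run_config = StackRunConfig(
    built_at=datetime.now(),
    image_name="local",  # example image name
    apis_to_serve=["inference", "safety", "memory"],
    providers={
        "inference": [
            Provider(
                provider_id="meta-reference",
                provider_type="meta-reference",
                config={},
            ),
        ],
    },
    models=[],
    shields=[],
    memory_banks=[],
)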

View file

@@ -22,8 +22,6 @@ class MetaReferenceShieldType(Enum):
class LlamaGuardShieldConfig(BaseModel):
model: str = "Llama-Guard-3-1B"
excluded_categories: List[str] = []
disable_input_check: bool = False
disable_output_check: bool = False
@field_validator("model")
@classmethod
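With disable_input_check and disable_output_check gone, the shield config narrows to the model plus excluded categories; a sketch of constructing it (the category code is hypothetical):

cfg = LlamaGuardShieldConfig()  # model defaults to "Llama-Guard-3-1B"
cfg = LlamaGuardShieldConfig(
    model="Llama-Guard-3-1B",
    excluded_categories=["S7"],  # hypothetical category code
)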

View file

@@ -91,8 +91,6 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
model=cfg.model,
inference_api=self.inference_api,
excluded_categories=cfg.excluded_categories,
disable_input_check=cfg.disable_input_check,
disable_output_check=cfg.disable_output_check,
)
elif typ == MetaReferenceShieldType.jailbreak_shield:
from .shields import JailbreakShield

View file

@@ -113,8 +113,6 @@ class LlamaGuardShield(ShieldBase):
model: str,
inference_api: Inference,
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(on_violation_action)
@@ -132,8 +130,6 @@ class LlamaGuardShield(ShieldBase):
self.model = model
self.inference_api = inference_api
self.excluded_categories = excluded_categories
self.disable_input_check = disable_input_check
self.disable_output_check = disable_output_check
def check_unsafe_response(self, response: str) -> Optional[str]:
match = re.match(r"^unsafe\n(.*)$", response)
@@ -180,12 +176,6 @@
async def run(self, messages: List[Message]) -> ShieldResponse:
messages = self.validate_messages(messages)
if self.disable_input_check and messages[-1].role == Role.user.value:
return ShieldResponse(is_violation=False)
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
return ShieldResponse(
is_violation=False,
)
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
shield_input_message = self.build_vision_shield_input(messages)