mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
Significantly upgrade the interactive configuration experience
This commit is contained in:
parent
8d157a8197
commit
5a7b01d292
7 changed files with 217 additions and 156 deletions
|
@ -150,9 +150,6 @@ class StackBuild(Subcommand):
|
|||
|
||||
def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
|
||||
import json
|
||||
|
||||
import yaml
|
||||
|
||||
from llama_stack.cli.table import print_table
|
||||
|
||||
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
|
||||
|
|
|
@ -148,14 +148,17 @@ class StackConfigure(Subcommand):
|
|||
"yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
config_dict = yaml.safe_load(config_file.read_text())
|
||||
config_dict = yaml.safe_load(run_config_file.read_text())
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
else:
|
||||
config = StackRunConfig(
|
||||
built_at=datetime.now(),
|
||||
image_name=image_name,
|
||||
apis_to_serve=[],
|
||||
api_providers={},
|
||||
providers={},
|
||||
models=[],
|
||||
shields=[],
|
||||
memory_banks=[],
|
||||
)
|
||||
|
||||
config = configure_api_providers(config, build_config.distribution_spec)
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import textwrap
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
@ -14,7 +15,6 @@ from llama_models.sku_list import (
|
|||
safety_models,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from prompt_toolkit import prompt
|
||||
from prompt_toolkit.validation import Validator
|
||||
|
@ -23,14 +23,14 @@ from termcolor import cprint
|
|||
from llama_stack.distribution.distribution import (
|
||||
builtin_automatically_routed_apis,
|
||||
get_provider_registry,
|
||||
stack_apis,
|
||||
)
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
|
||||
from llama_stack.providers.impls.meta_reference.safety.config import (
|
||||
MetaReferenceShieldType,
|
||||
)
|
||||
|
||||
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
|
||||
|
||||
ALLOWED_MODELS = (
|
||||
|
@ -38,162 +38,233 @@ ALLOWED_MODELS = (
|
|||
)
|
||||
|
||||
|
||||
def make_routing_entry_type(config_class: Any):
|
||||
class BaseModelWithConfig(BaseModel):
|
||||
routing_key: str
|
||||
config: config_class
|
||||
def configure_single_provider(
|
||||
registry: Dict[str, ProviderSpec], provider: Provider
|
||||
) -> Provider:
|
||||
provider_spec = registry[provider.provider_type]
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
try:
|
||||
if provider.config:
|
||||
existing = config_type(**provider.config)
|
||||
else:
|
||||
existing = None
|
||||
except Exception:
|
||||
existing = None
|
||||
|
||||
return BaseModelWithConfig
|
||||
cfg = prompt_for_config(config_type, existing)
|
||||
return Provider(
|
||||
provider_id=provider.provider_id,
|
||||
provider_type=provider.provider_type,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
|
||||
|
||||
def get_builtin_apis(provider_backed_apis: List[str]) -> List[str]:
    """Map provider-backed API names to their builtin routing-table API names.

    Walks the entries returned by ``builtin_automatically_routed_apis()`` in
    order and, for every entry whose ``router_api.value`` is present in
    ``provider_backed_apis``, collects that entry's ``routing_table_api.value``.
    """
    return [
        routed.routing_table_api.value
        for routed in builtin_automatically_routed_apis()
        if routed.router_api.value in provider_backed_apis
    ]
|
||||
|
||||
|
||||
# TODO: make sure we can deal with existing configuration values correctly
|
||||
# instead of just overwriting them
|
||||
def configure_api_providers(
|
||||
config: StackRunConfig, spec: DistributionSpec
|
||||
config: StackRunConfig, build_spec: DistributionSpec
|
||||
) -> StackRunConfig:
|
||||
apis = config.apis_to_serve or list(spec.providers.keys())
|
||||
# append the builtin routing APIs
|
||||
apis += get_builtin_apis(apis)
|
||||
is_nux = len(config.providers) == 0
|
||||
|
||||
router_api2builtin_api = {
|
||||
inf.router_api.value: inf.routing_table_api.value
|
||||
for inf in builtin_automatically_routed_apis()
|
||||
}
|
||||
apis = set((config.apis_to_serve or list(build_spec.providers.keys())))
|
||||
config.apis_to_serve = [a for a in apis if a != "telemetry"]
|
||||
|
||||
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
|
||||
if is_nux:
|
||||
print(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
Llama Stack is composed of several APIs working together. For each API served by the Stack,
|
||||
we need to configure the providers (implementations) you want to use for these APIs.
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
apis = [v.value for v in stack_apis()]
|
||||
all_providers = get_provider_registry()
|
||||
|
||||
# configure the simple case: non-routing providers go into api_providers
|
||||
for api_str in spec.providers.keys():
|
||||
if api_str not in apis:
|
||||
provider_registry = get_provider_registry()
|
||||
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
|
||||
for api_str in config.apis_to_serve:
|
||||
api = Api(api_str)
|
||||
if api in builtin_apis:
|
||||
continue
|
||||
if api not in provider_registry:
|
||||
raise ValueError(f"Unknown API `{api_str}`")
|
||||
|
||||
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
|
||||
api = Api(api_str)
|
||||
|
||||
p = spec.providers[api_str]
|
||||
cprint(f"=== Configuring provider `{p}` for API {api_str}...", "green")
|
||||
|
||||
if isinstance(p, list):
|
||||
existing_providers = config.providers.get(api_str, [])
|
||||
if existing_providers:
|
||||
cprint(
|
||||
f"[WARN] Interactive configuration of multiple providers {p} is not supported, configuring {p[0]} only, please manually configure {p[1:]} in routing_table of run.yaml",
|
||||
"yellow",
|
||||
f"Re-configuring existing providers for API `{api_str}`...",
|
||||
"green",
|
||||
attrs=["bold"],
|
||||
)
|
||||
p = p[0]
|
||||
|
||||
provider_spec = all_providers[api][p]
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
try:
|
||||
provider_config = config.api_providers.get(api_str)
|
||||
if provider_config:
|
||||
existing = config_type(**provider_config.config)
|
||||
else:
|
||||
existing = None
|
||||
except Exception:
|
||||
existing = None
|
||||
cfg = prompt_for_config(config_type, existing)
|
||||
|
||||
if api_str in router_api2builtin_api:
|
||||
# a routing api, we need to infer and assign it a routing_key and put it in the routing_table
|
||||
routing_key = "<PLEASE_FILL_ROUTING_KEY>"
|
||||
routing_entries = []
|
||||
if api_str == "inference":
|
||||
if hasattr(cfg, "model"):
|
||||
routing_key = cfg.model
|
||||
else:
|
||||
routing_key = prompt(
|
||||
"> Please enter the supported model your provider has for inference: ",
|
||||
default="Llama3.1-8B-Instruct",
|
||||
validator=Validator.from_callable(
|
||||
lambda x: resolve_model(x) is not None,
|
||||
error_message="Model must be: {}".format(
|
||||
[x.descriptor() for x in ALLOWED_MODELS]
|
||||
),
|
||||
),
|
||||
)
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=routing_key,
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
updated_providers = []
|
||||
for p in existing_providers:
|
||||
print(f"> Configuring provider `({p.provider_type})`")
|
||||
updated_providers.append(
|
||||
configure_single_provider(provider_registry[api], p)
|
||||
)
|
||||
|
||||
if api_str == "safety":
|
||||
# TODO: add support for other safety providers, and simplify safety provider config
|
||||
if p == "meta-reference":
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=[s.value for s in MetaReferenceShieldType],
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
)
|
||||
else:
|
||||
cprint(
|
||||
f"[WARN] Interactive configuration of safety provider {p} is not supported. Please look for `{routing_key}` in run.yaml and replace it appropriately.",
|
||||
"yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=routing_key,
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
)
|
||||
|
||||
if api_str == "memory":
|
||||
bank_types = list([x.value for x in MemoryBankType])
|
||||
routing_key = prompt(
|
||||
"> Please enter the supported memory bank type your provider has for memory: ",
|
||||
default="vector",
|
||||
validator=Validator.from_callable(
|
||||
lambda x: x in bank_types,
|
||||
error_message="Invalid provider, please enter one of the following: {}".format(
|
||||
bank_types
|
||||
),
|
||||
),
|
||||
)
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=routing_key,
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
)
|
||||
|
||||
config.routing_table[api_str] = routing_entries
|
||||
print("")
|
||||
else:
|
||||
config.api_providers[api_str] = GenericProviderConfig(
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
# we are newly configuring this API
|
||||
plist = build_spec.providers.get(api_str, [])
|
||||
plist = plist if isinstance(plist, list) else [plist]
|
||||
|
||||
if not plist:
|
||||
raise ValueError(f"No provider configured for API {api_str}?")
|
||||
|
||||
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
|
||||
updated_providers = []
|
||||
for i, provider_type in enumerate(plist):
|
||||
print(f"> Configuring provider `({provider_type})`")
|
||||
updated_providers.append(
|
||||
configure_single_provider(
|
||||
provider_registry[api],
|
||||
Provider(
|
||||
provider_id=(
|
||||
f"{provider_type}-{i:02d}"
|
||||
if len(plist) > 1
|
||||
else provider_type
|
||||
),
|
||||
provider_type=provider_type,
|
||||
config={},
|
||||
),
|
||||
)
|
||||
)
|
||||
print("")
|
||||
|
||||
config.providers[api_str] = updated_providers
|
||||
|
||||
if is_nux:
|
||||
print(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
=========================================================================================
|
||||
Now let's configure the `objects` you will be serving via the stack. These are:
|
||||
|
||||
- Models: the Llama model SKUs you expect to inference (e.g., Llama3.2-1B-Instruct)
|
||||
- Shields: the safety models you expect to use for safety (e.g., Llama-Guard-3-1B)
|
||||
- Memory Banks: the memory banks you expect to use for memory (e.g., Vector stores)
|
||||
|
||||
This wizard will guide you through setting up one of each of these objects. You can
|
||||
always add more later by editing the run.yaml file.
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
object_types = {
|
||||
"models": (ModelDef, configure_models, "inference"),
|
||||
"shields": (ShieldDef, configure_shields, "safety"),
|
||||
"memory_banks": (MemoryBankDef, configure_memory_banks, "memory"),
|
||||
}
|
||||
safety_providers = config.providers["safety"]
|
||||
|
||||
for otype, (odef, config_method, api_str) in object_types.items():
|
||||
existing_objects = getattr(config, otype)
|
||||
|
||||
if existing_objects:
|
||||
cprint(
|
||||
f"{len(existing_objects)} {otype} exist. Skipping...",
|
||||
"blue",
|
||||
attrs=["bold"],
|
||||
)
|
||||
updated_objects = existing_objects
|
||||
else:
|
||||
# we are newly configuring this API
|
||||
cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"])
|
||||
updated_objects = config_method(config.providers[api_str], safety_providers)
|
||||
|
||||
setattr(config, otype, updated_objects)
|
||||
print("")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_llama_guard_model(safety_providers: List[Provider]) -> Optional[str]:
    """Return the Llama Guard model named by the first safety provider, if any.

    Only the first entry of ``safety_providers`` is inspected.

    Args:
        safety_providers: providers configured for the safety API; may be empty.

    Returns:
        The ``model`` value found under the provider's ``llama_guard_shield``
        config, or ``None`` when there are no providers, the first provider is
        not of type ``meta-reference``, or no Llama Guard shield is configured.
    """
    if not safety_providers:
        return None

    provider = safety_providers[0]
    # This used to be an `assert`, which is stripped under `python -O` and
    # otherwise aborts the whole wizard with a bare AssertionError. A provider
    # of another type simply has no Llama Guard model to report.
    if provider.provider_type != "meta-reference":
        return None

    # Use `.get` so a config that lacks the key is treated the same as a
    # config where the shield is explicitly unset, instead of raising KeyError.
    shield_cfg = (provider.config or {}).get("llama_guard_shield")
    if not shield_cfg:
        return None
    return shield_cfg.get("model")
|
||||
|
||||
|
||||
def configure_models(
    providers: List[Provider], safety_providers: List[Provider]
) -> List[ModelDef]:
    """Interactively pick the model to serve and build its ``ModelDef`` entries.

    Prompts for a model name (validated with ``resolve_model``) and registers
    it against the first provider in ``providers``. If
    ``get_llama_guard_model(safety_providers)`` yields a model, a second
    ``ModelDef`` for it is appended, also bound to the first provider.
    """
    model_validator = Validator.from_callable(
        lambda x: resolve_model(x) is not None,
        error_message="Model must be: {}".format(
            [x.descriptor() for x in ALLOWED_MODELS]
        ),
    )
    chosen_model = prompt(
        "> Please enter the model you want to serve: ",
        default="Llama3.2-1B-Instruct",
        validator=model_validator,
    )

    model_defs = [
        ModelDef(
            identifier=chosen_model,
            llama_model=chosen_model,
            provider_id=providers[0].provider_id,
        )
    ]

    guard_model = get_llama_guard_model(safety_providers)
    if guard_model:
        guard_def = ModelDef(
            identifier=guard_model,
            llama_model=guard_model,
            provider_id=providers[0].provider_id,
        )
        model_defs.append(guard_def)

    return model_defs
|
||||
|
||||
|
||||
def configure_shields(
    providers: List[Provider], safety_providers: List[Provider]
) -> List[ShieldDef]:
    """Build the default shield list for the configured safety providers.

    Returns a single ``llama_guard`` ``ShieldDef`` bound to the first entry of
    ``providers`` when ``get_llama_guard_model(safety_providers)`` reports a
    model; otherwise returns an empty list.
    """
    if not get_llama_guard_model(safety_providers):
        return []

    llama_guard_shield = ShieldDef(
        identifier="llama_guard",
        type="llama_guard",
        provider_id=providers[0].provider_id,
        params={},
    )
    return [llama_guard_shield]
|
||||
|
||||
|
||||
def configure_memory_banks(
    providers: List[Provider], safety_providers: List[Provider]
) -> List[MemoryBankDef]:
    """Interactively name one vector memory bank and return it in a list.

    The bank is bound to the first entry of ``providers`` and uses the fixed
    embedding model and chunk size below. ``safety_providers`` is not used in
    the body; it is part of the signature shared with the other ``configure_*``
    object helpers.
    """
    chosen_name = prompt(
        "> Please enter a name for your memory bank: ",
        default="my-memory-bank",
    )

    bank = VectorMemoryBankDef(
        identifier=chosen_name,
        provider_id=providers[0].provider_id,
        embedding_model="all-MiniLM-L6-v2",
        chunk_size_in_tokens=512,
    )
    return [bank]
|
||||
|
||||
|
||||
def upgrade_from_routing_table_to_registry(
|
||||
config_dict: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
def get_providers(entries):
|
||||
return [
|
||||
Provider(
|
||||
provider_id=f"{entry['provider_type']}-{i:02d}",
|
||||
provider_id=(
|
||||
f"{entry['provider_type']}-{i:02d}"
|
||||
if len(entries) > 1
|
||||
else entry["provider_type"]
|
||||
),
|
||||
provider_type=entry["provider_type"],
|
||||
config=entry["config"],
|
||||
)
|
||||
|
@ -254,6 +325,9 @@ def upgrade_from_routing_table_to_registry(
|
|||
|
||||
if "api_providers" in config_dict:
|
||||
for api_str, provider in config_dict["api_providers"].items():
|
||||
if api_str in ("inference", "safety", "memory"):
|
||||
continue
|
||||
|
||||
if isinstance(provider, dict):
|
||||
providers_by_api[api_str] = [
|
||||
Provider(
|
||||
|
|
|
@ -75,6 +75,7 @@ in the runtime configuration to help route to the correct provider.""",
|
|||
)
|
||||
|
||||
|
||||
# TODO: rename as ProviderInstanceConfig
|
||||
class Provider(BaseModel):
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
|
@ -108,8 +109,8 @@ The list of APIs to serve. If not specified, all APIs specified in the provider_
|
|||
providers: Dict[str, List[Provider]]
|
||||
|
||||
models: List[ModelDef]
|
||||
memory_banks: List[MemoryBankDef]
|
||||
shields: List[ShieldDef]
|
||||
memory_banks: List[MemoryBankDef]
|
||||
|
||||
|
||||
# api_providers: Dict[
|
||||
|
|
|
@ -22,8 +22,6 @@ class MetaReferenceShieldType(Enum):
|
|||
class LlamaGuardShieldConfig(BaseModel):
|
||||
model: str = "Llama-Guard-3-1B"
|
||||
excluded_categories: List[str] = []
|
||||
disable_input_check: bool = False
|
||||
disable_output_check: bool = False
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
|
|
|
@ -91,8 +91,6 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
|
|||
model=cfg.model,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=cfg.excluded_categories,
|
||||
disable_input_check=cfg.disable_input_check,
|
||||
disable_output_check=cfg.disable_output_check,
|
||||
)
|
||||
elif typ == MetaReferenceShieldType.jailbreak_shield:
|
||||
from .shields import JailbreakShield
|
||||
|
|
|
@ -113,8 +113,6 @@ class LlamaGuardShield(ShieldBase):
|
|||
model: str,
|
||||
inference_api: Inference,
|
||||
excluded_categories: List[str] = None,
|
||||
disable_input_check: bool = False,
|
||||
disable_output_check: bool = False,
|
||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||
):
|
||||
super().__init__(on_violation_action)
|
||||
|
@ -132,8 +130,6 @@ class LlamaGuardShield(ShieldBase):
|
|||
self.model = model
|
||||
self.inference_api = inference_api
|
||||
self.excluded_categories = excluded_categories
|
||||
self.disable_input_check = disable_input_check
|
||||
self.disable_output_check = disable_output_check
|
||||
|
||||
def check_unsafe_response(self, response: str) -> Optional[str]:
|
||||
match = re.match(r"^unsafe\n(.*)$", response)
|
||||
|
@ -180,12 +176,6 @@ class LlamaGuardShield(ShieldBase):
|
|||
|
||||
async def run(self, messages: List[Message]) -> ShieldResponse:
|
||||
messages = self.validate_messages(messages)
|
||||
if self.disable_input_check and messages[-1].role == Role.user.value:
|
||||
return ShieldResponse(is_violation=False)
|
||||
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
|
||||
return ShieldResponse(
|
||||
is_violation=False,
|
||||
)
|
||||
|
||||
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
|
||||
shield_input_message = self.build_vision_shield_input(messages)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue