Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-29 15:23:51 +00:00

Significantly upgrade the interactive configuration experience

This commit is contained in:
parent 8d157a8197
commit 5a7b01d292

7 changed files with 217 additions and 156 deletions
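In broad strokes, the commit moves run.yaml away from the old flat `api_providers` map plus `routing_table`, to a `providers` registry keyed by API, with explicit `models`, `shields`, and `memory_banks` object lists that the wizard fills in. A minimal sketch of the new shape, using illustrative values that are assumptions rather than content of this diff:

# Sketch of the run-config shape after this commit; provider names and
# values here are illustrative assumptions, not code from the repository.
new_style_config = {
    "apis_to_serve": ["inference", "safety", "memory"],
    "providers": {
        "inference": [
            {
                "provider_id": "meta-reference",  # == provider_type when unique
                "provider_type": "meta-reference",
                "config": {},
            }
        ],
    },
    "models": [],        # ModelDef entries filled in by the wizard
    "shields": [],       # ShieldDef entries
    "memory_banks": [],  # MemoryBankDef entries
}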
@@ -150,9 +150,6 @@ class StackBuild(Subcommand):
 
     def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
         import json
-
-        import yaml
-
         from llama_stack.cli.table import print_table
 
         # eventually, this should query a registry at llama.meta.com/llamastack/distributions
@@ -148,14 +148,17 @@ class StackConfigure(Subcommand):
                 "yellow",
                 attrs=["bold"],
             )
-            config_dict = yaml.safe_load(config_file.read_text())
+            config_dict = yaml.safe_load(run_config_file.read_text())
             config = parse_and_maybe_upgrade_config(config_dict)
         else:
             config = StackRunConfig(
                 built_at=datetime.now(),
                 image_name=image_name,
                 apis_to_serve=[],
-                api_providers={},
+                providers={},
+                models=[],
+                shields=[],
+                memory_banks=[],
             )
 
         config = configure_api_providers(config, build_config.distribution_spec)
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import textwrap
 
 from typing import Any
 
@@ -14,7 +15,6 @@ from llama_models.sku_list import (
     safety_models,
 )
 
-from pydantic import BaseModel
 from llama_stack.distribution.datatypes import *  # noqa: F403
 from prompt_toolkit import prompt
 from prompt_toolkit.validation import Validator
@@ -23,14 +23,14 @@ from termcolor import cprint
 from llama_stack.distribution.distribution import (
     builtin_automatically_routed_apis,
     get_provider_registry,
-    stack_apis,
 )
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
 
 from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
-from llama_stack.providers.impls.meta_reference.safety.config import (
-    MetaReferenceShieldType,
-)
+from llama_stack.apis.models import *  # noqa: F403
+from llama_stack.apis.shields import *  # noqa: F403
+from llama_stack.apis.memory_banks import *  # noqa: F403
 
 
 ALLOWED_MODELS = (
@@ -38,84 +38,162 @@ ALLOWED_MODELS = (
 )
 
 
-def make_routing_entry_type(config_class: Any):
-    class BaseModelWithConfig(BaseModel):
-        routing_key: str
-        config: config_class
-
-    return BaseModelWithConfig
-
-
-def get_builtin_apis(provider_backed_apis: List[str]) -> List[str]:
-    """Get corresponding builtin APIs given provider backed APIs"""
-    res = []
-    for inf in builtin_automatically_routed_apis():
-        if inf.router_api.value in provider_backed_apis:
-            res.append(inf.routing_table_api.value)
-
-    return res
-
-
-# TODO: make sure we can deal with existing configuration values correctly
-# instead of just overwriting them
-def configure_api_providers(
-    config: StackRunConfig, spec: DistributionSpec
-) -> StackRunConfig:
-    apis = config.apis_to_serve or list(spec.providers.keys())
-    # append the bulitin routing APIs
-    apis += get_builtin_apis(apis)
-
-    router_api2builtin_api = {
-        inf.router_api.value: inf.routing_table_api.value
-        for inf in builtin_automatically_routed_apis()
-    }
-
-    config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
-
-    apis = [v.value for v in stack_apis()]
-    all_providers = get_provider_registry()
-
-    # configure simple case for with non-routing providers to api_providers
-    for api_str in spec.providers.keys():
-        if api_str not in apis:
-            raise ValueError(f"Unknown API `{api_str}`")
-
-        cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
-        api = Api(api_str)
-
-        p = spec.providers[api_str]
-        cprint(f"=== Configuring provider `{p}` for API {api_str}...", "green")
-
-        if isinstance(p, list):
-            cprint(
-                f"[WARN] Interactive configuration of multiple providers {p} is not supported, configuring {p[0]} only, please manually configure {p[1:]} in routing_table of run.yaml",
-                "yellow",
-            )
-            p = p[0]
-
-        provider_spec = all_providers[api][p]
-        config_type = instantiate_class_type(provider_spec.config_class)
-        try:
-            provider_config = config.api_providers.get(api_str)
-            if provider_config:
-                existing = config_type(**provider_config.config)
-            else:
-                existing = None
-        except Exception:
-            existing = None
-        cfg = prompt_for_config(config_type, existing)
-
-        if api_str in router_api2builtin_api:
-            # a routing api, we need to infer and assign it a routing_key and put it in the routing_table
-            routing_key = "<PLEASE_FILL_ROUTING_KEY>"
-            routing_entries = []
-            if api_str == "inference":
-                if hasattr(cfg, "model"):
-                    routing_key = cfg.model
-                else:
-                    routing_key = prompt(
-                        "> Please enter the supported model your provider has for inference: ",
-                        default="Llama3.1-8B-Instruct",
+def configure_single_provider(
+    registry: Dict[str, ProviderSpec], provider: Provider
+) -> Provider:
+    provider_spec = registry[provider.provider_type]
+    config_type = instantiate_class_type(provider_spec.config_class)
+    try:
+        if provider.config:
+            existing = config_type(**provider.config)
+        else:
+            existing = None
+    except Exception:
+        existing = None
+
+    cfg = prompt_for_config(config_type, existing)
+    return Provider(
+        provider_id=provider.provider_id,
+        provider_type=provider.provider_type,
+        config=cfg.dict(),
+    )
+
+
+def configure_api_providers(
+    config: StackRunConfig, build_spec: DistributionSpec
+) -> StackRunConfig:
+    is_nux = len(config.providers) == 0
+
+    apis = set((config.apis_to_serve or list(build_spec.providers.keys())))
+    config.apis_to_serve = [a for a in apis if a != "telemetry"]
+
+    if is_nux:
+        print(
+            textwrap.dedent(
+                """
+        Llama Stack is composed of several APIs working together. For each API served by the Stack,
+        we need to configure the providers (implementations) you want to use for these APIs.
+                """
+            )
+        )
+
+    provider_registry = get_provider_registry()
+    builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
+    for api_str in config.apis_to_serve:
+        api = Api(api_str)
+        if api in builtin_apis:
+            continue
+        if api not in provider_registry:
+            raise ValueError(f"Unknown API `{api_str}`")
+
+        existing_providers = config.providers.get(api_str, [])
+        if existing_providers:
+            cprint(
+                f"Re-configuring existing providers for API `{api_str}`...",
+                "green",
+                attrs=["bold"],
+            )
+            updated_providers = []
+            for p in existing_providers:
+                print(f"> Configuring provider `({p.provider_type})`")
+                updated_providers.append(
+                    configure_single_provider(provider_registry[api], p)
+                )
+                print("")
+        else:
+            # we are newly configuring this API
+            plist = build_spec.providers.get(api_str, [])
+            plist = plist if isinstance(plist, list) else [plist]
+
+            if not plist:
+                raise ValueError(f"No provider configured for API {api_str}?")
+
+            cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
+            updated_providers = []
+            for i, provider_type in enumerate(plist):
+                print(f"> Configuring provider `({provider_type})`")
+                updated_providers.append(
+                    configure_single_provider(
+                        provider_registry[api],
+                        Provider(
+                            provider_id=(
+                                f"{provider_type}-{i:02d}"
+                                if len(plist) > 1
+                                else provider_type
+                            ),
+                            provider_type=provider_type,
+                            config={},
+                        ),
+                    )
+                )
+                print("")
+
+        config.providers[api_str] = updated_providers
+
+    if is_nux:
+        print(
+            textwrap.dedent(
+                """
+        =========================================================================================
+        Now let's configure the `objects` you will be serving via the stack. These are:
+
+        - Models: the Llama model SKUs you expect to inference (e.g., Llama3.2-1B-Instruct)
+        - Shields: the safety models you expect to use for safety (e.g., Llama-Guard-3-1B)
+        - Memory Banks: the memory banks you expect to use for memory (e.g., Vector stores)
+
+        This wizard will guide you through setting up one of each of these objects. You can
+        always add more later by editing the run.yaml file.
+                """
+            )
+        )
+
+    object_types = {
+        "models": (ModelDef, configure_models, "inference"),
+        "shields": (ShieldDef, configure_shields, "safety"),
+        "memory_banks": (MemoryBankDef, configure_memory_banks, "memory"),
+    }
+    safety_providers = config.providers["safety"]
+
+    for otype, (odef, config_method, api_str) in object_types.items():
+        existing_objects = getattr(config, otype)
+
+        if existing_objects:
+            cprint(
+                f"{len(existing_objects)} {otype} exist. Skipping...",
+                "blue",
+                attrs=["bold"],
+            )
+            updated_objects = existing_objects
+        else:
+            # we are newly configuring this API
+            cprint(f"Configuring `{otype}`...", "blue", attrs=["bold"])
+            updated_objects = config_method(config.providers[api_str], safety_providers)
+
+        setattr(config, otype, updated_objects)
+        print("")
+
+    return config
+
+
+def get_llama_guard_model(safety_providers: List[Provider]) -> Optional[str]:
+    if not safety_providers:
+        return None
+
+    provider = safety_providers[0]
+    assert provider.provider_type == "meta-reference"
+
+    cfg = provider.config["llama_guard_shield"]
+    if not cfg:
+        return None
+    return cfg["model"]
+
+
+def configure_models(
+    providers: List[Provider], safety_providers: List[Provider]
+) -> List[ModelDef]:
+    model = prompt(
+        "> Please enter the model you want to serve: ",
+        default="Llama3.2-1B-Instruct",
         validator=Validator.from_callable(
             lambda x: resolve_model(x) is not None,
             error_message="Model must be: {}".format(
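The extracted configure_single_provider helper is what makes re-configuration idempotent: an existing provider's saved config becomes the default answers for the prompts, while the identity fields pass through untouched. A minimal usage sketch, assuming a registry and provider object already in hand (not code from this diff):

# Hypothetical usage of configure_single_provider (sketch, not from the diff).
# `registry` maps provider_type -> ProviderSpec for one API; `provider` is an
# existing Provider entry loaded from run.yaml.
updated = configure_single_provider(registry, provider)
assert updated.provider_id == provider.provider_id      # identity preserved
assert updated.provider_type == provider.provider_type  # only config is re-prompted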
@@ -123,68 +201,57 @@ def configure_api_providers(
             ),
         ),
     )
-                    routing_entries.append(
-                        RoutableProviderConfig(
-                            routing_key=routing_key,
-                            provider_type=p,
-                            config=cfg.dict(),
+    model = ModelDef(
+        identifier=model,
+        llama_model=model,
+        provider_id=providers[0].provider_id,
+    )
+
+    ret = [model]
+    if llama_guard := get_llama_guard_model(safety_providers):
+        ret.append(
+            ModelDef(
+                identifier=llama_guard,
+                llama_model=llama_guard,
+                provider_id=providers[0].provider_id,
             )
         )
 
-            if api_str == "safety":
-                # TODO: add support for other safety providers, and simplify safety provider config
-                if p == "meta-reference":
-                    routing_entries.append(
-                        RoutableProviderConfig(
-                            routing_key=[s.value for s in MetaReferenceShieldType],
-                            provider_type=p,
-                            config=cfg.dict(),
-                        )
-                    )
-                else:
-                    cprint(
-                        f"[WARN] Interactive configuration of safety provider {p} is not supported. Please look for `{routing_key}` in run.yaml and replace it appropriately.",
-                        "yellow",
-                        attrs=["bold"],
-                    )
-                    routing_entries.append(
-                        RoutableProviderConfig(
-                            routing_key=routing_key,
-                            provider_type=p,
-                            config=cfg.dict(),
-                        )
-                    )
+    return ret
+
+
+def configure_shields(
+    providers: List[Provider], safety_providers: List[Provider]
+) -> List[ShieldDef]:
+    if get_llama_guard_model(safety_providers):
+        return [
+            ShieldDef(
+                identifier="llama_guard",
+                type="llama_guard",
+                provider_id=providers[0].provider_id,
+                params={},
+            )
+        ]
+
+    return []
+
+
+def configure_memory_banks(
+    providers: List[Provider], safety_providers: List[Provider]
+) -> List[MemoryBankDef]:
+    bank_name = prompt(
+        "> Please enter a name for your memory bank: ",
+        default="my-memory-bank",
+    )
 
-            if api_str == "memory":
-                bank_types = list([x.value for x in MemoryBankType])
-                routing_key = prompt(
-                    "> Please enter the supported memory bank type your provider has for memory: ",
-                    default="vector",
-                    validator=Validator.from_callable(
-                        lambda x: x in bank_types,
-                        error_message="Invalid provider, please enter one of the following: {}".format(
-                            bank_types
-                        ),
-                    ),
-                )
-                routing_entries.append(
-                    RoutableProviderConfig(
-                        routing_key=routing_key,
-                        provider_type=p,
-                        config=cfg.dict(),
-                    )
-                )
-
-            config.routing_table[api_str] = routing_entries
-        else:
-            config.api_providers[api_str] = GenericProviderConfig(
-                provider_type=p,
-                config=cfg.dict(),
-            )
-
-        print("")
-
-    return config
+    return [
+        VectorMemoryBankDef(
+            identifier=bank_name,
+            provider_id=providers[0].provider_id,
+            embedding_model="all-MiniLM-L6-v2",
+            chunk_size_in_tokens=512,
+        )
+    ]
 
 
 def upgrade_from_routing_table_to_registry(
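Taken together, the three object helpers produce the stack's initial object registry. A sketch of representative return values, assuming a single configured provider whose provider_id is "meta-reference" (the identifiers mirror the defaults visible in the diff; the provider_id is an assumption):

# Illustrative outputs of configure_models / configure_shields /
# configure_memory_banks (values are assumptions, not repository code).
models = [ModelDef(identifier="Llama3.2-1B-Instruct",
                   llama_model="Llama3.2-1B-Instruct",
                   provider_id="meta-reference")]
shields = [ShieldDef(identifier="llama_guard", type="llama_guard",
                     provider_id="meta-reference", params={})]
memory_banks = [VectorMemoryBankDef(identifier="my-memory-bank",
                                    provider_id="meta-reference",
                                    embedding_model="all-MiniLM-L6-v2",
                                    chunk_size_in_tokens=512)]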
@@ -193,7 +260,11 @@ def upgrade_from_routing_table_to_registry(
     def get_providers(entries):
         return [
             Provider(
-                provider_id=f"{entry['provider_type']}-{i:02d}",
+                provider_id=(
+                    f"{entry['provider_type']}-{i:02d}"
+                    if len(entries) > 1
+                    else entry["provider_type"]
+                ),
                 provider_type=entry["provider_type"],
                 config=entry["config"],
             )
@@ -254,6 +325,9 @@ def upgrade_from_routing_table_to_registry(
 
     if "api_providers" in config_dict:
         for api_str, provider in config_dict["api_providers"].items():
+            if api_str in ("inference", "safety", "memory"):
+                continue
+
             if isinstance(provider, dict):
                 providers_by_api[api_str] = [
                     Provider(
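The provider_id naming rule above only adds an index suffix when an API has more than one provider entry, which keeps the common single-provider case readable. A small sketch of the behavior, with assumed inputs:

# Sketch of the provider_id naming rule (inputs are assumptions).
entries = [{"provider_type": "meta-reference", "config": {}}]
# one entry  -> provider_id == "meta-reference"
entries = [{"provider_type": "meta-reference", "config": {}},
           {"provider_type": "remote", "config": {}}]
# two entries -> provider_ids "meta-reference-00" and "remote-01"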
@@ -75,6 +75,7 @@ in the runtime configuration to help route to the correct provider.""",
 )
 
 
+# TODO: rename as ProviderInstanceConfig
 class Provider(BaseModel):
     provider_id: str
     provider_type: str
@@ -108,8 +109,8 @@ The list of APIs to serve. If not specified, all APIs specified in the provider_
     providers: Dict[str, List[Provider]]
 
     models: List[ModelDef]
-    memory_banks: List[MemoryBankDef]
     shields: List[ShieldDef]
+    memory_banks: List[MemoryBankDef]
 
 
     # api_providers: Dict[
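With these fields, a freshly built run config carries explicit (initially empty) object registries alongside the provider map. A minimal construction sketch mirroring the CLI change earlier in this commit (values are illustrative assumptions):

# Minimal sketch of constructing the updated StackRunConfig (values assumed).
config = StackRunConfig(
    built_at=datetime.now(),
    image_name="local",
    apis_to_serve=["inference"],
    providers={"inference": []},
    models=[],
    shields=[],
    memory_banks=[],
)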
@@ -22,8 +22,6 @@ class MetaReferenceShieldType(Enum):
 class LlamaGuardShieldConfig(BaseModel):
     model: str = "Llama-Guard-3-1B"
     excluded_categories: List[str] = []
-    disable_input_check: bool = False
-    disable_output_check: bool = False
 
     @field_validator("model")
     @classmethod
@@ -91,8 +91,6 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
                 model=cfg.model,
                 inference_api=self.inference_api,
                 excluded_categories=cfg.excluded_categories,
-                disable_input_check=cfg.disable_input_check,
-                disable_output_check=cfg.disable_output_check,
             )
         elif typ == MetaReferenceShieldType.jailbreak_shield:
             from .shields import JailbreakShield
@@ -113,8 +113,6 @@ class LlamaGuardShield(ShieldBase):
         model: str,
         inference_api: Inference,
         excluded_categories: List[str] = None,
-        disable_input_check: bool = False,
-        disable_output_check: bool = False,
         on_violation_action: OnViolationAction = OnViolationAction.RAISE,
     ):
         super().__init__(on_violation_action)
@@ -132,8 +130,6 @@ class LlamaGuardShield(ShieldBase):
         self.model = model
         self.inference_api = inference_api
         self.excluded_categories = excluded_categories
-        self.disable_input_check = disable_input_check
-        self.disable_output_check = disable_output_check
 
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
@@ -180,12 +176,6 @@ class LlamaGuardShield(ShieldBase):
 
     async def run(self, messages: List[Message]) -> ShieldResponse:
         messages = self.validate_messages(messages)
-        if self.disable_input_check and messages[-1].role == Role.user.value:
-            return ShieldResponse(is_violation=False)
-        elif self.disable_output_check and messages[-1].role == Role.assistant.value:
-            return ShieldResponse(
-                is_violation=False,
-            )
 
         if self.model == CoreModelId.llama_guard_3_11b_vision.value:
             shield_input_message = self.build_vision_shield_input(messages)
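With the disable_input_check / disable_output_check escape hatches gone, every run() call now proceeds to shield-prompt construction. A condensed sketch of the remaining control flow (a paraphrase; the text-path helper name is assumed from context, not shown in this diff):

# Condensed paraphrase of LlamaGuardShield.run after this commit (sketch).
async def run(self, messages: List[Message]) -> ShieldResponse:
    messages = self.validate_messages(messages)
    if self.model == CoreModelId.llama_guard_3_11b_vision.value:
        shield_input_message = self.build_vision_shield_input(messages)
    else:
        shield_input_message = self.build_text_shield_input(messages)  # assumed helper
    # ...inference call and unsafe-response parsing continue unchanged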