From 1ac188e1b39c3b719ada3b4708c1df73226d2ede Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Sun, 22 Sep 2024 23:27:04 -0700 Subject: [PATCH] fix config w/ safety --- .../distribution/routers/routing_tables.py | 1 + llama_stack/distribution/utils/dynamic.py | 2 ++ .../impls/meta_reference/safety/safety.py | 1 + tests/examples/router-local-run.yaml | 29 +++++++++++++++---- 4 files changed, 28 insertions(+), 5 deletions(-) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index fcd4d2b2b..b2e4f01eb 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -23,6 +23,7 @@ class CommonRoutingTableImpl(RoutingTable): routing_table_config: Dict[str, List[RoutableProviderConfig]], ) -> None: self.providers = {k: v for k, v in inner_impls} + print("routing table providers", self.providers) self.routing_keys = list(self.providers.keys()) self.routing_table_config = routing_table_config diff --git a/llama_stack/distribution/utils/dynamic.py b/llama_stack/distribution/utils/dynamic.py index 7c2ac2e6a..9818f6a6e 100644 --- a/llama_stack/distribution/utils/dynamic.py +++ b/llama_stack/distribution/utils/dynamic.py @@ -8,6 +8,7 @@ import importlib from typing import Any, Dict from llama_stack.distribution.datatypes import * # noqa: F403 +from termcolor import cprint def instantiate_class_type(fully_qualified_name): @@ -47,6 +48,7 @@ async def instantiate_provider( routing_table = provider_config inner_specs = {x.provider_id: x for x in provider_spec.inner_specs} + cprint(f"inner_specs: {inner_specs}", "cyan") inner_impls = [] for routing_entry in routing_table: impl = await instantiate_provider( diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py index 6eccf47a5..f768ab773 100644 --- a/llama_stack/providers/impls/meta_reference/safety/safety.py +++ b/llama_stack/providers/impls/meta_reference/safety/safety.py @@ -31,6 +31,7 @@ def resolve_and_get_path(model_name: str) -> str: class MetaReferenceSafetyImpl(Safety): def __init__(self, config: SafetyConfig) -> None: + print("Initializing MetaReferenceSafetyImpl w/ config", config) self.config = config async def initialize(self) -> None: diff --git a/tests/examples/router-local-run.yaml b/tests/examples/router-local-run.yaml index df4c453b2..774b4c266 100644 --- a/tests/examples/router-local-run.yaml +++ b/tests/examples/router-local-run.yaml @@ -38,11 +38,30 @@ routing_table: - routing_key: llama_guard provider_id: meta-reference config: - model: Llama-Guard-3-8B - excluded_categories: [] - disable_input_check: false - disable_output_check: false + llama_guard_shield: + model: Llama-Guard-3-8B + excluded_categories: [] + disable_input_check: false + disable_output_check: false + prompt_guard_shield: + model: Prompt-Guard-86M - routing_key: prompt_guard provider_id: meta-reference config: - model: Prompt-Guard-86M + llama_guard_shield: + model: Llama-Guard-3-8B + excluded_categories: [] + disable_input_check: false + disable_output_check: false + prompt_guard_shield: + model: Prompt-Guard-86M + - routing_key: injection_shield + provider_id: meta-reference + config: + llama_guard_shield: + model: Llama-Guard-3-8B + excluded_categories: [] + disable_input_check: false + disable_output_check: false + prompt_guard_shield: + model: Prompt-Guard-86M