fix config w/ safety

This commit is contained in:
Xi Yan 2024-09-22 23:27:04 -07:00
parent 4586692dee
commit 1ac188e1b3
4 changed files with 28 additions and 5 deletions

View file

@ -23,6 +23,7 @@ class CommonRoutingTableImpl(RoutingTable):
routing_table_config: Dict[str, List[RoutableProviderConfig]], routing_table_config: Dict[str, List[RoutableProviderConfig]],
) -> None: ) -> None:
self.providers = {k: v for k, v in inner_impls} self.providers = {k: v for k, v in inner_impls}
print("routing table providers", self.providers)
self.routing_keys = list(self.providers.keys()) self.routing_keys = list(self.providers.keys())
self.routing_table_config = routing_table_config self.routing_table_config = routing_table_config

View file

@ -8,6 +8,7 @@ import importlib
from typing import Any, Dict from typing import Any, Dict
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
from termcolor import cprint
def instantiate_class_type(fully_qualified_name): def instantiate_class_type(fully_qualified_name):
@ -47,6 +48,7 @@ async def instantiate_provider(
routing_table = provider_config routing_table = provider_config
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs} inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
cprint(f"inner_specs: {inner_specs}", "cyan")
inner_impls = [] inner_impls = []
for routing_entry in routing_table: for routing_entry in routing_table:
impl = await instantiate_provider( impl = await instantiate_provider(

View file

@ -31,6 +31,7 @@ def resolve_and_get_path(model_name: str) -> str:
class MetaReferenceSafetyImpl(Safety): class MetaReferenceSafetyImpl(Safety):
def __init__(self, config: SafetyConfig) -> None: def __init__(self, config: SafetyConfig) -> None:
print("Initializing MetaReferenceSafetyImpl w/ config", config)
self.config = config self.config = config
async def initialize(self) -> None: async def initialize(self) -> None:

View file

@ -38,11 +38,30 @@ routing_table:
- routing_key: llama_guard - routing_key: llama_guard
provider_id: meta-reference provider_id: meta-reference
config: config:
model: Llama-Guard-3-8B llama_guard_shield:
excluded_categories: [] model: Llama-Guard-3-8B
disable_input_check: false excluded_categories: []
disable_output_check: false disable_input_check: false
disable_output_check: false
prompt_guard_shield:
model: Prompt-Guard-86M
- routing_key: prompt_guard - routing_key: prompt_guard
provider_id: meta-reference provider_id: meta-reference
config: config:
model: Prompt-Guard-86M llama_guard_shield:
model: Llama-Guard-3-8B
excluded_categories: []
disable_input_check: false
disable_output_check: false
prompt_guard_shield:
model: Prompt-Guard-86M
- routing_key: injection_shield
provider_id: meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-8B
excluded_categories: []
disable_input_check: false
disable_output_check: false
prompt_guard_shield:
model: Prompt-Guard-86M