fix config w/ safety

This commit is contained in:
Xi Yan 2024-09-22 23:27:04 -07:00
parent 4586692dee
commit 1ac188e1b3
4 changed files with 28 additions and 5 deletions

View file

@ -23,6 +23,7 @@ class CommonRoutingTableImpl(RoutingTable):
routing_table_config: Dict[str, List[RoutableProviderConfig]],
) -> None:
self.providers = {k: v for k, v in inner_impls}
print("routing table providers", self.providers)
self.routing_keys = list(self.providers.keys())
self.routing_table_config = routing_table_config

View file

@ -8,6 +8,7 @@ import importlib
from typing import Any, Dict
from llama_stack.distribution.datatypes import * # noqa: F403
from termcolor import cprint
def instantiate_class_type(fully_qualified_name):
@ -47,6 +48,7 @@ async def instantiate_provider(
routing_table = provider_config
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
cprint(f"inner_specs: {inner_specs}", "cyan")
inner_impls = []
for routing_entry in routing_table:
impl = await instantiate_provider(

View file

@ -31,6 +31,7 @@ def resolve_and_get_path(model_name: str) -> str:
class MetaReferenceSafetyImpl(Safety):
def __init__(self, config: SafetyConfig) -> None:
print("Initializing MetaReferenceSafetyImpl w/ config", config)
self.config = config
async def initialize(self) -> None:

View file

@ -38,11 +38,30 @@ routing_table:
- routing_key: llama_guard
provider_id: meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-8B
excluded_categories: []
disable_input_check: false
disable_output_check: false
prompt_guard_shield:
model: Prompt-Guard-86M
- routing_key: prompt_guard
provider_id: meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-8B
excluded_categories: []
disable_input_check: false
disable_output_check: false
prompt_guard_shield:
model: Prompt-Guard-86M
- routing_key: injection_shield
provider_id: meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-8B
excluded_categories: []
disable_input_check: false
disable_output_check: false
prompt_guard_shield:
model: Prompt-Guard-86M