mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
fix config w/ safety
This commit is contained in:
parent
4586692dee
commit
1ac188e1b3
4 changed files with 28 additions and 5 deletions
|
@ -23,6 +23,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
routing_table_config: Dict[str, List[RoutableProviderConfig]],
|
routing_table_config: Dict[str, List[RoutableProviderConfig]],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.providers = {k: v for k, v in inner_impls}
|
self.providers = {k: v for k, v in inner_impls}
|
||||||
|
print("routing table providers", self.providers)
|
||||||
self.routing_keys = list(self.providers.keys())
|
self.routing_keys = list(self.providers.keys())
|
||||||
self.routing_table_config = routing_table_config
|
self.routing_table_config = routing_table_config
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ import importlib
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
from termcolor import cprint
|
||||||
|
|
||||||
|
|
||||||
def instantiate_class_type(fully_qualified_name):
|
def instantiate_class_type(fully_qualified_name):
|
||||||
|
@ -47,6 +48,7 @@ async def instantiate_provider(
|
||||||
routing_table = provider_config
|
routing_table = provider_config
|
||||||
|
|
||||||
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
|
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
|
||||||
|
cprint(f"inner_specs: {inner_specs}", "cyan")
|
||||||
inner_impls = []
|
inner_impls = []
|
||||||
for routing_entry in routing_table:
|
for routing_entry in routing_table:
|
||||||
impl = await instantiate_provider(
|
impl = await instantiate_provider(
|
||||||
|
|
|
@ -31,6 +31,7 @@ def resolve_and_get_path(model_name: str) -> str:
|
||||||
|
|
||||||
class MetaReferenceSafetyImpl(Safety):
|
class MetaReferenceSafetyImpl(Safety):
|
||||||
def __init__(self, config: SafetyConfig) -> None:
|
def __init__(self, config: SafetyConfig) -> None:
|
||||||
|
print("Initializing MetaReferenceSafetyImpl w/ config", config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
|
|
|
@ -38,11 +38,30 @@ routing_table:
|
||||||
- routing_key: llama_guard
|
- routing_key: llama_guard
|
||||||
provider_id: meta-reference
|
provider_id: meta-reference
|
||||||
config:
|
config:
|
||||||
model: Llama-Guard-3-8B
|
llama_guard_shield:
|
||||||
excluded_categories: []
|
model: Llama-Guard-3-8B
|
||||||
disable_input_check: false
|
excluded_categories: []
|
||||||
disable_output_check: false
|
disable_input_check: false
|
||||||
|
disable_output_check: false
|
||||||
|
prompt_guard_shield:
|
||||||
|
model: Prompt-Guard-86M
|
||||||
- routing_key: prompt_guard
|
- routing_key: prompt_guard
|
||||||
provider_id: meta-reference
|
provider_id: meta-reference
|
||||||
config:
|
config:
|
||||||
model: Prompt-Guard-86M
|
llama_guard_shield:
|
||||||
|
model: Llama-Guard-3-8B
|
||||||
|
excluded_categories: []
|
||||||
|
disable_input_check: false
|
||||||
|
disable_output_check: false
|
||||||
|
prompt_guard_shield:
|
||||||
|
model: Prompt-Guard-86M
|
||||||
|
- routing_key: injection_shield
|
||||||
|
provider_id: meta-reference
|
||||||
|
config:
|
||||||
|
llama_guard_shield:
|
||||||
|
model: Llama-Guard-3-8B
|
||||||
|
excluded_categories: []
|
||||||
|
disable_input_check: false
|
||||||
|
disable_output_check: false
|
||||||
|
prompt_guard_shield:
|
||||||
|
model: Prompt-Guard-86M
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue