forked from phoenix-oss/llama-stack-mirror
Add a RoutableProvider protocol, support for multiple routing keys (#163)
* Update configure.py to use multiple routing keys for safety * Refactor distribution/datatypes into a providers/datatypes * Cleanup
This commit is contained in:
parent
73decb3781
commit
eb2d8a31a5
24 changed files with 600 additions and 577 deletions
|
@ -4,13 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.distribution.datatypes import Api, RoutableProvider
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
|
||||
OnViolationAction,
|
||||
|
@ -35,7 +37,7 @@ def resolve_and_get_path(model_name: str) -> str:
|
|||
return model_dir
|
||||
|
||||
|
||||
class MetaReferenceSafetyImpl(Safety):
|
||||
class MetaReferenceSafetyImpl(Safety, RoutableProvider):
|
||||
def __init__(self, config: SafetyConfig, deps) -> None:
|
||||
self.config = config
|
||||
self.inference_api = deps[Api.inference]
|
||||
|
@ -46,6 +48,15 @@ class MetaReferenceSafetyImpl(Safety):
|
|||
model_dir = resolve_and_get_path(shield_cfg.model)
|
||||
_ = PromptGuardShield.instance(model_dir)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
available_shields = [v.value for v in MetaReferenceShieldType]
|
||||
for key in routing_keys:
|
||||
if key not in available_shields:
|
||||
raise ValueError(f"Unknown safety shield type: {key}")
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_type: str,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue