Add a RoutableProvider protocol, support for multiple routing keys (#163)

* Update configure.py to use multiple routing keys for safety
* Refactor distribution/datatypes into a providers/datatypes
* Cleanup
This commit is contained in:
Ashwin Bharambe 2024-09-30 17:30:21 -07:00 committed by GitHub
parent 73decb3781
commit eb2d8a31a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 600 additions and 577 deletions

View file

@ -4,13 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from llama_models.sku_list import resolve_model
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.datatypes import Api, RoutableProvider
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
OnViolationAction,
@ -35,7 +37,7 @@ def resolve_and_get_path(model_name: str) -> str:
return model_dir
class MetaReferenceSafetyImpl(Safety):
class MetaReferenceSafetyImpl(Safety, RoutableProvider):
def __init__(self, config: SafetyConfig, deps) -> None:
self.config = config
self.inference_api = deps[Api.inference]
@ -46,6 +48,15 @@ class MetaReferenceSafetyImpl(Safety):
model_dir = resolve_and_get_path(shield_cfg.model)
_ = PromptGuardShield.instance(model_dir)
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
available_shields = [v.value for v in MetaReferenceShieldType]
for key in routing_keys:
if key not in available_shields:
raise ValueError(f"Unknown safety shield type: {key}")
async def run_shield(
self,
shield_type: str,