add safety/list_shields to query available shields

This commit is contained in:
Xi Yan 2024-09-22 01:08:32 -07:00
parent 44fe099866
commit b8914bb56f
3 changed files with 52 additions and 3 deletions

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, Protocol
from typing import Any, Dict, List, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
@ -37,8 +37,16 @@ class RunShieldResponse(BaseModel):
violation: Optional[SafetyViolation] = None
@json_schema_type
class ListShieldsResponse(BaseModel):
shields: List[str] = None
class Safety(Protocol):
@webmethod(route="/safety/run_shield")
async def run_shield(
self, shield: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse: ...
@webmethod(route="/safety/list_shields")
async def list_shields(self) -> ListShieldsResponse: ...