Update the meta reference safety implementation to match new API

This commit is contained in:
Ashwin Bharambe 2024-09-20 14:17:44 -07:00 committed by Xi Yan
parent 7e40eead4e
commit 82ddd851c8
11 changed files with 115 additions and 130 deletions

View file

@ -37,8 +37,8 @@ class AgentTool(Enum):
class ToolDefinitionCommon(BaseModel):
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list)
class SearchEngineType(Enum):
@ -266,8 +266,8 @@ class Session(BaseModel):
class AgentConfigCommon(BaseModel):
sampling_params: Optional[SamplingParams] = SamplingParams()
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list)
tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)

View file

@ -13,11 +13,11 @@ import fire
import httpx
from llama_models.llama3.api.datatypes import UserMessage
from llama_stack.distribution.datatypes import RemoteProviderConfig
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.distribution.datatypes import RemoteProviderConfig
from .safety import * # noqa: F403
@ -69,11 +69,7 @@ async def run_main(host: str, port: int):
response = await client.run_shields(
RunShieldRequest(
messages=[message],
shields=[
ShieldDefinition(
shield_type=BuiltinShield.llama_guard,
)
],
shields=["llama_guard"],
)
)
print(response)

View file

@ -37,11 +37,8 @@ class RunShieldResponse(BaseModel):
violation: Optional[SafetyViolation] = None
ShieldType = str
class Safety(Protocol):
@webmethod(route="/safety/run_shield")
async def run_shield(
self, shield: ShieldType, messages: List[Message], params: Dict[str, Any] = None
self, shield: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse: ...