support json format

This commit is contained in:
Hardik Shah 2024-08-14 12:43:43 -07:00
parent 48b78430eb
commit 86df597a83
7 changed files with 97 additions and 29 deletions

View file

@ -8,12 +8,11 @@ from enum import Enum
from typing import Dict, Optional, Union
from llama_models.llama3_1.api.datatypes import ToolParamDefinition
from llama_models.schema_utils import json_schema_type
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
from pydantic import BaseModel, validator
from pydantic import BaseModel
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
@json_schema_type
@ -43,6 +42,16 @@ class ShieldDefinition(BaseModel):
on_violation_action: OnViolationAction = OnViolationAction.RAISE
execution_config: Optional[RestAPIExecutionConfig] = None
@validator("shield_type", pre=True)
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinShield(v)
except ValueError:
return v
return v
@json_schema_type
class ShieldResponse(BaseModel):
@ -51,3 +60,13 @@ class ShieldResponse(BaseModel):
is_violation: bool
violation_type: Optional[str] = None
violation_return_message: Optional[str] = None
@validator("shield_type", pre=True)
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinShield(v)
except ValueError:
return v
return v