Updates to prompt for tool calls (#29)

* update system prompts to drop new line

* Add tool prompt formats

* support json format

* JSON in caps

* function_tag system prompt is also added as a user message

* added docstrings for ToolPromptFormat

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
This commit is contained in:
Hardik Shah 2024-08-15 13:23:51 -07:00 committed by GitHub
parent 0d933ac4c5
commit b8fc4d4dee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 173 additions and 30 deletions

View file

@ -8,12 +8,11 @@ from enum import Enum
from typing import Dict, Optional, Union
from llama_models.llama3_1.api.datatypes import ToolParamDefinition
from llama_models.schema_utils import json_schema_type
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
from pydantic import BaseModel, validator
from pydantic import BaseModel
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
@json_schema_type
@ -43,6 +42,16 @@ class ShieldDefinition(BaseModel):
on_violation_action: OnViolationAction = OnViolationAction.RAISE
execution_config: Optional[RestAPIExecutionConfig] = None
@validator("shield_type", pre=True)
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinShield(v)
except ValueError:
return v
return v
@json_schema_type
class ShieldResponse(BaseModel):
@ -51,3 +60,13 @@ class ShieldResponse(BaseModel):
is_violation: bool
violation_type: Optional[str] = None
violation_return_message: Optional[str] = None
@validator("shield_type", pre=True)
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinShield(v)
except ValueError:
return v
return v