forked from phoenix-oss/llama-stack-mirror
Updates to prompt for tool calls (#29)
* update system prompts to drop new line * Add tool prompt formats * support json format * JSON in caps * function_tag system prompt is also added as a user message * added docstrings for ToolPromptFormat --------- Co-authored-by: Hardik Shah <hjshah@fb.com>
This commit is contained in:
parent
0d933ac4c5
commit
b8fc4d4dee
8 changed files with 173 additions and 30 deletions
|
@ -8,12 +8,11 @@ from enum import Enum
|
|||
from typing import Dict, Optional, Union
|
||||
|
||||
from llama_models.llama3_1.api.datatypes import ToolParamDefinition
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
|
||||
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from pydantic import BaseModel
|
||||
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -43,6 +42,16 @@ class ShieldDefinition(BaseModel):
|
|||
on_violation_action: OnViolationAction = OnViolationAction.RAISE
|
||||
execution_config: Optional[RestAPIExecutionConfig] = None
|
||||
|
||||
@validator("shield_type", pre=True)
|
||||
@classmethod
|
||||
def validate_field(cls, v):
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return BuiltinShield(v)
|
||||
except ValueError:
|
||||
return v
|
||||
return v
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ShieldResponse(BaseModel):
|
||||
|
@ -51,3 +60,13 @@ class ShieldResponse(BaseModel):
|
|||
is_violation: bool
|
||||
violation_type: Optional[str] = None
|
||||
violation_return_message: Optional[str] = None
|
||||
|
||||
@validator("shield_type", pre=True)
|
||||
@classmethod
|
||||
def validate_field(cls, v):
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return BuiltinShield(v)
|
||||
except ValueError:
|
||||
return v
|
||||
return v
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue