support json format

Hardik Shah 2024-08-14 12:43:43 -07:00
parent 48b78430eb
commit 86df597a83
7 changed files with 97 additions and 29 deletions

View file

@@ -134,7 +134,7 @@ class AgenticSystemInstanceConfig(BaseModel):
     # this is debug only
     debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
     tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.function_tag
+        default=ToolPromptFormat.json
     )
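
With the default flipped to json, an instance that still wants the older tag-based prompting has to opt in explicitly. A minimal sketch of the override (the other required fields of AgenticSystemInstanceConfig are elided and unchanged by this commit):

    from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat

    config = AgenticSystemInstanceConfig(
        # ... model, shields, sampling params, etc. as before ...
        tool_prompt_format=ToolPromptFormat.function_tag,  # restore pre-commit behavior
    )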

View file

@@ -19,7 +19,9 @@ from llama_models.llama3_1.api.datatypes import (
     ToolParamDefinition,
     UserMessage,
 )
+from termcolor import cprint
+from llama_toolchain.agentic_system.event_logger import EventLogger

 from .api import (
     AgenticSystem,
     AgenticSystemCreateRequest,
@@ -120,7 +122,18 @@ async def run_main(host: str, port: int):
                     required=True,
                 )
             },
         ),
+        AgenticSystemToolDefinition(
+            tool_name="custom_tool_2",
+            description="a second custom tool",
+            parameters={
+                "param2": ToolParamDefinition(
+                    param_type="str",
+                    description="a string parameter",
+                    required=True,
+                )
+            },
+        ),
     ]

     create_request = AgenticSystemCreateRequest(
@@ -138,7 +151,7 @@ async def run_main(host: str, port: int):
     )

     create_response = await api.create_agentic_system(create_request)
-    print("Create Response -->", create_response)
+    print(create_response)

     session_response = await api.create_agentic_system_session(
         AgenticSystemSessionCreateRequest(
@ -146,21 +159,28 @@ async def run_main(host: str, port: int):
session_name="test_session",
)
)
print("Session Response -->", session_response)
print(session_response)
turn_response = api.create_agentic_system_turn(
user_prompts = [
"Who are you?",
"Write code to check if a number is prime. Use that to check if 7 is prime",
]
for content in user_prompts:
cprint(f"User> {content}", color="blue")
iterator = api.create_agentic_system_turn(
AgenticSystemTurnCreateRequest(
system_id=create_response.system_id,
session_id=session_response.session_id,
messages=[
UserMessage(content="Who are you?"),
UserMessage(content=content),
],
stream=False,
stream=True,
)
)
print("Turn Response -->")
async for chunk in turn_response:
print(chunk)
async for event, log in EventLogger().log(iterator):
if log is not None:
log.print()
def main(host: str, port: int):
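
The rewritten loop above reduces to a reusable pattern: send each prompt as a streaming turn and render the event stream through EventLogger. A sketch lifted from that loop (the chat_once name is illustrative; the absolute import path for AgenticSystemTurnCreateRequest is assumed from the sibling imports in these diffs):

    from llama_models.llama3_1.api.datatypes import UserMessage
    from llama_toolchain.agentic_system.api import AgenticSystemTurnCreateRequest
    from llama_toolchain.agentic_system.event_logger import EventLogger
    from termcolor import cprint

    async def chat_once(api, system_id: str, session_id: str, content: str) -> None:
        cprint(f"User> {content}", color="blue")
        # Streaming turn: the client returns an async iterator of events.
        iterator = api.create_agentic_system_turn(
            AgenticSystemTurnCreateRequest(
                system_id=system_id,
                session_id=session_id,
                messages=[UserMessage(content=content)],
                stream=True,
            )
        )
        # EventLogger yields (event, log) pairs; log is None for events
        # that have nothing printable.
        async for event, log in EventLogger().log(iterator):
            if log is not None:
                log.print()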

View file

@@ -76,7 +76,7 @@ class AgentInstance(ShieldRunnerMixin):
         output_shields: List[ShieldDefinition],
         max_infer_iters: int = 10,
         prefix_messages: Optional[List[Message]] = None,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.function_tag,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
     ):
         self.system_id = system_id
         self.instance_config = instance_config

View file

@@ -6,14 +6,15 @@
 from typing import List

-from llama_models.llama3_1.api.datatypes import Message, Role
+from llama_models.llama3_1.api.datatypes import Message, Role, UserMessage
+from termcolor import cprint

 from llama_toolchain.safety.api.datatypes import (
     OnViolationAction,
     ShieldDefinition,
     ShieldResponse,
 )
 from llama_toolchain.safety.api.endpoints import RunShieldRequest, Safety
-from termcolor import cprint


 class SafetyException(Exception):  # noqa: N818
@@ -36,12 +37,11 @@ class ShieldRunnerMixin:
     async def run_shields(
         self, messages: List[Message], shields: List[ShieldDefinition]
     ) -> List[ShieldResponse]:
+        messages = messages.copy()
         # some shields like llama-guard require the first message to be a user message
         # since this might be a tool call, first role might not be user
         if len(messages) > 0 and messages[0].role != Role.user.value:
-            # TODO(ashwin): we need to change the type of the message, this kind of modification
-            # is no longer appropriate
-            messages[0].role = Role.user.value
+            messages[0] = UserMessage(content=messages[0].content)

         res = await self.safety_api.run_shields(
             RunShieldRequest(
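
The new copy-then-rebuild logic stands alone well; a sketch under the same imports as this file (the helper name is illustrative, not part of the module):

    from typing import List

    from llama_models.llama3_1.api.datatypes import Message, Role, UserMessage

    def coerce_first_message_to_user(messages: List[Message]) -> List[Message]:
        # Shields like llama-guard require the conversation to open with a
        # user turn, but the first message may come from a tool call.
        messages = messages.copy()
        if len(messages) > 0 and messages[0].role != Role.user.value:
            # Rebuild the message as a UserMessage instead of mutating .role
            # in place, so the message type stays consistent with its role.
            messages[0] = UserMessage(content=messages[0].content)
        return messages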

View file

@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 import json
+import textwrap
 from datetime import datetime
 from typing import List
@@ -15,6 +16,7 @@ from llama_toolchain.inference.api import (
     Message,
     SystemMessage,
     ToolDefinition,
+    UserMessage,
 )

 from .tools.builtin import SingleMessageBuiltinTool
@@ -49,18 +51,43 @@ Today Date: {formatted_date}\n"""

     if custom_tools:
-        custom_message = get_system_prompt_for_custom_tools(custom_tools)
-        content += custom_message
-    messages.append(SystemMessage(content=content))
+        if tool_prompt_format == ToolPromptFormat.function_tag:
+            custom_message = prompt_for_function_tag(custom_tools)
+            content += custom_message
+            messages.append(SystemMessage(content=content))
+        elif tool_prompt_format == ToolPromptFormat.json:
+            messages.append(SystemMessage(content=content))
+            # the json tool list is added as a user prompt
+            text = prompt_for_json(custom_tools)
+            messages.append(UserMessage(content=text))
+        else:
+            raise NotImplementedError(
+                f"Tool prompt format {tool_prompt_format} is not supported"
+            )
+    else:
+        messages.append(SystemMessage(content=content))

     return messages


-def get_system_prompt_for_custom_tools(custom_tools: List[ToolDefinition]) -> str:
+def prompt_for_json(custom_tools: List[ToolDefinition]) -> str:
+    tool_defs = "\n".join(
+        translate_custom_tool_definition_to_json(t) for t in custom_tools
+    )
+    content = textwrap.dedent(
+        """
+        Answer the user's question by making use of the following functions if needed.
+        If none of the functions can be used, please say so.
+        Here is a list of functions in JSON format:
+        {tool_defs}
+        Return function calls in JSON format.
+        """
+    )
+    content = content.lstrip("\n").format(tool_defs=tool_defs)
+    return content
+
+
+def prompt_for_function_tag(custom_tools: List[ToolDefinition]) -> str:
     custom_tool_params = ""
     for t in custom_tools:
         custom_tool_params += get_instruction_string(t) + "\n"
@@ -102,7 +129,6 @@ def get_parameters_string(custom_tool_definition) -> str:
     )


-# NOTE: Unused right now
 def translate_custom_tool_definition_to_json(tool_def):
     """Translates ToolDefinition to json as expected by model
     e.g. output for a function
@@ -153,4 +179,4 @@ def translate_custom_tool_definition_to_json(tool_def):
     else:
         func_def["function"]["parameters"] = {}

-    return json.dumps(func_def)
+    return json.dumps(func_def, indent=4)
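
For reference, a sketch of what the new json path feeds the model for one custom tool. The tool itself is hypothetical; the ToolDefinition and ToolParamDefinition field names and import paths are taken from the code in these diffs:

    from llama_models.llama3_1.api.datatypes import ToolParamDefinition
    from llama_toolchain.inference.api import ToolDefinition

    tools = [
        ToolDefinition(
            tool_name="get_boiling_point",  # hypothetical custom tool
            description="returns the boiling point of a liquid",
            parameters={
                "liquid_name": ToolParamDefinition(
                    param_type="str",
                    description="name of the liquid",
                    required=True,
                )
            },
        )
    ]
    # Prints the instruction text plus one indent=4 JSON function definition.
    # Per the branch above, this text goes to the model as a user message,
    # while the system message keeps only the base instructions.
    print(prompt_for_json(tools))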

View file

@@ -15,6 +15,7 @@ from llama_toolchain.agentic_system.api import (
     AgenticSystemSessionCreateRequest,
     AgenticSystemToolDefinition,
 )
+from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
 from llama_toolchain.agentic_system.client import AgenticSystemClient

 from llama_toolchain.agentic_system.tools.custom.execute import (
@@ -64,6 +65,7 @@ async def get_agent_system_instance(
     custom_tools: Optional[List[Any]] = None,
     disable_safety: bool = False,
     model: str = "Meta-Llama3.1-8B-Instruct",
+    tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
 ) -> AgenticSystemClientWrapper:
     custom_tools = custom_tools or []
@@ -113,6 +115,7 @@ async def get_agent_system_instance(
                 ]
             ),
             sampling_params=SamplingParams(),
+            tool_prompt_format=tool_prompt_format,
         ),
     )
     create_response = await api.create_agentic_system(create_request)
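
Tests and examples can now pick the prompt format per instance; a usage sketch showing only the keyword arguments visible in this diff (any other required arguments, such as host and port, are elided):

    instance = await get_agent_system_instance(
        custom_tools=[],
        disable_safety=False,
        model="Meta-Llama3.1-8B-Instruct",
        tool_prompt_format=ToolPromptFormat.function_tag,  # omit for the json default
    )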

View file

@@ -8,12 +8,11 @@ from enum import Enum

 from typing import Dict, Optional, Union

 from llama_models.llama3_1.api.datatypes import ToolParamDefinition
 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel
-
-from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
+from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
+from pydantic import BaseModel, validator


 @json_schema_type
@@ -43,6 +42,16 @@ class ShieldDefinition(BaseModel):
     on_violation_action: OnViolationAction = OnViolationAction.RAISE
     execution_config: Optional[RestAPIExecutionConfig] = None

+    @validator("shield_type", pre=True)
+    @classmethod
+    def validate_field(cls, v):
+        if isinstance(v, str):
+            try:
+                return BuiltinShield(v)
+            except ValueError:
+                return v
+        return v
+

 @json_schema_type
 class ShieldResponse(BaseModel):
@@ -51,3 +60,13 @@ class ShieldResponse(BaseModel):
     is_violation: bool
     violation_type: Optional[str] = None
     violation_return_message: Optional[str] = None
+
+    @validator("shield_type", pre=True)
+    @classmethod
+    def validate_field(cls, v):
+        if isinstance(v, str):
+            try:
+                return BuiltinShield(v)
+            except ValueError:
+                return v
+        return v
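
With the pre-validators in place, both models accept shield_type as a plain string and coerce it to the BuiltinShield enum when it matches; unknown strings pass through, preserving support for custom shields. A behavior sketch (the "llama_guard" value and the single-argument constructor are assumptions, not shown in this diff):

    d = ShieldDefinition(shield_type="llama_guard")  # assumed BuiltinShield value
    assert isinstance(d.shield_type, BuiltinShield)

    custom = ShieldDefinition(shield_type="my_custom_shield")  # unknown string
    assert custom.shield_type == "my_custom_shield"  # passed through unchanged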