Updates to prompt for tool calls (#29)

* update system prompts to drop new line

* Add tool prompt formats

* support json format

* JSON in caps

* function_tag system prompt is also added as a user message

* added docstrings for ToolPromptFormat

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
This commit is contained in:
Hardik Shah 2024-08-15 13:23:51 -07:00 committed by GitHub
parent 0d933ac4c5
commit b8fc4d4dee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 173 additions and 30 deletions

View file

@ -13,8 +13,15 @@ import fire
import httpx
from llama_models.llama3_1.api.datatypes import BuiltinTool, SamplingParams
from llama_models.llama3_1.api.datatypes import (
BuiltinTool,
SamplingParams,
ToolParamDefinition,
UserMessage,
)
from termcolor import cprint
from llama_toolchain.agentic_system.event_logger import EventLogger
from .api import (
AgenticSystem,
AgenticSystemCreateRequest,
@ -25,6 +32,7 @@ from .api import (
AgenticSystemToolDefinition,
AgenticSystemTurnCreateRequest,
AgenticSystemTurnResponseStreamChunk,
ToolPromptFormat,
)
@ -87,7 +95,7 @@ class AgenticSystemClient(AgenticSystem):
async def run_main(host: str, port: int):
# client to test remote impl of agentic system
api = await AgenticSystemClient(f"http://{host}:{port}")
api = AgenticSystemClient(f"http://{host}:{port}")
tool_definitions = [
AgenticSystemToolDefinition(
@ -96,13 +104,28 @@ async def run_main(host: str, port: int):
AgenticSystemToolDefinition(
tool_name=BuiltinTool.wolfram_alpha,
),
AgenticSystemToolDefinition(
tool_name=BuiltinTool.photogen,
),
AgenticSystemToolDefinition(
tool_name=BuiltinTool.code_interpreter,
),
]
tool_definitions += [
AgenticSystemToolDefinition(
tool_name="get_boiling_point",
description="Get the boiling point of a imaginary liquids (eg. polyjuice)",
parameters={
"liquid_name": ToolParamDefinition(
param_type="str",
description="The name of the liquid",
required=True,
),
"celcius": ToolParamDefinition(
param_type="str",
description="Whether to return the boiling point in Celcius",
required=False,
),
},
),
]
create_request = AgenticSystemCreateRequest(
model="Meta-Llama3.1-8B-Instruct",
@ -114,12 +137,44 @@ async def run_main(host: str, port: int):
output_shields=[],
quantization_config=None,
debug_prefix_messages=[],
tool_prompt_format=ToolPromptFormat.json,
),
)
create_response = await api.create_agentic_system(create_request)
print(create_response)
# TODO: Add chat session / turn apis to test e2e
session_response = await api.create_agentic_system_session(
AgenticSystemSessionCreateRequest(
system_id=create_response.system_id,
session_name="test_session",
)
)
print(session_response)
user_prompts = [
"Who are you?",
"what is the 100th prime number?",
"Search web for who was 44th President of USA?",
"Write code to check if a number is prime. Use that to check if 7 is prime",
"What is the boiling point of polyjuicepotion ?",
]
for content in user_prompts:
cprint(f"User> {content}", color="blue")
iterator = api.create_agentic_system_turn(
AgenticSystemTurnCreateRequest(
system_id=create_response.system_id,
session_id=session_response.session_id,
messages=[
UserMessage(content=content),
],
stream=True,
)
)
async for event, log in EventLogger().log(iterator):
if log is not None:
log.print()
def main(host: str, port: int):