re-work tool definitions, fix FastAPI issues, fix tool regressions

This commit is contained in:
Ashwin Bharambe 2024-08-24 22:07:06 -07:00
parent 8d14d4228b
commit 8efe614719
11 changed files with 144 additions and 104 deletions

View file

@ -10,8 +10,6 @@ import uuid
from datetime import datetime
from typing import AsyncGenerator, List
from llama_models.llama3.api.datatypes import ToolPromptFormat
from termcolor import cprint
from llama_toolchain.agentic_system.api import * # noqa: F403
@ -20,7 +18,10 @@ from llama_toolchain.memory.api import * # noqa: F403
from llama_toolchain.safety.api import * # noqa: F403
from llama_toolchain.tools.base import BaseTool
from llama_toolchain.tools.builtin import SingleMessageBuiltinTool
from llama_toolchain.tools.builtin import (
interpret_content_as_attachment,
SingleMessageBuiltinTool,
)
from .safety import SafetyException, ShieldRunnerMixin
@ -192,7 +193,7 @@ class ChatAgent(ShieldRunnerMixin):
yield res
async for res in self._run(
turn_id, session, input_messages, attachments, sampling_params, stream
session, turn_id, input_messages, attachments, sampling_params, stream
):
if isinstance(res, bool):
return
@ -358,7 +359,7 @@ class ChatAgent(ShieldRunnerMixin):
req = ChatCompletionRequest(
model=self.agent_config.model,
messages=input_messages,
tools=self.agent_config.tools,
tools=self._get_tools(),
tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True,
sampling_params=sampling_params,
@ -555,17 +556,13 @@ class ChatAgent(ShieldRunnerMixin):
yield False
return
if isinstance(result_message.content, Attachment):
if out_attachment := interpret_content_as_attachment(
result_message.content
):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to final message
output_attachments.append(result_message.content)
elif isinstance(result_message.content, list) or isinstance(
result_message.content, tuple
):
for c in result_message.content:
if isinstance(c, Attachment):
output_attachments.append(c)
output_attachments.append(out_attachment)
input_messages = input_messages + [message, result_message]
@ -667,6 +664,27 @@ class ChatAgent(ShieldRunnerMixin):
"\n=== END-RETRIEVED-CONTEXT ===\n",
]
def _get_tools(self) -> List[ToolDefinition]:
ret = []
for t in self.agent_config.tools:
if isinstance(t, BraveSearchToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search))
elif isinstance(t, WolframAlphaToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha))
elif isinstance(t, PhotogenToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.photogen))
elif isinstance(t, CodeInterpreterToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.code_interpreter))
elif isinstance(t, FunctionCallToolDefinition):
ret.append(
ToolDefinition(
tool_name=t.function_name,
description=t.description,
parameters=t.parameters,
)
)
return ret
def attachment_message(urls: List[URL]) -> ToolResponseMessage:
content = []