Refactor custom tool execution utilities

This commit is contained in:
Ashwin Bharambe 2024-08-25 14:34:20 -07:00
parent 440d125ea0
commit ceef117abc
9 changed files with 209 additions and 145 deletions

View file

@ -111,7 +111,7 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
) )
for content in user_prompts: for content in user_prompts:
cprint(f"User> {content}", color="blue") cprint(f"User> {content}", color="white", attrs=["bold"])
iterator = api.create_agentic_system_turn( iterator = api.create_agentic_system_turn(
AgenticSystemTurnCreateRequest( AgenticSystemTurnCreateRequest(
agent_id=create_response.agent_id, agent_id=create_response.agent_id,

View file

@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, List
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.agentic_system.api import * # noqa: F403
from llama_toolchain.memory.api import * # noqa: F403
from llama_toolchain.safety.api import * # noqa: F403
from llama_toolchain.agentic_system.api import (
AgenticSystemTurnResponseEventType as EventType,
)
from llama_toolchain.tools.custom.datatypes import CustomTool
class AgentWithCustomToolExecutor:
def __init__(
self,
api: AgenticSystem,
agent_id: str,
session_id: str,
agent_config: AgentConfig,
custom_tools: List[CustomTool],
):
self.api = api
self.agent_id = agent_id
self.session_id = session_id
self.agent_config = agent_config
self.custom_tools = custom_tools
async def execute_turn(
self,
messages: List[Message],
attachments: Optional[List[Attachment]] = None,
max_iters: int = 5,
stream: bool = True,
) -> AsyncGenerator:
tools_dict = {t.get_name(): t for t in self.custom_tools}
current_messages = messages.copy()
n_iter = 0
while n_iter < max_iters:
n_iter += 1
request = AgenticSystemTurnCreateRequest(
agent_id=self.agent_id,
session_id=self.session_id,
messages=current_messages,
attachments=attachments,
stream=stream,
)
turn = None
async for chunk in self.api.create_agentic_system_turn(request):
if chunk.event.payload.event_type != EventType.turn_complete.value:
yield chunk
else:
turn = chunk.event.payload.turn
message = turn.output_message
if len(message.tool_calls) == 0:
yield chunk
return
if message.stop_reason == StopReason.out_of_tokens:
yield chunk
return
tool_call = message.tool_calls[0]
if tool_call.tool_name not in tools_dict:
m = ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else",
)
next_message = m
else:
tool = tools_dict[tool_call.tool_name]
result_messages = await execute_custom_tool(tool, message)
next_message = result_messages[0]
yield next_message
current_messages = [next_message]
async def execute_custom_tool(tool: CustomTool, message: Message) -> List[Message]:
result_messages = await tool.run([message])
assert (
len(result_messages) == 1
), f"Expected single message, got {len(result_messages)}"
return result_messages

View file

@ -596,7 +596,10 @@ class ChatAgent(ShieldRunnerMixin):
and self.agent_config.tool_choice == ToolChoice.required and self.agent_config.tool_choice == ToolChoice.required
): ):
return False return False
return attachments or AgenticSystemTool.memory.value in enabled_tools else:
return True
return AgenticSystemTool.memory.value in enabled_tools
def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]: def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
for t in self.agent_config.tools: for t in self.agent_config.tools:
@ -629,7 +632,10 @@ class ChatAgent(ShieldRunnerMixin):
] ]
await self.memory_api.insert_documents(bank.bank_id, documents) await self.memory_api.insert_documents(bank.bank_id, documents)
assert len(bank_ids) > 0, "No memory banks configured?" if not bank_ids:
# this can happen if the per-session memory bank is not yet populated
# (i.e., no prior turns uploaded an Attachment)
return None, []
query = " ".join(m.content for m in messages) query = " ".join(m.content for m in messages)
tasks = [ tasks = [

View file

@ -1,83 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, List
from llama_models.llama3.api.datatypes import StopReason, ToolResponseMessage
from llama_toolchain.agentic_system.api import (
AgenticSystem,
AgenticSystemTurnCreateRequest,
AgenticSystemTurnResponseEventType as EventType,
)
from llama_toolchain.inference.api import Message
async def execute_with_custom_tools(
system: AgenticSystem,
agent_id: str,
session_id: str,
messages: List[Message],
custom_tools: List[Any],
max_iters: int = 5,
stream: bool = True,
) -> AsyncGenerator:
# first create a session, or do you keep a persistent session?
tools_dict = {t.get_name(): t for t in custom_tools}
current_messages = messages.copy()
n_iter = 0
while n_iter < max_iters:
n_iter += 1
request = AgenticSystemTurnCreateRequest(
agent_id=agent_id,
session_id=session_id,
messages=current_messages,
stream=stream,
)
turn = None
async for chunk in system.create_agentic_system_turn(request):
if chunk.event.payload.event_type != EventType.turn_complete.value:
yield chunk
else:
turn = chunk.event.payload.turn
message = turn.output_message
if len(message.tool_calls) == 0:
yield chunk
return
if message.stop_reason == StopReason.out_of_tokens:
yield chunk
return
tool_call = message.tool_calls[0]
if tool_call.tool_name not in tools_dict:
m = ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else",
)
next_message = m
else:
tool = tools_dict[tool_call.tool_name]
result_messages = await execute_custom_tool(tool, message)
next_message = result_messages[0]
yield next_message
current_messages = [next_message]
async def execute_custom_tool(tool: Any, message: Message) -> List[Message]:
result_messages = await tool.run([message])
assert (
len(result_messages) == 1
), f"Expected single message, got {len(result_messages)}"
return result_messages

View file

@ -16,7 +16,10 @@ def available_agentic_system_providers() -> List[ProviderSpec]:
provider_id="meta-reference", provider_id="meta-reference",
pip_packages=[ pip_packages=[
"codeshield", "codeshield",
"matplotlib",
"pillow", "pillow",
"pandas",
"scikit-learn",
"torch", "torch",
"transformers", "transformers",
], ],

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -5,70 +5,92 @@
# the root directory of this source tree. # the root directory of this source tree.
import uuid import uuid
from typing import Any, List, Optional
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.agentic_system.api import * # noqa: F403 from llama_toolchain.agentic_system.api import * # noqa: F403
from llama_toolchain.agentic_system.client import AgenticSystemClient
from llama_toolchain.memory.api import * # noqa: F403 from llama_toolchain.memory.api import * # noqa: F403
from llama_toolchain.safety.api import * # noqa: F403 from llama_toolchain.safety.api import * # noqa: F403
from llama_toolchain.agentic_system.client import AgenticSystemClient from llama_toolchain.tools.custom.datatypes import CustomTool
from llama_toolchain.agentic_system.meta_reference.execute_with_custom_tools import ( from .execute_with_custom_tools import AgentWithCustomToolExecutor
execute_with_custom_tools,
)
# TODO: this should move back to the llama-agentic-system repo class AttachmentBehavior(Enum):
rag = "rag"
code_interpreter = "code_interpreter"
auto = "auto"
class AgenticSystemClientWrapper: def default_builtins() -> List[BuiltinTool]:
def __init__(self, api, agent_id, custom_tools): return [
self.api = api BuiltinTool.brave_search,
self.agent_id = agent_id BuiltinTool.wolfram_alpha,
self.custom_tools = custom_tools BuiltinTool.photogen,
self.session_id = None BuiltinTool.code_interpreter,
]
async def create_session(self, name: str = None):
if name is None:
name = f"Session-{uuid.uuid4()}"
response = await self.api.create_agentic_system_session(
agent_id=self.agent_id,
session_name=name,
)
self.session_id = response.session_id
return self.session_id
async def run(self, messages: List[Message], stream: bool = True):
async for chunk in execute_with_custom_tools(
self.api,
self.agent_id,
self.session_id,
messages,
self.custom_tools,
stream=stream,
):
yield chunk
async def get_agent_system_instance( @dataclass
host: str, class QuickToolConfig:
port: int, custom_tools: List[CustomTool] = field(default_factory=list)
custom_tools: Optional[List[Any]] = None,
disable_safety: bool = False, prompt_format: ToolPromptFormat = ToolPromptFormat.json
# use this to control whether you want the model to write code to
# process them, or you want to "RAG" them beforehand
attachment_behavior: AttachmentBehavior = AttachmentBehavior.auto
builtin_tools: List[BuiltinTool] = field(default_factory=default_builtins)
# if you have a memory bank already pre-populated, specify it here
memory_bank_id: Optional[str] = None
# This is a utility function; it does not provide all bells and whistles
# you can get from the underlying AgenticSystem API. Any limitations should
# ideally be resolved by making another well-scoped utility function instead
# of adding complex options here.
async def make_agent_config_with_custom_tools(
model: str = "Meta-Llama3.1-8B-Instruct", model: str = "Meta-Llama3.1-8B-Instruct",
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json, disable_safety: bool = False,
) -> AgenticSystemClientWrapper: tool_config: QuickToolConfig = QuickToolConfig(),
custom_tools = custom_tools or [] ) -> AgentConfig:
tool_definitions = []
api = AgenticSystemClient(base_url=f"http://{host}:{port}") # ensure code interpreter is enabled if attachments need it
builtin_tools = tool_config.builtin_tools
tool_choice = ToolChoice.auto
if tool_config.attachment_behavior == AttachmentBehavior.code_interpreter:
if BuiltinTool.code_interpreter not in builtin_tools:
builtin_tools.append(BuiltinTool.code_interpreter)
tool_definitions = [ tool_choice = ToolChoice.required
BraveSearchToolDefinition(),
WolframAlphaToolDefinition(), for t in builtin_tools:
PhotogenToolDefinition(), if t == BuiltinTool.brave_search:
CodeInterpreterToolDefinition(), tool_definitions.append(BraveSearchToolDefinition())
] + [t.get_tool_definition() for t in custom_tools] elif t == BuiltinTool.wolfram_alpha:
tool_definitions.append(WolframAlphaToolDefinition())
elif t == BuiltinTool.photogen:
tool_definitions.append(PhotogenToolDefinition())
elif t == BuiltinTool.code_interpreter:
tool_definitions.append(CodeInterpreterToolDefinition())
# enable memory unless we are specifically disabling it
if tool_config.attachment_behavior != AttachmentBehavior.code_interpreter:
bank_configs = []
if tool_config.memory_bank_id:
bank_configs.append(
AgenticSystemVectorMemoryBankConfig(bank_id=tool_config.memory_bank_id)
)
tool_definitions.append(MemoryToolDefinition(memory_bank_configs=bank_configs))
tool_definitions += [t.get_tool_definition() for t in tool_config.custom_tools]
if not disable_safety: if not disable_safety:
for t in tool_definitions: for t in tool_definitions:
@ -78,10 +100,13 @@ async def get_agent_system_instance(
ShieldDefinition(shield_type=BuiltinShield.injection_shield), ShieldDefinition(shield_type=BuiltinShield.injection_shield),
] ]
agent_config = AgentConfig( cfg = AgentConfig(
model=model, model=model,
instructions="You are a helpful assistant", instructions="You are a helpful assistant",
sampling_params=SamplingParams(),
tools=tool_definitions, tools=tool_definitions,
tool_prompt_format=tool_config.prompt_format,
tool_choice=tool_choice,
input_shields=( input_shields=(
[] []
if disable_safety if disable_safety
@ -97,8 +122,28 @@ async def get_agent_system_instance(
ShieldDefinition(shield_type=BuiltinShield.llama_guard), ShieldDefinition(shield_type=BuiltinShield.llama_guard),
] ]
), ),
sampling_params=SamplingParams(),
tool_prompt_format=tool_prompt_format,
) )
return cfg
async def get_agent_with_custom_tools(
host: str,
port: int,
agent_config: AgentConfig,
custom_tools: List[CustomTool],
) -> AgentWithCustomToolExecutor:
api = AgenticSystemClient(base_url=f"http://{host}:{port}")
create_response = await api.create_agentic_system(agent_config) create_response = await api.create_agentic_system(agent_config)
return AgenticSystemClientWrapper(api, create_response.agent_id, custom_tools) agent_id = create_response.agent_id
name = f"Session-{uuid.uuid4()}"
response = await api.create_agentic_system_session(
agent_id=agent_id,
session_name=name,
)
session_id = response.session_id
return AgentWithCustomToolExecutor(
api, agent_id, session_id, agent_config, custom_tools
)

View file

@ -30,6 +30,8 @@ async def get_provider_impl(config: FaissImplConfig, _deps: Dict[Api, ProviderSp
return impl return impl
# This should be a broader utility
# This should support local file URLs and data URLs also
async def content_from_doc(doc: MemoryBankDocument) -> str: async def content_from_doc(doc: MemoryBankDocument) -> str:
if isinstance(doc.content, URL): if isinstance(doc.content, URL):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:

View file

@ -33,7 +33,7 @@ def interpret_content_as_attachment(content: str) -> Optional[Attachment]:
snippet = match.group(1) snippet = match.group(1)
data = json.loads(snippet) data = json.loads(snippet)
return Attachment( return Attachment(
url=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"] content=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"]
) )
return None return None