From ceef117abc65ab8b6cd8fe0cf6f9b847d6f9ab79 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sun, 25 Aug 2024 14:34:20 -0700 Subject: [PATCH] Refactor custom tool execution utilities --- llama_toolchain/agentic_system/client.py | 2 +- .../execute_with_custom_tools.py | 96 +++++++++++ .../meta_reference/agent_instance.py | 10 +- .../execute_with_custom_tools.py | 83 ---------- llama_toolchain/agentic_system/providers.py | 3 + .../agentic_system/tools/__init__.py | 5 - llama_toolchain/agentic_system/utils.py | 151 ++++++++++++------ .../memory/meta_reference/faiss/faiss.py | 2 + llama_toolchain/tools/builtin.py | 2 +- 9 files changed, 209 insertions(+), 145 deletions(-) create mode 100644 llama_toolchain/agentic_system/execute_with_custom_tools.py delete mode 100644 llama_toolchain/agentic_system/meta_reference/execute_with_custom_tools.py delete mode 100644 llama_toolchain/agentic_system/tools/__init__.py diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py index f4ae21da4..4048e6da3 100644 --- a/llama_toolchain/agentic_system/client.py +++ b/llama_toolchain/agentic_system/client.py @@ -111,7 +111,7 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None): ) for content in user_prompts: - cprint(f"User> {content}", color="blue") + cprint(f"User> {content}", color="white", attrs=["bold"]) iterator = api.create_agentic_system_turn( AgenticSystemTurnCreateRequest( agent_id=create_response.agent_id, diff --git a/llama_toolchain/agentic_system/execute_with_custom_tools.py b/llama_toolchain/agentic_system/execute_with_custom_tools.py new file mode 100644 index 000000000..e8038bc20 --- /dev/null +++ b/llama_toolchain/agentic_system/execute_with_custom_tools.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import AsyncGenerator, List + +from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_toolchain.agentic_system.api import * # noqa: F403 +from llama_toolchain.memory.api import * # noqa: F403 +from llama_toolchain.safety.api import * # noqa: F403 + +from llama_toolchain.agentic_system.api import ( + AgenticSystemTurnResponseEventType as EventType, +) +from llama_toolchain.tools.custom.datatypes import CustomTool + + +class AgentWithCustomToolExecutor: + def __init__( + self, + api: AgenticSystem, + agent_id: str, + session_id: str, + agent_config: AgentConfig, + custom_tools: List[CustomTool], + ): + self.api = api + self.agent_id = agent_id + self.session_id = session_id + self.agent_config = agent_config + self.custom_tools = custom_tools + + async def execute_turn( + self, + messages: List[Message], + attachments: Optional[List[Attachment]] = None, + max_iters: int = 5, + stream: bool = True, + ) -> AsyncGenerator: + tools_dict = {t.get_name(): t for t in self.custom_tools} + + current_messages = messages.copy() + n_iter = 0 + while n_iter < max_iters: + n_iter += 1 + + request = AgenticSystemTurnCreateRequest( + agent_id=self.agent_id, + session_id=self.session_id, + messages=current_messages, + attachments=attachments, + stream=stream, + ) + + turn = None + async for chunk in self.api.create_agentic_system_turn(request): + if chunk.event.payload.event_type != EventType.turn_complete.value: + yield chunk + else: + turn = chunk.event.payload.turn + + message = turn.output_message + if len(message.tool_calls) == 0: + yield chunk + return + + if message.stop_reason == StopReason.out_of_tokens: + yield chunk + return + + tool_call = message.tool_calls[0] + if tool_call.tool_name not in tools_dict: + m = ToolResponseMessage( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else", + ) + next_message = m + else: + tool = tools_dict[tool_call.tool_name] + result_messages = await execute_custom_tool(tool, message) + next_message = result_messages[0] + + yield next_message + current_messages = [next_message] + + +async def execute_custom_tool(tool: CustomTool, message: Message) -> List[Message]: + result_messages = await tool.run([message]) + assert ( + len(result_messages) == 1 + ), f"Expected single message, got {len(result_messages)}" + + return result_messages diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py index 127ebf18e..c559470cc 100644 --- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py +++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py @@ -596,7 +596,10 @@ class ChatAgent(ShieldRunnerMixin): and self.agent_config.tool_choice == ToolChoice.required ): return False - return attachments or AgenticSystemTool.memory.value in enabled_tools + else: + return True + + return AgenticSystemTool.memory.value in enabled_tools def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]: for t in self.agent_config.tools: @@ -629,7 +632,10 @@ class ChatAgent(ShieldRunnerMixin): ] await self.memory_api.insert_documents(bank.bank_id, documents) - assert len(bank_ids) > 0, "No memory banks configured?" 
+ if not bank_ids: + # this can happen if the per-session memory bank is not yet populated + # (i.e., no prior turns uploaded an Attachment) + return None, [] query = " ".join(m.content for m in messages) tasks = [ diff --git a/llama_toolchain/agentic_system/meta_reference/execute_with_custom_tools.py b/llama_toolchain/agentic_system/meta_reference/execute_with_custom_tools.py deleted file mode 100644 index 2d0068894..000000000 --- a/llama_toolchain/agentic_system/meta_reference/execute_with_custom_tools.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any, AsyncGenerator, List - -from llama_models.llama3.api.datatypes import StopReason, ToolResponseMessage - -from llama_toolchain.agentic_system.api import ( - AgenticSystem, - AgenticSystemTurnCreateRequest, - AgenticSystemTurnResponseEventType as EventType, -) - -from llama_toolchain.inference.api import Message - - -async def execute_with_custom_tools( - system: AgenticSystem, - agent_id: str, - session_id: str, - messages: List[Message], - custom_tools: List[Any], - max_iters: int = 5, - stream: bool = True, -) -> AsyncGenerator: - # first create a session, or do you keep a persistent session? - tools_dict = {t.get_name(): t for t in custom_tools} - - current_messages = messages.copy() - n_iter = 0 - while n_iter < max_iters: - n_iter += 1 - - request = AgenticSystemTurnCreateRequest( - agent_id=agent_id, - session_id=session_id, - messages=current_messages, - stream=stream, - ) - - turn = None - async for chunk in system.create_agentic_system_turn(request): - if chunk.event.payload.event_type != EventType.turn_complete.value: - yield chunk - else: - turn = chunk.event.payload.turn - - message = turn.output_message - if len(message.tool_calls) == 0: - yield chunk - return - - if message.stop_reason == StopReason.out_of_tokens: - yield chunk - return - - tool_call = message.tool_calls[0] - if tool_call.tool_name not in tools_dict: - m = ToolResponseMessage( - call_id=tool_call.call_id, - tool_name=tool_call.tool_name, - content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else", - ) - next_message = m - else: - tool = tools_dict[tool_call.tool_name] - result_messages = await execute_custom_tool(tool, message) - next_message = result_messages[0] - - yield next_message - current_messages = [next_message] - - -async def execute_custom_tool(tool: Any, message: Message) -> List[Message]: - result_messages = await tool.run([message]) - assert ( - len(result_messages) == 1 - ), f"Expected single message, got {len(result_messages)}" - - return result_messages diff --git a/llama_toolchain/agentic_system/providers.py b/llama_toolchain/agentic_system/providers.py index 7d49fd004..027a7de43 100644 --- a/llama_toolchain/agentic_system/providers.py +++ b/llama_toolchain/agentic_system/providers.py @@ -16,7 +16,10 @@ def available_agentic_system_providers() -> List[ProviderSpec]: provider_id="meta-reference", pip_packages=[ "codeshield", + "matplotlib", "pillow", + "pandas", + "scikit-learn", "torch", "transformers", ], diff --git a/llama_toolchain/agentic_system/tools/__init__.py b/llama_toolchain/agentic_system/tools/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_toolchain/agentic_system/tools/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. 
and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_toolchain/agentic_system/utils.py b/llama_toolchain/agentic_system/utils.py index f07d02c73..f418b5fbf 100644 --- a/llama_toolchain/agentic_system/utils.py +++ b/llama_toolchain/agentic_system/utils.py @@ -5,70 +5,92 @@ # the root directory of this source tree. import uuid -from typing import Any, List, Optional + +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Optional from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_toolchain.agentic_system.api import * # noqa: F403 +from llama_toolchain.agentic_system.client import AgenticSystemClient from llama_toolchain.memory.api import * # noqa: F403 from llama_toolchain.safety.api import * # noqa: F403 -from llama_toolchain.agentic_system.client import AgenticSystemClient +from llama_toolchain.tools.custom.datatypes import CustomTool -from llama_toolchain.agentic_system.meta_reference.execute_with_custom_tools import ( - execute_with_custom_tools, -) +from .execute_with_custom_tools import AgentWithCustomToolExecutor -# TODO: this should move back to the llama-agentic-system repo +class AttachmentBehavior(Enum): + rag = "rag" + code_interpreter = "code_interpreter" + auto = "auto" -class AgenticSystemClientWrapper: - def __init__(self, api, agent_id, custom_tools): - self.api = api - self.agent_id = agent_id - self.custom_tools = custom_tools - self.session_id = None - - async def create_session(self, name: str = None): - if name is None: - name = f"Session-{uuid.uuid4()}" - - response = await self.api.create_agentic_system_session( - agent_id=self.agent_id, - session_name=name, - ) - self.session_id = response.session_id - return self.session_id - - async def run(self, messages: List[Message], stream: bool = True): - async for chunk in execute_with_custom_tools( - self.api, - self.agent_id, - self.session_id, - messages, - self.custom_tools, - stream=stream, - ): - yield chunk +def default_builtins() -> List[BuiltinTool]: + return [ + BuiltinTool.brave_search, + BuiltinTool.wolfram_alpha, + BuiltinTool.photogen, + BuiltinTool.code_interpreter, + ] -async def get_agent_system_instance( - host: str, - port: int, - custom_tools: Optional[List[Any]] = None, - disable_safety: bool = False, +@dataclass +class QuickToolConfig: + custom_tools: List[CustomTool] = field(default_factory=list) + + prompt_format: ToolPromptFormat = ToolPromptFormat.json + + # use this to control whether you want the model to write code to + # process them, or you want to "RAG" them beforehand + attachment_behavior: AttachmentBehavior = AttachmentBehavior.auto + + builtin_tools: List[BuiltinTool] = field(default_factory=default_builtins) + + # if you have a memory bank already pre-populated, specify it here + memory_bank_id: Optional[str] = None + + +# This is a utility function; it does not provide all bells and whistles +# you can get from the underlying AgenticSystem API. Any limitations should +# ideally be resolved by making another well-scoped utility function instead +# of adding complex options here. 
+async def make_agent_config_with_custom_tools( model: str = "Meta-Llama3.1-8B-Instruct", - tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json, -) -> AgenticSystemClientWrapper: - custom_tools = custom_tools or [] + disable_safety: bool = False, + tool_config: QuickToolConfig = QuickToolConfig(), +) -> AgentConfig: + tool_definitions = [] - api = AgenticSystemClient(base_url=f"http://{host}:{port}") + # ensure code interpreter is enabled if attachments need it + builtin_tools = tool_config.builtin_tools + tool_choice = ToolChoice.auto + if tool_config.attachment_behavior == AttachmentBehavior.code_interpreter: + if BuiltinTool.code_interpreter not in builtin_tools: + builtin_tools.append(BuiltinTool.code_interpreter) - tool_definitions = [ - BraveSearchToolDefinition(), - WolframAlphaToolDefinition(), - PhotogenToolDefinition(), - CodeInterpreterToolDefinition(), - ] + [t.get_tool_definition() for t in custom_tools] + tool_choice = ToolChoice.required + + for t in builtin_tools: + if t == BuiltinTool.brave_search: + tool_definitions.append(BraveSearchToolDefinition()) + elif t == BuiltinTool.wolfram_alpha: + tool_definitions.append(WolframAlphaToolDefinition()) + elif t == BuiltinTool.photogen: + tool_definitions.append(PhotogenToolDefinition()) + elif t == BuiltinTool.code_interpreter: + tool_definitions.append(CodeInterpreterToolDefinition()) + + # enable memory unless we are specifically disabling it + if tool_config.attachment_behavior != AttachmentBehavior.code_interpreter: + bank_configs = [] + if tool_config.memory_bank_id: + bank_configs.append( + AgenticSystemVectorMemoryBankConfig(bank_id=tool_config.memory_bank_id) + ) + tool_definitions.append(MemoryToolDefinition(memory_bank_configs=bank_configs)) + + tool_definitions += [t.get_tool_definition() for t in tool_config.custom_tools] if not disable_safety: for t in tool_definitions: @@ -78,10 +100,13 @@ async def get_agent_system_instance( ShieldDefinition(shield_type=BuiltinShield.injection_shield), ] - agent_config = AgentConfig( + cfg = AgentConfig( model=model, instructions="You are a helpful assistant", + sampling_params=SamplingParams(), tools=tool_definitions, + tool_prompt_format=tool_config.prompt_format, + tool_choice=tool_choice, input_shields=( [] if disable_safety @@ -97,8 +122,28 @@ async def get_agent_system_instance( ShieldDefinition(shield_type=BuiltinShield.llama_guard), ] ), - sampling_params=SamplingParams(), - tool_prompt_format=tool_prompt_format, ) + return cfg + + +async def get_agent_with_custom_tools( + host: str, + port: int, + agent_config: AgentConfig, + custom_tools: List[CustomTool], +) -> AgentWithCustomToolExecutor: + api = AgenticSystemClient(base_url=f"http://{host}:{port}") + create_response = await api.create_agentic_system(agent_config) - return AgenticSystemClientWrapper(api, create_response.agent_id, custom_tools) + agent_id = create_response.agent_id + + name = f"Session-{uuid.uuid4()}" + response = await api.create_agentic_system_session( + agent_id=agent_id, + session_name=name, + ) + session_id = response.session_id + + return AgentWithCustomToolExecutor( + api, agent_id, session_id, agent_config, custom_tools + ) diff --git a/llama_toolchain/memory/meta_reference/faiss/faiss.py b/llama_toolchain/memory/meta_reference/faiss/faiss.py index 85a92f35f..e6d7b2b46 100644 --- a/llama_toolchain/memory/meta_reference/faiss/faiss.py +++ b/llama_toolchain/memory/meta_reference/faiss/faiss.py @@ -30,6 +30,8 @@ async def get_provider_impl(config: FaissImplConfig, _deps: Dict[Api, 
ProviderSp return impl +# This should be a broader utility +# This should support local file URLs and data URLs also async def content_from_doc(doc: MemoryBankDocument) -> str: if isinstance(doc.content, URL): async with httpx.AsyncClient() as client: diff --git a/llama_toolchain/tools/builtin.py b/llama_toolchain/tools/builtin.py index e5e71187f..f2ddeefa7 100644 --- a/llama_toolchain/tools/builtin.py +++ b/llama_toolchain/tools/builtin.py @@ -33,7 +33,7 @@ def interpret_content_as_attachment(content: str) -> Optional[Attachment]: snippet = match.group(1) data = json.loads(snippet) return Attachment( - url=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"] + content=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"] ) return None
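
Usage sketch (commentary, not part of the patch): the pieces introduced above are meant to compose as follows. make_agent_config_with_custom_tools builds an AgentConfig from a QuickToolConfig, get_agent_with_custom_tools registers the agent plus a fresh session against a running agentic_system server and returns an AgentWithCustomToolExecutor, and execute_turn streams the turn's events. The host, port, prompt, and empty custom-tool list below are illustrative assumptions; the function names and defaults are taken from the diff.

import asyncio

from llama_models.llama3.api.datatypes import UserMessage

from llama_toolchain.agentic_system.utils import (
    QuickToolConfig,
    get_agent_with_custom_tools,
    make_agent_config_with_custom_tools,
)


async def main() -> None:
    # Default QuickToolConfig: JSON tool prompt format, the four default builtin
    # tools, auto attachment behavior, no custom tools.
    agent_config = await make_agent_config_with_custom_tools(
        tool_config=QuickToolConfig(custom_tools=[]),
    )

    # Creates the agent and a session on the server, returns the executor.
    agent = await get_agent_with_custom_tools(
        host="localhost",  # assumption: wherever the agentic_system server runs
        port=5000,         # assumption
        agent_config=agent_config,
        custom_tools=[],
    )

    # execute_turn is an async generator: it yields streaming events and, when a
    # custom tool fires, the ToolResponseMessage that tool produced.
    async for chunk in agent.execute_turn(
        [UserMessage(content="What is the latest Llama release?")]
    ):
        print(chunk)


if __name__ == "__main__":
    asyncio.run(main())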
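
QuickToolConfig groups the tool-related options that used to be separate arguments to get_agent_system_instance (custom tools, prompt format) and adds attachment handling, builtin-tool selection, and an optional memory bank. A short sketch of the two non-default attachment behaviors; the memory bank id is hypothetical and assumed to have been populated elsewhere.

from llama_models.llama3.api.datatypes import BuiltinTool

from llama_toolchain.agentic_system.utils import AttachmentBehavior, QuickToolConfig

# RAG the attachments: the memory tool is enabled and wired to the given bank,
# and code_interpreter is not forced on.
rag_config = QuickToolConfig(
    attachment_behavior=AttachmentBehavior.rag,
    memory_bank_id="docs-bank-0",  # hypothetical, pre-populated bank
    builtin_tools=[BuiltinTool.brave_search],
)

# Let the model write code against the attachments instead: code_interpreter is
# force-enabled and tool_choice becomes required.
code_config = QuickToolConfig(
    attachment_behavior=AttachmentBehavior.code_interpreter,
)

Either config is then passed as tool_config to make_agent_config_with_custom_tools.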
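
On the TODO above content_from_doc about supporting local file and data URLs: a minimal, stdlib-only sketch of decoding a data: URL payload. The helper name and where it would hook into content_from_doc are hypothetical; nothing below is part of this change.

import base64
from urllib.parse import unquote


def content_from_data_url(uri: str) -> str:
    # data:[<mediatype>][;base64],<payload>
    header, _, payload = uri.partition(",")
    if ";base64" in header:
        return base64.b64decode(payload).decode("utf-8")
    return unquote(payload)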