From 830252257bbfbe5fd529b567d75c022de1bb3e23 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 24 Aug 2024 22:56:43 -0700 Subject: [PATCH] fix agentic_system utils --- llama_toolchain/agentic_system/utils.py | 29 +++++++---------------- llama_toolchain/tools/custom/datatypes.py | 6 ++--- 2 files changed, 11 insertions(+), 24 deletions(-) diff --git a/llama_toolchain/agentic_system/utils.py b/llama_toolchain/agentic_system/utils.py index 1ac05ce73..f07d02c73 100644 --- a/llama_toolchain/agentic_system/utils.py +++ b/llama_toolchain/agentic_system/utils.py @@ -7,20 +7,15 @@ import uuid from typing import Any, List, Optional -from llama_models.llama3.api.datatypes import ( - BuiltinTool, - Message, - SamplingParams, - ToolPromptFormat, -) - -from llama_toolchain.agentic_system.api import AgentConfig, AgenticSystemToolDefinition +from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_toolchain.agentic_system.api import * # noqa: F403 +from llama_toolchain.memory.api import * # noqa: F403 +from llama_toolchain.safety.api import * # noqa: F403 from llama_toolchain.agentic_system.client import AgenticSystemClient from llama_toolchain.agentic_system.meta_reference.execute_with_custom_tools import ( execute_with_custom_tools, ) -from llama_toolchain.safety.api.datatypes import BuiltinShield, ShieldDefinition # TODO: this should move back to the llama-agentic-system repo @@ -69,18 +64,10 @@ async def get_agent_system_instance( api = AgenticSystemClient(base_url=f"http://{host}:{port}") tool_definitions = [ - AgenticSystemToolDefinition( - tool_name=BuiltinTool.brave_search, - ), - AgenticSystemToolDefinition( - tool_name=BuiltinTool.wolfram_alpha, - ), - AgenticSystemToolDefinition( - tool_name=BuiltinTool.photogen, - ), - AgenticSystemToolDefinition( - tool_name=BuiltinTool.code_interpreter, - ), + BraveSearchToolDefinition(), + WolframAlphaToolDefinition(), + PhotogenToolDefinition(), + CodeInterpreterToolDefinition(), ] + [t.get_tool_definition() for t in custom_tools] if not disable_safety: diff --git a/llama_toolchain/tools/custom/datatypes.py b/llama_toolchain/tools/custom/datatypes.py index a7fe34e9b..05b142d6f 100644 --- a/llama_toolchain/tools/custom/datatypes.py +++ b/llama_toolchain/tools/custom/datatypes.py @@ -54,9 +54,9 @@ class CustomTool: } ) - def get_tool_definition(self) -> AgenticSystemToolDefinition: - return AgenticSystemToolDefinition( - tool_name=self.get_name(), + def get_tool_definition(self) -> FunctionCallToolDefinition: + return FunctionCallToolDefinition( + function_name=self.get_name(), description=self.get_description(), parameters=self.get_params_definition(), )