Refactor custom tool execution utilities

This commit is contained in:
Ashwin Bharambe 2024-08-25 14:34:20 -07:00
parent 440d125ea0
commit ceef117abc
9 changed files with 209 additions and 145 deletions

View file

@ -5,70 +5,92 @@
# the root directory of this source tree.
import uuid
from typing import Any, List, Optional
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.agentic_system.api import * # noqa: F403
from llama_toolchain.agentic_system.client import AgenticSystemClient
from llama_toolchain.memory.api import * # noqa: F403
from llama_toolchain.safety.api import * # noqa: F403
from llama_toolchain.agentic_system.client import AgenticSystemClient
from llama_toolchain.tools.custom.datatypes import CustomTool
from llama_toolchain.agentic_system.meta_reference.execute_with_custom_tools import (
execute_with_custom_tools,
)
from .execute_with_custom_tools import AgentWithCustomToolExecutor
# TODO: this should move back to the llama-agentic-system repo
class AttachmentBehavior(Enum):
    """Choices for how attachments should be handled by a quick agent config.

    Each member's value is the string form of its own name.
    """

    # "RAG" the attachments beforehand.
    rag = "rag"
    # Have the model write code to process the attachments.
    code_interpreter = "code_interpreter"
    # Express no specific preference.
    auto = "auto"
class AgenticSystemClientWrapper:
    """Bundle an agentic-system API client with one agent and its custom tools.

    ``session_id`` starts out as ``None`` and is only filled in by
    ``create_session``; ``run`` forwards whatever value it currently holds.
    """

    def __init__(self, api, agent_id, custom_tools):
        self.api = api
        self.agent_id = agent_id
        self.custom_tools = custom_tools
        # Set by create_session(); None until a session has been created.
        self.session_id = None

    async def create_session(self, name: Optional[str] = None):
        """Create a session for this agent and remember its id.

        Args:
            name: Session name to use. When omitted, a unique
                ``Session-<uuid4>`` name is generated.

        Returns:
            The ``session_id`` reported by the API, which is also stored on
            ``self.session_id``.
        """
        if name is None:
            name = f"Session-{uuid.uuid4()}"

        response = await self.api.create_agentic_system_session(
            agent_id=self.agent_id,
            session_name=name,
        )
        self.session_id = response.session_id
        return self.session_id

    async def run(self, messages: List[Message], stream: bool = True):
        """Yield each chunk produced by ``execute_with_custom_tools``.

        Args:
            messages: Messages handed to ``execute_with_custom_tools``.
            stream: Forwarded unchanged as the ``stream`` keyword argument.
        """
        async for chunk in execute_with_custom_tools(
            self.api,
            self.agent_id,
            self.session_id,
            messages,
            self.custom_tools,
            stream=stream,
        ):
            yield chunk
def default_builtins() -> List[BuiltinTool]:
    """Return the builtin tools that are enabled by default.

    A new list is constructed on every call, so the caller owns the result.
    """
    enabled = (
        BuiltinTool.brave_search,
        BuiltinTool.wolfram_alpha,
        BuiltinTool.photogen,
        BuiltinTool.code_interpreter,
    )
    return list(enabled)
async def get_agent_system_instance(
host: str,
port: int,
custom_tools: Optional[List[Any]] = None,
disable_safety: bool = False,
@dataclass
class QuickToolConfig:
    """Tool-related settings used to assemble an agent config quickly."""

    # Custom tools to expose; defaults to an empty list.
    custom_tools: List[CustomTool] = field(default_factory=list)

    # Prompt format used for tools; defaults to JSON.
    prompt_format: ToolPromptFormat = field(default=ToolPromptFormat.json)

    # Controls whether the model should write code to process attachments,
    # or whether they should be "RAG"-ed beforehand.
    attachment_behavior: AttachmentBehavior = field(default=AttachmentBehavior.auto)

    # Builtin tools to enable; a fresh default list is built per instance.
    builtin_tools: List[BuiltinTool] = field(default_factory=default_builtins)

    # Id of an already pre-populated memory bank, if there is one.
    memory_bank_id: Optional[str] = field(default=None)
# This is a utility function; it does not provide all bells and whistles
# you can get from the underlying AgenticSystem API. Any limitations should
# ideally be resolved by making another well-scoped utility function instead
# of adding complex options here.
async def make_agent_config_with_custom_tools(
model: str = "Meta-Llama3.1-8B-Instruct",
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
) -> AgenticSystemClientWrapper:
custom_tools = custom_tools or []
disable_safety: bool = False,
tool_config: QuickToolConfig = QuickToolConfig(),
) -> AgentConfig:
tool_definitions = []
api = AgenticSystemClient(base_url=f"http://{host}:{port}")
# ensure code interpreter is enabled if attachments need it
builtin_tools = tool_config.builtin_tools
tool_choice = ToolChoice.auto
if tool_config.attachment_behavior == AttachmentBehavior.code_interpreter:
if BuiltinTool.code_interpreter not in builtin_tools:
builtin_tools.append(BuiltinTool.code_interpreter)
tool_definitions = [
BraveSearchToolDefinition(),
WolframAlphaToolDefinition(),
PhotogenToolDefinition(),
CodeInterpreterToolDefinition(),
] + [t.get_tool_definition() for t in custom_tools]
tool_choice = ToolChoice.required
for t in builtin_tools:
if t == BuiltinTool.brave_search:
tool_definitions.append(BraveSearchToolDefinition())
elif t == BuiltinTool.wolfram_alpha:
tool_definitions.append(WolframAlphaToolDefinition())
elif t == BuiltinTool.photogen:
tool_definitions.append(PhotogenToolDefinition())
elif t == BuiltinTool.code_interpreter:
tool_definitions.append(CodeInterpreterToolDefinition())
# enable memory unless we are specifically disabling it
if tool_config.attachment_behavior != AttachmentBehavior.code_interpreter:
bank_configs = []
if tool_config.memory_bank_id:
bank_configs.append(
AgenticSystemVectorMemoryBankConfig(bank_id=tool_config.memory_bank_id)
)
tool_definitions.append(MemoryToolDefinition(memory_bank_configs=bank_configs))
tool_definitions += [t.get_tool_definition() for t in tool_config.custom_tools]
if not disable_safety:
for t in tool_definitions:
@ -78,10 +100,13 @@ async def get_agent_system_instance(
ShieldDefinition(shield_type=BuiltinShield.injection_shield),
]
agent_config = AgentConfig(
cfg = AgentConfig(
model=model,
instructions="You are a helpful assistant",
sampling_params=SamplingParams(),
tools=tool_definitions,
tool_prompt_format=tool_config.prompt_format,
tool_choice=tool_choice,
input_shields=(
[]
if disable_safety
@ -97,8 +122,28 @@ async def get_agent_system_instance(
ShieldDefinition(shield_type=BuiltinShield.llama_guard),
]
),
sampling_params=SamplingParams(),
tool_prompt_format=tool_prompt_format,
)
return cfg
async def get_agent_with_custom_tools(
    host: str,
    port: int,
    agent_config: AgentConfig,
    custom_tools: List[CustomTool],
) -> AgentWithCustomToolExecutor:
    """Create an agent and a session on the server and wrap them in an executor.

    Args:
        host: Host name used to build the ``http://{host}:{port}`` base URL.
        port: Port used to build the base URL.
        agent_config: Configuration sent to ``create_agentic_system``; it is
            also handed to the returned executor.
        custom_tools: Custom tools handed to the returned executor.

    Returns:
        An ``AgentWithCustomToolExecutor`` built from the client, the new
        agent id, the new session id, ``agent_config`` and ``custom_tools``.
    """
    api = AgenticSystemClient(base_url=f"http://{host}:{port}")

    create_response = await api.create_agentic_system(agent_config)
    agent_id = create_response.agent_id

    # NOTE: the stale early ``return AgenticSystemClientWrapper(...)`` that
    # used to sit here made everything below unreachable and contradicted the
    # declared return type, so it has been removed.
    name = f"Session-{uuid.uuid4()}"
    response = await api.create_agentic_system_session(
        agent_id=agent_id,
        session_name=name,
    )
    session_id = response.session_id

    return AgentWithCustomToolExecutor(
        api, agent_id, session_id, agent_config, custom_tools
    )