mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
Refactor custom tool execution utilities
This commit is contained in:
parent
440d125ea0
commit
ceef117abc
9 changed files with 209 additions and 145 deletions
|
@ -111,7 +111,7 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
|
|||
)
|
||||
|
||||
for content in user_prompts:
|
||||
cprint(f"User> {content}", color="blue")
|
||||
cprint(f"User> {content}", color="white", attrs=["bold"])
|
||||
iterator = api.create_agentic_system_turn(
|
||||
AgenticSystemTurnCreateRequest(
|
||||
agent_id=create_response.agent_id,
|
||||
|
|
96
llama_toolchain/agentic_system/execute_with_custom_tools.py
Normal file
96
llama_toolchain/agentic_system/execute_with_custom_tools.py
Normal file
|
@ -0,0 +1,96 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator, List
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_toolchain.agentic_system.api import * # noqa: F403
|
||||
from llama_toolchain.memory.api import * # noqa: F403
|
||||
from llama_toolchain.safety.api import * # noqa: F403
|
||||
|
||||
from llama_toolchain.agentic_system.api import (
|
||||
AgenticSystemTurnResponseEventType as EventType,
|
||||
)
|
||||
from llama_toolchain.tools.custom.datatypes import CustomTool
|
||||
|
||||
|
||||
class AgentWithCustomToolExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
api: AgenticSystem,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
agent_config: AgentConfig,
|
||||
custom_tools: List[CustomTool],
|
||||
):
|
||||
self.api = api
|
||||
self.agent_id = agent_id
|
||||
self.session_id = session_id
|
||||
self.agent_config = agent_config
|
||||
self.custom_tools = custom_tools
|
||||
|
||||
async def execute_turn(
|
||||
self,
|
||||
messages: List[Message],
|
||||
attachments: Optional[List[Attachment]] = None,
|
||||
max_iters: int = 5,
|
||||
stream: bool = True,
|
||||
) -> AsyncGenerator:
|
||||
tools_dict = {t.get_name(): t for t in self.custom_tools}
|
||||
|
||||
current_messages = messages.copy()
|
||||
n_iter = 0
|
||||
while n_iter < max_iters:
|
||||
n_iter += 1
|
||||
|
||||
request = AgenticSystemTurnCreateRequest(
|
||||
agent_id=self.agent_id,
|
||||
session_id=self.session_id,
|
||||
messages=current_messages,
|
||||
attachments=attachments,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
turn = None
|
||||
async for chunk in self.api.create_agentic_system_turn(request):
|
||||
if chunk.event.payload.event_type != EventType.turn_complete.value:
|
||||
yield chunk
|
||||
else:
|
||||
turn = chunk.event.payload.turn
|
||||
|
||||
message = turn.output_message
|
||||
if len(message.tool_calls) == 0:
|
||||
yield chunk
|
||||
return
|
||||
|
||||
if message.stop_reason == StopReason.out_of_tokens:
|
||||
yield chunk
|
||||
return
|
||||
|
||||
tool_call = message.tool_calls[0]
|
||||
if tool_call.tool_name not in tools_dict:
|
||||
m = ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
tool_name=tool_call.tool_name,
|
||||
content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else",
|
||||
)
|
||||
next_message = m
|
||||
else:
|
||||
tool = tools_dict[tool_call.tool_name]
|
||||
result_messages = await execute_custom_tool(tool, message)
|
||||
next_message = result_messages[0]
|
||||
|
||||
yield next_message
|
||||
current_messages = [next_message]
|
||||
|
||||
|
||||
async def execute_custom_tool(tool: CustomTool, message: Message) -> List[Message]:
|
||||
result_messages = await tool.run([message])
|
||||
assert (
|
||||
len(result_messages) == 1
|
||||
), f"Expected single message, got {len(result_messages)}"
|
||||
|
||||
return result_messages
|
|
@ -596,7 +596,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
and self.agent_config.tool_choice == ToolChoice.required
|
||||
):
|
||||
return False
|
||||
return attachments or AgenticSystemTool.memory.value in enabled_tools
|
||||
else:
|
||||
return True
|
||||
|
||||
return AgenticSystemTool.memory.value in enabled_tools
|
||||
|
||||
def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
|
||||
for t in self.agent_config.tools:
|
||||
|
@ -629,7 +632,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
]
|
||||
await self.memory_api.insert_documents(bank.bank_id, documents)
|
||||
|
||||
assert len(bank_ids) > 0, "No memory banks configured?"
|
||||
if not bank_ids:
|
||||
# this can happen if the per-session memory bank is not yet populated
|
||||
# (i.e., no prior turns uploaded an Attachment)
|
||||
return None, []
|
||||
|
||||
query = " ".join(m.content for m in messages)
|
||||
tasks = [
|
||||
|
|
|
@ -1,83 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, List
|
||||
|
||||
from llama_models.llama3.api.datatypes import StopReason, ToolResponseMessage
|
||||
|
||||
from llama_toolchain.agentic_system.api import (
|
||||
AgenticSystem,
|
||||
AgenticSystemTurnCreateRequest,
|
||||
AgenticSystemTurnResponseEventType as EventType,
|
||||
)
|
||||
|
||||
from llama_toolchain.inference.api import Message
|
||||
|
||||
|
||||
async def execute_with_custom_tools(
|
||||
system: AgenticSystem,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
messages: List[Message],
|
||||
custom_tools: List[Any],
|
||||
max_iters: int = 5,
|
||||
stream: bool = True,
|
||||
) -> AsyncGenerator:
|
||||
# first create a session, or do you keep a persistent session?
|
||||
tools_dict = {t.get_name(): t for t in custom_tools}
|
||||
|
||||
current_messages = messages.copy()
|
||||
n_iter = 0
|
||||
while n_iter < max_iters:
|
||||
n_iter += 1
|
||||
|
||||
request = AgenticSystemTurnCreateRequest(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=current_messages,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
turn = None
|
||||
async for chunk in system.create_agentic_system_turn(request):
|
||||
if chunk.event.payload.event_type != EventType.turn_complete.value:
|
||||
yield chunk
|
||||
else:
|
||||
turn = chunk.event.payload.turn
|
||||
|
||||
message = turn.output_message
|
||||
if len(message.tool_calls) == 0:
|
||||
yield chunk
|
||||
return
|
||||
|
||||
if message.stop_reason == StopReason.out_of_tokens:
|
||||
yield chunk
|
||||
return
|
||||
|
||||
tool_call = message.tool_calls[0]
|
||||
if tool_call.tool_name not in tools_dict:
|
||||
m = ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
tool_name=tool_call.tool_name,
|
||||
content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else",
|
||||
)
|
||||
next_message = m
|
||||
else:
|
||||
tool = tools_dict[tool_call.tool_name]
|
||||
result_messages = await execute_custom_tool(tool, message)
|
||||
next_message = result_messages[0]
|
||||
|
||||
yield next_message
|
||||
current_messages = [next_message]
|
||||
|
||||
|
||||
async def execute_custom_tool(tool: Any, message: Message) -> List[Message]:
|
||||
result_messages = await tool.run([message])
|
||||
assert (
|
||||
len(result_messages) == 1
|
||||
), f"Expected single message, got {len(result_messages)}"
|
||||
|
||||
return result_messages
|
|
@ -16,7 +16,10 @@ def available_agentic_system_providers() -> List[ProviderSpec]:
|
|||
provider_id="meta-reference",
|
||||
pip_packages=[
|
||||
"codeshield",
|
||||
"matplotlib",
|
||||
"pillow",
|
||||
"pandas",
|
||||
"scikit-learn",
|
||||
"torch",
|
||||
"transformers",
|
||||
],
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -5,70 +5,92 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_toolchain.agentic_system.api import * # noqa: F403
|
||||
from llama_toolchain.agentic_system.client import AgenticSystemClient
|
||||
from llama_toolchain.memory.api import * # noqa: F403
|
||||
from llama_toolchain.safety.api import * # noqa: F403
|
||||
from llama_toolchain.agentic_system.client import AgenticSystemClient
|
||||
from llama_toolchain.tools.custom.datatypes import CustomTool
|
||||
|
||||
from llama_toolchain.agentic_system.meta_reference.execute_with_custom_tools import (
|
||||
execute_with_custom_tools,
|
||||
)
|
||||
from .execute_with_custom_tools import AgentWithCustomToolExecutor
|
||||
|
||||
|
||||
# TODO: this should move back to the llama-agentic-system repo
|
||||
class AttachmentBehavior(Enum):
|
||||
rag = "rag"
|
||||
code_interpreter = "code_interpreter"
|
||||
auto = "auto"
|
||||
|
||||
|
||||
class AgenticSystemClientWrapper:
|
||||
def __init__(self, api, agent_id, custom_tools):
|
||||
self.api = api
|
||||
self.agent_id = agent_id
|
||||
self.custom_tools = custom_tools
|
||||
self.session_id = None
|
||||
|
||||
async def create_session(self, name: str = None):
|
||||
if name is None:
|
||||
name = f"Session-{uuid.uuid4()}"
|
||||
|
||||
response = await self.api.create_agentic_system_session(
|
||||
agent_id=self.agent_id,
|
||||
session_name=name,
|
||||
)
|
||||
self.session_id = response.session_id
|
||||
return self.session_id
|
||||
|
||||
async def run(self, messages: List[Message], stream: bool = True):
|
||||
async for chunk in execute_with_custom_tools(
|
||||
self.api,
|
||||
self.agent_id,
|
||||
self.session_id,
|
||||
messages,
|
||||
self.custom_tools,
|
||||
stream=stream,
|
||||
):
|
||||
yield chunk
|
||||
def default_builtins() -> List[BuiltinTool]:
|
||||
return [
|
||||
BuiltinTool.brave_search,
|
||||
BuiltinTool.wolfram_alpha,
|
||||
BuiltinTool.photogen,
|
||||
BuiltinTool.code_interpreter,
|
||||
]
|
||||
|
||||
|
||||
async def get_agent_system_instance(
|
||||
host: str,
|
||||
port: int,
|
||||
custom_tools: Optional[List[Any]] = None,
|
||||
disable_safety: bool = False,
|
||||
@dataclass
|
||||
class QuickToolConfig:
|
||||
custom_tools: List[CustomTool] = field(default_factory=list)
|
||||
|
||||
prompt_format: ToolPromptFormat = ToolPromptFormat.json
|
||||
|
||||
# use this to control whether you want the model to write code to
|
||||
# process them, or you want to "RAG" them beforehand
|
||||
attachment_behavior: AttachmentBehavior = AttachmentBehavior.auto
|
||||
|
||||
builtin_tools: List[BuiltinTool] = field(default_factory=default_builtins)
|
||||
|
||||
# if you have a memory bank already pre-populated, specify it here
|
||||
memory_bank_id: Optional[str] = None
|
||||
|
||||
|
||||
# This is a utility function; it does not provide all bells and whistles
|
||||
# you can get from the underlying AgenticSystem API. Any limitations should
|
||||
# ideally be resolved by making another well-scoped utility function instead
|
||||
# of adding complex options here.
|
||||
async def make_agent_config_with_custom_tools(
|
||||
model: str = "Meta-Llama3.1-8B-Instruct",
|
||||
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
|
||||
) -> AgenticSystemClientWrapper:
|
||||
custom_tools = custom_tools or []
|
||||
disable_safety: bool = False,
|
||||
tool_config: QuickToolConfig = QuickToolConfig(),
|
||||
) -> AgentConfig:
|
||||
tool_definitions = []
|
||||
|
||||
api = AgenticSystemClient(base_url=f"http://{host}:{port}")
|
||||
# ensure code interpreter is enabled if attachments need it
|
||||
builtin_tools = tool_config.builtin_tools
|
||||
tool_choice = ToolChoice.auto
|
||||
if tool_config.attachment_behavior == AttachmentBehavior.code_interpreter:
|
||||
if BuiltinTool.code_interpreter not in builtin_tools:
|
||||
builtin_tools.append(BuiltinTool.code_interpreter)
|
||||
|
||||
tool_definitions = [
|
||||
BraveSearchToolDefinition(),
|
||||
WolframAlphaToolDefinition(),
|
||||
PhotogenToolDefinition(),
|
||||
CodeInterpreterToolDefinition(),
|
||||
] + [t.get_tool_definition() for t in custom_tools]
|
||||
tool_choice = ToolChoice.required
|
||||
|
||||
for t in builtin_tools:
|
||||
if t == BuiltinTool.brave_search:
|
||||
tool_definitions.append(BraveSearchToolDefinition())
|
||||
elif t == BuiltinTool.wolfram_alpha:
|
||||
tool_definitions.append(WolframAlphaToolDefinition())
|
||||
elif t == BuiltinTool.photogen:
|
||||
tool_definitions.append(PhotogenToolDefinition())
|
||||
elif t == BuiltinTool.code_interpreter:
|
||||
tool_definitions.append(CodeInterpreterToolDefinition())
|
||||
|
||||
# enable memory unless we are specifically disabling it
|
||||
if tool_config.attachment_behavior != AttachmentBehavior.code_interpreter:
|
||||
bank_configs = []
|
||||
if tool_config.memory_bank_id:
|
||||
bank_configs.append(
|
||||
AgenticSystemVectorMemoryBankConfig(bank_id=tool_config.memory_bank_id)
|
||||
)
|
||||
tool_definitions.append(MemoryToolDefinition(memory_bank_configs=bank_configs))
|
||||
|
||||
tool_definitions += [t.get_tool_definition() for t in tool_config.custom_tools]
|
||||
|
||||
if not disable_safety:
|
||||
for t in tool_definitions:
|
||||
|
@ -78,10 +100,13 @@ async def get_agent_system_instance(
|
|||
ShieldDefinition(shield_type=BuiltinShield.injection_shield),
|
||||
]
|
||||
|
||||
agent_config = AgentConfig(
|
||||
cfg = AgentConfig(
|
||||
model=model,
|
||||
instructions="You are a helpful assistant",
|
||||
sampling_params=SamplingParams(),
|
||||
tools=tool_definitions,
|
||||
tool_prompt_format=tool_config.prompt_format,
|
||||
tool_choice=tool_choice,
|
||||
input_shields=(
|
||||
[]
|
||||
if disable_safety
|
||||
|
@ -97,8 +122,28 @@ async def get_agent_system_instance(
|
|||
ShieldDefinition(shield_type=BuiltinShield.llama_guard),
|
||||
]
|
||||
),
|
||||
sampling_params=SamplingParams(),
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
)
|
||||
return cfg
|
||||
|
||||
|
||||
async def get_agent_with_custom_tools(
|
||||
host: str,
|
||||
port: int,
|
||||
agent_config: AgentConfig,
|
||||
custom_tools: List[CustomTool],
|
||||
) -> AgentWithCustomToolExecutor:
|
||||
api = AgenticSystemClient(base_url=f"http://{host}:{port}")
|
||||
|
||||
create_response = await api.create_agentic_system(agent_config)
|
||||
return AgenticSystemClientWrapper(api, create_response.agent_id, custom_tools)
|
||||
agent_id = create_response.agent_id
|
||||
|
||||
name = f"Session-{uuid.uuid4()}"
|
||||
response = await api.create_agentic_system_session(
|
||||
agent_id=agent_id,
|
||||
session_name=name,
|
||||
)
|
||||
session_id = response.session_id
|
||||
|
||||
return AgentWithCustomToolExecutor(
|
||||
api, agent_id, session_id, agent_config, custom_tools
|
||||
)
|
||||
|
|
|
@ -30,6 +30,8 @@ async def get_provider_impl(config: FaissImplConfig, _deps: Dict[Api, ProviderSp
|
|||
return impl
|
||||
|
||||
|
||||
# This should be a broader utility
|
||||
# This should support local file URLs and data URLs also
|
||||
async def content_from_doc(doc: MemoryBankDocument) -> str:
|
||||
if isinstance(doc.content, URL):
|
||||
async with httpx.AsyncClient() as client:
|
||||
|
|
|
@ -33,7 +33,7 @@ def interpret_content_as_attachment(content: str) -> Optional[Attachment]:
|
|||
snippet = match.group(1)
|
||||
data = json.loads(snippet)
|
||||
return Attachment(
|
||||
url=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"]
|
||||
content=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"]
|
||||
)
|
||||
|
||||
return None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue