mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 07:39:38 +00:00
Removing custom tool and agent utilities and moving them client side
This commit is contained in:
parent
fa864f70da
commit
099ac81bc7
17 changed files with 100 additions and 392 deletions
|
@ -25,14 +25,10 @@ from llama_stack.apis.inference import * # noqa: F403
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
from llama_stack.apis.safety import * # noqa: F403
|
from llama_stack.apis.safety import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.tools.base import BaseTool
|
|
||||||
from llama_stack.tools.builtin import (
|
|
||||||
interpret_content_as_attachment,
|
|
||||||
SingleMessageBuiltinTool,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .rag.context_retriever import generate_rag_query
|
from .rag.context_retriever import generate_rag_query
|
||||||
from .safety import SafetyException, ShieldRunnerMixin
|
from .safety import SafetyException, ShieldRunnerMixin
|
||||||
|
from .tools.base import BaseTool
|
||||||
|
from .tools.builtin import interpret_content_as_attachment, SingleMessageBuiltinTool
|
||||||
|
|
||||||
|
|
||||||
def make_random_string(length: int = 8):
|
def make_random_string(length: int = 8):
|
||||||
|
|
|
@ -14,16 +14,16 @@ from llama_stack.apis.inference import Inference
|
||||||
from llama_stack.apis.memory import Memory
|
from llama_stack.apis.memory import Memory
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.agents import * # noqa: F403
|
from llama_stack.apis.agents import * # noqa: F403
|
||||||
from llama_stack.tools.builtin import (
|
|
||||||
|
from .agent_instance import ChatAgent
|
||||||
|
from .config import MetaReferenceImplConfig
|
||||||
|
from .tools.builtin import (
|
||||||
CodeInterpreterTool,
|
CodeInterpreterTool,
|
||||||
PhotogenTool,
|
PhotogenTool,
|
||||||
SearchTool,
|
SearchTool,
|
||||||
WolframAlphaTool,
|
WolframAlphaTool,
|
||||||
)
|
)
|
||||||
from llama_stack.tools.safety import with_safety
|
from .tools.safety import with_safety
|
||||||
|
|
||||||
from .agent_instance import ChatAgent
|
|
||||||
from .config import MetaReferenceImplConfig
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
|
@ -0,0 +1,93 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from llama_models.llama3.api.datatypes import (
|
||||||
|
Attachment,
|
||||||
|
BuiltinTool,
|
||||||
|
CompletionMessage,
|
||||||
|
StopReason,
|
||||||
|
ToolCall,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..tools.builtin import CodeInterpreterTool
|
||||||
|
|
||||||
|
|
||||||
|
class TestCodeInterpreter(unittest.IsolatedAsyncioTestCase):
|
||||||
|
async def test_matplotlib(self):
|
||||||
|
tool = CodeInterpreterTool()
|
||||||
|
code = """
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
x = np.array([1, 1])
|
||||||
|
y = np.array([0, 10])
|
||||||
|
|
||||||
|
plt.plot(x, y)
|
||||||
|
plt.title('x = 1')
|
||||||
|
plt.xlabel('x')
|
||||||
|
plt.ylabel('y')
|
||||||
|
plt.grid(True)
|
||||||
|
plt.axvline(x=1, color='r')
|
||||||
|
plt.show()
|
||||||
|
"""
|
||||||
|
message = CompletionMessage(
|
||||||
|
role="assistant",
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(
|
||||||
|
call_id="call_id",
|
||||||
|
tool_name=BuiltinTool.code_interpreter,
|
||||||
|
arguments={"code": code},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
stop_reason=StopReason.end_of_message,
|
||||||
|
)
|
||||||
|
ret = await tool.run([message])
|
||||||
|
|
||||||
|
self.assertEqual(len(ret), 1)
|
||||||
|
|
||||||
|
output = ret[0].content
|
||||||
|
self.assertIsInstance(output, Attachment)
|
||||||
|
self.assertEqual(output.mime_type, "image/png")
|
||||||
|
|
||||||
|
async def test_path_unlink(self):
|
||||||
|
tool = CodeInterpreterTool()
|
||||||
|
code = """
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
dpath = Path(os.environ["MPLCONFIGDIR"])
|
||||||
|
with open(dpath / "test", "w") as f:
|
||||||
|
f.write("hello")
|
||||||
|
|
||||||
|
Path(dpath / "test").unlink()
|
||||||
|
print("_OK_")
|
||||||
|
"""
|
||||||
|
message = CompletionMessage(
|
||||||
|
role="assistant",
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(
|
||||||
|
call_id="call_id",
|
||||||
|
tool_name=BuiltinTool.code_interpreter,
|
||||||
|
arguments={"code": code},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
stop_reason=StopReason.end_of_message,
|
||||||
|
)
|
||||||
|
ret = await tool.run([message])
|
||||||
|
|
||||||
|
self.assertEqual(len(ret), 1)
|
||||||
|
|
||||||
|
output = ret[0].content
|
||||||
|
self.assertTrue("_OK_" in output)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
|
@ -1,184 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
|
||||||
from llama_models.llama3.api.tool_utils import ToolUtils
|
|
||||||
|
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
|
|
||||||
|
|
||||||
|
|
||||||
class LogEvent:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
role: Optional[str] = None,
|
|
||||||
content: str = "",
|
|
||||||
end: str = "\n",
|
|
||||||
color="white",
|
|
||||||
):
|
|
||||||
self.role = role
|
|
||||||
self.content = content
|
|
||||||
self.color = color
|
|
||||||
self.end = "\n" if end is None else end
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
if self.role is not None:
|
|
||||||
return f"{self.role}> {self.content}"
|
|
||||||
else:
|
|
||||||
return f"{self.content}"
|
|
||||||
|
|
||||||
def print(self, flush=True):
|
|
||||||
cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)
|
|
||||||
|
|
||||||
|
|
||||||
EventType = AgentTurnResponseEventType
|
|
||||||
|
|
||||||
|
|
||||||
class EventLogger:
|
|
||||||
async def log(
|
|
||||||
self,
|
|
||||||
event_generator,
|
|
||||||
stream=True,
|
|
||||||
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
|
|
||||||
):
|
|
||||||
previous_event_type = None
|
|
||||||
previous_step_type = None
|
|
||||||
|
|
||||||
async for chunk in event_generator:
|
|
||||||
if not hasattr(chunk, "event"):
|
|
||||||
# Need to check for custom tool first
|
|
||||||
# since it does not produce event but instead
|
|
||||||
# a Message
|
|
||||||
if isinstance(chunk, ToolResponseMessage):
|
|
||||||
yield chunk, LogEvent(
|
|
||||||
role="CustomTool", content=chunk.content, color="grey"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
event = chunk.event
|
|
||||||
event_type = event.payload.event_type
|
|
||||||
if event_type in {
|
|
||||||
EventType.turn_start.value,
|
|
||||||
EventType.turn_complete.value,
|
|
||||||
}:
|
|
||||||
# Currently not logging any turn realted info
|
|
||||||
yield event, None
|
|
||||||
continue
|
|
||||||
|
|
||||||
step_type = event.payload.step_type
|
|
||||||
# handle safety
|
|
||||||
if (
|
|
||||||
step_type == StepType.shield_call
|
|
||||||
and event_type == EventType.step_complete.value
|
|
||||||
):
|
|
||||||
response = event.payload.step_details.response
|
|
||||||
if not response.is_violation:
|
|
||||||
yield event, LogEvent(
|
|
||||||
role=step_type, content="No Violation", color="magenta"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
yield event, LogEvent(
|
|
||||||
role=step_type,
|
|
||||||
content=f"{response.violation_type} {response.violation_return_message}",
|
|
||||||
color="red",
|
|
||||||
)
|
|
||||||
|
|
||||||
# handle inference
|
|
||||||
if step_type == StepType.inference:
|
|
||||||
if stream:
|
|
||||||
if event_type == EventType.step_start.value:
|
|
||||||
# TODO: Currently this event is never received
|
|
||||||
yield event, LogEvent(
|
|
||||||
role=step_type, content="", end="", color="yellow"
|
|
||||||
)
|
|
||||||
elif event_type == EventType.step_progress.value:
|
|
||||||
# HACK: if previous was not step/event was not inference's step_progress
|
|
||||||
# this is the first time we are getting model inference response
|
|
||||||
# aka equivalent to step_start for inference. Hence,
|
|
||||||
# start with "Model>".
|
|
||||||
if (
|
|
||||||
previous_event_type != EventType.step_progress.value
|
|
||||||
and previous_step_type != StepType.inference
|
|
||||||
):
|
|
||||||
yield event, LogEvent(
|
|
||||||
role=step_type, content="", end="", color="yellow"
|
|
||||||
)
|
|
||||||
|
|
||||||
if event.payload.tool_call_delta:
|
|
||||||
if isinstance(event.payload.tool_call_delta.content, str):
|
|
||||||
yield event, LogEvent(
|
|
||||||
role=None,
|
|
||||||
content=event.payload.tool_call_delta.content,
|
|
||||||
end="",
|
|
||||||
color="cyan",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
yield event, LogEvent(
|
|
||||||
role=None,
|
|
||||||
content=event.payload.model_response_text_delta,
|
|
||||||
end="",
|
|
||||||
color="yellow",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# step_complete
|
|
||||||
yield event, LogEvent(role=None, content="")
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Not streaming
|
|
||||||
if event_type == EventType.step_complete.value:
|
|
||||||
response = event.payload.step_details.model_response
|
|
||||||
if response.tool_calls:
|
|
||||||
content = ToolUtils.encode_tool_call(
|
|
||||||
response.tool_calls[0], tool_prompt_format
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
content = response.content
|
|
||||||
yield event, LogEvent(
|
|
||||||
role=step_type,
|
|
||||||
content=content,
|
|
||||||
color="yellow",
|
|
||||||
)
|
|
||||||
|
|
||||||
# handle tool_execution
|
|
||||||
if (
|
|
||||||
step_type == StepType.tool_execution
|
|
||||||
and
|
|
||||||
# Only print tool calls and responses at the step_complete event
|
|
||||||
event_type == EventType.step_complete.value
|
|
||||||
):
|
|
||||||
details = event.payload.step_details
|
|
||||||
for t in details.tool_calls:
|
|
||||||
yield event, LogEvent(
|
|
||||||
role=step_type,
|
|
||||||
content=f"Tool:{t.tool_name} Args:{t.arguments}",
|
|
||||||
color="green",
|
|
||||||
)
|
|
||||||
for r in details.tool_responses:
|
|
||||||
yield event, LogEvent(
|
|
||||||
role=step_type,
|
|
||||||
content=f"Tool:{r.tool_name} Response:{r.content}",
|
|
||||||
color="green",
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
step_type == StepType.memory_retrieval
|
|
||||||
and event_type == EventType.step_complete.value
|
|
||||||
):
|
|
||||||
details = event.payload.step_details
|
|
||||||
content = interleaved_text_media_as_str(details.inserted_context)
|
|
||||||
content = content[:200] + "..." if len(content) > 200 else content
|
|
||||||
|
|
||||||
yield event, LogEvent(
|
|
||||||
role=step_type,
|
|
||||||
content=f"Retrieved context from banks: {details.memory_bank_ids}.\n====\n{content}\n>",
|
|
||||||
color="cyan",
|
|
||||||
)
|
|
||||||
|
|
||||||
preivous_event_type = event_type
|
|
||||||
previous_step_type = step_type
|
|
|
@ -1,94 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from typing import AsyncGenerator, List
|
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
|
||||||
from llama_stack.apis.agents import * # noqa: F403
|
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
|
||||||
from llama_stack.apis.safety import * # noqa: F403
|
|
||||||
|
|
||||||
from llama_stack.apis.agents import AgentTurnResponseEventType as EventType
|
|
||||||
from llama_stack.tools.custom.datatypes import CustomTool
|
|
||||||
|
|
||||||
|
|
||||||
class AgentWithCustomToolExecutor:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
api: Agents,
|
|
||||||
agent_id: str,
|
|
||||||
session_id: str,
|
|
||||||
agent_config: AgentConfig,
|
|
||||||
custom_tools: List[CustomTool],
|
|
||||||
):
|
|
||||||
self.api = api
|
|
||||||
self.agent_id = agent_id
|
|
||||||
self.session_id = session_id
|
|
||||||
self.agent_config = agent_config
|
|
||||||
self.custom_tools = custom_tools
|
|
||||||
|
|
||||||
async def execute_turn(
|
|
||||||
self,
|
|
||||||
messages: List[Message],
|
|
||||||
attachments: Optional[List[Attachment]] = None,
|
|
||||||
max_iters: int = 5,
|
|
||||||
stream: bool = True,
|
|
||||||
) -> AsyncGenerator:
|
|
||||||
tools_dict = {t.get_name(): t for t in self.custom_tools}
|
|
||||||
|
|
||||||
current_messages = messages.copy()
|
|
||||||
n_iter = 0
|
|
||||||
while n_iter < max_iters:
|
|
||||||
n_iter += 1
|
|
||||||
|
|
||||||
request = AgentTurnCreateRequest(
|
|
||||||
agent_id=self.agent_id,
|
|
||||||
session_id=self.session_id,
|
|
||||||
messages=current_messages,
|
|
||||||
attachments=attachments,
|
|
||||||
stream=stream,
|
|
||||||
)
|
|
||||||
|
|
||||||
turn = None
|
|
||||||
async for chunk in self.api.create_agent_turn(request):
|
|
||||||
if chunk.event.payload.event_type != EventType.turn_complete.value:
|
|
||||||
yield chunk
|
|
||||||
else:
|
|
||||||
turn = chunk.event.payload.turn
|
|
||||||
|
|
||||||
message = turn.output_message
|
|
||||||
if len(message.tool_calls) == 0:
|
|
||||||
yield chunk
|
|
||||||
return
|
|
||||||
|
|
||||||
if message.stop_reason == StopReason.out_of_tokens:
|
|
||||||
yield chunk
|
|
||||||
return
|
|
||||||
|
|
||||||
tool_call = message.tool_calls[0]
|
|
||||||
if tool_call.tool_name not in tools_dict:
|
|
||||||
m = ToolResponseMessage(
|
|
||||||
call_id=tool_call.call_id,
|
|
||||||
tool_name=tool_call.tool_name,
|
|
||||||
content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else",
|
|
||||||
)
|
|
||||||
next_message = m
|
|
||||||
else:
|
|
||||||
tool = tools_dict[tool_call.tool_name]
|
|
||||||
result_messages = await execute_custom_tool(tool, message)
|
|
||||||
next_message = result_messages[0]
|
|
||||||
|
|
||||||
yield next_message
|
|
||||||
current_messages = [next_message]
|
|
||||||
|
|
||||||
|
|
||||||
async def execute_custom_tool(tool: CustomTool, message: Message) -> List[Message]:
|
|
||||||
result_messages = await tool.run([message])
|
|
||||||
assert (
|
|
||||||
len(result_messages) == 1
|
|
||||||
), f"Expected single message, got {len(result_messages)}"
|
|
||||||
|
|
||||||
return result_messages
|
|
|
@ -1,98 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
from abc import abstractmethod
|
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
|
||||||
from llama_stack.apis.agents import * # noqa: F403
|
|
||||||
|
|
||||||
|
|
||||||
class CustomTool:
|
|
||||||
"""
|
|
||||||
Developers can define their custom tools that models can use
|
|
||||||
by extending this class.
|
|
||||||
|
|
||||||
Developers need to provide
|
|
||||||
- name
|
|
||||||
- description
|
|
||||||
- params_definition
|
|
||||||
- implement tool's behavior in `run_impl` method
|
|
||||||
|
|
||||||
NOTE: The return of the `run` method needs to be json serializable
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_name(self) -> str:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_description(self) -> str:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_params_definition(self) -> Dict[str, ToolParamDefinition]:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def get_instruction_string(self) -> str:
|
|
||||||
return f"Use the function '{self.get_name()}' to: {self.get_description()}"
|
|
||||||
|
|
||||||
def parameters_for_system_prompt(self) -> str:
|
|
||||||
return json.dumps(
|
|
||||||
{
|
|
||||||
"name": self.get_name(),
|
|
||||||
"description": self.get_description(),
|
|
||||||
"parameters": {
|
|
||||||
name: definition.__dict__
|
|
||||||
for name, definition in self.get_params_definition().items()
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_tool_definition(self) -> FunctionCallToolDefinition:
|
|
||||||
return FunctionCallToolDefinition(
|
|
||||||
function_name=self.get_name(),
|
|
||||||
description=self.get_description(),
|
|
||||||
parameters=self.get_params_definition(),
|
|
||||||
)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def run(self, messages: List[Message]) -> List[Message]:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class SingleMessageCustomTool(CustomTool):
|
|
||||||
"""
|
|
||||||
Helper class to handle custom tools that take a single message
|
|
||||||
Extending this class and implementing the `run_impl` method will
|
|
||||||
allow for the tool be called by the model and the necessary plumbing.
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
|
|
||||||
assert len(messages) == 1, "Expected single message"
|
|
||||||
|
|
||||||
message = messages[0]
|
|
||||||
|
|
||||||
tool_call = message.tool_calls[0]
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await self.run_impl(**tool_call.arguments)
|
|
||||||
response_str = json.dumps(response, ensure_ascii=False)
|
|
||||||
except Exception as e:
|
|
||||||
response_str = f"Error when running tool: {e}"
|
|
||||||
|
|
||||||
message = ToolResponseMessage(
|
|
||||||
call_id=tool_call.call_id,
|
|
||||||
tool_name=tool_call.tool_name,
|
|
||||||
content=response_str,
|
|
||||||
)
|
|
||||||
return [message]
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def run_impl(self, *args, **kwargs):
|
|
||||||
raise NotImplementedError()
|
|
|
@ -1,5 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
Loading…
Add table
Add a link
Reference in a new issue