diff --git a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
index a8b826972..d7f10a4f5 100644
--- a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
+++ b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
@@ -25,14 +25,10 @@ from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 
-from llama_stack.tools.base import BaseTool
-from llama_stack.tools.builtin import (
-    interpret_content_as_attachment,
-    SingleMessageBuiltinTool,
-)
-
 from .rag.context_retriever import generate_rag_query
 from .safety import SafetyException, ShieldRunnerMixin
+from .tools.base import BaseTool
+from .tools.builtin import interpret_content_as_attachment, SingleMessageBuiltinTool
 
 
 def make_random_string(length: int = 8):
diff --git a/llama_stack/providers/impls/meta_reference/agents/agents.py b/llama_stack/providers/impls/meta_reference/agents/agents.py
index d77a31bb0..25517ba6c 100644
--- a/llama_stack/providers/impls/meta_reference/agents/agents.py
+++ b/llama_stack/providers/impls/meta_reference/agents/agents.py
@@ -14,16 +14,16 @@ from llama_stack.apis.inference import Inference
 from llama_stack.apis.memory import Memory
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.agents import *  # noqa: F403
-from llama_stack.tools.builtin import (
+
+from .agent_instance import ChatAgent
+from .config import MetaReferenceImplConfig
+from .tools.builtin import (
     CodeInterpreterTool,
     PhotogenTool,
     SearchTool,
     WolframAlphaTool,
 )
-from llama_stack.tools.safety import with_safety
-
-from .agent_instance import ChatAgent
-from .config import MetaReferenceImplConfig
+from .tools.safety import with_safety
 
 
 logger = logging.getLogger()
diff --git a/llama_stack/providers/utils/agents/__init__.py b/llama_stack/providers/impls/meta_reference/agents/tests/__init__.py
similarity index 100%
rename from llama_stack/providers/utils/agents/__init__.py
rename to llama_stack/providers/impls/meta_reference/agents/tests/__init__.py
diff --git a/llama_stack/providers/impls/meta_reference/agents/tests/code_execution.py b/llama_stack/providers/impls/meta_reference/agents/tests/code_execution.py
new file mode 100644
index 000000000..495cd2c92
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/agents/tests/code_execution.py
@@ -0,0 +1,93 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import unittest
+
+from llama_models.llama3.api.datatypes import (
+    Attachment,
+    BuiltinTool,
+    CompletionMessage,
+    StopReason,
+    ToolCall,
+)
+
+from ..tools.builtin import CodeInterpreterTool
+
+
+class TestCodeInterpreter(unittest.IsolatedAsyncioTestCase):
+    async def test_matplotlib(self):
+        tool = CodeInterpreterTool()
+        code = """
+import matplotlib.pyplot as plt
+import numpy as np
+
+x = np.array([1, 1])
+y = np.array([0, 10])
+
+plt.plot(x, y)
+plt.title('x = 1')
+plt.xlabel('x')
+plt.ylabel('y')
+plt.grid(True)
+plt.axvline(x=1, color='r')
+plt.show()
+        """
+        message = CompletionMessage(
+            role="assistant",
+            content="",
+            tool_calls=[
+                ToolCall(
+                    call_id="call_id",
+                    tool_name=BuiltinTool.code_interpreter,
+                    arguments={"code": code},
+                )
+            ],
+            stop_reason=StopReason.end_of_message,
+        )
+        ret = await tool.run([message])
+
+        self.assertEqual(len(ret), 1)
+
+        output = ret[0].content
+        self.assertIsInstance(output, Attachment)
+        self.assertEqual(output.mime_type, "image/png")
+
+    async def test_path_unlink(self):
+        tool = CodeInterpreterTool()
+        code = """
+import os
+from pathlib import Path
+import tempfile
+
+dpath = Path(os.environ["MPLCONFIGDIR"])
+with open(dpath / "test", "w") as f:
+    f.write("hello")
+
+Path(dpath / "test").unlink()
+print("_OK_")
+        """
+        message = CompletionMessage(
+            role="assistant",
+            content="",
+            tool_calls=[
+                ToolCall(
+                    call_id="call_id",
+                    tool_name=BuiltinTool.code_interpreter,
+                    arguments={"code": code},
+                )
+            ],
+            stop_reason=StopReason.end_of_message,
+        )
+        ret = await tool.run([message])
+
+        self.assertEqual(len(ret), 1)
+
+        output = ret[0].content
+        self.assertTrue("_OK_" in output)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/llama_stack/tools/__init__.py b/llama_stack/providers/impls/meta_reference/agents/tools/__init__.py
similarity index 100%
rename from llama_stack/tools/__init__.py
rename to llama_stack/providers/impls/meta_reference/agents/tools/__init__.py
diff --git a/llama_stack/tools/base.py b/llama_stack/providers/impls/meta_reference/agents/tools/base.py
similarity index 100%
rename from llama_stack/tools/base.py
rename to llama_stack/providers/impls/meta_reference/agents/tools/base.py
diff --git a/llama_stack/tools/builtin.py b/llama_stack/providers/impls/meta_reference/agents/tools/builtin.py
similarity index 100%
rename from llama_stack/tools/builtin.py
rename to llama_stack/providers/impls/meta_reference/agents/tools/builtin.py
diff --git a/llama_stack/tools/custom/__init__.py b/llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/__init__.py
similarity index 100%
rename from llama_stack/tools/custom/__init__.py
rename to llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/__init__.py
diff --git a/llama_stack/tools/ipython_tool/code_env_prefix.py b/llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/code_env_prefix.py
similarity index 100%
rename from llama_stack/tools/ipython_tool/code_env_prefix.py
rename to llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/code_env_prefix.py
diff --git a/llama_stack/tools/ipython_tool/code_execution.py b/llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/code_execution.py
similarity index 100%
rename from llama_stack/tools/ipython_tool/code_execution.py
rename to llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/code_execution.py
diff --git a/llama_stack/tools/ipython_tool/matplotlib_custom_backend.py b/llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/matplotlib_custom_backend.py
similarity index 100%
rename from llama_stack/tools/ipython_tool/matplotlib_custom_backend.py
rename to llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/matplotlib_custom_backend.py
diff --git a/llama_stack/tools/ipython_tool/utils.py b/llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/utils.py
similarity index 100%
rename from llama_stack/tools/ipython_tool/utils.py
rename to llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/utils.py
diff --git a/llama_stack/tools/safety.py b/llama_stack/providers/impls/meta_reference/agents/tools/safety.py
similarity index 100%
rename from llama_stack/tools/safety.py
rename to llama_stack/providers/impls/meta_reference/agents/tools/safety.py
diff --git a/llama_stack/providers/utils/agents/event_logger.py b/llama_stack/providers/utils/agents/event_logger.py
deleted file mode 100644
index 1d3f2a68a..000000000
--- a/llama_stack/providers/utils/agents/event_logger.py
+++ /dev/null
@@ -1,184 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import Optional
-
-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_models.llama3.api.tool_utils import ToolUtils
-
-from termcolor import cprint
-
-from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
-
-
-class LogEvent:
-    def __init__(
-        self,
-        role: Optional[str] = None,
-        content: str = "",
-        end: str = "\n",
-        color="white",
-    ):
-        self.role = role
-        self.content = content
-        self.color = color
-        self.end = "\n" if end is None else end
-
-    def __str__(self):
-        if self.role is not None:
-            return f"{self.role}> {self.content}"
-        else:
-            return f"{self.content}"
-
-    def print(self, flush=True):
-        cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)
-
-
-EventType = AgentTurnResponseEventType
-
-
-class EventLogger:
-    async def log(
-        self,
-        event_generator,
-        stream=True,
-        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
-    ):
-        previous_event_type = None
-        previous_step_type = None
-
-        async for chunk in event_generator:
-            if not hasattr(chunk, "event"):
-                # Need to check for custom tool first
-                # since it does not produce event but instead
-                # a Message
-                if isinstance(chunk, ToolResponseMessage):
-                    yield chunk, LogEvent(
-                        role="CustomTool", content=chunk.content, color="grey"
-                    )
-                continue
-
-            event = chunk.event
-            event_type = event.payload.event_type
-            if event_type in {
-                EventType.turn_start.value,
-                EventType.turn_complete.value,
-            }:
-                # Currently not logging any turn realted info
-                yield event, None
-                continue
-
-            step_type = event.payload.step_type
-            # handle safety
-            if (
-                step_type == StepType.shield_call
-                and event_type == EventType.step_complete.value
-            ):
-                response = event.payload.step_details.response
-                if not response.is_violation:
-                    yield event, LogEvent(
-                        role=step_type, content="No Violation", color="magenta"
-                    )
-                else:
-                    yield event, LogEvent(
-                        role=step_type,
-                        content=f"{response.violation_type} {response.violation_return_message}",
-                        color="red",
-                    )
-
-            # handle inference
-            if step_type == StepType.inference:
-                if stream:
-                    if event_type == EventType.step_start.value:
-                        # TODO: Currently this event is never received
-                        yield event, LogEvent(
-                            role=step_type, content="", end="", color="yellow"
-                        )
-                    elif event_type == EventType.step_progress.value:
-                        # HACK: if previous was not step/event was not inference's step_progress
-                        # this is the first time we are getting model inference response
-                        # aka equivalent to step_start for inference. Hence,
-                        # start with "Model>".
-                        if (
-                            previous_event_type != EventType.step_progress.value
-                            and previous_step_type != StepType.inference
-                        ):
-                            yield event, LogEvent(
-                                role=step_type, content="", end="", color="yellow"
-                            )
-
-                        if event.payload.tool_call_delta:
-                            if isinstance(event.payload.tool_call_delta.content, str):
-                                yield event, LogEvent(
-                                    role=None,
-                                    content=event.payload.tool_call_delta.content,
-                                    end="",
-                                    color="cyan",
-                                )
-                        else:
-                            yield event, LogEvent(
-                                role=None,
-                                content=event.payload.model_response_text_delta,
-                                end="",
-                                color="yellow",
-                            )
-                    else:
-                        # step_complete
-                        yield event, LogEvent(role=None, content="")
-
-                else:
-                    # Not streaming
-                    if event_type == EventType.step_complete.value:
-                        response = event.payload.step_details.model_response
-                        if response.tool_calls:
-                            content = ToolUtils.encode_tool_call(
-                                response.tool_calls[0], tool_prompt_format
-                            )
-                        else:
-                            content = response.content
-                        yield event, LogEvent(
-                            role=step_type,
-                            content=content,
-                            color="yellow",
-                        )
-
-            # handle tool_execution
-            if (
-                step_type == StepType.tool_execution
-                and
-                # Only print tool calls and responses at the step_complete event
-                event_type == EventType.step_complete.value
-            ):
-                details = event.payload.step_details
-                for t in details.tool_calls:
-                    yield event, LogEvent(
-                        role=step_type,
-                        content=f"Tool:{t.tool_name} Args:{t.arguments}",
-                        color="green",
-                    )
-                for r in details.tool_responses:
-                    yield event, LogEvent(
-                        role=step_type,
-                        content=f"Tool:{r.tool_name} Response:{r.content}",
-                        color="green",
-                    )
-
-            if (
-                step_type == StepType.memory_retrieval
-                and event_type == EventType.step_complete.value
-            ):
-                details = event.payload.step_details
-                content = interleaved_text_media_as_str(details.inserted_context)
-                content = content[:200] + "..." if len(content) > 200 else content
-
-                yield event, LogEvent(
-                    role=step_type,
-                    content=f"Retrieved context from banks: {details.memory_bank_ids}.\n====\n{content}\n>",
-                    color="cyan",
-                )
-
-            preivous_event_type = event_type
-            previous_step_type = step_type
diff --git a/llama_stack/providers/utils/agents/execute_with_custom_tools.py b/llama_stack/providers/utils/agents/execute_with_custom_tools.py
deleted file mode 100644
index 928d444ca..000000000
--- a/llama_stack/providers/utils/agents/execute_with_custom_tools.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import AsyncGenerator, List
-
-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.agents import *  # noqa: F403
-from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.apis.safety import *  # noqa: F403
-
-from llama_stack.apis.agents import AgentTurnResponseEventType as EventType
-from llama_stack.tools.custom.datatypes import CustomTool
-
-
-class AgentWithCustomToolExecutor:
-    def __init__(
-        self,
-        api: Agents,
-        agent_id: str,
-        session_id: str,
-        agent_config: AgentConfig,
-        custom_tools: List[CustomTool],
-    ):
-        self.api = api
-        self.agent_id = agent_id
-        self.session_id = session_id
-        self.agent_config = agent_config
-        self.custom_tools = custom_tools
-
-    async def execute_turn(
-        self,
-        messages: List[Message],
-        attachments: Optional[List[Attachment]] = None,
-        max_iters: int = 5,
-        stream: bool = True,
-    ) -> AsyncGenerator:
-        tools_dict = {t.get_name(): t for t in self.custom_tools}
-
-        current_messages = messages.copy()
-        n_iter = 0
-        while n_iter < max_iters:
-            n_iter += 1
-
-            request = AgentTurnCreateRequest(
-                agent_id=self.agent_id,
-                session_id=self.session_id,
-                messages=current_messages,
-                attachments=attachments,
-                stream=stream,
-            )
-
-            turn = None
-            async for chunk in self.api.create_agent_turn(request):
-                if chunk.event.payload.event_type != EventType.turn_complete.value:
-                    yield chunk
-                else:
-                    turn = chunk.event.payload.turn
-
-            message = turn.output_message
-            if len(message.tool_calls) == 0:
-                yield chunk
-                return
-
-            if message.stop_reason == StopReason.out_of_tokens:
-                yield chunk
-                return
-
-            tool_call = message.tool_calls[0]
-            if tool_call.tool_name not in tools_dict:
-                m = ToolResponseMessage(
-                    call_id=tool_call.call_id,
-                    tool_name=tool_call.tool_name,
-                    content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else",
-                )
-                next_message = m
-            else:
-                tool = tools_dict[tool_call.tool_name]
-                result_messages = await execute_custom_tool(tool, message)
-                next_message = result_messages[0]
-
-            yield next_message
-            current_messages = [next_message]
-
-
-async def execute_custom_tool(tool: CustomTool, message: Message) -> List[Message]:
-    result_messages = await tool.run([message])
-    assert (
-        len(result_messages) == 1
-    ), f"Expected single message, got {len(result_messages)}"
-
-    return result_messages
diff --git a/llama_stack/tools/custom/datatypes.py b/llama_stack/tools/custom/datatypes.py
deleted file mode 100644
index c8dacefa3..000000000
--- a/llama_stack/tools/custom/datatypes.py
+++ /dev/null
@@ -1,98 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import json
-
-from abc import abstractmethod
-from typing import Dict, List
-
-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.agents import *  # noqa: F403
-
-
-class CustomTool:
-    """
-    Developers can define their custom tools that models can use
-    by extending this class.
-
-    Developers need to provide
-        - name
-        - description
-        - params_definition
-        - implement tool's behavior in `run_impl` method
-
-    NOTE: The return of the `run` method needs to be json serializable
-    """
-
-    @abstractmethod
-    def get_name(self) -> str:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_description(self) -> str:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_params_definition(self) -> Dict[str, ToolParamDefinition]:
-        raise NotImplementedError
-
-    def get_instruction_string(self) -> str:
-        return f"Use the function '{self.get_name()}' to: {self.get_description()}"
-
-    def parameters_for_system_prompt(self) -> str:
-        return json.dumps(
-            {
-                "name": self.get_name(),
-                "description": self.get_description(),
-                "parameters": {
-                    name: definition.__dict__
-                    for name, definition in self.get_params_definition().items()
-                },
-            }
-        )
-
-    def get_tool_definition(self) -> FunctionCallToolDefinition:
-        return FunctionCallToolDefinition(
-            function_name=self.get_name(),
-            description=self.get_description(),
-            parameters=self.get_params_definition(),
-        )
-
-    @abstractmethod
-    async def run(self, messages: List[Message]) -> List[Message]:
-        raise NotImplementedError
-
-
-class SingleMessageCustomTool(CustomTool):
-    """
-    Helper class to handle custom tools that take a single message
-    Extending this class and implementing the `run_impl` method will
-    allow for the tool be called by the model and the necessary plumbing.
-    """
-
-    async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
-        assert len(messages) == 1, "Expected single message"
-
-        message = messages[0]
-
-        tool_call = message.tool_calls[0]
-
-        try:
-            response = await self.run_impl(**tool_call.arguments)
-            response_str = json.dumps(response, ensure_ascii=False)
-        except Exception as e:
-            response_str = f"Error when running tool: {e}"
-
-        message = ToolResponseMessage(
-            call_id=tool_call.call_id,
-            tool_name=tool_call.tool_name,
-            content=response_str,
-        )
-        return [message]
-
-    @abstractmethod
-    async def run_impl(self, *args, **kwargs):
-        raise NotImplementedError()
diff --git a/llama_stack/tools/ipython_tool/__init__.py b/llama_stack/tools/ipython_tool/__init__.py
deleted file mode 100644
index 756f351d8..000000000
--- a/llama_stack/tools/ipython_tool/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.