diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 65be92348..325ce9490 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -47,78 +47,6 @@ class Attachment(BaseModel): mime_type: str -class AgentTool(Enum): - brave_search = "brave_search" - wolfram_alpha = "wolfram_alpha" - photogen = "photogen" - code_interpreter = "code_interpreter" - - function_call = "function_call" - memory = "memory" - - -class ToolDefinitionCommon(BaseModel): - input_shields: Optional[List[str]] = Field(default_factory=list) - output_shields: Optional[List[str]] = Field(default_factory=list) - - -class SearchEngineType(Enum): - bing = "bing" - brave = "brave" - tavily = "tavily" - - -@json_schema_type -class SearchToolDefinition(ToolDefinitionCommon): - # NOTE: brave_search is just a placeholder since model always uses - # brave_search as tool call name - type: Literal[AgentTool.brave_search.value] = AgentTool.brave_search.value - api_key: str - engine: SearchEngineType = SearchEngineType.brave - remote_execution: Optional[RestAPIExecutionConfig] = None - - -@json_schema_type -class WolframAlphaToolDefinition(ToolDefinitionCommon): - type: Literal[AgentTool.wolfram_alpha.value] = AgentTool.wolfram_alpha.value - api_key: str - remote_execution: Optional[RestAPIExecutionConfig] = None - - -@json_schema_type -class PhotogenToolDefinition(ToolDefinitionCommon): - type: Literal[AgentTool.photogen.value] = AgentTool.photogen.value - remote_execution: Optional[RestAPIExecutionConfig] = None - - -@json_schema_type -class CodeInterpreterToolDefinition(ToolDefinitionCommon): - type: Literal[AgentTool.code_interpreter.value] = AgentTool.code_interpreter.value - enable_inline_code_execution: bool = True - remote_execution: Optional[RestAPIExecutionConfig] = None - - -@json_schema_type -class FunctionCallToolDefinition(ToolDefinitionCommon): - type: Literal[AgentTool.function_call.value] = AgentTool.function_call.value - function_name: str - description: str - parameters: Dict[str, ToolParamDefinition] - remote_execution: Optional[RestAPIExecutionConfig] = None - - -AgentToolDefinition = Annotated[ - Union[ - SearchToolDefinition, - WolframAlphaToolDefinition, - PhotogenToolDefinition, - CodeInterpreterToolDefinition, - FunctionCallToolDefinition, - ], - Field(discriminator="type"), -] - - class StepCommon(BaseModel): turn_id: str step_id: str @@ -211,10 +139,6 @@ class AgentConfigCommon(BaseModel): input_shields: Optional[List[str]] = Field(default_factory=list) output_shields: Optional[List[str]] = Field(default_factory=list) - - tools: Optional[List[AgentToolDefinition]] = Field( - default_factory=list, deprecated=True - ) available_tools: Optional[List[str]] = Field(default_factory=list) preprocessing_tools: Optional[List[str]] = Field(default_factory=list) tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 60b2bdab9..15d59ca8f 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -21,6 +21,8 @@ class ToolParameter(BaseModel): name: str parameter_type: str description: str + required: bool + default: Optional[Any] = None @json_schema_type diff --git a/llama_stack/llama_stack/providers/tests/agents/conftest.py b/llama_stack/llama_stack/providers/tests/agents/conftest.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/llama_stack/providers/tests/agents/conftest.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 00d8bbd36..8d52ac1b9 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -5,13 +5,15 @@ # the root directory of this source tree. import copy +import json import logging import os +import re import secrets import string import uuid from datetime import datetime -from typing import AsyncGenerator, Dict, List +from typing import AsyncGenerator, List from urllib.parse import urlparse import httpx @@ -29,16 +31,11 @@ from llama_stack.apis.agents import ( AgentTurnResponseTurnCompletePayload, AgentTurnResponseTurnStartPayload, Attachment, - CodeInterpreterToolDefinition, - FunctionCallToolDefinition, InferenceStep, - PhotogenToolDefinition, - SearchToolDefinition, ShieldCallStep, StepType, ToolExecutionStep, Turn, - WolframAlphaToolDefinition, ) from llama_stack.apis.common.content_types import ( URL, @@ -67,15 +64,6 @@ from llama_stack.providers.utils.telemetry import tracing from .persistence import AgentPersistence from .safety import SafetyException, ShieldRunnerMixin -from .tools.base import BaseTool -from .tools.builtin import ( - CodeInterpreterTool, - PhotogenTool, - SearchTool, - WolframAlphaTool, - interpret_content_as_attachment, -) -from .tools.safety import SafeTool log = logging.getLogger(__name__) @@ -86,6 +74,9 @@ def make_random_string(length: int = 8): ) +TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") + + class ChatAgent(ShieldRunnerMixin): def __init__( self, @@ -111,29 +102,6 @@ class ChatAgent(ShieldRunnerMixin): self.tool_runtime_api = tool_runtime_api self.tool_groups_api = tool_groups_api - builtin_tools = [] - for tool_defn in agent_config.tools: - if isinstance(tool_defn, WolframAlphaToolDefinition): - tool = WolframAlphaTool(tool_defn.api_key) - elif isinstance(tool_defn, SearchToolDefinition): - tool = SearchTool(tool_defn.engine, tool_defn.api_key) - elif isinstance(tool_defn, CodeInterpreterToolDefinition): - tool = CodeInterpreterTool() - elif isinstance(tool_defn, PhotogenToolDefinition): - tool = PhotogenTool(dump_dir=self.tempdir) - else: - continue - - builtin_tools.append( - SafeTool( - tool, - safety_api, - tool_defn.input_shields, - tool_defn.output_shields, - ) - ) - self.tools_dict = {t.get_name(): t for t in builtin_tools} - ShieldRunnerMixin.__init__( self, safety_api, @@ -453,7 +421,7 @@ class ChatAgent(ShieldRunnerMixin): async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, - tools=self._get_tools(), + tools=await self._get_tools(), tool_prompt_format=self.agent_config.tool_prompt_format, stream=True, sampling_params=sampling_params, @@ -595,7 +563,8 @@ class ChatAgent(ShieldRunnerMixin): }, ) as span: result_messages = await execute_tool_call_maybe( - self.tools_dict, + self.tool_runtime_api, + session_id, [message], ) assert ( @@ -627,6 +596,20 @@ class ChatAgent(ShieldRunnerMixin): # TODO: add tool-input touchpoint and a "start" event for this step also # but that needs a lot more refactoring of Tool code potentially + def interpret_content_as_attachment( + content: str, + ) -> Optional[Attachment]: + match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content) + if match: + snippet = match.group(1) + data = json.loads(snippet) + return Attachment( + url=URL(uri="file://" + data["filepath"]), + mime_type=data["mimetype"], + ) + + return None + if out_attachment := interpret_content_as_attachment( result_message.content ): @@ -639,25 +622,25 @@ class ChatAgent(ShieldRunnerMixin): n_iter += 1 - def _get_tools(self) -> List[ToolDefinition]: + async def _get_tools(self) -> List[ToolDefinition]: ret = [] - for t in self.agent_config.tools: - if isinstance(t, SearchToolDefinition): - ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search)) - elif isinstance(t, WolframAlphaToolDefinition): - ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha)) - elif isinstance(t, PhotogenToolDefinition): - ret.append(ToolDefinition(tool_name=BuiltinTool.photogen)) - elif isinstance(t, CodeInterpreterToolDefinition): - ret.append(ToolDefinition(tool_name=BuiltinTool.code_interpreter)) - elif isinstance(t, FunctionCallToolDefinition): - ret.append( - ToolDefinition( - tool_name=t.function_name, - description=t.description, - parameters=t.parameters, - ) + for tool_name in self.agent_config.available_tools: + tool = await self.tool_groups_api.get_tool(tool_name) + params = {} + for param in tool.parameters: + params[param.name] = ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, ) + ret.append( + ToolDefinition( + tool_name=tool.identifier, + description=tool.description, + parameters=params, + ) + ) return ret @@ -696,7 +679,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa async def execute_tool_call_maybe( - tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage] + tool_runtime_api: ToolRuntime, session_id: str, messages: List[CompletionMessage] ) -> List[ToolResponseMessage]: # While Tools.run interface takes a list of messages, # All tools currently only run on a single message @@ -712,7 +695,17 @@ async def execute_tool_call_maybe( name = name.value - assert name in tools_dict, f"Tool {name} not found" - tool = tools_dict[name] - result_messages = await tool.run(messages) - return result_messages + result = await tool_runtime_api.invoke_tool( + tool_name=name, + args=dict( + session_id=session_id, + **tool_call.arguments, + ), + ) + return [ + ToolResponseMessage( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=result.content, + ) + ] diff --git a/llama_stack/providers/inline/agents/meta_reference/rag/__init__.py b/llama_stack/providers/inline/agents/meta_reference/rag/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/providers/inline/agents/meta_reference/rag/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/inline/agents/meta_reference/tests/__init__.py b/llama_stack/providers/inline/agents/meta_reference/tests/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/providers/inline/agents/meta_reference/tests/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/inline/agents/meta_reference/tests/code_execution.py b/llama_stack/providers/inline/agents/meta_reference/tests/code_execution.py deleted file mode 100644 index 495cd2c92..000000000 --- a/llama_stack/providers/inline/agents/meta_reference/tests/code_execution.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import unittest - -from llama_models.llama3.api.datatypes import ( - Attachment, - BuiltinTool, - CompletionMessage, - StopReason, - ToolCall, -) - -from ..tools.builtin import CodeInterpreterTool - - -class TestCodeInterpreter(unittest.IsolatedAsyncioTestCase): - async def test_matplotlib(self): - tool = CodeInterpreterTool() - code = """ -import matplotlib.pyplot as plt -import numpy as np - -x = np.array([1, 1]) -y = np.array([0, 10]) - -plt.plot(x, y) -plt.title('x = 1') -plt.xlabel('x') -plt.ylabel('y') -plt.grid(True) -plt.axvline(x=1, color='r') -plt.show() - """ - message = CompletionMessage( - role="assistant", - content="", - tool_calls=[ - ToolCall( - call_id="call_id", - tool_name=BuiltinTool.code_interpreter, - arguments={"code": code}, - ) - ], - stop_reason=StopReason.end_of_message, - ) - ret = await tool.run([message]) - - self.assertEqual(len(ret), 1) - - output = ret[0].content - self.assertIsInstance(output, Attachment) - self.assertEqual(output.mime_type, "image/png") - - async def test_path_unlink(self): - tool = CodeInterpreterTool() - code = """ -import os -from pathlib import Path -import tempfile - -dpath = Path(os.environ["MPLCONFIGDIR"]) -with open(dpath / "test", "w") as f: - f.write("hello") - -Path(dpath / "test").unlink() -print("_OK_") - """ - message = CompletionMessage( - role="assistant", - content="", - tool_calls=[ - ToolCall( - call_id="call_id", - tool_name=BuiltinTool.code_interpreter, - arguments={"code": code}, - ) - ], - stop_reason=StopReason.end_of_message, - ) - ret = await tool.run([message]) - - self.assertEqual(len(ret), 1) - - output = ret[0].content - self.assertTrue("_OK_" in output) - - -if __name__ == "__main__": - unittest.main() diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/__init__.py b/llama_stack/providers/inline/agents/meta_reference/tools/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/providers/inline/agents/meta_reference/tools/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/base.py b/llama_stack/providers/inline/agents/meta_reference/tools/base.py deleted file mode 100644 index 15fba7e2e..000000000 --- a/llama_stack/providers/inline/agents/meta_reference/tools/base.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from abc import ABC, abstractmethod -from typing import List - -from llama_stack.apis.inference import Message - - -class BaseTool(ABC): - @abstractmethod - def get_name(self) -> str: - raise NotImplementedError - - @abstractmethod - async def run(self, messages: List[Message]) -> List[Message]: - raise NotImplementedError diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/builtin.py b/llama_stack/providers/inline/agents/meta_reference/tools/builtin.py deleted file mode 100644 index 5045bf32d..000000000 --- a/llama_stack/providers/inline/agents/meta_reference/tools/builtin.py +++ /dev/null @@ -1,396 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import json -import logging -import re -import tempfile - -from abc import abstractmethod -from typing import List, Optional - -import requests - -from .ipython_tool.code_execution import ( - CodeExecutionContext, - CodeExecutionRequest, - CodeExecutor, - TOOLS_ATTACHMENT_KEY_REGEX, -) - -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.agents import * # noqa: F403 - -from .base import BaseTool - - -log = logging.getLogger(__name__) - - -def interpret_content_as_attachment(content: str) -> Optional[Attachment]: - match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content) - if match: - snippet = match.group(1) - data = json.loads(snippet) - return Attachment( - url=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"] - ) - - return None - - -class SingleMessageBuiltinTool(BaseTool): - async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]: - assert len(messages) == 1, f"Expected single message, got {len(messages)}" - - message = messages[0] - assert len(message.tool_calls) == 1, "Expected a single tool call" - - tool_call = messages[0].tool_calls[0] - - query = tool_call.arguments["query"] - response: str = await self.run_impl(query) - - message = ToolResponseMessage( - call_id=tool_call.call_id, - tool_name=tool_call.tool_name, - content=response, - ) - return [message] - - @abstractmethod - async def run_impl(self, query: str) -> str: - raise NotImplementedError() - - -class PhotogenTool(SingleMessageBuiltinTool): - def __init__(self, dump_dir: str) -> None: - self.dump_dir = dump_dir - - def get_name(self) -> str: - return BuiltinTool.photogen.value - - async def run_impl(self, query: str) -> str: - """ - Implement this to give the model an ability to generate images. - - Return: - info = { - "filepath": str(image_filepath), - "mimetype": "image/png", - } - """ - raise NotImplementedError() - - -class SearchTool(SingleMessageBuiltinTool): - def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None: - self.api_key = api_key - self.engine_type = engine - if engine == SearchEngineType.bing: - self.engine = BingSearch(api_key, **kwargs) - elif engine == SearchEngineType.brave: - self.engine = BraveSearch(api_key, **kwargs) - elif engine == SearchEngineType.tavily: - self.engine = TavilySearch(api_key, **kwargs) - else: - raise ValueError(f"Unknown search engine: {engine}") - - def get_name(self) -> str: - return BuiltinTool.brave_search.value - - async def run_impl(self, query: str) -> str: - return await self.engine.search(query) - - -class BingSearch: - def __init__(self, api_key: str, top_k: int = 3, **kwargs) -> None: - self.api_key = api_key - self.top_k = top_k - - async def search(self, query: str) -> str: - url = "https://api.bing.microsoft.com/v7.0/search" - headers = { - "Ocp-Apim-Subscription-Key": self.api_key, - } - params = { - "count": self.top_k, - "textDecorations": True, - "textFormat": "HTML", - "q": query, - } - - response = requests.get(url=url, params=params, headers=headers) - response.raise_for_status() - clean = self._clean_response(response.json()) - return json.dumps(clean) - - def _clean_response(self, search_response): - clean_response = [] - query = search_response["queryContext"]["originalQuery"] - if "webPages" in search_response: - pages = search_response["webPages"]["value"] - for p in pages: - selected_keys = {"name", "url", "snippet"} - clean_response.append( - {k: v for k, v in p.items() if k in selected_keys} - ) - if "news" in search_response: - clean_news = [] - news = search_response["news"]["value"] - for n in news: - selected_keys = {"name", "url", "description"} - clean_news.append({k: v for k, v in n.items() if k in selected_keys}) - - clean_response.append(clean_news) - - return {"query": query, "top_k": clean_response} - - -class BraveSearch: - def __init__(self, api_key: str) -> None: - self.api_key = api_key - - async def search(self, query: str) -> str: - url = "https://api.search.brave.com/res/v1/web/search" - headers = { - "X-Subscription-Token": self.api_key, - "Accept-Encoding": "gzip", - "Accept": "application/json", - } - payload = {"q": query} - response = requests.get(url=url, params=payload, headers=headers) - return json.dumps(self._clean_brave_response(response.json())) - - def _clean_brave_response(self, search_response, top_k=3): - query = None - clean_response = [] - if "query" in search_response: - if "original" in search_response["query"]: - query = search_response["query"]["original"] - if "mixed" in search_response: - mixed_results = search_response["mixed"] - for m in mixed_results["main"][:top_k]: - r_type = m["type"] - results = search_response[r_type]["results"] - if r_type == "web": - # For web data - add a single output from the search - idx = m["index"] - selected_keys = [ - "type", - "title", - "url", - "description", - "date", - "extra_snippets", - ] - cleaned = { - k: v for k, v in results[idx].items() if k in selected_keys - } - elif r_type == "faq": - # For faw data - take a list of all the questions & answers - selected_keys = ["type", "question", "answer", "title", "url"] - cleaned = [] - for q in results: - cleaned.append( - {k: v for k, v in q.items() if k in selected_keys} - ) - elif r_type == "infobox": - idx = m["index"] - selected_keys = [ - "type", - "title", - "url", - "description", - "long_desc", - ] - cleaned = { - k: v for k, v in results[idx].items() if k in selected_keys - } - elif r_type == "videos": - selected_keys = [ - "type", - "url", - "title", - "description", - "date", - ] - cleaned = [] - for q in results: - cleaned.append( - {k: v for k, v in q.items() if k in selected_keys} - ) - elif r_type == "locations": - # For faw data - take a list of all the questions & answers - selected_keys = [ - "type", - "title", - "url", - "description", - "coordinates", - "postal_address", - "contact", - "rating", - "distance", - "zoom_level", - ] - cleaned = [] - for q in results: - cleaned.append( - {k: v for k, v in q.items() if k in selected_keys} - ) - elif r_type == "news": - # For faw data - take a list of all the questions & answers - selected_keys = [ - "type", - "title", - "url", - "description", - ] - cleaned = [] - for q in results: - cleaned.append( - {k: v for k, v in q.items() if k in selected_keys} - ) - else: - cleaned = [] - - clean_response.append(cleaned) - - return {"query": query, "top_k": clean_response} - - -class TavilySearch: - def __init__(self, api_key: str) -> None: - self.api_key = api_key - - async def search(self, query: str) -> str: - response = requests.post( - "https://api.tavily.com/search", - json={"api_key": self.api_key, "query": query}, - ) - return json.dumps(self._clean_tavily_response(response.json())) - - def _clean_tavily_response(self, search_response, top_k=3): - return {"query": search_response["query"], "top_k": search_response["results"]} - - -class WolframAlphaTool(SingleMessageBuiltinTool): - def __init__(self, api_key: str) -> None: - self.api_key = api_key - self.url = "https://api.wolframalpha.com/v2/query" - - def get_name(self) -> str: - return BuiltinTool.wolfram_alpha.value - - async def run_impl(self, query: str) -> str: - params = { - "input": query, - "appid": self.api_key, - "format": "plaintext", - "output": "json", - } - response = requests.get( - self.url, - params=params, - ) - - return json.dumps(self._clean_wolfram_alpha_response(response.json())) - - def _clean_wolfram_alpha_response(self, wa_response): - remove = { - "queryresult": [ - "datatypes", - "error", - "timedout", - "timedoutpods", - "numpods", - "timing", - "parsetiming", - "parsetimedout", - "recalculate", - "id", - "host", - "server", - "related", - "version", - { - "pods": [ - "scanner", - "id", - "error", - "expressiontypes", - "states", - "infos", - "position", - "numsubpods", - ] - }, - "assumptions", - ], - } - for main_key in remove: - for key_to_remove in remove[main_key]: - try: - if key_to_remove == "assumptions": - if "assumptions" in wa_response[main_key]: - del wa_response[main_key][key_to_remove] - if isinstance(key_to_remove, dict): - for sub_key in key_to_remove: - if sub_key == "pods": - for i in range(len(wa_response[main_key][sub_key])): - if ( - wa_response[main_key][sub_key][i]["title"] - == "Result" - ): - del wa_response[main_key][sub_key][i + 1 :] - break - sub_items = wa_response[main_key][sub_key] - for i in range(len(sub_items)): - for sub_key_to_remove in key_to_remove[sub_key]: - if sub_key_to_remove in sub_items[i]: - del sub_items[i][sub_key_to_remove] - elif key_to_remove in wa_response[main_key]: - del wa_response[main_key][key_to_remove] - except KeyError: - pass - return wa_response - - -class CodeInterpreterTool(BaseTool): - def __init__(self) -> None: - ctx = CodeExecutionContext( - matplotlib_dump_dir=tempfile.mkdtemp(), - ) - self.code_executor = CodeExecutor(ctx) - - def get_name(self) -> str: - return BuiltinTool.code_interpreter.value - - async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]: - message = messages[0] - assert len(message.tool_calls) == 1, "Expected a single tool call" - - tool_call = messages[0].tool_calls[0] - script = tool_call.arguments["code"] - - req = CodeExecutionRequest(scripts=[script]) - res = self.code_executor.execute(req) - - pieces = [res["process_status"]] - for out_type in ["stdout", "stderr"]: - res_out = res[out_type] - if res_out != "": - pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"]) - if out_type == "stderr": - log.error(f"ipython tool error: ↓\n{res_out}") - - message = ToolResponseMessage( - call_id=tool_call.call_id, - tool_name=tool_call.tool_name, - content="\n".join(pieces), - ) - return [message] diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/__init__.py b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_env_prefix.py b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_env_prefix.py deleted file mode 100644 index 10f64ec94..000000000 --- a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_env_prefix.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import errno - -# Disabling potentially dangerous functions -import os as _os -from functools import partial - -os_funcs_to_disable = [ - "kill", - "system", - "putenv", - "remove", - "removedirs", - "rmdir", - "fchdir", - "setuid", - "fork", - "forkpty", - "killpg", - "rename", - "renames", - "truncate", - "replace", - # "unlink", # Commenting as this was blocking matpltlib from rendering plots correctly - "fchmod", - "fchown", - "chmod", - "chown", - "chroot", - "fchdir", - "lchflags", - "lchmod", - "lchown", - "chdir", -] - - -def call_not_allowed(*args, **kwargs): - raise OSError(errno.EPERM, "Call are not permitted in this environment") - - -for func_name in os_funcs_to_disable: - if hasattr(_os, func_name): - setattr(_os, func_name, partial(call_not_allowed, _func_name=f"os.{func_name}")) - -import shutil as _shutil - -for func_name in ["rmtree", "move", "chown"]: - if hasattr(_shutil, func_name): - setattr( - _shutil, - func_name, - partial(call_not_allowed, _func_name=f"shutil.{func_name}"), - ) - -import subprocess as _subprocess - - -def popen_not_allowed(*args, **kwargs): - raise _subprocess.CalledProcessError( - -1, - args[0] if args else "unknown", - stderr="subprocess.Popen is not allowed in this environment", - ) - - -_subprocess.Popen = popen_not_allowed - - -import atexit as _atexit -import builtins as _builtins -import io as _io -import json as _json -import sys as _sys - -# NB! The following "unused" imports crucial, make sure not not to remove -# them with linters - they're used in code_execution.py -from contextlib import ( # noqa - contextmanager as _contextmanager, - redirect_stderr as _redirect_stderr, - redirect_stdout as _redirect_stdout, -) -from multiprocessing.connection import Connection as _Connection - -# Mangle imports to avoid polluting model execution namespace. - -_IO_SINK = _io.StringIO() -_NETWORK_TIMEOUT = 5 -_NETWORK_CONNECTIONS = None - - -def _open_connections(): - global _NETWORK_CONNECTIONS - if _NETWORK_CONNECTIONS is not None: - # Ensure connections only opened once. - return _NETWORK_CONNECTIONS - req_w_fd, resp_r_fd = _sys.argv[1], _sys.argv[2] - req_con = _Connection(int(req_w_fd), readable=False) - resp_con = _Connection(int(resp_r_fd), writable=False) - _NETWORK_CONNECTIONS = (req_con, resp_con) - return _NETWORK_CONNECTIONS - - -_builtins._open_connections = _open_connections - - -@_atexit.register -def _close_connections(): - global _NETWORK_CONNECTIONS - if _NETWORK_CONNECTIONS is None: - return - for con in _NETWORK_CONNECTIONS: - con.close() - del _NETWORK_CONNECTIONS - - -def _network_call(request): - # NOTE: We communicate with the parent process in json, encoded - # in raw bytes. We do this because native send/recv methods use - # pickle which involves execution of arbitrary code. - _open_connections() - req_con, resp_con = _NETWORK_CONNECTIONS - - req_con.send_bytes(_json.dumps(request).encode("utf-8")) - if resp_con.poll(timeout=_NETWORK_TIMEOUT) is None: - raise Exception(f"Network request timed out: {_json.dumps(request)}") - else: - return _json.loads(resp_con.recv_bytes().decode("utf-8")) diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_execution.py b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_execution.py deleted file mode 100644 index fa2e367e5..000000000 --- a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_execution.py +++ /dev/null @@ -1,256 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import base64 -import json -import multiprocessing -import os -import re -import subprocess -import sys -import tempfile -import textwrap -import time -from dataclasses import dataclass -from datetime import datetime -from io import BytesIO -from pathlib import Path -from typing import List - -from PIL import Image - -from .utils import get_code_env_prefix - -TOOLS_ATTACHMENT_KEY = "__tools_attachment__" -TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") - -DIRNAME = Path(__file__).parent - -CODE_EXEC_TIMEOUT = 20 -CODE_ENV_PREFIX = get_code_env_prefix() - -STDOUTERR_SINK_WRAPPER_TEMPLATE = """\ -with _redirect_stdout(_IO_SINK), _redirect_stderr(_IO_SINK): -{code}\ -""" - -TRYEXCEPT_WRAPPER_TEMPLATE = """\ -try: -{code} -except: - pass\ -""" - - -def generate_bwrap_command(bind_dirs: List[str]) -> str: - """ - Generate the bwrap command string for binding all - directories in the current directory read-only. - """ - bwrap_args = "" - bwrap_args += "--ro-bind / / " - # Add the --dev flag to mount device files - bwrap_args += "--dev /dev " - for d in bind_dirs: - bwrap_args += f"--bind {d} {d} " - - # Add the --unshare-all flag to isolate the sandbox from the rest of the system - bwrap_args += "--unshare-all " - # Add the --die-with-parent flag to ensure the child process dies when bwrap's parent dies - bwrap_args += "--die-with-parent " - return bwrap_args - - -@dataclass -class CodeExecutionContext: - matplotlib_dump_dir: str - use_proxy: bool = False - - -@dataclass -class CodeExecutionRequest: - scripts: List[str] - only_last_cell_stdouterr: bool = True - only_last_cell_fail: bool = True - seed: int = 0 - strip_fpaths_in_stderr: bool = True - - -class CodeExecutor: - def __init__(self, context: CodeExecutionContext): - self.context = context - - def execute(self, req: CodeExecutionRequest) -> dict: - scripts = req.scripts - for i in range(len(scripts) - 1): - if req.only_last_cell_stdouterr: - scripts[i] = STDOUTERR_SINK_WRAPPER_TEMPLATE.format( - code=textwrap.indent(scripts[i], " " * 4) - ) - if req.only_last_cell_fail: - scripts[i] = TRYEXCEPT_WRAPPER_TEMPLATE.format( - code=textwrap.indent(scripts[i], " " * 4) - ) - - # Seeds prefix: - seed = req.seed - seeds_prefix = f"""\ -def _set_seeds(): - import random - random.seed({seed}) - import numpy as np - np.random.seed({seed}) -_set_seeds()\ -""" - - script = "\n\n".join([seeds_prefix] + [CODE_ENV_PREFIX] + scripts) - with tempfile.TemporaryDirectory() as dpath: - bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath]) - cmd = [*bwrap_prefix.split(), sys.executable, "-c", script] - code_fpath = os.path.join(dpath, "code.py") - with open(code_fpath, "w") as f: - f.write(script) - - try: - python_path = os.environ.get("PYTHONPATH", "") - env = dict( - os.environ, - PYTHONHASHSEED=str(seed), - MPLCONFIGDIR=dpath, - MPLBACKEND="module://matplotlib_custom_backend", - PYTHONPATH=f"{DIRNAME}:{python_path}", - ) - stdout, stderr, returncode = do_subprocess( - cmd=cmd, - env=env, - ctx=self.context, - ) - - stderr = stderr.strip() - if req.strip_fpaths_in_stderr: - pattern = r'File "([^"]+)", line (\d+)' - stderr = re.sub(pattern, r"line \2", stderr) - - return { - "process_status": "completed", - "returncode": returncode, - "stdout": stdout.strip(), - "stderr": stderr, - } - - except subprocess.TimeoutExpired: - return { - "process_status": "timeout", - "stdout": "Timed out", - "stderr": "Timed out", - } - - except Exception as e: - return { - "process_status": "error", - "error_type": type(e).__name__, - "stderr": str(e), - "stdout": str(e), - } - - -def process_matplotlib_response(response, matplotlib_dump_dir: str): - image_data = response["image_data"] - # Convert the base64 string to a bytes object - images = [base64.b64decode(d["image_base64"]) for d in image_data] - # Create a list of PIL images from the bytes objects - images = [Image.open(BytesIO(img)) for img in images] - # Create a list of image paths - image_paths = [] - for i, img in enumerate(images): - # create new directory for each day to better organize data: - dump_dname = datetime.today().strftime("%Y-%m-%d") - dump_dpath = Path(matplotlib_dump_dir, dump_dname) - dump_dpath.mkdir(parents=True, exist_ok=True) - # save image into a file - dump_fname = f"matplotlib_{str(time.time()).replace('.', '_')}_{i}.png" - dump_fpath = dump_dpath / dump_fname - img.save(dump_fpath, "PNG") - image_paths.append(str(dump_fpath)) - - # this is kind of convoluted, we send back this response to the subprocess which - # prints it out - info = { - "filepath": str(image_paths[-1]), - "mimetype": "image/png", - } - return f"{TOOLS_ATTACHMENT_KEY}={json.dumps(info)}" - - -def execute_subprocess_request(request, ctx: CodeExecutionContext): - "Route requests from the subprocess (via network Pipes) to the internet/tools." - if request["type"] == "matplotlib": - return process_matplotlib_response(request, ctx.matplotlib_dump_dir) - else: - raise Exception(f'Unrecognised network request type: {request["type"]}') - - -def do_subprocess(*, cmd: list, env: dict, ctx: CodeExecutionContext): - # Create Pipes to be used for any external tool/network requests. - req_r, req_w = multiprocessing.Pipe(duplex=False) - resp_r, resp_w = multiprocessing.Pipe(duplex=False) - - cmd += [str(req_w.fileno()), str(resp_r.fileno())] - proc = subprocess.Popen( - cmd, - pass_fds=(req_w.fileno(), resp_r.fileno()), - text=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - close_fds=True, - env=env, - ) - - # Close unnecessary fds. - req_w.close() - resp_r.close() - - pipe_close = False - done_read = False - start = time.monotonic() - while proc.poll() is None and not pipe_close: - if req_r.poll(0.1): - # NB: Python pipe semantics for poll and recv mean that - # poll() returns True is a pipe is closed. - # CF old school PEP from '09 - # https://bugs.python.org/issue5573 - try: - request = json.loads(req_r.recv_bytes().decode("utf-8")) - response = execute_subprocess_request(request, ctx) - - resp_w.send_bytes(json.dumps(response).encode("utf-8")) - except EOFError: - # The request pipe is closed - set a marker to exit - # after the next attempt at reading stdout/stderr. - pipe_close = True - - try: - # If lots has been printed, pipe might be full but - # proc cannot exit until all the stdout/stderr - # been written/read. - stdout, stderr = proc.communicate(timeout=0.3) - done_read = True - except subprocess.TimeoutExpired: - # The program has not terminated. Ignore it, there - # may be more network/tool requests. - continue - if time.monotonic() - start > CODE_EXEC_TIMEOUT: - proc.terminate() - raise subprocess.TimeoutExpired(cmd, CODE_EXEC_TIMEOUT) - - if not done_read: - # Solve race condition where process terminates before - # we hit the while loop. - stdout, stderr = proc.communicate(timeout=0.3) - - resp_w.close() - req_r.close() - return stdout, stderr, proc.returncode diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py deleted file mode 100644 index 7fec08cf2..000000000 --- a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -""" -A custom Matplotlib backend that overrides the show method to return image bytes. -""" - -import base64 -import io -import json as _json -import logging - -import matplotlib -from matplotlib.backend_bases import FigureManagerBase - -# Import necessary components from Matplotlib -from matplotlib.backends.backend_agg import FigureCanvasAgg - -log = logging.getLogger(__name__) - - -class CustomFigureCanvas(FigureCanvasAgg): - def show(self): - # Save the figure to a BytesIO object - buf = io.BytesIO() - self.print_png(buf) - image_bytes = buf.getvalue() - buf.close() - return image_bytes - - -class CustomFigureManager(FigureManagerBase): - def __init__(self, canvas, num): - super().__init__(canvas, num) - - -# Mimic module initialization that integrates with the Matplotlib backend system -def _create_figure_manager(num, *args, **kwargs): - """ - Create a custom figure manager instance. - """ - FigureClass = kwargs.pop("FigureClass", None) # noqa: N806 - if FigureClass is None: - from matplotlib.figure import Figure - - FigureClass = Figure # noqa: N806 - fig = FigureClass(*args, **kwargs) - canvas = CustomFigureCanvas(fig) - manager = CustomFigureManager(canvas, num) - return manager - - -def show(): - """ - Handle all figures and potentially return their images as bytes. - - This function iterates over all figures registered with the custom backend, - renders them as images in bytes format, and could return a list of bytes objects, - one for each figure, or handle them as needed. - """ - image_data = [] - for manager in matplotlib._pylab_helpers.Gcf.get_all_fig_managers(): - # Get the figure from the manager - fig = manager.canvas.figure - buf = io.BytesIO() # Create a buffer for the figure - fig.savefig(buf, format="png") # Save the figure to the buffer in PNG format - buf.seek(0) # Go to the beginning of the buffer - image_bytes = buf.getvalue() # Retrieve bytes value - image_base64 = base64.b64encode(image_bytes).decode("utf-8") - image_data.append({"image_base64": image_base64}) - buf.close() - - req_con, resp_con = _open_connections() - - _json_dump = _json.dumps( - { - "type": "matplotlib", - "image_data": image_data, - } - ) - req_con.send_bytes(_json_dump.encode("utf-8")) - resp = _json.loads(resp_con.recv_bytes().decode("utf-8")) - log.info(resp) - - -FigureCanvas = CustomFigureCanvas -FigureManager = CustomFigureManager diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/utils.py b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/utils.py deleted file mode 100644 index d6f539a39..000000000 --- a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/utils.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import os - -DIR = os.path.dirname(os.path.realpath(__file__)) -CODE_ENV_PREFIX_FILE = os.path.join(DIR, "code_env_prefix.py") -CODE_ENV_PREFIX = None - - -def get_code_env_prefix() -> str: - global CODE_ENV_PREFIX - - if CODE_ENV_PREFIX is None: - with open(CODE_ENV_PREFIX_FILE, "r") as f: - CODE_ENV_PREFIX = f.read() - - return CODE_ENV_PREFIX diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index dd9882aa6..f5158b57c 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import os import tempfile import pytest @@ -68,7 +69,14 @@ def tool_runtime_memory() -> ProviderFixture: provider_id="memory-runtime", provider_type="inline::memory-runtime", config={}, - ) + ), + Provider( + provider_id="brave-search", + provider_type="inline::brave-search", + config={ + "api_key": os.environ["BRAVE_SEARCH_API_KEY"], + }, + ), ], ) @@ -131,6 +139,20 @@ async def agents_stack(request, inference_model, safety_shield): ) ) tool_groups = [ + ToolGroupInput( + tool_group_id="brave_search_group", + tool_group=UserDefinedToolGroupDef( + tools=[ + ToolDef( + name="brave_search", + description="brave_search", + parameters=[], + metadata={}, + ), + ], + ), + provider_id="brave-search", + ), ToolGroupInput( tool_group_id="memory_group", tool_group=UserDefinedToolGroupDef( @@ -163,7 +185,7 @@ async def agents_stack(request, inference_model, safety_shield): ], ), provider_id="memory-runtime", - ) + ), ] test_stack = await construct_stack_for_test( diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 4ff94e4fe..78ca2341f 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -50,7 +50,8 @@ def common_params(inference_model): sampling_params=SamplingParams(temperature=0.7, top_p=0.95), input_shields=[], output_shields=[], - tools=[], + available_tools=[], + preprocessing_tools=[], max_infer_iters=5, ) @@ -91,7 +92,7 @@ async def create_agent_turn_with_search_tool( agents_stack: Dict[str, object], search_query_messages: List[object], common_params: Dict[str, str], - search_tool_definition: SearchToolDefinition, + tool_name: str, ) -> None: """ Create an agent turn with a search tool. @@ -107,7 +108,7 @@ async def create_agent_turn_with_search_tool( agent_config = AgentConfig( **{ **common_params, - "tools": [search_tool_definition], + "available_tools": [tool_name], } ) @@ -254,7 +255,6 @@ class TestAgents: agent_config = AgentConfig( **{ **common_params, - "tools": [], "preprocessing_tools": ["memory"], "tool_choice": ToolChoice.auto, } @@ -295,29 +295,11 @@ class TestAgents: if "BRAVE_SEARCH_API_KEY" not in os.environ: pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test") - search_tool_definition = SearchToolDefinition( - type=AgentTool.brave_search.value, - api_key=os.environ["BRAVE_SEARCH_API_KEY"], - engine=SearchEngineType.brave, - ) await create_agent_turn_with_search_tool( - agents_stack, search_query_messages, common_params, search_tool_definition - ) - - @pytest.mark.asyncio - async def test_create_agent_turn_with_tavily_search( - self, agents_stack, search_query_messages, common_params - ): - if "TAVILY_SEARCH_API_KEY" not in os.environ: - pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test") - - search_tool_definition = SearchToolDefinition( - type=AgentTool.brave_search.value, # place holder only - api_key=os.environ["TAVILY_SEARCH_API_KEY"], - engine=SearchEngineType.tavily, - ) - await create_agent_turn_with_search_tool( - agents_stack, search_query_messages, common_params, search_tool_definition + agents_stack, + search_query_messages, + common_params, + "brave_search", )