diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/__init__.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/__init__.py new file mode 100644 index 000000000..663b9655b --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .code_interpreter import CodeInterpreterToolRuntimeImpl +from .config import CodeInterpreterToolConfig + +__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"] + + +async def get_provider_impl(config: CodeInterpreterToolConfig, _deps): + impl = CodeInterpreterToolRuntimeImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py new file mode 100644 index 000000000..10f64ec94 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import errno + +# Disabling potentially dangerous functions +import os as _os +from functools import partial + +os_funcs_to_disable = [ + "kill", + "system", + "putenv", + "remove", + "removedirs", + "rmdir", + "fchdir", + "setuid", + "fork", + "forkpty", + "killpg", + "rename", + "renames", + "truncate", + "replace", + # "unlink", # Commenting as this was blocking matplotlib from rendering plots correctly + "fchmod", + "fchown", + "chmod", + "chown", + "chroot", + "fchdir", + "lchflags", + "lchmod", + "lchown", + "chdir", +] + + +def call_not_allowed(*args, **kwargs): + raise OSError(errno.EPERM, "Calls are not permitted in this environment") + + +for func_name in os_funcs_to_disable: + if hasattr(_os, func_name): + setattr(_os, func_name, partial(call_not_allowed, _func_name=f"os.{func_name}")) + +import shutil as _shutil + +for func_name in ["rmtree", "move", "chown"]: + if hasattr(_shutil, func_name): + setattr( + _shutil, + func_name, + partial(call_not_allowed, _func_name=f"shutil.{func_name}"), + ) + +import subprocess as _subprocess + + +def popen_not_allowed(*args, **kwargs): + raise _subprocess.CalledProcessError( + -1, + args[0] if args else "unknown", + stderr="subprocess.Popen is not allowed in this environment", + ) + + +_subprocess.Popen = popen_not_allowed + + +import atexit as _atexit +import builtins as _builtins +import io as _io +import json as _json +import sys as _sys + +# NB! The following "unused" imports are crucial, make sure not to remove +# them with linters - they're used in code_execution.py +from contextlib import ( # noqa + contextmanager as _contextmanager, + redirect_stderr as _redirect_stderr, + redirect_stdout as _redirect_stdout, +) +from multiprocessing.connection import Connection as _Connection + +# Mangle imports to avoid polluting model execution namespace. 
+ +_IO_SINK = _io.StringIO() +_NETWORK_TIMEOUT = 5 +_NETWORK_CONNECTIONS = None + + +def _open_connections(): + global _NETWORK_CONNECTIONS + if _NETWORK_CONNECTIONS is not None: + # Ensure connections only opened once. + return _NETWORK_CONNECTIONS + req_w_fd, resp_r_fd = _sys.argv[1], _sys.argv[2] + req_con = _Connection(int(req_w_fd), readable=False) + resp_con = _Connection(int(resp_r_fd), writable=False) + _NETWORK_CONNECTIONS = (req_con, resp_con) + return _NETWORK_CONNECTIONS + + +_builtins._open_connections = _open_connections + + +@_atexit.register +def _close_connections(): + global _NETWORK_CONNECTIONS + if _NETWORK_CONNECTIONS is None: + return + for con in _NETWORK_CONNECTIONS: + con.close() + del _NETWORK_CONNECTIONS + + +def _network_call(request): + # NOTE: We communicate with the parent process in json, encoded + # in raw bytes. We do this because native send/recv methods use + # pickle which involves execution of arbitrary code. + _open_connections() + req_con, resp_con = _NETWORK_CONNECTIONS + + req_con.send_bytes(_json.dumps(request).encode("utf-8")) + if resp_con.poll(timeout=_NETWORK_TIMEOUT) is None: + raise Exception(f"Network request timed out: {_json.dumps(request)}") + else: + return _json.loads(resp_con.recv_bytes().decode("utf-8")) diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_execution.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_execution.py new file mode 100644 index 000000000..fa2e367e5 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_execution.py @@ -0,0 +1,256 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import base64 +import json +import multiprocessing +import os +import re +import subprocess +import sys +import tempfile +import textwrap +import time +from dataclasses import dataclass +from datetime import datetime +from io import BytesIO +from pathlib import Path +from typing import List + +from PIL import Image + +from .utils import get_code_env_prefix + +TOOLS_ATTACHMENT_KEY = "__tools_attachment__" +TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") + +DIRNAME = Path(__file__).parent + +CODE_EXEC_TIMEOUT = 20 +CODE_ENV_PREFIX = get_code_env_prefix() + +STDOUTERR_SINK_WRAPPER_TEMPLATE = """\ +with _redirect_stdout(_IO_SINK), _redirect_stderr(_IO_SINK): +{code}\ +""" + +TRYEXCEPT_WRAPPER_TEMPLATE = """\ +try: +{code} +except: + pass\ +""" + + +def generate_bwrap_command(bind_dirs: List[str]) -> str: + """ + Generate the bwrap command string for binding all + directories in the current directory read-only. + """ + bwrap_args = "" + bwrap_args += "--ro-bind / / " + # Add the --dev flag to mount device files + bwrap_args += "--dev /dev " + for d in bind_dirs: + bwrap_args += f"--bind {d} {d} " + + # Add the --unshare-all flag to isolate the sandbox from the rest of the system + bwrap_args += "--unshare-all " + # Add the --die-with-parent flag to ensure the child process dies when bwrap's parent dies + bwrap_args += "--die-with-parent " + return bwrap_args + + +@dataclass +class CodeExecutionContext: + matplotlib_dump_dir: str + use_proxy: bool = False + + +@dataclass +class CodeExecutionRequest: + scripts: List[str] + only_last_cell_stdouterr: bool = True + only_last_cell_fail: bool = True + seed: int = 0 + strip_fpaths_in_stderr: bool = True + + +class CodeExecutor: + def __init__(self, context: CodeExecutionContext): + self.context = context + + def execute(self, req: CodeExecutionRequest) -> dict: + scripts = req.scripts + for i in range(len(scripts) - 1): + if req.only_last_cell_stdouterr: + scripts[i] = 
STDOUTERR_SINK_WRAPPER_TEMPLATE.format( + code=textwrap.indent(scripts[i], " " * 4) + ) + if req.only_last_cell_fail: + scripts[i] = TRYEXCEPT_WRAPPER_TEMPLATE.format( + code=textwrap.indent(scripts[i], " " * 4) + ) + + # Seeds prefix: + seed = req.seed + seeds_prefix = f"""\ +def _set_seeds(): + import random + random.seed({seed}) + import numpy as np + np.random.seed({seed}) +_set_seeds()\ +""" + + script = "\n\n".join([seeds_prefix] + [CODE_ENV_PREFIX] + scripts) + with tempfile.TemporaryDirectory() as dpath: + bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath]) + cmd = [*bwrap_prefix.split(), sys.executable, "-c", script] + code_fpath = os.path.join(dpath, "code.py") + with open(code_fpath, "w") as f: + f.write(script) + + try: + python_path = os.environ.get("PYTHONPATH", "") + env = dict( + os.environ, + PYTHONHASHSEED=str(seed), + MPLCONFIGDIR=dpath, + MPLBACKEND="module://matplotlib_custom_backend", + PYTHONPATH=f"{DIRNAME}:{python_path}", + ) + stdout, stderr, returncode = do_subprocess( + cmd=cmd, + env=env, + ctx=self.context, + ) + + stderr = stderr.strip() + if req.strip_fpaths_in_stderr: + pattern = r'File "([^"]+)", line (\d+)' + stderr = re.sub(pattern, r"line \2", stderr) + + return { + "process_status": "completed", + "returncode": returncode, + "stdout": stdout.strip(), + "stderr": stderr, + } + + except subprocess.TimeoutExpired: + return { + "process_status": "timeout", + "stdout": "Timed out", + "stderr": "Timed out", + } + + except Exception as e: + return { + "process_status": "error", + "error_type": type(e).__name__, + "stderr": str(e), + "stdout": str(e), + } + + +def process_matplotlib_response(response, matplotlib_dump_dir: str): + image_data = response["image_data"] + # Convert the base64 string to a bytes object + images = [base64.b64decode(d["image_base64"]) for d in image_data] + # Create a list of PIL images from the bytes objects + images = [Image.open(BytesIO(img)) for img in images] + # Create a list of image 
paths + image_paths = [] + for i, img in enumerate(images): + # create new directory for each day to better organize data: + dump_dname = datetime.today().strftime("%Y-%m-%d") + dump_dpath = Path(matplotlib_dump_dir, dump_dname) + dump_dpath.mkdir(parents=True, exist_ok=True) + # save image into a file + dump_fname = f"matplotlib_{str(time.time()).replace('.', '_')}_{i}.png" + dump_fpath = dump_dpath / dump_fname + img.save(dump_fpath, "PNG") + image_paths.append(str(dump_fpath)) + + # this is kind of convoluted, we send back this response to the subprocess which + # prints it out + info = { + "filepath": str(image_paths[-1]), + "mimetype": "image/png", + } + return f"{TOOLS_ATTACHMENT_KEY}={json.dumps(info)}" + + +def execute_subprocess_request(request, ctx: CodeExecutionContext): + "Route requests from the subprocess (via network Pipes) to the internet/tools." + if request["type"] == "matplotlib": + return process_matplotlib_response(request, ctx.matplotlib_dump_dir) + else: + raise Exception(f'Unrecognised network request type: {request["type"]}') + + +def do_subprocess(*, cmd: list, env: dict, ctx: CodeExecutionContext): + # Create Pipes to be used for any external tool/network requests. + req_r, req_w = multiprocessing.Pipe(duplex=False) + resp_r, resp_w = multiprocessing.Pipe(duplex=False) + + cmd += [str(req_w.fileno()), str(resp_r.fileno())] + proc = subprocess.Popen( + cmd, + pass_fds=(req_w.fileno(), resp_r.fileno()), + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + close_fds=True, + env=env, + ) + + # Close unnecessary fds. + req_w.close() + resp_r.close() + + pipe_close = False + done_read = False + start = time.monotonic() + while proc.poll() is None and not pipe_close: + if req_r.poll(0.1): + # NB: Python pipe semantics for poll and recv mean that + # poll() returns True if a pipe is closed. 
+ # CF old school PEP from '09 + # https://bugs.python.org/issue5573 + try: + request = json.loads(req_r.recv_bytes().decode("utf-8")) + response = execute_subprocess_request(request, ctx) + + resp_w.send_bytes(json.dumps(response).encode("utf-8")) + except EOFError: + # The request pipe is closed - set a marker to exit + # after the next attempt at reading stdout/stderr. + pipe_close = True + + try: + # If lots has been printed, pipe might be full but + # proc cannot exit until all the stdout/stderr + # been written/read. + stdout, stderr = proc.communicate(timeout=0.3) + done_read = True + except subprocess.TimeoutExpired: + # The program has not terminated. Ignore it, there + # may be more network/tool requests. + continue + if time.monotonic() - start > CODE_EXEC_TIMEOUT: + proc.terminate() + raise subprocess.TimeoutExpired(cmd, CODE_EXEC_TIMEOUT) + + if not done_read: + # Solve race condition where process terminates before + # we hit the while loop. + stdout, stderr = proc.communicate(timeout=0.3) + + resp_w.close() + req_r.close() + return stdout, stderr, proc.returncode diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py new file mode 100644 index 000000000..2e062d6d7 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ + +import logging +import tempfile +from typing import Any, Dict, List + +from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime +from llama_stack.providers.datatypes import ToolsProtocolPrivate + +from .code_execution import CodeExecutionContext, CodeExecutionRequest, CodeExecutor +from .config import CodeInterpreterToolConfig + +log = logging.getLogger(__name__) + + +class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): + def __init__(self, config: CodeInterpreterToolConfig): + self.config = config + ctx = CodeExecutionContext( + matplotlib_dump_dir=tempfile.mkdtemp(), + ) + self.code_executor = CodeExecutor(ctx) + + async def initialize(self): + pass + + async def register_tool(self, tool: Tool): + if tool.identifier != "code_interpreter": + raise ValueError(f"Tool identifier {tool.identifier} is not supported") + + async def unregister_tool(self, tool_id: str) -> None: + return + + async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]: + raise NotImplementedError("Code interpreter tool group not supported") + + async def invoke_tool( + self, tool_name: str, args: Dict[str, Any] + ) -> ToolInvocationResult: + script = args["code"] + req = CodeExecutionRequest(scripts=[script]) + res = self.code_executor.execute(req) + pieces = [res["process_status"]] + for out_type in ["stdout", "stderr"]: + res_out = res[out_type] + if res_out != "": + pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"]) + if out_type == "stderr": + log.error(f"ipython tool error: ↓\n{res_out}") + return ToolInvocationResult(content="\n".join(pieces)) diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/config.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/config.py new file mode 100644 index 000000000..167a2c318 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/config.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + + +class CodeInterpreterToolConfig(BaseModel): + pass diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/utils.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/utils.py new file mode 100644 index 000000000..d6f539a39 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/utils.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os + +DIR = os.path.dirname(os.path.realpath(__file__)) +CODE_ENV_PREFIX_FILE = os.path.join(DIR, "code_env_prefix.py") +CODE_ENV_PREFIX = None + + +def get_code_env_prefix() -> str: + global CODE_ENV_PREFIX + + if CODE_ENV_PREFIX is None: + with open(CODE_ENV_PREFIX_FILE, "r") as f: + CODE_ENV_PREFIX = f.read() + + return CODE_ENV_PREFIX diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index 9058fb718..e4e61109f 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -41,6 +41,13 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.inline.tool_runtime.tavily_search.config.TavilySearchToolConfig", provider_data_validator="llama_stack.providers.inline.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator", ), + InlineProviderSpec( + api=Api.tool_runtime, + provider_type="inline::code-interpreter", + pip_packages=[], + module="llama_stack.providers.inline.tool_runtime.code_interpreter", + config_class="llama_stack.providers.inline.tool_runtime.code_interpreter.config.CodeInterpreterToolConfig", + ), remote_provider_spec( api=Api.tool_runtime, 
adapter=AdapterSpec( diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index c0690e4e3..ca44325d7 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -84,6 +84,11 @@ def tool_runtime_memory() -> ProviderFixture: "api_key": os.environ["TAVILY_SEARCH_API_KEY"], }, ), + Provider( + provider_id="code-interpreter", + provider_type="inline::code-interpreter", + config={}, + ), ], ) @@ -221,6 +226,20 @@ async def agents_stack(request, inference_model, safety_shield): ), provider_id="memory-runtime", ), + ToolGroupInput( + tool_group_id="code_interpreter_group", + tool_group=UserDefinedToolGroupDef( + tools=[ + ToolDef( + name="code_interpreter", + description="code_interpreter", + parameters=[], + metadata={}, + ) + ], + ), + provider_id="code-interpreter", + ), ] test_stack = await construct_stack_for_test( diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index a7b08239b..4e335d8d3 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -110,3 +110,29 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config): assert "Tool:brave_search Response:" in logs_str assert "mark zuckerberg" in logs_str.lower() assert "No Violation" in logs_str + + +def test_builtin_tool_code_execution(llama_stack_client, agent_config): + agent_config = { + **agent_config, + "available_tools": [ + "code_interpreter", + ], + } + agent = Agent(llama_stack_client, agent_config) + session_id = agent.create_session(f"test-session-{uuid4()}") + + response = agent.create_turn( + messages=[ + { + "role": "user", + "content": "Write code to answer the question: What is the 100th prime number?", + }, + ], + session_id=session_id, + ) + logs = [str(log) for log in EventLogger().log(response) if log is not None] + logs_str = "".join(logs) + + assert "541" in logs_str + assert 
"Tool:code_interpreter Response" in logs_str diff --git a/tests/client-sdk/conftest.py b/tests/client-sdk/conftest.py index 2366008dd..28808ae4c 100644 --- a/tests/client-sdk/conftest.py +++ b/tests/client-sdk/conftest.py @@ -6,8 +6,8 @@ import os import pytest -from llama_stack import LlamaStackAsLibraryClient +from llama_stack import LlamaStackAsLibraryClient from llama_stack.providers.tests.env import get_env_or_fail from llama_stack_client import LlamaStackClient