mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 10:13:05 +00:00
add code interpreter
This commit is contained in:
parent
0155700ea6
commit
40f35f3a8d
10 changed files with 545 additions and 1 deletions
|
@ -0,0 +1,16 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .code_interpreter import CodeInterpreterToolRuntimeImpl
|
||||||
|
from .config import CodeInterpreterToolConfig
|
||||||
|
|
||||||
|
__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps):
    """Provider entry point: build and initialize the code-interpreter runtime.

    `_deps` is accepted for interface parity with other providers but unused.
    """
    runtime = CodeInterpreterToolRuntimeImpl(config)
    await runtime.initialize()
    return runtime
|
|
@ -0,0 +1,133 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import errno

# Disabling potentially dangerous functions
import os as _os
from functools import partial

# os-level calls that would let sandboxed code affect the host: process
# control, privilege changes, and filesystem mutation.
# (The original list contained "fchdir" twice; deduplicated.)
os_funcs_to_disable = [
    "kill",
    "system",
    "putenv",
    "remove",
    "removedirs",
    "rmdir",
    "fchdir",
    "setuid",
    "fork",
    "forkpty",
    "killpg",
    "rename",
    "renames",
    "truncate",
    "replace",
    # "unlink",  # Commenting as this was blocking matplotlib from rendering plots correctly
    "fchmod",
    "fchown",
    "chmod",
    "chown",
    "chroot",
    "lchflags",
    "lchmod",
    "lchown",
    "chdir",
]


def call_not_allowed(*args, _func_name: str = "<unknown>", **kwargs):
    """Stand-in for a disabled function: always raises EPERM.

    `_func_name` is bound by each `partial` below; it was previously swallowed
    by **kwargs and never shown — now it names the blocked call, and the
    message grammar ("Call are") is fixed.
    """
    raise OSError(errno.EPERM, f"Call to {_func_name} is not permitted in this environment")


for func_name in os_funcs_to_disable:
    if hasattr(_os, func_name):
        setattr(_os, func_name, partial(call_not_allowed, _func_name=f"os.{func_name}"))

import shutil as _shutil

# shutil helpers that can delete/move/chown whole trees on the host.
for func_name in ["rmtree", "move", "chown"]:
    if hasattr(_shutil, func_name):
        setattr(
            _shutil,
            func_name,
            partial(call_not_allowed, _func_name=f"shutil.{func_name}"),
        )

import subprocess as _subprocess


def popen_not_allowed(*args, **kwargs):
    """Stand-in for subprocess.Popen: always raises CalledProcessError."""
    raise _subprocess.CalledProcessError(
        -1,
        args[0] if args else "unknown",
        stderr="subprocess.Popen is not allowed in this environment",
    )


_subprocess.Popen = popen_not_allowed
|
||||||
|
|
||||||
|
|
||||||
|
import atexit as _atexit
|
||||||
|
import builtins as _builtins
|
||||||
|
import io as _io
|
||||||
|
import json as _json
|
||||||
|
import sys as _sys
|
||||||
|
|
||||||
|
# NB! The following "unused" imports are crucial — make sure not to remove
|
||||||
|
# them with linters - they're used in code_execution.py
|
||||||
|
from contextlib import ( # noqa
|
||||||
|
contextmanager as _contextmanager,
|
||||||
|
redirect_stderr as _redirect_stderr,
|
||||||
|
redirect_stdout as _redirect_stdout,
|
||||||
|
)
|
||||||
|
from multiprocessing.connection import Connection as _Connection
|
||||||
|
|
||||||
|
# Mangle imports to avoid polluting model execution namespace.
|
||||||
|
|
||||||
|
_IO_SINK = _io.StringIO()
|
||||||
|
_NETWORK_TIMEOUT = 5
|
||||||
|
_NETWORK_CONNECTIONS = None
|
||||||
|
|
||||||
|
|
||||||
|
def _open_connections():
    """Lazily create the request/response pipe connections handed down via argv.

    The parent process passes two inherited file descriptors as argv[1] and
    argv[2]: a write-only pipe for outgoing requests and a read-only pipe for
    responses. The pair is cached so the connections are only opened once.
    """
    global _NETWORK_CONNECTIONS
    if _NETWORK_CONNECTIONS is None:
        write_fd = int(_sys.argv[1])
        read_fd = int(_sys.argv[2])
        _NETWORK_CONNECTIONS = (
            _Connection(write_fd, readable=False),
            _Connection(read_fd, writable=False),
        )
    return _NETWORK_CONNECTIONS


# Expose the helper through builtins so executed model code can reach it.
_builtins._open_connections = _open_connections
|
||||||
|
|
||||||
|
|
||||||
|
@_atexit.register
def _close_connections():
    """At interpreter exit, close both pipe connections if they were opened."""
    global _NETWORK_CONNECTIONS
    if _NETWORK_CONNECTIONS is not None:
        request_con, response_con = _NETWORK_CONNECTIONS
        request_con.close()
        response_con.close()
        del _NETWORK_CONNECTIONS
|
||||||
|
|
||||||
|
|
||||||
|
def _network_call(request):
    """Send *request* to the parent process and return its decoded response.

    NOTE: We communicate with the parent process in json, encoded in raw
    bytes. We do this because native send/recv methods use pickle which
    involves execution of arbitrary code.

    Raises if no response arrives within _NETWORK_TIMEOUT seconds.
    """
    req_con, resp_con = _open_connections()

    req_con.send_bytes(_json.dumps(request).encode("utf-8"))
    # BUG FIX: Connection.poll() returns a bool, never None, so the original
    # `poll(...) is None` timeout check could never fire and a lost response
    # would block forever in recv_bytes().
    if not resp_con.poll(timeout=_NETWORK_TIMEOUT):
        raise Exception(f"Network request timed out: {_json.dumps(request)}")
    return _json.loads(resp_con.recv_bytes().decode("utf-8"))
|
|
@ -0,0 +1,256 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import textwrap
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from io import BytesIO
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from .utils import get_code_env_prefix
|
||||||
|
|
||||||
|
TOOLS_ATTACHMENT_KEY = "__tools_attachment__"
|
||||||
|
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
||||||
|
|
||||||
|
DIRNAME = Path(__file__).parent
|
||||||
|
|
||||||
|
CODE_EXEC_TIMEOUT = 20
|
||||||
|
CODE_ENV_PREFIX = get_code_env_prefix()
|
||||||
|
|
||||||
|
STDOUTERR_SINK_WRAPPER_TEMPLATE = """\
|
||||||
|
with _redirect_stdout(_IO_SINK), _redirect_stderr(_IO_SINK):
|
||||||
|
{code}\
|
||||||
|
"""
|
||||||
|
|
||||||
|
TRYEXCEPT_WRAPPER_TEMPLATE = """\
|
||||||
|
try:
|
||||||
|
{code}
|
||||||
|
except:
|
||||||
|
pass\
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def generate_bwrap_command(bind_dirs: List[str]) -> str:
    """
    Generate the bwrap command string for binding all
    directories in the current directory read-only.
    """
    # Base sandbox: whole filesystem read-only, plus device files mounted.
    flags = ["--ro-bind / /", "--dev /dev"]
    # Each requested directory is bound read-write at its own path.
    flags.extend(f"--bind {d} {d}" for d in bind_dirs)
    # Isolate all namespaces and tie the child's lifetime to bwrap's parent.
    flags.extend(["--unshare-all", "--die-with-parent"])
    # Trailing space preserved for callers that concatenate further args.
    return " ".join(flags) + " "
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CodeExecutionContext:
    """Ambient configuration shared by all executions of one CodeExecutor."""

    # Directory where matplotlib figures rendered by executed code are saved.
    matplotlib_dump_dir: str
    # NOTE(review): not read anywhere in this file — presumably toggles an
    # outbound proxy for sandboxed network requests; confirm before relying on it.
    use_proxy: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CodeExecutionRequest:
    """A batch of python "cells" to run sequentially in one sandboxed process."""

    # Source of each cell; cells are concatenated and run in order.
    scripts: List[str]
    # Suppress stdout/stderr from every cell except the last.
    only_last_cell_stdouterr: bool = True
    # Swallow exceptions from every cell except the last.
    only_last_cell_fail: bool = True
    # Seed for random / numpy.random / PYTHONHASHSEED to make runs deterministic.
    seed: int = 0
    # Rewrite 'File "...", line N' in stderr so sandbox paths are not leaked.
    strip_fpaths_in_stderr: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class CodeExecutor:
    """Runs a CodeExecutionRequest inside a bwrap sandbox via a python subprocess."""

    def __init__(self, context: CodeExecutionContext):
        self.context = context

    def execute(self, req: CodeExecutionRequest) -> dict:
        """Execute the request's scripts and return a result dict.

        Returns {"process_status": "completed"|"timeout"|"error", ...};
        "completed" results also carry "returncode", "stdout" and "stderr".
        """
        # BUG FIX: work on a copy — the original rebound req.scripts elements
        # in place, leaking the wrapped sources back to the caller's request.
        scripts = list(req.scripts)
        # Wrap every cell but the last so only the final cell's output/failure
        # is surfaced.
        for i in range(len(scripts) - 1):
            if req.only_last_cell_stdouterr:
                scripts[i] = STDOUTERR_SINK_WRAPPER_TEMPLATE.format(
                    code=textwrap.indent(scripts[i], " " * 4)
                )
            if req.only_last_cell_fail:
                scripts[i] = TRYEXCEPT_WRAPPER_TEMPLATE.format(
                    code=textwrap.indent(scripts[i], " " * 4)
                )

        # Seeds prefix: pin random and numpy RNGs for deterministic runs.
        seed = req.seed
        seeds_prefix = f"""\
def _set_seeds():
    import random
    random.seed({seed})
    import numpy as np
    np.random.seed({seed})
_set_seeds()\
"""

        script = "\n\n".join([seeds_prefix] + [CODE_ENV_PREFIX] + scripts)
        with tempfile.TemporaryDirectory() as dpath:
            bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath])
            cmd = [*bwrap_prefix.split(), sys.executable, "-c", script]
            # Dumped for debugging; the child actually runs via `-c`.
            code_fpath = os.path.join(dpath, "code.py")
            with open(code_fpath, "w") as f:
                f.write(script)

            try:
                python_path = os.environ.get("PYTHONPATH", "")
                env = dict(
                    os.environ,
                    PYTHONHASHSEED=str(seed),
                    MPLCONFIGDIR=dpath,
                    # Route matplotlib rendering through the custom backend
                    # shipped next to this module (hence PYTHONPATH below).
                    MPLBACKEND="module://matplotlib_custom_backend",
                    PYTHONPATH=f"{DIRNAME}:{python_path}",
                )
                stdout, stderr, returncode = do_subprocess(
                    cmd=cmd,
                    env=env,
                    ctx=self.context,
                )

                stderr = stderr.strip()
                if req.strip_fpaths_in_stderr:
                    # Hide sandbox file paths from tracebacks.
                    pattern = r'File "([^"]+)", line (\d+)'
                    stderr = re.sub(pattern, r"line \2", stderr)

                return {
                    "process_status": "completed",
                    "returncode": returncode,
                    "stdout": stdout.strip(),
                    "stderr": stderr,
                }

            except subprocess.TimeoutExpired:
                return {
                    "process_status": "timeout",
                    "stdout": "Timed out",
                    "stderr": "Timed out",
                }

            except Exception as e:
                return {
                    "process_status": "error",
                    "error_type": type(e).__name__,
                    "stderr": str(e),
                    "stdout": str(e),
                }
|
||||||
|
|
||||||
|
|
||||||
|
def process_matplotlib_response(response, matplotlib_dump_dir: str):
    """Save the base64-encoded matplotlib figures in *response* to disk.

    Returns the attachment marker string the subprocess prints, pointing at
    the last saved image.
    """
    # Decode every base64 payload, then materialize PIL images from the bytes.
    raw_blobs = [base64.b64decode(d["image_base64"]) for d in response["image_data"]]
    pil_images = [Image.open(BytesIO(blob)) for blob in raw_blobs]

    saved_paths = []
    for idx, image in enumerate(pil_images):
        # One sub-directory per day keeps the dump directory browsable.
        day_dir = Path(matplotlib_dump_dir, datetime.today().strftime("%Y-%m-%d"))
        day_dir.mkdir(parents=True, exist_ok=True)
        # Timestamp + index keeps filenames unique within a turn.
        stamp = str(time.time()).replace(".", "_")
        target = day_dir / f"matplotlib_{stamp}_{idx}.png"
        image.save(target, "PNG")
        saved_paths.append(str(target))

    # This is kind of convoluted: we send this marker back to the subprocess,
    # which prints it, and the parent extracts the attachment from stdout.
    info = {
        "filepath": str(saved_paths[-1]),
        "mimetype": "image/png",
    }
    return f"{TOOLS_ATTACHMENT_KEY}={json.dumps(info)}"
|
||||||
|
|
||||||
|
|
||||||
|
def execute_subprocess_request(request, ctx: CodeExecutionContext):
    """Route requests from the subprocess (via network Pipes) to the internet/tools."""
    request_type = request["type"]
    # Guard clause: matplotlib is currently the only supported request kind.
    if request_type != "matplotlib":
        raise Exception(f"Unrecognised network request type: {request_type}")
    return process_matplotlib_response(request, ctx.matplotlib_dump_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def do_subprocess(*, cmd: list, env: dict, ctx: CodeExecutionContext):
    """Run *cmd*, servicing its tool/network requests until it exits.

    Returns (stdout, stderr, returncode). Raises subprocess.TimeoutExpired
    if the child runs longer than CODE_EXEC_TIMEOUT seconds.
    """
    # Create Pipes to be used for any external tool/network requests.
    req_r, req_w = multiprocessing.Pipe(duplex=False)
    resp_r, resp_w = multiprocessing.Pipe(duplex=False)

    # The child finds its fd numbers via argv (see _open_connections in
    # code_env_prefix.py).
    cmd += [str(req_w.fileno()), str(resp_r.fileno())]
    proc = subprocess.Popen(
        cmd,
        pass_fds=(req_w.fileno(), resp_r.fileno()),
        text=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        close_fds=True,
        env=env,
    )

    # Close unnecessary fds: the child owns these ends now, and closing them
    # here lets EOFError signal child exit on our remaining ends.
    req_w.close()
    resp_r.close()

    pipe_close = False
    done_read = False
    start = time.monotonic()
    while proc.poll() is None and not pipe_close:
        if req_r.poll(0.1):
            # NB: Python pipe semantics for poll and recv mean that
            # poll() returns True if a pipe is closed.
            # CF old school PEP from '09
            # https://bugs.python.org/issue5573
            try:
                request = json.loads(req_r.recv_bytes().decode("utf-8"))
                response = execute_subprocess_request(request, ctx)

                resp_w.send_bytes(json.dumps(response).encode("utf-8"))
            except EOFError:
                # The request pipe is closed - set a marker to exit
                # after the next attempt at reading stdout/stderr.
                pipe_close = True

        try:
            # If lots has been printed, pipe might be full but
            # proc cannot exit until all the stdout/stderr
            # been written/read.
            stdout, stderr = proc.communicate(timeout=0.3)
            done_read = True
        except subprocess.TimeoutExpired:
            # The program has not terminated. Ignore it, there
            # may be more network/tool requests.
            continue
        if time.monotonic() - start > CODE_EXEC_TIMEOUT:
            proc.terminate()
            raise subprocess.TimeoutExpired(cmd, CODE_EXEC_TIMEOUT)

    if not done_read:
        # Solve race condition where process terminates before
        # we hit the while loop.
        stdout, stderr = proc.communicate(timeout=0.3)

    resp_w.close()
    req_r.close()
    return stdout, stderr, proc.returncode
|
|
@ -0,0 +1,55 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import tempfile
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
|
||||||
|
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||||
|
|
||||||
|
from .code_execution import CodeExecutionContext, CodeExecutionRequest, CodeExecutor
|
||||||
|
from .config import CodeInterpreterToolConfig
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
    """Inline tool runtime that executes model-written python in a sandbox."""

    def __init__(self, config: CodeInterpreterToolConfig):
        self.config = config
        # Matplotlib figures produced by executed code land in a fresh temp dir.
        self.code_executor = CodeExecutor(
            CodeExecutionContext(matplotlib_dump_dir=tempfile.mkdtemp())
        )

    async def initialize(self):
        """No async setup required."""
        pass

    async def register_tool(self, tool: Tool):
        """Only the built-in code_interpreter tool may be registered here."""
        if tool.identifier != "code_interpreter":
            raise ValueError(f"Tool identifier {tool.identifier} is not supported")

    async def unregister_tool(self, tool_id: str) -> None:
        """Nothing to clean up per tool."""
        return

    async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
        raise NotImplementedError("Code interpreter tool group not supported")

    async def invoke_tool(
        self, tool_name: str, args: Dict[str, Any]
    ) -> ToolInvocationResult:
        """Run args["code"] in the sandbox and format status + captured streams."""
        request = CodeExecutionRequest(scripts=[args["code"]])
        result = self.code_executor.execute(request)

        sections = [result["process_status"]]
        for stream in ["stdout", "stderr"]:
            captured = result[stream]
            if captured != "":
                sections.extend([f"[{stream}]", captured, f"[/{stream}]"])
                if stream == "stderr":
                    log.error(f"ipython tool error: ↓\n{captured}")
        return ToolInvocationResult(content="\n".join(sections))
|
|
@ -0,0 +1,11 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class CodeInterpreterToolConfig(BaseModel):
    """Configuration for the inline code-interpreter runtime (no options yet)."""

    pass
|
|
@ -0,0 +1,21 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import os

# Location of this module and of the sandbox-preamble source shipped beside it.
DIR = os.path.dirname(os.path.realpath(__file__))
CODE_ENV_PREFIX_FILE = os.path.join(DIR, "code_env_prefix.py")
# Module-level cache: the preamble is read from disk at most once.
CODE_ENV_PREFIX = None


def get_code_env_prefix() -> str:
    """Return the sandbox preamble source, loading it lazily on first call."""
    global CODE_ENV_PREFIX

    if CODE_ENV_PREFIX is None:
        with open(CODE_ENV_PREFIX_FILE, "r") as prefix_file:
            CODE_ENV_PREFIX = prefix_file.read()

    return CODE_ENV_PREFIX
|
|
@ -41,6 +41,13 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
config_class="llama_stack.providers.inline.tool_runtime.tavily_search.config.TavilySearchToolConfig",
|
config_class="llama_stack.providers.inline.tool_runtime.tavily_search.config.TavilySearchToolConfig",
|
||||||
provider_data_validator="llama_stack.providers.inline.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
|
provider_data_validator="llama_stack.providers.inline.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
|
||||||
),
|
),
|
||||||
|
InlineProviderSpec(
|
||||||
|
api=Api.tool_runtime,
|
||||||
|
provider_type="inline::code-interpreter",
|
||||||
|
pip_packages=[],
|
||||||
|
module="llama_stack.providers.inline.tool_runtime.code_interpreter",
|
||||||
|
config_class="llama_stack.providers.inline.tool_runtime.code_interpreter.config.CodeInterpreterToolConfig",
|
||||||
|
),
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.tool_runtime,
|
api=Api.tool_runtime,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
|
|
|
@ -84,6 +84,11 @@ def tool_runtime_memory() -> ProviderFixture:
|
||||||
"api_key": os.environ["TAVILY_SEARCH_API_KEY"],
|
"api_key": os.environ["TAVILY_SEARCH_API_KEY"],
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
Provider(
|
||||||
|
provider_id="code-interpreter",
|
||||||
|
provider_type="inline::code-interpreter",
|
||||||
|
config={},
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -221,6 +226,20 @@ async def agents_stack(request, inference_model, safety_shield):
|
||||||
),
|
),
|
||||||
provider_id="memory-runtime",
|
provider_id="memory-runtime",
|
||||||
),
|
),
|
||||||
|
ToolGroupInput(
|
||||||
|
tool_group_id="code_interpreter_group",
|
||||||
|
tool_group=UserDefinedToolGroupDef(
|
||||||
|
tools=[
|
||||||
|
ToolDef(
|
||||||
|
name="code_interpreter",
|
||||||
|
description="code_interpreter",
|
||||||
|
parameters=[],
|
||||||
|
metadata={},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
provider_id="code-interpreter",
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
test_stack = await construct_stack_for_test(
|
test_stack = await construct_stack_for_test(
|
||||||
|
|
|
@ -110,3 +110,29 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config):
|
||||||
assert "Tool:brave_search Response:" in logs_str
|
assert "Tool:brave_search Response:" in logs_str
|
||||||
assert "mark zuckerberg" in logs_str.lower()
|
assert "mark zuckerberg" in logs_str.lower()
|
||||||
assert "No Violation" in logs_str
|
assert "No Violation" in logs_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_builtin_tool_code_execution(llama_stack_client, agent_config):
    """End-to-end: an agent with code_interpreter computes the 100th prime (541)."""
    # Copy the fixture config rather than shadowing it, adding the tool.
    config_with_tool = {
        **agent_config,
        "available_tools": [
            "code_interpreter",
        ],
    }
    agent = Agent(llama_stack_client, config_with_tool)
    session_id = agent.create_session(f"test-session-{uuid4()}")

    response = agent.create_turn(
        messages=[
            {
                "role": "user",
                "content": "Write code to answer the question: What is the 100th prime number?",
            },
        ],
        session_id=session_id,
    )
    logs_str = "".join(str(log) for log in EventLogger().log(response) if log is not None)

    assert "541" in logs_str
    assert "Tool:code_interpreter Response" in logs_str
|
||||||
|
|
|
@ -6,8 +6,8 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from llama_stack import LlamaStackAsLibraryClient
|
|
||||||
|
|
||||||
|
from llama_stack import LlamaStackAsLibraryClient
|
||||||
from llama_stack.providers.tests.env import get_env_or_fail
|
from llama_stack.providers.tests.env import get_env_or_fail
|
||||||
from llama_stack_client import LlamaStackClient
|
from llama_stack_client import LlamaStackClient
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue