mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-13 00:26:10 +00:00
agents to use tools api (#673)
# What does this PR do? PR #639 introduced the notion of Tools API and ability to invoke tools through API just as any resource. This PR changes the Agents to start using the Tools API to invoke tools. Major changes include: 1) Ability to specify tool groups with AgentConfig 2) Agent gets the corresponding tool definitions for the specified tools and pass along to the model 3) Attachements are now named as Documents and their behavior is mostly unchanged from user perspective 4) You can specify args that can be injected to a tool call through Agent config. This is especially useful in case of memory tool, where you want the tool to operate on a specific memory bank. 5) You can also register tool groups with args, which lets the agent inject these as well into the tool call. 6) All tests have been migrated to use new tools API and fixtures including client SDK tests 7) Telemetry just works with tools API because of our trace protocol decorator ## Test Plan ``` pytest -s -v -k fireworks llama_stack/providers/tests/agents/test_agents.py \ --safety-shield=meta-llama/Llama-Guard-3-8B \ --inference-model=meta-llama/Llama-3.1-8B-Instruct pytest -s -v -k together llama_stack/providers/tests/tools/test_tools.py \ --safety-shield=meta-llama/Llama-Guard-3-8B \ --inference-model=meta-llama/Llama-3.1-8B-Instruct LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py ``` run.yaml: https://gist.github.com/dineshyv/0365845ad325e1c2cab755788ccc5994 Notebook: https://colab.research.google.com/drive/1ck7hXQxRl6UvT-ijNRZ-gMZxH1G3cN2d?usp=sharing
This commit is contained in:
parent
596afc6497
commit
a5c57cd381
116 changed files with 4959 additions and 2778 deletions
5
llama_stack/providers/inline/tool_runtime/__init__.py
Normal file
5
llama_stack/providers/inline/tool_runtime/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -1,20 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .brave_search import BraveSearchToolRuntimeImpl
|
||||
from .config import BraveSearchToolConfig
|
||||
|
||||
|
||||
class BraveSearchToolProviderDataValidator(BaseModel):
|
||||
api_key: str
|
||||
|
||||
|
||||
async def get_provider_impl(config: BraveSearchToolConfig, _deps):
|
||||
impl = BraveSearchToolRuntimeImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,123 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
|
||||
from .config import BraveSearchToolConfig
|
||||
|
||||
|
||||
class BraveSearchToolRuntimeImpl(
|
||||
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
|
||||
):
|
||||
def __init__(self, config: BraveSearchToolConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
if tool.identifier != "brave_search":
|
||||
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
return
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
if self.config.api_key:
|
||||
return self.config.api_key
|
||||
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.api_key:
|
||||
raise ValueError(
|
||||
'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
|
||||
)
|
||||
return provider_data.api_key
|
||||
|
||||
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
|
||||
raise NotImplementedError("Brave search tool group not supported")
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
url = "https://api.search.brave.com/res/v1/web/search"
|
||||
headers = {
|
||||
"X-Subscription-Token": api_key,
|
||||
"Accept-Encoding": "gzip",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
payload = {"q": args["query"]}
|
||||
response = requests.get(url=url, params=payload, headers=headers)
|
||||
response.raise_for_status()
|
||||
results = self._clean_brave_response(response.json())
|
||||
content_items = "\n".join([str(result) for result in results])
|
||||
return ToolInvocationResult(
|
||||
content=content_items,
|
||||
)
|
||||
|
||||
def _clean_brave_response(self, search_response):
|
||||
clean_response = []
|
||||
if "mixed" in search_response:
|
||||
mixed_results = search_response["mixed"]
|
||||
for m in mixed_results["main"][: self.config.max_results]:
|
||||
r_type = m["type"]
|
||||
results = search_response[r_type]["results"]
|
||||
cleaned = self._clean_result_by_type(r_type, results, m.get("index"))
|
||||
clean_response.append(cleaned)
|
||||
|
||||
return clean_response
|
||||
|
||||
def _clean_result_by_type(self, r_type, results, idx=None):
|
||||
type_cleaners = {
|
||||
"web": (
|
||||
["type", "title", "url", "description", "date", "extra_snippets"],
|
||||
lambda x: x[idx],
|
||||
),
|
||||
"faq": (["type", "question", "answer", "title", "url"], lambda x: x),
|
||||
"infobox": (
|
||||
["type", "title", "url", "description", "long_desc"],
|
||||
lambda x: x[idx],
|
||||
),
|
||||
"videos": (["type", "url", "title", "description", "date"], lambda x: x),
|
||||
"locations": (
|
||||
[
|
||||
"type",
|
||||
"title",
|
||||
"url",
|
||||
"description",
|
||||
"coordinates",
|
||||
"postal_address",
|
||||
"contact",
|
||||
"rating",
|
||||
"distance",
|
||||
"zoom_level",
|
||||
],
|
||||
lambda x: x,
|
||||
),
|
||||
"news": (["type", "title", "url", "description"], lambda x: x),
|
||||
}
|
||||
|
||||
if r_type not in type_cleaners:
|
||||
return ""
|
||||
|
||||
selected_keys, result_selector = type_cleaners[r_type]
|
||||
results = result_selector(results)
|
||||
|
||||
if isinstance(results, list):
|
||||
cleaned = [
|
||||
{k: v for k, v in item.items() if k in selected_keys}
|
||||
for item in results
|
||||
]
|
||||
else:
|
||||
cleaned = {k: v for k, v in results.items() if k in selected_keys}
|
||||
|
||||
return str(cleaned)
|
|
@ -1,20 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BraveSearchToolConfig(BaseModel):
|
||||
api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The Brave Search API Key",
|
||||
)
|
||||
max_results: int = Field(
|
||||
default=3,
|
||||
description="The maximum number of results to return",
|
||||
)
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .code_interpreter import CodeInterpreterToolRuntimeImpl
|
||||
from .config import CodeInterpreterToolConfig
|
||||
|
||||
__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]
|
||||
|
||||
|
||||
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps):
|
||||
impl = CodeInterpreterToolRuntimeImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -0,0 +1,133 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import errno
|
||||
|
||||
# Disabling potentially dangerous functions
|
||||
import os as _os
|
||||
from functools import partial
|
||||
|
||||
os_funcs_to_disable = [
|
||||
"kill",
|
||||
"system",
|
||||
"putenv",
|
||||
"remove",
|
||||
"removedirs",
|
||||
"rmdir",
|
||||
"fchdir",
|
||||
"setuid",
|
||||
"fork",
|
||||
"forkpty",
|
||||
"killpg",
|
||||
"rename",
|
||||
"renames",
|
||||
"truncate",
|
||||
"replace",
|
||||
# "unlink", # Commenting as this was blocking matpltlib from rendering plots correctly
|
||||
"fchmod",
|
||||
"fchown",
|
||||
"chmod",
|
||||
"chown",
|
||||
"chroot",
|
||||
"fchdir",
|
||||
"lchflags",
|
||||
"lchmod",
|
||||
"lchown",
|
||||
"chdir",
|
||||
]
|
||||
|
||||
|
||||
def call_not_allowed(*args, **kwargs):
|
||||
raise OSError(errno.EPERM, "Call are not permitted in this environment")
|
||||
|
||||
|
||||
for func_name in os_funcs_to_disable:
|
||||
if hasattr(_os, func_name):
|
||||
setattr(_os, func_name, partial(call_not_allowed, _func_name=f"os.{func_name}"))
|
||||
|
||||
import shutil as _shutil
|
||||
|
||||
for func_name in ["rmtree", "move", "chown"]:
|
||||
if hasattr(_shutil, func_name):
|
||||
setattr(
|
||||
_shutil,
|
||||
func_name,
|
||||
partial(call_not_allowed, _func_name=f"shutil.{func_name}"),
|
||||
)
|
||||
|
||||
import subprocess as _subprocess
|
||||
|
||||
|
||||
def popen_not_allowed(*args, **kwargs):
|
||||
raise _subprocess.CalledProcessError(
|
||||
-1,
|
||||
args[0] if args else "unknown",
|
||||
stderr="subprocess.Popen is not allowed in this environment",
|
||||
)
|
||||
|
||||
|
||||
_subprocess.Popen = popen_not_allowed
|
||||
|
||||
|
||||
import atexit as _atexit
|
||||
import builtins as _builtins
|
||||
import io as _io
|
||||
import json as _json
|
||||
import sys as _sys
|
||||
|
||||
# NB! The following "unused" imports crucial, make sure not not to remove
|
||||
# them with linters - they're used in code_execution.py
|
||||
from contextlib import ( # noqa
|
||||
contextmanager as _contextmanager,
|
||||
redirect_stderr as _redirect_stderr,
|
||||
redirect_stdout as _redirect_stdout,
|
||||
)
|
||||
from multiprocessing.connection import Connection as _Connection
|
||||
|
||||
# Mangle imports to avoid polluting model execution namespace.
|
||||
|
||||
_IO_SINK = _io.StringIO()
|
||||
_NETWORK_TIMEOUT = 5
|
||||
_NETWORK_CONNECTIONS = None
|
||||
|
||||
|
||||
def _open_connections():
|
||||
global _NETWORK_CONNECTIONS
|
||||
if _NETWORK_CONNECTIONS is not None:
|
||||
# Ensure connections only opened once.
|
||||
return _NETWORK_CONNECTIONS
|
||||
req_w_fd, resp_r_fd = _sys.argv[1], _sys.argv[2]
|
||||
req_con = _Connection(int(req_w_fd), readable=False)
|
||||
resp_con = _Connection(int(resp_r_fd), writable=False)
|
||||
_NETWORK_CONNECTIONS = (req_con, resp_con)
|
||||
return _NETWORK_CONNECTIONS
|
||||
|
||||
|
||||
_builtins._open_connections = _open_connections
|
||||
|
||||
|
||||
@_atexit.register
|
||||
def _close_connections():
|
||||
global _NETWORK_CONNECTIONS
|
||||
if _NETWORK_CONNECTIONS is None:
|
||||
return
|
||||
for con in _NETWORK_CONNECTIONS:
|
||||
con.close()
|
||||
del _NETWORK_CONNECTIONS
|
||||
|
||||
|
||||
def _network_call(request):
|
||||
# NOTE: We communicate with the parent process in json, encoded
|
||||
# in raw bytes. We do this because native send/recv methods use
|
||||
# pickle which involves execution of arbitrary code.
|
||||
_open_connections()
|
||||
req_con, resp_con = _NETWORK_CONNECTIONS
|
||||
|
||||
req_con.send_bytes(_json.dumps(request).encode("utf-8"))
|
||||
if resp_con.poll(timeout=_NETWORK_TIMEOUT) is None:
|
||||
raise Exception(f"Network request timed out: {_json.dumps(request)}")
|
||||
else:
|
||||
return _json.loads(resp_con.recv_bytes().decode("utf-8"))
|
|
@ -0,0 +1,256 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from .utils import get_code_env_prefix
|
||||
|
||||
TOOLS_ATTACHMENT_KEY = "__tools_attachment__"
|
||||
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
||||
|
||||
DIRNAME = Path(__file__).parent
|
||||
|
||||
CODE_EXEC_TIMEOUT = 20
|
||||
CODE_ENV_PREFIX = get_code_env_prefix()
|
||||
|
||||
STDOUTERR_SINK_WRAPPER_TEMPLATE = """\
|
||||
with _redirect_stdout(_IO_SINK), _redirect_stderr(_IO_SINK):
|
||||
{code}\
|
||||
"""
|
||||
|
||||
TRYEXCEPT_WRAPPER_TEMPLATE = """\
|
||||
try:
|
||||
{code}
|
||||
except:
|
||||
pass\
|
||||
"""
|
||||
|
||||
|
||||
def generate_bwrap_command(bind_dirs: List[str]) -> str:
|
||||
"""
|
||||
Generate the bwrap command string for binding all
|
||||
directories in the current directory read-only.
|
||||
"""
|
||||
bwrap_args = ""
|
||||
bwrap_args += "--ro-bind / / "
|
||||
# Add the --dev flag to mount device files
|
||||
bwrap_args += "--dev /dev "
|
||||
for d in bind_dirs:
|
||||
bwrap_args += f"--bind {d} {d} "
|
||||
|
||||
# Add the --unshare-all flag to isolate the sandbox from the rest of the system
|
||||
bwrap_args += "--unshare-all "
|
||||
# Add the --die-with-parent flag to ensure the child process dies when bwrap's parent dies
|
||||
bwrap_args += "--die-with-parent "
|
||||
return bwrap_args
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeExecutionContext:
|
||||
matplotlib_dump_dir: str
|
||||
use_proxy: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeExecutionRequest:
|
||||
scripts: List[str]
|
||||
only_last_cell_stdouterr: bool = True
|
||||
only_last_cell_fail: bool = True
|
||||
seed: int = 0
|
||||
strip_fpaths_in_stderr: bool = True
|
||||
|
||||
|
||||
class CodeExecutor:
|
||||
def __init__(self, context: CodeExecutionContext):
|
||||
self.context = context
|
||||
|
||||
def execute(self, req: CodeExecutionRequest) -> dict:
|
||||
scripts = req.scripts
|
||||
for i in range(len(scripts) - 1):
|
||||
if req.only_last_cell_stdouterr:
|
||||
scripts[i] = STDOUTERR_SINK_WRAPPER_TEMPLATE.format(
|
||||
code=textwrap.indent(scripts[i], " " * 4)
|
||||
)
|
||||
if req.only_last_cell_fail:
|
||||
scripts[i] = TRYEXCEPT_WRAPPER_TEMPLATE.format(
|
||||
code=textwrap.indent(scripts[i], " " * 4)
|
||||
)
|
||||
|
||||
# Seeds prefix:
|
||||
seed = req.seed
|
||||
seeds_prefix = f"""\
|
||||
def _set_seeds():
|
||||
import random
|
||||
random.seed({seed})
|
||||
import numpy as np
|
||||
np.random.seed({seed})
|
||||
_set_seeds()\
|
||||
"""
|
||||
|
||||
script = "\n\n".join([seeds_prefix] + [CODE_ENV_PREFIX] + scripts)
|
||||
with tempfile.TemporaryDirectory() as dpath:
|
||||
bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath])
|
||||
cmd = [*bwrap_prefix.split(), sys.executable, "-c", script]
|
||||
code_fpath = os.path.join(dpath, "code.py")
|
||||
with open(code_fpath, "w") as f:
|
||||
f.write(script)
|
||||
|
||||
try:
|
||||
python_path = os.environ.get("PYTHONPATH", "")
|
||||
env = dict(
|
||||
os.environ,
|
||||
PYTHONHASHSEED=str(seed),
|
||||
MPLCONFIGDIR=dpath,
|
||||
MPLBACKEND="module://matplotlib_custom_backend",
|
||||
PYTHONPATH=f"{DIRNAME}:{python_path}",
|
||||
)
|
||||
stdout, stderr, returncode = do_subprocess(
|
||||
cmd=cmd,
|
||||
env=env,
|
||||
ctx=self.context,
|
||||
)
|
||||
|
||||
stderr = stderr.strip()
|
||||
if req.strip_fpaths_in_stderr:
|
||||
pattern = r'File "([^"]+)", line (\d+)'
|
||||
stderr = re.sub(pattern, r"line \2", stderr)
|
||||
|
||||
return {
|
||||
"process_status": "completed",
|
||||
"returncode": returncode,
|
||||
"stdout": stdout.strip(),
|
||||
"stderr": stderr,
|
||||
}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {
|
||||
"process_status": "timeout",
|
||||
"stdout": "Timed out",
|
||||
"stderr": "Timed out",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"process_status": "error",
|
||||
"error_type": type(e).__name__,
|
||||
"stderr": str(e),
|
||||
"stdout": str(e),
|
||||
}
|
||||
|
||||
|
||||
def process_matplotlib_response(response, matplotlib_dump_dir: str):
|
||||
image_data = response["image_data"]
|
||||
# Convert the base64 string to a bytes object
|
||||
images = [base64.b64decode(d["image_base64"]) for d in image_data]
|
||||
# Create a list of PIL images from the bytes objects
|
||||
images = [Image.open(BytesIO(img)) for img in images]
|
||||
# Create a list of image paths
|
||||
image_paths = []
|
||||
for i, img in enumerate(images):
|
||||
# create new directory for each day to better organize data:
|
||||
dump_dname = datetime.today().strftime("%Y-%m-%d")
|
||||
dump_dpath = Path(matplotlib_dump_dir, dump_dname)
|
||||
dump_dpath.mkdir(parents=True, exist_ok=True)
|
||||
# save image into a file
|
||||
dump_fname = f"matplotlib_{str(time.time()).replace('.', '_')}_{i}.png"
|
||||
dump_fpath = dump_dpath / dump_fname
|
||||
img.save(dump_fpath, "PNG")
|
||||
image_paths.append(str(dump_fpath))
|
||||
|
||||
# this is kind of convoluted, we send back this response to the subprocess which
|
||||
# prints it out
|
||||
info = {
|
||||
"filepath": str(image_paths[-1]),
|
||||
"mimetype": "image/png",
|
||||
}
|
||||
return f"{TOOLS_ATTACHMENT_KEY}={json.dumps(info)}"
|
||||
|
||||
|
||||
def execute_subprocess_request(request, ctx: CodeExecutionContext):
|
||||
"Route requests from the subprocess (via network Pipes) to the internet/tools."
|
||||
if request["type"] == "matplotlib":
|
||||
return process_matplotlib_response(request, ctx.matplotlib_dump_dir)
|
||||
else:
|
||||
raise Exception(f'Unrecognised network request type: {request["type"]}')
|
||||
|
||||
|
||||
def do_subprocess(*, cmd: list, env: dict, ctx: CodeExecutionContext):
|
||||
# Create Pipes to be used for any external tool/network requests.
|
||||
req_r, req_w = multiprocessing.Pipe(duplex=False)
|
||||
resp_r, resp_w = multiprocessing.Pipe(duplex=False)
|
||||
|
||||
cmd += [str(req_w.fileno()), str(resp_r.fileno())]
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
pass_fds=(req_w.fileno(), resp_r.fileno()),
|
||||
text=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
close_fds=True,
|
||||
env=env,
|
||||
)
|
||||
|
||||
# Close unnecessary fds.
|
||||
req_w.close()
|
||||
resp_r.close()
|
||||
|
||||
pipe_close = False
|
||||
done_read = False
|
||||
start = time.monotonic()
|
||||
while proc.poll() is None and not pipe_close:
|
||||
if req_r.poll(0.1):
|
||||
# NB: Python pipe semantics for poll and recv mean that
|
||||
# poll() returns True is a pipe is closed.
|
||||
# CF old school PEP from '09
|
||||
# https://bugs.python.org/issue5573
|
||||
try:
|
||||
request = json.loads(req_r.recv_bytes().decode("utf-8"))
|
||||
response = execute_subprocess_request(request, ctx)
|
||||
|
||||
resp_w.send_bytes(json.dumps(response).encode("utf-8"))
|
||||
except EOFError:
|
||||
# The request pipe is closed - set a marker to exit
|
||||
# after the next attempt at reading stdout/stderr.
|
||||
pipe_close = True
|
||||
|
||||
try:
|
||||
# If lots has been printed, pipe might be full but
|
||||
# proc cannot exit until all the stdout/stderr
|
||||
# been written/read.
|
||||
stdout, stderr = proc.communicate(timeout=0.3)
|
||||
done_read = True
|
||||
except subprocess.TimeoutExpired:
|
||||
# The program has not terminated. Ignore it, there
|
||||
# may be more network/tool requests.
|
||||
continue
|
||||
if time.monotonic() - start > CODE_EXEC_TIMEOUT:
|
||||
proc.terminate()
|
||||
raise subprocess.TimeoutExpired(cmd, CODE_EXEC_TIMEOUT)
|
||||
|
||||
if not done_read:
|
||||
# Solve race condition where process terminates before
|
||||
# we hit the while loop.
|
||||
stdout, stderr = proc.communicate(timeout=0.3)
|
||||
|
||||
resp_w.close()
|
||||
req_r.close()
|
||||
return stdout, stderr, proc.returncode
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
|
||||
from .code_execution import CodeExecutionContext, CodeExecutionRequest, CodeExecutor
|
||||
from .config import CodeInterpreterToolConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||
def __init__(self, config: CodeInterpreterToolConfig):
|
||||
self.config = config
|
||||
ctx = CodeExecutionContext(
|
||||
matplotlib_dump_dir=tempfile.mkdtemp(),
|
||||
)
|
||||
self.code_executor = CodeExecutor(ctx)
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
return
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return [
|
||||
ToolDef(
|
||||
name="code_interpreter",
|
||||
description="Execute code",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="code",
|
||||
description="The code to execute",
|
||||
parameter_type="string",
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
script = args["code"]
|
||||
req = CodeExecutionRequest(scripts=[script])
|
||||
res = self.code_executor.execute(req)
|
||||
pieces = [res["process_status"]]
|
||||
for out_type in ["stdout", "stderr"]:
|
||||
res_out = res[out_type]
|
||||
if res_out != "":
|
||||
pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
|
||||
if out_type == "stderr":
|
||||
log.error(f"ipython tool error: ↓\n{res_out}")
|
||||
return ToolInvocationResult(content="\n".join(pieces))
|
|
@ -0,0 +1,11 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeInterpreterToolConfig(BaseModel):
|
||||
pass
|
|
@ -0,0 +1,90 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
A custom Matplotlib backend that overrides the show method to return image bytes.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json as _json
|
||||
import logging
|
||||
|
||||
import matplotlib
|
||||
from matplotlib.backend_bases import FigureManagerBase
|
||||
|
||||
# Import necessary components from Matplotlib
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CustomFigureCanvas(FigureCanvasAgg):
|
||||
def show(self):
|
||||
# Save the figure to a BytesIO object
|
||||
buf = io.BytesIO()
|
||||
self.print_png(buf)
|
||||
image_bytes = buf.getvalue()
|
||||
buf.close()
|
||||
return image_bytes
|
||||
|
||||
|
||||
class CustomFigureManager(FigureManagerBase):
|
||||
def __init__(self, canvas, num):
|
||||
super().__init__(canvas, num)
|
||||
|
||||
|
||||
# Mimic module initialization that integrates with the Matplotlib backend system
|
||||
def _create_figure_manager(num, *args, **kwargs):
|
||||
"""
|
||||
Create a custom figure manager instance.
|
||||
"""
|
||||
FigureClass = kwargs.pop("FigureClass", None) # noqa: N806
|
||||
if FigureClass is None:
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
FigureClass = Figure # noqa: N806
|
||||
fig = FigureClass(*args, **kwargs)
|
||||
canvas = CustomFigureCanvas(fig)
|
||||
manager = CustomFigureManager(canvas, num)
|
||||
return manager
|
||||
|
||||
|
||||
def show():
|
||||
"""
|
||||
Handle all figures and potentially return their images as bytes.
|
||||
|
||||
This function iterates over all figures registered with the custom backend,
|
||||
renders them as images in bytes format, and could return a list of bytes objects,
|
||||
one for each figure, or handle them as needed.
|
||||
"""
|
||||
image_data = []
|
||||
for manager in matplotlib._pylab_helpers.Gcf.get_all_fig_managers():
|
||||
# Get the figure from the manager
|
||||
fig = manager.canvas.figure
|
||||
buf = io.BytesIO() # Create a buffer for the figure
|
||||
fig.savefig(buf, format="png") # Save the figure to the buffer in PNG format
|
||||
buf.seek(0) # Go to the beginning of the buffer
|
||||
image_bytes = buf.getvalue() # Retrieve bytes value
|
||||
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
image_data.append({"image_base64": image_base64})
|
||||
buf.close()
|
||||
|
||||
req_con, resp_con = _open_connections()
|
||||
|
||||
_json_dump = _json.dumps(
|
||||
{
|
||||
"type": "matplotlib",
|
||||
"image_data": image_data,
|
||||
}
|
||||
)
|
||||
req_con.send_bytes(_json_dump.encode("utf-8"))
|
||||
resp = _json.loads(resp_con.recv_bytes().decode("utf-8"))
|
||||
log.info(resp)
|
||||
|
||||
|
||||
FigureCanvas = CustomFigureCanvas
|
||||
FigureManager = CustomFigureManager
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
CODE_ENV_PREFIX_FILE = os.path.join(DIR, "code_env_prefix.py")
|
||||
CODE_ENV_PREFIX = None
|
||||
|
||||
|
||||
def get_code_env_prefix() -> str:
|
||||
global CODE_ENV_PREFIX
|
||||
|
||||
if CODE_ENV_PREFIX is None:
|
||||
with open(CODE_ENV_PREFIX_FILE, "r") as f:
|
||||
CODE_ENV_PREFIX = f.read()
|
||||
|
||||
return CODE_ENV_PREFIX
|
20
llama_stack/providers/inline/tool_runtime/memory/__init__.py
Normal file
20
llama_stack/providers/inline/tool_runtime/memory/__init__.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import MemoryToolRuntimeConfig
|
||||
from .memory import MemoryToolRuntimeImpl
|
||||
|
||||
|
||||
async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
|
||||
impl = MemoryToolRuntimeImpl(
|
||||
config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
90
llama_stack/providers/inline/tool_runtime/memory/config.py
Normal file
90
llama_stack/providers/inline/tool_runtime/memory/config.py
Normal file
|
@ -0,0 +1,90 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Annotated, List, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class _MemoryBankConfigCommon(BaseModel):
|
||||
bank_id: str
|
||||
|
||||
|
||||
class VectorMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal["vector"] = "vector"
|
||||
|
||||
|
||||
class KeyValueMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal["keyvalue"] = "keyvalue"
|
||||
keys: List[str] # what keys to focus on
|
||||
|
||||
|
||||
class KeywordMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal["keyword"] = "keyword"
|
||||
|
||||
|
||||
class GraphMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal["graph"] = "graph"
|
||||
entities: List[str] # what entities to focus on
|
||||
|
||||
|
||||
MemoryBankConfig = Annotated[
|
||||
Union[
|
||||
VectorMemoryBankConfig,
|
||||
KeyValueMemoryBankConfig,
|
||||
KeywordMemoryBankConfig,
|
||||
GraphMemoryBankConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class MemoryQueryGenerator(Enum):
|
||||
default = "default"
|
||||
llm = "llm"
|
||||
custom = "custom"
|
||||
|
||||
|
||||
class DefaultMemoryQueryGeneratorConfig(BaseModel):
|
||||
type: Literal[MemoryQueryGenerator.default.value] = (
|
||||
MemoryQueryGenerator.default.value
|
||||
)
|
||||
sep: str = " "
|
||||
|
||||
|
||||
class LLMMemoryQueryGeneratorConfig(BaseModel):
|
||||
type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
|
||||
model: str
|
||||
template: str
|
||||
|
||||
|
||||
class CustomMemoryQueryGeneratorConfig(BaseModel):
|
||||
type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
|
||||
|
||||
|
||||
MemoryQueryGeneratorConfig = Annotated[
|
||||
Union[
|
||||
DefaultMemoryQueryGeneratorConfig,
|
||||
LLMMemoryQueryGeneratorConfig,
|
||||
CustomMemoryQueryGeneratorConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class MemoryToolConfig(BaseModel):
|
||||
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
class MemoryToolRuntimeConfig(BaseModel):
|
||||
# This config defines how a query is generated using the messages
|
||||
# for memory bank retrieval.
|
||||
query_generator_config: MemoryQueryGeneratorConfig = Field(
|
||||
default=DefaultMemoryQueryGeneratorConfig()
|
||||
)
|
||||
max_tokens_in_context: int = 4096
|
||||
max_chunks: int = 5
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import List
|
||||
|
||||
from jinja2 import Template
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.inference import UserMessage
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
from .config import (
|
||||
DefaultMemoryQueryGeneratorConfig,
|
||||
LLMMemoryQueryGeneratorConfig,
|
||||
MemoryQueryGenerator,
|
||||
MemoryQueryGeneratorConfig,
|
||||
)
|
||||
|
||||
|
||||
async def generate_rag_query(
|
||||
config: MemoryQueryGeneratorConfig,
|
||||
messages: List[InterleavedContent],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Generates a query that will be used for
|
||||
retrieving relevant information from the memory bank.
|
||||
"""
|
||||
if config.type == MemoryQueryGenerator.default.value:
|
||||
query = await default_rag_query_generator(config, messages, **kwargs)
|
||||
elif config.type == MemoryQueryGenerator.llm.value:
|
||||
query = await llm_rag_query_generator(config, messages, **kwargs)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
|
||||
return query
|
||||
|
||||
|
||||
async def default_rag_query_generator(
|
||||
config: DefaultMemoryQueryGeneratorConfig,
|
||||
messages: List[InterleavedContent],
|
||||
**kwargs,
|
||||
):
|
||||
return config.sep.join(interleaved_content_as_str(m) for m in messages)
|
||||
|
||||
|
||||
async def llm_rag_query_generator(
|
||||
config: LLMMemoryQueryGeneratorConfig,
|
||||
messages: List[InterleavedContent],
|
||||
**kwargs,
|
||||
):
|
||||
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
|
||||
inference_api = kwargs["inference_api"]
|
||||
|
||||
m_dict = {
|
||||
"messages": [
|
||||
message.model_dump() if isinstance(message, BaseModel) else message
|
||||
for message in messages
|
||||
]
|
||||
}
|
||||
|
||||
template = Template(config.template)
|
||||
content = template.render(m_dict)
|
||||
|
||||
model = config.model
|
||||
message = UserMessage(content=content)
|
||||
response = await inference_api.chat_completion(
|
||||
model_id=model,
|
||||
messages=[message],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
query = response.completion_message.content
|
||||
|
||||
return query
|
146
llama_stack/providers/inline/tool_runtime/memory/memory.py
Normal file
146
llama_stack/providers/inline/tool_runtime/memory/memory.py
Normal file
|
@ -0,0 +1,146 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import secrets
|
||||
import string
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.memory import Memory, QueryDocumentsResponse
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.tools import (
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
||||
|
||||
from .config import MemoryToolConfig, MemoryToolRuntimeConfig
|
||||
from .context_retriever import generate_rag_query
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def make_random_string(length: int = 8):
|
||||
return "".join(
|
||||
secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
|
||||
)
|
||||
|
||||
|
||||
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||
def __init__(
|
||||
self,
|
||||
config: MemoryToolRuntimeConfig,
|
||||
memory_api: Memory,
|
||||
memory_banks_api: MemoryBanks,
|
||||
inference_api: Inference,
|
||||
):
|
||||
self.config = config
|
||||
self.memory_api = memory_api
|
||||
self.memory_banks_api = memory_banks_api
|
||||
self.inference_api = inference_api
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return [
|
||||
ToolDef(
|
||||
name="query_memory",
|
||||
description="Retrieve context from memory",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="messages",
|
||||
description="The input messages to search for",
|
||||
parameter_type="array",
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
async def _retrieve_context(
|
||||
self, input_messages: List[InterleavedContent], bank_ids: List[str]
|
||||
) -> Optional[List[InterleavedContent]]:
|
||||
if not bank_ids:
|
||||
return None
|
||||
query = await generate_rag_query(
|
||||
self.config.query_generator_config,
|
||||
input_messages,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
tasks = [
|
||||
self.memory_api.query_documents(
|
||||
bank_id=bank_id,
|
||||
query=query,
|
||||
params={
|
||||
"max_chunks": self.config.max_chunks,
|
||||
},
|
||||
)
|
||||
for bank_id in bank_ids
|
||||
]
|
||||
results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
|
||||
chunks = [c for r in results for c in r.chunks]
|
||||
scores = [s for r in results for s in r.scores]
|
||||
|
||||
if not chunks:
|
||||
return None
|
||||
|
||||
# sort by score
|
||||
chunks, scores = zip(
|
||||
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
|
||||
)
|
||||
|
||||
tokens = 0
|
||||
picked = []
|
||||
for c in chunks[: self.config.max_chunks]:
|
||||
tokens += c.token_count
|
||||
if tokens > self.config.max_tokens_in_context:
|
||||
log.error(
|
||||
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
|
||||
)
|
||||
break
|
||||
picked.append(f"id:{c.document_id}; content:{c.content}")
|
||||
|
||||
return [
|
||||
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
|
||||
*picked,
|
||||
"\n=== END-RETRIEVED-CONTEXT ===\n",
|
||||
]
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
tool = await self.tool_store.get_tool(tool_name)
|
||||
tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
|
||||
final_args = tool_group.args or {}
|
||||
final_args.update(args)
|
||||
config = MemoryToolConfig()
|
||||
if tool.metadata and tool.metadata.get("config") is not None:
|
||||
config = MemoryToolConfig(**tool.metadata["config"])
|
||||
if "memory_bank_ids" in final_args:
|
||||
bank_ids = final_args["memory_bank_ids"]
|
||||
else:
|
||||
bank_ids = [
|
||||
bank_config.bank_id for bank_config in config.memory_bank_configs
|
||||
]
|
||||
if "messages" not in final_args:
|
||||
raise ValueError("messages are required")
|
||||
context = await self._retrieve_context(
|
||||
final_args["messages"],
|
||||
bank_ids,
|
||||
)
|
||||
if context is None:
|
||||
context = []
|
||||
return ToolInvocationResult(
|
||||
content=concat_interleaved_content(context), error_code=0
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue