migrate tools and make tool runtime discover

This commit is contained in:
Dinesh Yeduguru 2024-12-17 14:00:29 -08:00
parent 69a17e93b7
commit 482a0e4839
13 changed files with 1007 additions and 25 deletions

View file

@ -4,55 +4,100 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
import logging
from enum import Enum
from typing import Any, Dict
import llama_stack.providers.inline.tool_runtime.meta_reference.builtins as builtins
import pkgutil
from typing import Any, Dict, Optional, Type
from llama_stack.apis.tools import Tool, ToolRuntime
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
from .config import MetaReferenceToolRuntimeConfig
logger = logging.getLogger(__name__)
class ToolType(Enum):
bing_search = "bing_search"
brave_search = "brave_search"
tavily_search = "tavily_search"
print_tool = "print_tool"
class MetaReferenceToolRuntimeImpl(
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
def __init__(self, config: MetaReferenceToolRuntimeConfig):
self.config = config
self.tools: Dict[str, Type[BaseTool]] = {}
self.tool_instances: Dict[str, BaseTool] = {}
self._discover_tools()
def _discover_tools(self):
# Import all tools from the tools package
tools_package = "llama_stack.providers.inline.tool_runtime.tools"
package = importlib.import_module(tools_package)
for _, name, _ in pkgutil.iter_modules(package.__path__):
module = importlib.import_module(f"{tools_package}.{name}")
for attr_name in dir(module):
attr = getattr(module, attr_name)
if (
isinstance(attr, type)
and issubclass(attr, BaseTool)
and attr != BaseTool
):
self.tools[attr.tool_id()] = attr
async def _create_tool_instance(
self, tool_id: str, tool_def: Optional[Tool] = None
) -> BaseTool:
"""Create a new tool instance with proper configuration"""
if tool_id not in self.tools:
raise ValueError(f"Tool {tool_id} not found in available tools")
tool_class = self.tools[tool_id]
# Get tool definition if not provided
if tool_def is None:
tool_def = await self.tool_store.get_tool(tool_id)
# Build configuration
config = dict(tool_def.provider_metadata.get("config") or {})
if tool_class.requires_api_key:
config["api_key"] = self._get_api_key()
return tool_class(config=config)
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
print(f"registering tool {tool.identifier}")
if tool.provider_resource_id not in ToolType.__members__:
raise ValueError(
f"Tool {tool.identifier} not a supported tool by Meta Reference"
)
if tool.identifier not in self.tools:
raise ValueError(f"Tool {tool.identifier} not found in available tools")
async def unregister_tool(self, tool_id: str) -> None:
raise NotImplementedError("Meta Reference does not support unregistering tools")
# Validate provider_metadata against tool's config type if specified
tool_class = self.tools[tool.identifier]
config_type = tool_class.get_provider_config_type()
if (
config_type
and tool.provider_metadata
and tool.provider_metadata.get("config")
):
config_type(**tool.provider_metadata.get("config"))
self.tool_instances[tool.identifier] = await self._create_tool_instance(
tool.identifier, tool
)
async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
tool = await self.tool_store.get_tool(tool_id)
if args.get("__api_key__") is not None:
logger.warning(
"__api_key__ is a reserved argument for this tool: {tool_id}"
)
args["__api_key__"] = self._get_api_key()
return await getattr(builtins, tool.provider_resource_id)(**args)
if tool_id not in self.tools:
raise ValueError(f"Tool {tool_id} not found")
if tool_id not in self.tool_instances:
self.tool_instances[tool_id] = await self._create_tool_instance(tool_id)
return await self.tool_instances[tool_id].execute(**args)
async def unregister_tool(self, tool_id: str) -> None:
if tool_id in self.tool_instances:
del self.tool_instances[tool_id]
raise NotImplementedError("Meta Reference does not support unregistering tools")
def _get_api_key(self) -> str:
provider_data = self.get_request_provider_data()

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Type, TypeVar
T = TypeVar("T")
class BaseTool(ABC):
"""Base class for all tools"""
requires_api_key: bool = False
def __init__(self, config: Optional[Dict[str, Any]] = None):
self.config = config or {}
@classmethod
@abstractmethod
def tool_id(cls) -> str:
"""Unique identifier for the tool"""
pass
@abstractmethod
async def execute(self, **kwargs) -> Any:
"""Execute the tool with given arguments"""
pass
@classmethod
def get_provider_config_type(cls) -> Optional[Type[T]]:
"""Override to specify a Pydantic model for tool configuration"""
return None

View file

@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import List
import requests
from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
from pydantic import BaseModel
class BingSearchConfig(BaseModel):
api_key: str
max_results: int = 5
class BingSearchTool(BaseTool):
requires_api_key: bool = True
@classmethod
def tool_id(cls) -> str:
return "bing_search"
@classmethod
def get_provider_config_type(cls):
return BingSearchConfig
async def execute(self, query: str) -> List[dict]:
config = BingSearchConfig(**self.config)
url = "https://api.bing.microsoft.com/v7.0/search"
headers = {
"Ocp-Apim-Subscription-Key": config.api_key,
}
params = {
"count": config.max_results,
"textDecorations": True,
"textFormat": "HTML",
"q": query,
}
response = requests.get(url=url, params=params, headers=headers)
response.raise_for_status()
return json.dumps(self._clean_response(response.json()))
def _clean_response(self, search_response):
clean_response = []
query = search_response["queryContext"]["originalQuery"]
if "webPages" in search_response:
pages = search_response["webPages"]["value"]
for p in pages:
selected_keys = {"name", "url", "snippet"}
clean_response.append(
{k: v for k, v in p.items() if k in selected_keys}
)
if "news" in search_response:
clean_news = []
news = search_response["news"]["value"]
for n in news:
selected_keys = {"name", "url", "description"}
clean_news.append({k: v for k, v in n.items() if k in selected_keys})
clean_response.append(clean_news)
return {"query": query, "results": clean_response}

View file

@ -0,0 +1,101 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
import requests
from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
from pydantic import BaseModel
class BraveSearchConfig(BaseModel):
api_key: str
max_results: int = 3
class BraveSearchTool(BaseTool):
requires_api_key: bool = True
@classmethod
def tool_id(cls) -> str:
return "brave_search"
@classmethod
def get_provider_config_type(cls):
return BraveSearchConfig
async def execute(self, query: str) -> List[dict]:
config = BraveSearchConfig(**self.config)
url = "https://api.search.brave.com/res/v1/web/search"
headers = {
"X-Subscription-Token": config.api_key,
"Accept-Encoding": "gzip",
"Accept": "application/json",
}
payload = {"q": query}
response = requests.get(url=url, params=payload, headers=headers)
response.raise_for_status()
return self._clean_brave_response(response.json(), config.max_results)
def _clean_brave_response(self, search_response, top_k=3):
query = None
clean_response = []
if "query" in search_response:
if "original" in search_response["query"]:
query = search_response["query"]["original"]
if "mixed" in search_response:
mixed_results = search_response["mixed"]
for m in mixed_results["main"][:top_k]:
r_type = m["type"]
results = search_response[r_type]["results"]
cleaned = self._clean_result_by_type(r_type, results, m.get("index"))
clean_response.append(cleaned)
return {"query": query, "results": clean_response}
def _clean_result_by_type(self, r_type, results, idx=None):
type_cleaners = {
"web": (
["type", "title", "url", "description", "date", "extra_snippets"],
lambda x: x[idx],
),
"faq": (["type", "question", "answer", "title", "url"], lambda x: x),
"infobox": (
["type", "title", "url", "description", "long_desc"],
lambda x: x[idx],
),
"videos": (["type", "url", "title", "description", "date"], lambda x: x),
"locations": (
[
"type",
"title",
"url",
"description",
"coordinates",
"postal_address",
"contact",
"rating",
"distance",
"zoom_level",
],
lambda x: x,
),
"news": (["type", "title", "url", "description"], lambda x: x),
}
if r_type not in type_cleaners:
return []
selected_keys, result_selector = type_cleaners[r_type]
results = result_selector(results)
if isinstance(results, list):
return [
{k: v for k, v in item.items() if k in selected_keys}
for item in results
]
return {k: v for k, v in results.items() if k in selected_keys}

View file

@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import tempfile
from typing import Dict
from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
from pydantic import BaseModel
from .ipython_tool.code_execution import (
CodeExecutionContext,
CodeExecutionRequest,
CodeExecutor,
)
class CodeInterpreterConfig(BaseModel):
matplotlib_dump_dir: str = None
class CodeInterpreterTool(BaseTool):
@classmethod
def tool_id(cls) -> str:
return "code_interpreter"
@classmethod
def get_provider_config_type(cls):
return CodeInterpreterConfig
async def execute(self, code: str) -> Dict:
config = CodeInterpreterConfig(**self.config)
ctx = CodeExecutionContext(
matplotlib_dump_dir=config.matplotlib_dump_dir or tempfile.mkdtemp(),
)
executor = CodeExecutor(ctx)
req = CodeExecutionRequest(scripts=[code])
result = executor.execute(req)
response = {"status": result["process_status"], "output": []}
for out_type in ["stdout", "stderr"]:
if result[out_type]:
response["output"].append(
{"type": out_type, "content": result[out_type]}
)
return response

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,133 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import errno
# Disabling potentially dangerous functions
import os as _os
from functools import partial
os_funcs_to_disable = [
"kill",
"system",
"putenv",
"remove",
"removedirs",
"rmdir",
"fchdir",
"setuid",
"fork",
"forkpty",
"killpg",
"rename",
"renames",
"truncate",
"replace",
# "unlink", # Commenting as this was blocking matpltlib from rendering plots correctly
"fchmod",
"fchown",
"chmod",
"chown",
"chroot",
"fchdir",
"lchflags",
"lchmod",
"lchown",
"chdir",
]
def call_not_allowed(*args, **kwargs):
raise OSError(errno.EPERM, "Call are not permitted in this environment")
for func_name in os_funcs_to_disable:
if hasattr(_os, func_name):
setattr(_os, func_name, partial(call_not_allowed, _func_name=f"os.{func_name}"))
import shutil as _shutil
for func_name in ["rmtree", "move", "chown"]:
if hasattr(_shutil, func_name):
setattr(
_shutil,
func_name,
partial(call_not_allowed, _func_name=f"shutil.{func_name}"),
)
import subprocess as _subprocess
def popen_not_allowed(*args, **kwargs):
raise _subprocess.CalledProcessError(
-1,
args[0] if args else "unknown",
stderr="subprocess.Popen is not allowed in this environment",
)
_subprocess.Popen = popen_not_allowed
import atexit as _atexit
import builtins as _builtins
import io as _io
import json as _json
import sys as _sys
# NB! The following "unused" imports crucial, make sure not not to remove
# them with linters - they're used in code_execution.py
from contextlib import ( # noqa
contextmanager as _contextmanager,
redirect_stderr as _redirect_stderr,
redirect_stdout as _redirect_stdout,
)
from multiprocessing.connection import Connection as _Connection
# Mangle imports to avoid polluting model execution namespace.
_IO_SINK = _io.StringIO()
_NETWORK_TIMEOUT = 5
_NETWORK_CONNECTIONS = None
def _open_connections():
global _NETWORK_CONNECTIONS
if _NETWORK_CONNECTIONS is not None:
# Ensure connections only opened once.
return _NETWORK_CONNECTIONS
req_w_fd, resp_r_fd = _sys.argv[1], _sys.argv[2]
req_con = _Connection(int(req_w_fd), readable=False)
resp_con = _Connection(int(resp_r_fd), writable=False)
_NETWORK_CONNECTIONS = (req_con, resp_con)
return _NETWORK_CONNECTIONS
_builtins._open_connections = _open_connections
@_atexit.register
def _close_connections():
global _NETWORK_CONNECTIONS
if _NETWORK_CONNECTIONS is None:
return
for con in _NETWORK_CONNECTIONS:
con.close()
del _NETWORK_CONNECTIONS
def _network_call(request):
# NOTE: We communicate with the parent process in json, encoded
# in raw bytes. We do this because native send/recv methods use
# pickle which involves execution of arbitrary code.
_open_connections()
req_con, resp_con = _NETWORK_CONNECTIONS
req_con.send_bytes(_json.dumps(request).encode("utf-8"))
if resp_con.poll(timeout=_NETWORK_TIMEOUT) is None:
raise Exception(f"Network request timed out: {_json.dumps(request)}")
else:
return _json.loads(resp_con.recv_bytes().decode("utf-8"))

View file

@ -0,0 +1,256 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import json
import multiprocessing
import os
import re
import subprocess
import sys
import tempfile
import textwrap
import time
from dataclasses import dataclass
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import List
from PIL import Image
from .utils import get_code_env_prefix
TOOLS_ATTACHMENT_KEY = "__tools_attachment__"
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
DIRNAME = Path(__file__).parent
CODE_EXEC_TIMEOUT = 20
CODE_ENV_PREFIX = get_code_env_prefix()
STDOUTERR_SINK_WRAPPER_TEMPLATE = """\
with _redirect_stdout(_IO_SINK), _redirect_stderr(_IO_SINK):
{code}\
"""
TRYEXCEPT_WRAPPER_TEMPLATE = """\
try:
{code}
except:
pass\
"""
def generate_bwrap_command(bind_dirs: List[str]) -> str:
"""
Generate the bwrap command string for binding all
directories in the current directory read-only.
"""
bwrap_args = ""
bwrap_args += "--ro-bind / / "
# Add the --dev flag to mount device files
bwrap_args += "--dev /dev "
for d in bind_dirs:
bwrap_args += f"--bind {d} {d} "
# Add the --unshare-all flag to isolate the sandbox from the rest of the system
bwrap_args += "--unshare-all "
# Add the --die-with-parent flag to ensure the child process dies when bwrap's parent dies
bwrap_args += "--die-with-parent "
return bwrap_args
@dataclass
class CodeExecutionContext:
matplotlib_dump_dir: str
use_proxy: bool = False
@dataclass
class CodeExecutionRequest:
scripts: List[str]
only_last_cell_stdouterr: bool = True
only_last_cell_fail: bool = True
seed: int = 0
strip_fpaths_in_stderr: bool = True
class CodeExecutor:
def __init__(self, context: CodeExecutionContext):
self.context = context
def execute(self, req: CodeExecutionRequest) -> dict:
scripts = req.scripts
for i in range(len(scripts) - 1):
if req.only_last_cell_stdouterr:
scripts[i] = STDOUTERR_SINK_WRAPPER_TEMPLATE.format(
code=textwrap.indent(scripts[i], " " * 4)
)
if req.only_last_cell_fail:
scripts[i] = TRYEXCEPT_WRAPPER_TEMPLATE.format(
code=textwrap.indent(scripts[i], " " * 4)
)
# Seeds prefix:
seed = req.seed
seeds_prefix = f"""\
def _set_seeds():
import random
random.seed({seed})
import numpy as np
np.random.seed({seed})
_set_seeds()\
"""
script = "\n\n".join([seeds_prefix] + [CODE_ENV_PREFIX] + scripts)
with tempfile.TemporaryDirectory() as dpath:
bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath])
cmd = [*bwrap_prefix.split(), sys.executable, "-c", script]
code_fpath = os.path.join(dpath, "code.py")
with open(code_fpath, "w") as f:
f.write(script)
try:
python_path = os.environ.get("PYTHONPATH", "")
env = dict(
os.environ,
PYTHONHASHSEED=str(seed),
MPLCONFIGDIR=dpath,
MPLBACKEND="module://matplotlib_custom_backend",
PYTHONPATH=f"{DIRNAME}:{python_path}",
)
stdout, stderr, returncode = do_subprocess(
cmd=cmd,
env=env,
ctx=self.context,
)
stderr = stderr.strip()
if req.strip_fpaths_in_stderr:
pattern = r'File "([^"]+)", line (\d+)'
stderr = re.sub(pattern, r"line \2", stderr)
return {
"process_status": "completed",
"returncode": returncode,
"stdout": stdout.strip(),
"stderr": stderr,
}
except subprocess.TimeoutExpired:
return {
"process_status": "timeout",
"stdout": "Timed out",
"stderr": "Timed out",
}
except Exception as e:
return {
"process_status": "error",
"error_type": type(e).__name__,
"stderr": str(e),
"stdout": str(e),
}
def process_matplotlib_response(response, matplotlib_dump_dir: str):
image_data = response["image_data"]
# Convert the base64 string to a bytes object
images = [base64.b64decode(d["image_base64"]) for d in image_data]
# Create a list of PIL images from the bytes objects
images = [Image.open(BytesIO(img)) for img in images]
# Create a list of image paths
image_paths = []
for i, img in enumerate(images):
# create new directory for each day to better organize data:
dump_dname = datetime.today().strftime("%Y-%m-%d")
dump_dpath = Path(matplotlib_dump_dir, dump_dname)
dump_dpath.mkdir(parents=True, exist_ok=True)
# save image into a file
dump_fname = f"matplotlib_{str(time.time()).replace('.', '_')}_{i}.png"
dump_fpath = dump_dpath / dump_fname
img.save(dump_fpath, "PNG")
image_paths.append(str(dump_fpath))
# this is kind of convoluted, we send back this response to the subprocess which
# prints it out
info = {
"filepath": str(image_paths[-1]),
"mimetype": "image/png",
}
return f"{TOOLS_ATTACHMENT_KEY}={json.dumps(info)}"
def execute_subprocess_request(request, ctx: CodeExecutionContext):
"Route requests from the subprocess (via network Pipes) to the internet/tools."
if request["type"] == "matplotlib":
return process_matplotlib_response(request, ctx.matplotlib_dump_dir)
else:
raise Exception(f'Unrecognised network request type: {request["type"]}')
def do_subprocess(*, cmd: list, env: dict, ctx: CodeExecutionContext):
# Create Pipes to be used for any external tool/network requests.
req_r, req_w = multiprocessing.Pipe(duplex=False)
resp_r, resp_w = multiprocessing.Pipe(duplex=False)
cmd += [str(req_w.fileno()), str(resp_r.fileno())]
proc = subprocess.Popen(
cmd,
pass_fds=(req_w.fileno(), resp_r.fileno()),
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
close_fds=True,
env=env,
)
# Close unnecessary fds.
req_w.close()
resp_r.close()
pipe_close = False
done_read = False
start = time.monotonic()
while proc.poll() is None and not pipe_close:
if req_r.poll(0.1):
# NB: Python pipe semantics for poll and recv mean that
# poll() returns True is a pipe is closed.
# CF old school PEP from '09
# https://bugs.python.org/issue5573
try:
request = json.loads(req_r.recv_bytes().decode("utf-8"))
response = execute_subprocess_request(request, ctx)
resp_w.send_bytes(json.dumps(response).encode("utf-8"))
except EOFError:
# The request pipe is closed - set a marker to exit
# after the next attempt at reading stdout/stderr.
pipe_close = True
try:
# If lots has been printed, pipe might be full but
# proc cannot exit until all the stdout/stderr
# been written/read.
stdout, stderr = proc.communicate(timeout=0.3)
done_read = True
except subprocess.TimeoutExpired:
# The program has not terminated. Ignore it, there
# may be more network/tool requests.
continue
if time.monotonic() - start > CODE_EXEC_TIMEOUT:
proc.terminate()
raise subprocess.TimeoutExpired(cmd, CODE_EXEC_TIMEOUT)
if not done_read:
# Solve race condition where process terminates before
# we hit the while loop.
stdout, stderr = proc.communicate(timeout=0.3)
resp_w.close()
req_r.close()
return stdout, stderr, proc.returncode

View file

@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
A custom Matplotlib backend that overrides the show method to return image bytes.
"""
import base64
import io
import json as _json
import logging
import matplotlib
from matplotlib.backend_bases import FigureManagerBase
# Import necessary components from Matplotlib
from matplotlib.backends.backend_agg import FigureCanvasAgg
log = logging.getLogger(__name__)
class CustomFigureCanvas(FigureCanvasAgg):
def show(self):
# Save the figure to a BytesIO object
buf = io.BytesIO()
self.print_png(buf)
image_bytes = buf.getvalue()
buf.close()
return image_bytes
class CustomFigureManager(FigureManagerBase):
def __init__(self, canvas, num):
super().__init__(canvas, num)
# Mimic module initialization that integrates with the Matplotlib backend system
def _create_figure_manager(num, *args, **kwargs):
"""
Create a custom figure manager instance.
"""
FigureClass = kwargs.pop("FigureClass", None) # noqa: N806
if FigureClass is None:
from matplotlib.figure import Figure
FigureClass = Figure # noqa: N806
fig = FigureClass(*args, **kwargs)
canvas = CustomFigureCanvas(fig)
manager = CustomFigureManager(canvas, num)
return manager
def show():
"""
Handle all figures and potentially return their images as bytes.
This function iterates over all figures registered with the custom backend,
renders them as images in bytes format, and could return a list of bytes objects,
one for each figure, or handle them as needed.
"""
image_data = []
for manager in matplotlib._pylab_helpers.Gcf.get_all_fig_managers():
# Get the figure from the manager
fig = manager.canvas.figure
buf = io.BytesIO() # Create a buffer for the figure
fig.savefig(buf, format="png") # Save the figure to the buffer in PNG format
buf.seek(0) # Go to the beginning of the buffer
image_bytes = buf.getvalue() # Retrieve bytes value
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
image_data.append({"image_base64": image_base64})
buf.close()
req_con, resp_con = _open_connections()
_json_dump = _json.dumps(
{
"type": "matplotlib",
"image_data": image_data,
}
)
req_con.send_bytes(_json_dump.encode("utf-8"))
resp = _json.loads(resp_con.recv_bytes().decode("utf-8"))
log.info(resp)
FigureCanvas = CustomFigureCanvas
FigureManager = CustomFigureManager

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
DIR = os.path.dirname(os.path.realpath(__file__))
CODE_ENV_PREFIX_FILE = os.path.join(DIR, "code_env_prefix.py")
CODE_ENV_PREFIX = None
def get_code_env_prefix() -> str:
global CODE_ENV_PREFIX
if CODE_ENV_PREFIX is None:
with open(CODE_ENV_PREFIX_FILE, "r") as f:
CODE_ENV_PREFIX = f.read()
return CODE_ENV_PREFIX

View file

@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
from pydantic import BaseModel
class PhotogenConfig(BaseModel):
dump_dir: str
class PhotogenTool(BaseTool):
@classmethod
def tool_id(cls) -> str:
return "photogen"
@classmethod
def get_provider_config_type(cls):
return PhotogenConfig
async def execute(self, query: str) -> Dict:
config = PhotogenConfig(**self.config)
"""
Implement this to give the model an ability to generate images.
Return:
info = {
"filepath": str(image_filepath),
"mimetype": "image/png",
}
"""
raise NotImplementedError()

View file

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
import requests
from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
from pydantic import BaseModel
class TavilySearchConfig(BaseModel):
api_key: str
max_results: int = 3
class TavilySearchTool(BaseTool):
requires_api_key: bool = True
@classmethod
def tool_id(cls) -> str:
return "tavily_search"
@classmethod
def get_provider_config_type(cls):
return TavilySearchConfig
async def execute(self, query: str) -> List[dict]:
config = TavilySearchConfig(**self.config)
response = requests.post(
"https://api.tavily.com/search",
json={"api_key": config.api_key, "query": query},
)
response.raise_for_status()
search_response = response.json()
return {
"query": search_response["query"],
"results": search_response["results"][: config.max_results],
}

View file

@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Dict
import requests
from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
from pydantic import BaseModel
class WolframAlphaConfig(BaseModel):
api_key: str
class WolframAlphaTool(BaseTool):
requires_api_key: bool = True
@classmethod
def tool_id(cls) -> str:
return "wolfram_alpha"
@classmethod
def get_provider_config_type(cls):
return WolframAlphaConfig
async def execute(self, query: str) -> Dict:
config = WolframAlphaConfig(**self.config)
url = "https://api.wolframalpha.com/v2/query"
params = {
"input": query,
"appid": config.api_key,
"format": "plaintext",
"output": "json",
}
response = requests.get(url, params=params)
response.raise_for_status()
return json.dumps(self._clean_wolfram_alpha_response(response.json()))
def _clean_wolfram_alpha_response(self, wa_response):
remove = {
"queryresult": [
"datatypes",
"error",
"timedout",
"timedoutpods",
"numpods",
"timing",
"parsetiming",
"parsetimedout",
"recalculate",
"id",
"host",
"server",
"related",
"version",
{
"pods": [
"scanner",
"id",
"error",
"expressiontypes",
"states",
"infos",
"position",
"numsubpods",
]
},
"assumptions",
],
}
result = wa_response.copy()
for main_key, to_remove in remove.items():
if main_key not in result:
continue
for item in to_remove:
if isinstance(item, dict):
for sub_key, sub_items in item.items():
if sub_key == "pods":
pods = result[main_key].get(sub_key, [])
for i, pod in enumerate(pods):
if pod.get("title") == "Result":
pods = pods[: i + 1]
break
for remove_key in sub_items:
pod.pop(remove_key, None)
else:
result[main_key].pop(item, None)
return result