Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 17:29:01 +00:00)

migrate tools and make tool runtime discover

parent 69a17e93b7
commit 482a0e4839

13 changed files with 1007 additions and 25 deletions
@@ -4,55 +4,100 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import importlib
 import logging
-from enum import Enum
-from typing import Any, Dict
+import pkgutil
+from typing import Any, Dict, Optional, Type

-import llama_stack.providers.inline.tool_runtime.meta_reference.builtins as builtins

 from llama_stack.apis.tools import Tool, ToolRuntime
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool

 from .config import MetaReferenceToolRuntimeConfig

 logger = logging.getLogger(__name__)


-class ToolType(Enum):
-    bing_search = "bing_search"
-    brave_search = "brave_search"
-    tavily_search = "tavily_search"
-    print_tool = "print_tool"
-
-
 class MetaReferenceToolRuntimeImpl(
     ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
 ):
     def __init__(self, config: MetaReferenceToolRuntimeConfig):
         self.config = config
+        self.tools: Dict[str, Type[BaseTool]] = {}
+        self.tool_instances: Dict[str, BaseTool] = {}
+        self._discover_tools()
+
+    def _discover_tools(self):
+        # Import all tools from the tools package
+        tools_package = "llama_stack.providers.inline.tool_runtime.meta_reference.tools"
+        package = importlib.import_module(tools_package)
+
+        for _, name, _ in pkgutil.iter_modules(package.__path__):
+            module = importlib.import_module(f"{tools_package}.{name}")
+            for attr_name in dir(module):
+                attr = getattr(module, attr_name)
+                if (
+                    isinstance(attr, type)
+                    and issubclass(attr, BaseTool)
+                    and attr != BaseTool
+                ):
+                    self.tools[attr.tool_id()] = attr
+
+    async def _create_tool_instance(
+        self, tool_id: str, tool_def: Optional[Tool] = None
+    ) -> BaseTool:
+        """Create a new tool instance with proper configuration"""
+        if tool_id not in self.tools:
+            raise ValueError(f"Tool {tool_id} not found in available tools")
+
+        tool_class = self.tools[tool_id]
+
+        # Get tool definition if not provided
+        if tool_def is None:
+            tool_def = await self.tool_store.get_tool(tool_id)
+
+        # Build configuration
+        config = dict(tool_def.provider_metadata.get("config") or {})
+        if tool_class.requires_api_key:
+            config["api_key"] = self._get_api_key()
+
+        return tool_class(config=config)

     async def initialize(self):
         pass

     async def register_tool(self, tool: Tool):
-        print(f"registering tool {tool.identifier}")
-        if tool.provider_resource_id not in ToolType.__members__:
-            raise ValueError(
-                f"Tool {tool.identifier} not a supported tool by Meta Reference"
-            )
-
-    async def unregister_tool(self, tool_id: str) -> None:
-        raise NotImplementedError("Meta Reference does not support unregistering tools")
+        if tool.identifier not in self.tools:
+            raise ValueError(f"Tool {tool.identifier} not found in available tools")
+
+        # Validate provider_metadata against tool's config type if specified
+        tool_class = self.tools[tool.identifier]
+        config_type = tool_class.get_provider_config_type()
+        if (
+            config_type
+            and tool.provider_metadata
+            and tool.provider_metadata.get("config")
+        ):
+            config_type(**tool.provider_metadata.get("config"))
+
+        self.tool_instances[tool.identifier] = await self._create_tool_instance(
+            tool.identifier, tool
+        )

     async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
-        tool = await self.tool_store.get_tool(tool_id)
-        if args.get("__api_key__") is not None:
-            logger.warning(
-                "__api_key__ is a reserved argument for this tool: {tool_id}"
-            )
-        args["__api_key__"] = self._get_api_key()
-        return await getattr(builtins, tool.provider_resource_id)(**args)
+        if tool_id not in self.tools:
+            raise ValueError(f"Tool {tool_id} not found")
+
+        if tool_id not in self.tool_instances:
+            self.tool_instances[tool_id] = await self._create_tool_instance(tool_id)
+
+        return await self.tool_instances[tool_id].execute(**args)
+
+    async def unregister_tool(self, tool_id: str) -> None:
+        if tool_id in self.tool_instances:
+            del self.tool_instances[tool_id]
+        raise NotImplementedError("Meta Reference does not support unregistering tools")

     def _get_api_key(self) -> str:
         provider_data = self.get_request_provider_data()
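The discovery path replaces the old hard-coded ToolType enum: any module placed in the tools package that defines a concrete BaseTool subclass is registered at construction time under its tool_id. A minimal, self-contained sketch of that filter; the inline BaseTool, EchoTool, and discover() below are illustrative and not part of this commit:

from abc import ABC, abstractmethod
from typing import Any, Dict, Type


class BaseTool(ABC):
    @classmethod
    @abstractmethod
    def tool_id(cls) -> str: ...

    @abstractmethod
    async def execute(self, **kwargs) -> Any: ...


class EchoTool(BaseTool):
    @classmethod
    def tool_id(cls) -> str:
        return "echo"

    async def execute(self, text: str) -> str:
        return text


def discover(namespace: Dict[str, Any]) -> Dict[str, Type[BaseTool]]:
    # Same predicate as _discover_tools: keep classes that subclass
    # BaseTool but are not BaseTool itself.
    return {
        attr.tool_id(): attr
        for attr in namespace.values()
        if isinstance(attr, type) and issubclass(attr, BaseTool) and attr is not BaseTool
    }


print(discover(dict(globals())))  # -> {'echo': <class '__main__.EchoTool'>}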
@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Type, TypeVar

T = TypeVar("T")


class BaseTool(ABC):
    """Base class for all tools"""

    requires_api_key: bool = False

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        self.config = config or {}

    @classmethod
    @abstractmethod
    def tool_id(cls) -> str:
        """Unique identifier for the tool"""
        pass

    @abstractmethod
    async def execute(self, **kwargs) -> Any:
        """Execute the tool with given arguments"""
        pass

    @classmethod
    def get_provider_config_type(cls) -> Optional[Type[T]]:
        """Override to specify a Pydantic model for tool configuration"""
        return None
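As a sketch of the contract this base class defines: a new tool only needs a tool_id, an async execute, and optionally a Pydantic config model that register_tool validates. The WordCountTool below is hypothetical, not part of this commit:

from pydantic import BaseModel

from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool


class WordCountConfig(BaseModel):
    min_length: int = 1


class WordCountTool(BaseTool):
    @classmethod
    def tool_id(cls) -> str:
        return "word_count"

    @classmethod
    def get_provider_config_type(cls):
        return WordCountConfig

    async def execute(self, text: str) -> int:
        # self.config is whatever dict the runtime passed to __init__;
        # validate it through the declared model before use.
        config = WordCountConfig(**self.config)
        return sum(1 for w in text.split() if len(w) >= config.min_length)

Dropping such a module into the tools package is enough for _discover_tools to register it under its tool_id.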
@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import List

import requests

from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
from pydantic import BaseModel


class BingSearchConfig(BaseModel):
    api_key: str
    max_results: int = 5


class BingSearchTool(BaseTool):
    requires_api_key: bool = True

    @classmethod
    def tool_id(cls) -> str:
        return "bing_search"

    @classmethod
    def get_provider_config_type(cls):
        return BingSearchConfig

    async def execute(self, query: str) -> List[dict]:
        config = BingSearchConfig(**self.config)
        url = "https://api.bing.microsoft.com/v7.0/search"
        headers = {
            "Ocp-Apim-Subscription-Key": config.api_key,
        }
        params = {
            "count": config.max_results,
            "textDecorations": True,
            "textFormat": "HTML",
            "q": query,
        }

        response = requests.get(url=url, params=params, headers=headers)
        response.raise_for_status()
        return json.dumps(self._clean_response(response.json()))

    def _clean_response(self, search_response):
        clean_response = []
        query = search_response["queryContext"]["originalQuery"]
        if "webPages" in search_response:
            pages = search_response["webPages"]["value"]
            for p in pages:
                selected_keys = {"name", "url", "snippet"}
                clean_response.append(
                    {k: v for k, v in p.items() if k in selected_keys}
                )
        if "news" in search_response:
            clean_news = []
            news = search_response["news"]["value"]
            for n in news:
                selected_keys = {"name", "url", "description"}
                clean_news.append({k: v for k, v in n.items() if k in selected_keys})
            clean_response.append(clean_news)

        return {"query": query, "results": clean_response}
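For ad-hoc testing outside the runtime, a tool can be constructed and awaited directly. A sketch, assuming a valid Bing subscription key (the key value is a placeholder):

import asyncio

tool = BingSearchTool(config={"api_key": "<your-bing-api-key>", "max_results": 3})
results_json = asyncio.run(tool.execute(query="llama stack"))
print(results_json)  # JSON string of {"query": ..., "results": [...]}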
@@ -0,0 +1,101 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

import requests

from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
from pydantic import BaseModel


class BraveSearchConfig(BaseModel):
    api_key: str
    max_results: int = 3


class BraveSearchTool(BaseTool):
    requires_api_key: bool = True

    @classmethod
    def tool_id(cls) -> str:
        return "brave_search"

    @classmethod
    def get_provider_config_type(cls):
        return BraveSearchConfig

    async def execute(self, query: str) -> List[dict]:
        config = BraveSearchConfig(**self.config)
        url = "https://api.search.brave.com/res/v1/web/search"
        headers = {
            "X-Subscription-Token": config.api_key,
            "Accept-Encoding": "gzip",
            "Accept": "application/json",
        }
        payload = {"q": query}
        response = requests.get(url=url, params=payload, headers=headers)
        response.raise_for_status()
        return self._clean_brave_response(response.json(), config.max_results)

    def _clean_brave_response(self, search_response, top_k=3):
        query = None
        clean_response = []
        if "query" in search_response:
            if "original" in search_response["query"]:
                query = search_response["query"]["original"]
        if "mixed" in search_response:
            mixed_results = search_response["mixed"]
            for m in mixed_results["main"][:top_k]:
                r_type = m["type"]
                results = search_response[r_type]["results"]
                cleaned = self._clean_result_by_type(r_type, results, m.get("index"))
                clean_response.append(cleaned)

        return {"query": query, "results": clean_response}

    def _clean_result_by_type(self, r_type, results, idx=None):
        type_cleaners = {
            "web": (
                ["type", "title", "url", "description", "date", "extra_snippets"],
                lambda x: x[idx],
            ),
            "faq": (["type", "question", "answer", "title", "url"], lambda x: x),
            "infobox": (
                ["type", "title", "url", "description", "long_desc"],
                lambda x: x[idx],
            ),
            "videos": (["type", "url", "title", "description", "date"], lambda x: x),
            "locations": (
                [
                    "type",
                    "title",
                    "url",
                    "description",
                    "coordinates",
                    "postal_address",
                    "contact",
                    "rating",
                    "distance",
                    "zoom_level",
                ],
                lambda x: x,
            ),
            "news": (["type", "title", "url", "description"], lambda x: x),
        }

        if r_type not in type_cleaners:
            return []

        selected_keys, result_selector = type_cleaners[r_type]
        results = result_selector(results)

        if isinstance(results, list):
            return [
                {k: v for k, v in item.items() if k in selected_keys}
                for item in results
            ]
        return {k: v for k, v in results.items() if k in selected_keys}
@@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import tempfile
from typing import Dict, Optional

from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
from pydantic import BaseModel

from .ipython_tool.code_execution import (
    CodeExecutionContext,
    CodeExecutionRequest,
    CodeExecutor,
)


class CodeInterpreterConfig(BaseModel):
    matplotlib_dump_dir: Optional[str] = None


class CodeInterpreterTool(BaseTool):
    @classmethod
    def tool_id(cls) -> str:
        return "code_interpreter"

    @classmethod
    def get_provider_config_type(cls):
        return CodeInterpreterConfig

    async def execute(self, code: str) -> Dict:
        config = CodeInterpreterConfig(**self.config)

        ctx = CodeExecutionContext(
            matplotlib_dump_dir=config.matplotlib_dump_dir or tempfile.mkdtemp(),
        )
        executor = CodeExecutor(ctx)

        req = CodeExecutionRequest(scripts=[code])
        result = executor.execute(req)

        response = {"status": result["process_status"], "output": []}

        for out_type in ["stdout", "stderr"]:
            if result[out_type]:
                response["output"].append(
                    {"type": out_type, "content": result[out_type]}
                )

        return response
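A usage sketch for the interpreter; it assumes the ipython_tool sandbox below is importable, the bwrap binary is installed, and the packages the seed prefix imports (e.g. numpy) are available. The output shape follows execute() above:

import asyncio


async def main():
    tool = CodeInterpreterTool(config={})
    result = await tool.execute(code="print(21 * 2)")
    # e.g. {'status': 'completed', 'output': [{'type': 'stdout', 'content': '42'}]}
    print(result)


asyncio.run(main())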
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@@ -0,0 +1,133 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import errno

# Disabling potentially dangerous functions
import os as _os
from functools import partial

os_funcs_to_disable = [
    "kill",
    "system",
    "putenv",
    "remove",
    "removedirs",
    "rmdir",
    "fchdir",
    "setuid",
    "fork",
    "forkpty",
    "killpg",
    "rename",
    "renames",
    "truncate",
    "replace",
    # "unlink",  # Commented out as this was blocking matplotlib from rendering plots correctly
    "fchmod",
    "fchown",
    "chmod",
    "chown",
    "chroot",
    "fchdir",
    "lchflags",
    "lchmod",
    "lchown",
    "chdir",
]


def call_not_allowed(*args, **kwargs):
    raise OSError(errno.EPERM, "Calls are not permitted in this environment")


for func_name in os_funcs_to_disable:
    if hasattr(_os, func_name):
        setattr(_os, func_name, partial(call_not_allowed, _func_name=f"os.{func_name}"))

import shutil as _shutil

for func_name in ["rmtree", "move", "chown"]:
    if hasattr(_shutil, func_name):
        setattr(
            _shutil,
            func_name,
            partial(call_not_allowed, _func_name=f"shutil.{func_name}"),
        )

import subprocess as _subprocess


def popen_not_allowed(*args, **kwargs):
    raise _subprocess.CalledProcessError(
        -1,
        args[0] if args else "unknown",
        stderr="subprocess.Popen is not allowed in this environment",
    )


_subprocess.Popen = popen_not_allowed


import atexit as _atexit
import builtins as _builtins
import io as _io
import json as _json
import sys as _sys

# NB! The following "unused" imports are crucial; make sure not to remove
# them with linters - they're used in code_execution.py
from contextlib import (  # noqa
    contextmanager as _contextmanager,
    redirect_stderr as _redirect_stderr,
    redirect_stdout as _redirect_stdout,
)
from multiprocessing.connection import Connection as _Connection

# Mangle imports to avoid polluting model execution namespace.

_IO_SINK = _io.StringIO()
_NETWORK_TIMEOUT = 5
_NETWORK_CONNECTIONS = None


def _open_connections():
    global _NETWORK_CONNECTIONS
    if _NETWORK_CONNECTIONS is not None:
        # Ensure connections only opened once.
        return _NETWORK_CONNECTIONS
    req_w_fd, resp_r_fd = _sys.argv[1], _sys.argv[2]
    req_con = _Connection(int(req_w_fd), readable=False)
    resp_con = _Connection(int(resp_r_fd), writable=False)
    _NETWORK_CONNECTIONS = (req_con, resp_con)
    return _NETWORK_CONNECTIONS


_builtins._open_connections = _open_connections


@_atexit.register
def _close_connections():
    global _NETWORK_CONNECTIONS
    if _NETWORK_CONNECTIONS is None:
        return
    for con in _NETWORK_CONNECTIONS:
        con.close()
    del _NETWORK_CONNECTIONS


def _network_call(request):
    # NOTE: We communicate with the parent process in json, encoded
    # in raw bytes. We do this because native send/recv methods use
    # pickle which involves execution of arbitrary code.
    _open_connections()
    req_con, resp_con = _NETWORK_CONNECTIONS

    req_con.send_bytes(_json.dumps(request).encode("utf-8"))
    # poll() returns a bool, so test it directly rather than comparing to None
    if not resp_con.poll(timeout=_NETWORK_TIMEOUT):
        raise Exception(f"Network request timed out: {_json.dumps(request)}")
    return _json.loads(resp_con.recv_bytes().decode("utf-8"))
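The guard pattern above, in isolation: each dangerous callable is swapped for a stub bound (via functools.partial) to the name it replaces. A standalone sketch; the stub here additionally reports the offending name, which the commit's version does not:

import errno
import os
from functools import partial


def _blocked(*args, _func_name: str, **kwargs):
    raise OSError(errno.EPERM, f"{_func_name} is not permitted in this environment")


os.system = partial(_blocked, _func_name="os.system")

try:
    os.system("echo hi")
except OSError as e:
    print(e)  # [Errno 1] os.system is not permitted in this environment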
@@ -0,0 +1,256 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import base64
import json
import multiprocessing
import os
import re
import subprocess
import sys
import tempfile
import textwrap
import time
from dataclasses import dataclass
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import List

from PIL import Image

from .utils import get_code_env_prefix

TOOLS_ATTACHMENT_KEY = "__tools_attachment__"
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")

DIRNAME = Path(__file__).parent

CODE_EXEC_TIMEOUT = 20
CODE_ENV_PREFIX = get_code_env_prefix()

STDOUTERR_SINK_WRAPPER_TEMPLATE = """\
with _redirect_stdout(_IO_SINK), _redirect_stderr(_IO_SINK):
{code}\
"""

TRYEXCEPT_WRAPPER_TEMPLATE = """\
try:
{code}
except:
    pass\
"""


def generate_bwrap_command(bind_dirs: List[str]) -> str:
    """
    Generate the bwrap command string for binding all
    directories in the current directory read-only.
    """
    bwrap_args = ""
    bwrap_args += "--ro-bind / / "
    # Add the --dev flag to mount device files
    bwrap_args += "--dev /dev "
    for d in bind_dirs:
        bwrap_args += f"--bind {d} {d} "

    # Add the --unshare-all flag to isolate the sandbox from the rest of the system
    bwrap_args += "--unshare-all "
    # Add the --die-with-parent flag to ensure the child process dies when bwrap's parent dies
    bwrap_args += "--die-with-parent "
    return bwrap_args


@dataclass
class CodeExecutionContext:
    matplotlib_dump_dir: str
    use_proxy: bool = False


@dataclass
class CodeExecutionRequest:
    scripts: List[str]
    only_last_cell_stdouterr: bool = True
    only_last_cell_fail: bool = True
    seed: int = 0
    strip_fpaths_in_stderr: bool = True


class CodeExecutor:
    def __init__(self, context: CodeExecutionContext):
        self.context = context

    def execute(self, req: CodeExecutionRequest) -> dict:
        scripts = req.scripts
        for i in range(len(scripts) - 1):
            if req.only_last_cell_stdouterr:
                scripts[i] = STDOUTERR_SINK_WRAPPER_TEMPLATE.format(
                    code=textwrap.indent(scripts[i], " " * 4)
                )
            if req.only_last_cell_fail:
                scripts[i] = TRYEXCEPT_WRAPPER_TEMPLATE.format(
                    code=textwrap.indent(scripts[i], " " * 4)
                )

        # Seeds prefix:
        seed = req.seed
        seeds_prefix = f"""\
def _set_seeds():
    import random
    random.seed({seed})
    import numpy as np
    np.random.seed({seed})
_set_seeds()\
"""

        script = "\n\n".join([seeds_prefix] + [CODE_ENV_PREFIX] + scripts)
        with tempfile.TemporaryDirectory() as dpath:
            bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath])
            cmd = [*bwrap_prefix.split(), sys.executable, "-c", script]
            code_fpath = os.path.join(dpath, "code.py")
            with open(code_fpath, "w") as f:
                f.write(script)

            try:
                python_path = os.environ.get("PYTHONPATH", "")
                env = dict(
                    os.environ,
                    PYTHONHASHSEED=str(seed),
                    MPLCONFIGDIR=dpath,
                    MPLBACKEND="module://matplotlib_custom_backend",
                    PYTHONPATH=f"{DIRNAME}:{python_path}",
                )
                stdout, stderr, returncode = do_subprocess(
                    cmd=cmd,
                    env=env,
                    ctx=self.context,
                )

                stderr = stderr.strip()
                if req.strip_fpaths_in_stderr:
                    pattern = r'File "([^"]+)", line (\d+)'
                    stderr = re.sub(pattern, r"line \2", stderr)

                return {
                    "process_status": "completed",
                    "returncode": returncode,
                    "stdout": stdout.strip(),
                    "stderr": stderr,
                }

            except subprocess.TimeoutExpired:
                return {
                    "process_status": "timeout",
                    "stdout": "Timed out",
                    "stderr": "Timed out",
                }

            except Exception as e:
                return {
                    "process_status": "error",
                    "error_type": type(e).__name__,
                    "stderr": str(e),
                    "stdout": str(e),
                }


def process_matplotlib_response(response, matplotlib_dump_dir: str):
    image_data = response["image_data"]
    # Convert the base64 strings to bytes objects
    images = [base64.b64decode(d["image_base64"]) for d in image_data]
    # Create a list of PIL images from the bytes objects
    images = [Image.open(BytesIO(img)) for img in images]
    # Create a list of image paths
    image_paths = []
    for i, img in enumerate(images):
        # Create a new directory for each day to better organize data:
        dump_dname = datetime.today().strftime("%Y-%m-%d")
        dump_dpath = Path(matplotlib_dump_dir, dump_dname)
        dump_dpath.mkdir(parents=True, exist_ok=True)
        # Save the image into a file
        dump_fname = f"matplotlib_{str(time.time()).replace('.', '_')}_{i}.png"
        dump_fpath = dump_dpath / dump_fname
        img.save(dump_fpath, "PNG")
        image_paths.append(str(dump_fpath))

    # This is kind of convoluted: we send this response back to the subprocess,
    # which prints it out.
    info = {
        "filepath": str(image_paths[-1]),
        "mimetype": "image/png",
    }
    return f"{TOOLS_ATTACHMENT_KEY}={json.dumps(info)}"


def execute_subprocess_request(request, ctx: CodeExecutionContext):
    "Route requests from the subprocess (via network Pipes) to the internet/tools."
    if request["type"] == "matplotlib":
        return process_matplotlib_response(request, ctx.matplotlib_dump_dir)
    else:
        raise Exception(f'Unrecognised network request type: {request["type"]}')


def do_subprocess(*, cmd: list, env: dict, ctx: CodeExecutionContext):
    # Create Pipes to be used for any external tool/network requests.
    req_r, req_w = multiprocessing.Pipe(duplex=False)
    resp_r, resp_w = multiprocessing.Pipe(duplex=False)

    cmd += [str(req_w.fileno()), str(resp_r.fileno())]
    proc = subprocess.Popen(
        cmd,
        pass_fds=(req_w.fileno(), resp_r.fileno()),
        text=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        close_fds=True,
        env=env,
    )

    # Close unnecessary fds.
    req_w.close()
    resp_r.close()

    pipe_close = False
    done_read = False
    start = time.monotonic()
    while proc.poll() is None and not pipe_close:
        if req_r.poll(0.1):
            # NB: Python pipe semantics for poll and recv mean that
            # poll() returns True if a pipe is closed.
            # Cf. the old-school issue from '09:
            # https://bugs.python.org/issue5573
            try:
                request = json.loads(req_r.recv_bytes().decode("utf-8"))
                response = execute_subprocess_request(request, ctx)

                resp_w.send_bytes(json.dumps(response).encode("utf-8"))
            except EOFError:
                # The request pipe is closed - set a marker to exit
                # after the next attempt at reading stdout/stderr.
                pipe_close = True

        try:
            # If lots has been printed, the pipe might be full, but
            # proc cannot exit until all the stdout/stderr has
            # been written/read.
            stdout, stderr = proc.communicate(timeout=0.3)
            done_read = True
        except subprocess.TimeoutExpired:
            # The program has not terminated. Ignore it; there
            # may be more network/tool requests.
            continue
        if time.monotonic() - start > CODE_EXEC_TIMEOUT:
            proc.terminate()
            raise subprocess.TimeoutExpired(cmd, CODE_EXEC_TIMEOUT)

    if not done_read:
        # Solve a race condition where the process terminates before
        # we hit the while loop.
        stdout, stderr = proc.communicate(timeout=0.3)

    resp_w.close()
    req_r.close()
    return stdout, stderr, proc.returncode
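The parent/child handshake above rides on raw JSON bytes rather than the Connection objects' default send/recv framing, since those methods use pickle and unpickling would execute arbitrary code from the sandbox. A minimal sketch of one round trip, here within a single process:

import json
import multiprocessing

req_r, req_w = multiprocessing.Pipe(duplex=False)

# Child side: encode the request to UTF-8 JSON bytes.
req_w.send_bytes(json.dumps({"type": "matplotlib", "image_data": []}).encode("utf-8"))

# Parent side: decode without ever touching pickle.
request = json.loads(req_r.recv_bytes().decode("utf-8"))
print(request["type"])  # matplotlib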
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
A custom Matplotlib backend that overrides the show method to return image bytes.
"""

import base64
import io
import json as _json
import logging

import matplotlib
from matplotlib.backend_bases import FigureManagerBase

# Import necessary components from Matplotlib
from matplotlib.backends.backend_agg import FigureCanvasAgg

log = logging.getLogger(__name__)


class CustomFigureCanvas(FigureCanvasAgg):
    def show(self):
        # Save the figure to a BytesIO object
        buf = io.BytesIO()
        self.print_png(buf)
        image_bytes = buf.getvalue()
        buf.close()
        return image_bytes


class CustomFigureManager(FigureManagerBase):
    def __init__(self, canvas, num):
        super().__init__(canvas, num)


# Mimic module initialization that integrates with the Matplotlib backend system
def _create_figure_manager(num, *args, **kwargs):
    """
    Create a custom figure manager instance.
    """
    FigureClass = kwargs.pop("FigureClass", None)  # noqa: N806
    if FigureClass is None:
        from matplotlib.figure import Figure

        FigureClass = Figure  # noqa: N806
    fig = FigureClass(*args, **kwargs)
    canvas = CustomFigureCanvas(fig)
    manager = CustomFigureManager(canvas, num)
    return manager


def show():
    """
    Handle all figures and potentially return their images as bytes.

    This function iterates over all figures registered with the custom backend,
    renders them as images in bytes format, and could return a list of bytes objects,
    one for each figure, or handle them as needed.
    """
    image_data = []
    for manager in matplotlib._pylab_helpers.Gcf.get_all_fig_managers():
        # Get the figure from the manager
        fig = manager.canvas.figure
        buf = io.BytesIO()  # Create a buffer for the figure
        fig.savefig(buf, format="png")  # Save the figure to the buffer in PNG format
        buf.seek(0)  # Go to the beginning of the buffer
        image_bytes = buf.getvalue()  # Retrieve bytes value
        image_base64 = base64.b64encode(image_bytes).decode("utf-8")
        image_data.append({"image_base64": image_base64})
        buf.close()

    req_con, resp_con = _open_connections()

    _json_dump = _json.dumps(
        {
            "type": "matplotlib",
            "image_data": image_data,
        }
    )
    req_con.send_bytes(_json_dump.encode("utf-8"))
    resp = _json.loads(resp_con.recv_bytes().decode("utf-8"))
    log.info(resp)


FigureCanvas = CustomFigureCanvas
FigureManager = CustomFigureManager
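What CustomFigureCanvas.show() boils down to, demonstrated with the stock Agg backend; a standalone sketch that assumes matplotlib is installed:

import io

import matplotlib

matplotlib.use("Agg")  # headless rendering for this sketch
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1])

buf = io.BytesIO()
fig.savefig(buf, format="png")
png_bytes = buf.getvalue()
print(len(png_bytes), png_bytes[:8])  # b'\x89PNG\r\n\x1a\n' magic header

code_execution.py selects the custom module via MPLBACKEND="module://matplotlib_custom_backend", so a plt.show() in sandboxed user code ships PNG bytes back over the pipe instead of opening a window.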
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os

DIR = os.path.dirname(os.path.realpath(__file__))
CODE_ENV_PREFIX_FILE = os.path.join(DIR, "code_env_prefix.py")
CODE_ENV_PREFIX = None


def get_code_env_prefix() -> str:
    global CODE_ENV_PREFIX

    if CODE_ENV_PREFIX is None:
        with open(CODE_ENV_PREFIX_FILE, "r") as f:
            CODE_ENV_PREFIX = f.read()

    return CODE_ENV_PREFIX
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Dict

from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
from pydantic import BaseModel


class PhotogenConfig(BaseModel):
    dump_dir: str


class PhotogenTool(BaseTool):
    @classmethod
    def tool_id(cls) -> str:
        return "photogen"

    @classmethod
    def get_provider_config_type(cls):
        return PhotogenConfig

    async def execute(self, query: str) -> Dict:
        """
        Implement this to give the model an ability to generate images.

        Return:
            info = {
                "filepath": str(image_filepath),
                "mimetype": "image/png",
            }
        """
        config = PhotogenConfig(**self.config)
        raise NotImplementedError()
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

import requests

from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
from pydantic import BaseModel


class TavilySearchConfig(BaseModel):
    api_key: str
    max_results: int = 3


class TavilySearchTool(BaseTool):
    requires_api_key: bool = True

    @classmethod
    def tool_id(cls) -> str:
        return "tavily_search"

    @classmethod
    def get_provider_config_type(cls):
        return TavilySearchConfig

    async def execute(self, query: str) -> List[dict]:
        config = TavilySearchConfig(**self.config)
        response = requests.post(
            "https://api.tavily.com/search",
            json={"api_key": config.api_key, "query": query},
        )
        response.raise_for_status()
        search_response = response.json()
        return {
            "query": search_response["query"],
            "results": search_response["results"][: config.max_results],
        }
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import Dict

import requests

from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
from pydantic import BaseModel


class WolframAlphaConfig(BaseModel):
    api_key: str


class WolframAlphaTool(BaseTool):
    requires_api_key: bool = True

    @classmethod
    def tool_id(cls) -> str:
        return "wolfram_alpha"

    @classmethod
    def get_provider_config_type(cls):
        return WolframAlphaConfig

    async def execute(self, query: str) -> Dict:
        config = WolframAlphaConfig(**self.config)
        url = "https://api.wolframalpha.com/v2/query"
        params = {
            "input": query,
            "appid": config.api_key,
            "format": "plaintext",
            "output": "json",
        }
        response = requests.get(url, params=params)
        response.raise_for_status()
        return json.dumps(self._clean_wolfram_alpha_response(response.json()))

    def _clean_wolfram_alpha_response(self, wa_response):
        remove = {
            "queryresult": [
                "datatypes",
                "error",
                "timedout",
                "timedoutpods",
                "numpods",
                "timing",
                "parsetiming",
                "parsetimedout",
                "recalculate",
                "id",
                "host",
                "server",
                "related",
                "version",
                {
                    "pods": [
                        "scanner",
                        "id",
                        "error",
                        "expressiontypes",
                        "states",
                        "infos",
                        "position",
                        "numsubpods",
                    ]
                },
                "assumptions",
            ],
        }

        result = wa_response.copy()
        for main_key, to_remove in remove.items():
            if main_key not in result:
                continue

            for item in to_remove:
                if isinstance(item, dict):
                    for sub_key, sub_items in item.items():
                        if sub_key == "pods":
                            pods = result[main_key].get(sub_key, [])
                            for i, pod in enumerate(pods):
                                if pod.get("title") == "Result":
                                    pods = pods[: i + 1]
                                    break
                                for remove_key in sub_items:
                                    pod.pop(remove_key, None)
                else:
                    result[main_key].pop(item, None)

        return result