mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-15 01:26:10 +00:00
agents to use tools api (#673)
# What does this PR do?

PR #639 introduced the notion of a Tools API and the ability to invoke tools through the API just like any other resource. This PR changes the Agents implementation to use the Tools API to invoke tools. Major changes include:

1) Tool groups can be specified with `AgentConfig`.
2) The agent fetches the corresponding tool definitions for the specified tools and passes them along to the model.
3) Attachments are now named Documents, and their behavior is mostly unchanged from the user's perspective.
4) You can specify args to be injected into a tool call through the agent config. This is especially useful for the memory tool, where you want the tool to operate on a specific memory bank.
5) You can also register tool groups with args, which the agent likewise injects into the tool call.
6) All tests, including the client SDK tests, have been migrated to the new Tools API and fixtures.
7) Telemetry just works with the Tools API because of our trace protocol decorator.

## Test Plan

```
pytest -s -v -k fireworks llama_stack/providers/tests/agents/test_agents.py \
  --safety-shield=meta-llama/Llama-Guard-3-8B \
  --inference-model=meta-llama/Llama-3.1-8B-Instruct

pytest -s -v -k together llama_stack/providers/tests/tools/test_tools.py \
  --safety-shield=meta-llama/Llama-Guard-3-8B \
  --inference-model=meta-llama/Llama-3.1-8B-Instruct

LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py
```

run.yaml: https://gist.github.com/dineshyv/0365845ad325e1c2cab755788ccc5994
Notebook: https://colab.research.google.com/drive/1ck7hXQxRl6UvT-ijNRZ-gMZxH1G3cN2d?usp=sharing
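For illustration, here is a minimal sketch of the new configuration and registration surface, assembled from the `AgentConfig`, `AgentToolGroupWithArgs`, and `register_tool_group` usage that appears in the diff below. The model ID, memory bank name, and provider ID are placeholders, not values taken from this PR:

```python
# Hedged sketch only; it mirrors the AgentConfig / AgentToolGroupWithArgs /
# register_tool_group usage in this PR's code and tests. The model ID, bank
# name, and provider ID are illustrative placeholders.
from llama_stack.apis.agents import AgentConfig, AgentToolGroupWithArgs

agent_config = AgentConfig(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder
    instructions="You are a helpful assistant.",
    toolgroups=[
        # Plain tool group (changes 1-2): the agent looks up the group's tool
        # definitions via the ToolGroups API and forwards them to the model.
        "builtin::code_interpreter",
        # Tool group with args (change 4): these are merged into every call
        # of the group's tools, e.g. pinning memory queries to given banks.
        AgentToolGroupWithArgs(
            name="builtin::memory",
            args={"memory_banks": ["my_memory_bank"]},  # placeholder
        ),
    ],
    enable_session_persistence=False,
)


async def register_memory_toolgroup(tool_groups_api) -> None:
    # Change 5: a tool group can also be registered with args up front; the
    # signature mirrors MockToolGroupsAPI.register_tool_group in the tests.
    await tool_groups_api.register_tool_group(
        toolgroup_id="builtin::memory",
        provider_id="memory-runtime",  # placeholder provider ID
        args={"memory_banks": ["my_memory_bank"]},
    )
```

At turn time, `execute_tool_call_maybe` merges these group args over the model-generated arguments before calling `ToolRuntime.invoke_tool`, so the injected values take precedence.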
This commit is contained in:
parent 596afc6497
commit a5c57cd381
116 changed files with 4959 additions and 2778 deletions
@@ -22,6 +22,8 @@ async def get_provider_impl(
         deps[Api.memory],
         deps[Api.safety],
         deps[Api.memory_banks],
+        deps[Api.tool_runtime],
+        deps[Api.tool_groups],
     )
     await impl.initialize()
     return impl
@@ -4,8 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

 import asyncio
 import copy
 import json
 import logging
 import os
 import re
@@ -13,16 +13,16 @@ import secrets
 import string
 import uuid
 from datetime import datetime
-from typing import AsyncGenerator, Dict, List, Optional, Tuple
+from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
 from urllib.parse import urlparse

 import httpx

-from llama_models.llama3.api.datatypes import BuiltinTool
+from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDefinition

 from llama_stack.apis.agents import (
     AgentConfig,
-    AgentTool,
+    AgentToolGroup,
+    AgentToolGroupWithArgs,
     AgentTurnCreateRequest,
     AgentTurnResponseEvent,
     AgentTurnResponseEventType,
@@ -33,25 +33,14 @@ from llama_stack.apis.agents import (
     AgentTurnResponseTurnCompletePayload,
     AgentTurnResponseTurnStartPayload,
     Attachment,
-    CodeInterpreterToolDefinition,
-    FunctionCallToolDefinition,
+    Document,
     InferenceStep,
     MemoryRetrievalStep,
-    MemoryToolDefinition,
-    PhotogenToolDefinition,
-    SearchToolDefinition,
     ShieldCallStep,
     StepType,
     ToolExecutionStep,
     Turn,
-    WolframAlphaToolDefinition,
 )

-from llama_stack.apis.common.content_types import (
-    InterleavedContent,
-    TextContentItem,
-    URL,
-)
+from llama_stack.apis.common.content_types import TextContentItem, URL
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
     CompletionMessage,
@@ -62,32 +51,20 @@ from llama_stack.apis.inference import (
     SystemMessage,
     ToolCallDelta,
     ToolCallParseStatus,
     ToolChoice,
     ToolDefinition,
     ToolResponse,
     ToolResponseMessage,
     UserMessage,
 )
-from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
+from llama_stack.apis.memory import Memory, MemoryBankDocument
 from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
 from llama_stack.apis.safety import Safety

+from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.providers.utils.kvstore import KVStore
-from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing

 from .persistence import AgentPersistence
-from .rag.context_retriever import generate_rag_query
 from .safety import SafetyException, ShieldRunnerMixin
-from .tools.base import BaseTool
-from .tools.builtin import (
-    CodeInterpreterTool,
-    interpret_content_as_attachment,
-    PhotogenTool,
-    SearchTool,
-    WolframAlphaTool,
-)
-from .tools.safety import SafeTool

 log = logging.getLogger(__name__)

@@ -98,6 +75,12 @@ def make_random_string(length: int = 8):
     )


+TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
+MEMORY_QUERY_TOOL = "query_memory"
+WEB_SEARCH_TOOL = "web_search"
+MEMORY_GROUP = "builtin::memory"
+
+
 class ChatAgent(ShieldRunnerMixin):
     def __init__(
         self,
@@ -108,6 +91,8 @@ class ChatAgent(ShieldRunnerMixin):
         memory_api: Memory,
         memory_banks_api: MemoryBanks,
         safety_api: Safety,
+        tool_runtime_api: ToolRuntime,
+        tool_groups_api: ToolGroups,
         persistence_store: KVStore,
     ):
         self.agent_id = agent_id
@@ -118,29 +103,8 @@ class ChatAgent(ShieldRunnerMixin):
         self.memory_banks_api = memory_banks_api
         self.safety_api = safety_api
         self.storage = AgentPersistence(agent_id, persistence_store)

-        builtin_tools = []
-        for tool_defn in agent_config.tools:
-            if isinstance(tool_defn, WolframAlphaToolDefinition):
-                tool = WolframAlphaTool(tool_defn.api_key)
-            elif isinstance(tool_defn, SearchToolDefinition):
-                tool = SearchTool(tool_defn.engine, tool_defn.api_key)
-            elif isinstance(tool_defn, CodeInterpreterToolDefinition):
-                tool = CodeInterpreterTool()
-            elif isinstance(tool_defn, PhotogenToolDefinition):
-                tool = PhotogenTool(dump_dir=self.tempdir)
-            else:
-                continue
-
-            builtin_tools.append(
-                SafeTool(
-                    tool,
-                    safety_api,
-                    tool_defn.input_shields,
-                    tool_defn.output_shields,
-                )
-            )
-        self.tools_dict = {t.get_name(): t for t in builtin_tools}
+        self.tool_runtime_api = tool_runtime_api
+        self.tool_groups_api = tool_groups_api

         ShieldRunnerMixin.__init__(
             self,
@@ -228,9 +192,10 @@ class ChatAgent(ShieldRunnerMixin):
            session_id=request.session_id,
            turn_id=turn_id,
            input_messages=messages,
-           attachments=request.attachments or [],
            sampling_params=self.agent_config.sampling_params,
            stream=request.stream,
+           documents=request.documents,
+           toolgroups_for_turn=request.toolgroups,
        ):
            if isinstance(chunk, CompletionMessage):
                log.info(
@@ -278,9 +243,10 @@ class ChatAgent(ShieldRunnerMixin):
        session_id: str,
        turn_id: str,
        input_messages: List[Message],
-       attachments: List[Attachment],
        sampling_params: SamplingParams,
        stream: bool = False,
+       documents: Optional[List[Document]] = None,
+       toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
    ) -> AsyncGenerator:
        # Doing async generators makes downstream code much simpler and everything amenable to
        # streaming. However, it also makes things complicated here because AsyncGenerators cannot
@@ -297,7 +263,13 @@ class ChatAgent(ShieldRunnerMixin):
            yield res

        async for res in self._run(
-           session_id, turn_id, input_messages, attachments, sampling_params, stream
+           session_id,
+           turn_id,
+           input_messages,
+           sampling_params,
+           stream,
+           documents,
+           toolgroups_for_turn,
        ):
            if isinstance(res, bool):
                return
@@ -353,6 +325,7 @@ class ChatAgent(ShieldRunnerMixin):
                event=AgentTurnResponseEvent(
                    payload=AgentTurnResponseStepCompletePayload(
                        step_type=StepType.shield_call.value,
+                       step_id=step_id,
                        step_details=ShieldCallStep(
                            step_id=step_id,
                            turn_id=turn_id,
@@ -373,6 +346,7 @@ class ChatAgent(ShieldRunnerMixin):
                event=AgentTurnResponseEvent(
                    payload=AgentTurnResponseStepCompletePayload(
                        step_type=StepType.shield_call.value,
+                       step_id=step_id,
                        step_details=ShieldCallStep(
                            step_id=step_id,
                            turn_id=turn_id,
@@ -388,73 +362,116 @@ class ChatAgent(ShieldRunnerMixin):
        session_id: str,
        turn_id: str,
        input_messages: List[Message],
-       attachments: List[Attachment],
        sampling_params: SamplingParams,
        stream: bool = False,
+       documents: Optional[List[Document]] = None,
+       toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
    ) -> AsyncGenerator:
-       enabled_tools = set(t.type for t in self.agent_config.tools)
-       need_rag_context = await self._should_retrieve_context(
-           input_messages, attachments
-       )
-       if need_rag_context:
-           step_id = str(uuid.uuid4())
-           yield AgentTurnResponseStreamChunk(
-               event=AgentTurnResponseEvent(
-                   payload=AgentTurnResponseStepStartPayload(
-                       step_type=StepType.memory_retrieval.value,
-                       step_id=step_id,
+       toolgroup_args = {}
+       for toolgroup in self.agent_config.toolgroups:
+           if isinstance(toolgroup, AgentToolGroupWithArgs):
+               toolgroup_args[toolgroup.name] = toolgroup.args
+       if toolgroups_for_turn:
+           for toolgroup in toolgroups_for_turn:
+               if isinstance(toolgroup, AgentToolGroupWithArgs):
+                   toolgroup_args[toolgroup.name] = toolgroup.args

+       tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
+       if documents:
+           await self.handle_documents(
+               session_id, documents, input_messages, tool_defs
+           )
+       if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
+           memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None)
+           if memory_tool_group is None:
+               raise ValueError(f"Memory tool group not found for {MEMORY_QUERY_TOOL}")
+           with tracing.span(MEMORY_QUERY_TOOL) as span:
+               step_id = str(uuid.uuid4())
+               yield AgentTurnResponseStreamChunk(
+                   event=AgentTurnResponseEvent(
+                       payload=AgentTurnResponseStepStartPayload(
+                           step_type=StepType.tool_execution.value,
+                           step_id=step_id,
+                       )
+                   )
+               )
-           )
+               query_args = {
+                   "messages": [msg.content for msg in input_messages],
+                   **toolgroup_args.get(memory_tool_group, {}),
+               }

-           # TODO: find older context from the session and either replace it
-           # or append with a sliding window. this is really a very simplistic implementation
-           with tracing.span("retrieve_rag_context") as span:
-               rag_context, bank_ids = await self._retrieve_context(
-                   session_id, input_messages, attachments
+               session_info = await self.storage.get_session_info(session_id)
+               # if the session has a memory bank id, let the memory tool use it
+               if session_info.memory_bank_id:
+                   if "memory_bank_ids" not in query_args:
+                       query_args["memory_bank_ids"] = []
+                   query_args["memory_bank_ids"].append(session_info.memory_bank_id)
+               yield AgentTurnResponseStreamChunk(
+                   event=AgentTurnResponseEvent(
+                       payload=AgentTurnResponseStepProgressPayload(
+                           step_type=StepType.tool_execution.value,
+                           step_id=step_id,
+                           tool_call_delta=ToolCallDelta(
+                               parse_status=ToolCallParseStatus.success,
+                               content=ToolCall(
+                                   call_id="",
+                                   tool_name=MEMORY_QUERY_TOOL,
+                                   arguments={},
+                               ),
+                           ),
+                       )
+                   )
+               )
+               result = await self.tool_runtime_api.invoke_tool(
+                   tool_name=MEMORY_QUERY_TOOL,
+                   args=query_args,
+               )

+               yield AgentTurnResponseStreamChunk(
+                   event=AgentTurnResponseEvent(
+                       payload=AgentTurnResponseStepCompletePayload(
+                           step_type=StepType.tool_execution.value,
+                           step_id=step_id,
+                           step_details=ToolExecutionStep(
+                               step_id=step_id,
+                               turn_id=turn_id,
+                               tool_calls=[
+                                   ToolCall(
+                                       call_id="",
+                                       tool_name=MEMORY_QUERY_TOOL,
+                                       arguments={},
+                                   )
+                               ],
+                               tool_responses=[
+                                   ToolResponse(
+                                       call_id="",
+                                       tool_name=MEMORY_QUERY_TOOL,
+                                       content=result.content,
+                                   )
+                               ],
+                           ),
+                       )
+                   )
+               )
                span.set_attribute(
                    "input", [m.model_dump_json() for m in input_messages]
                )
-               span.set_attribute("output", rag_context)
-               span.set_attribute("bank_ids", bank_ids)

-           step_id = str(uuid.uuid4())
-           yield AgentTurnResponseStreamChunk(
-               event=AgentTurnResponseEvent(
-                   payload=AgentTurnResponseStepCompletePayload(
-                       step_type=StepType.memory_retrieval.value,
-                       step_id=step_id,
-                       step_details=MemoryRetrievalStep(
-                           turn_id=turn_id,
-                           step_id=step_id,
-                           memory_bank_ids=bank_ids,
-                           inserted_context=rag_context or "",
-                       ),
-                   )
-               )
-           )
-
-           if rag_context:
-               last_message = input_messages[-1]
-               last_message.context = rag_context
-
-       elif attachments and AgentTool.code_interpreter.value in enabled_tools:
-           urls = [a.content for a in attachments if isinstance(a.content, URL)]
-           # TODO: we need to migrate URL away from str type
-           pattern = re.compile("^(https?://|file://|data:)")
-           urls += [
-               URL(uri=a.content) for a in attachments if pattern.match(a.content)
-           ]
-           msg = await attachment_message(self.tempdir, urls)
-           input_messages.append(msg)
+               span.set_attribute("output", result.content)
+               span.set_attribute("error_code", result.error_code)
+               span.set_attribute("error_message", result.error_message)
+               span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
+               if result.error_code == 0:
+                   last_message = input_messages[-1]
+                   last_message.context = result.content

        output_attachments = []

        n_iter = 0
+       # Build a map of custom tools to their definitions for faster lookup
+       client_tools = {}
+       for tool in self.agent_config.client_tools:
+           client_tools[tool.name] = tool
        while True:
            msg = input_messages[-1]

            step_id = str(uuid.uuid4())
            yield AgentTurnResponseStreamChunk(
                event=AgentTurnResponseEvent(
@@ -473,7 +490,11 @@ class ChatAgent(ShieldRunnerMixin):
            async for chunk in await self.inference_api.chat_completion(
                self.agent_config.model,
                input_messages,
-               tools=self._get_tools(),
+               tools=[
+                   tool
+                   for tool in tool_defs.values()
+                   if tool_to_group.get(tool.tool_name, None) != MEMORY_GROUP
+               ],
                tool_prompt_format=self.agent_config.tool_prompt_format,
                stream=True,
                sampling_params=sampling_params,
@@ -572,9 +593,9 @@ class ChatAgent(ShieldRunnerMixin):
            # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
            if len(output_attachments) > 0:
                if isinstance(message.content, list):
-                   message.content += attachments
+                   message.content += output_attachments
                else:
-                   message.content = [message.content] + attachments
+                   message.content = [message.content] + output_attachments
            yield message
        else:
            log.info(f"Partial message: {str(message)}")
@@ -582,9 +603,7 @@ class ChatAgent(ShieldRunnerMixin):
            else:
                log.info(f"{str(message)}")
            tool_call = message.tool_calls[0]

-           name = tool_call.tool_name
-           if not isinstance(name, BuiltinTool) or name not in enabled_tools:
+           if tool_call.tool_name in client_tools:
                yield message
                return

@@ -607,16 +626,22 @@ class ChatAgent(ShieldRunnerMixin):
                )
            )

+           tool_name = tool_call.tool_name
+           if isinstance(tool_name, BuiltinTool):
+               tool_name = tool_name.value
            with tracing.span(
                "tool_execution",
                {
-                   "tool_name": tool_call.tool_name,
+                   "tool_name": tool_name,
                    "input": message.model_dump_json(),
                },
            ) as span:
                result_messages = await execute_tool_call_maybe(
-                   self.tools_dict,
+                   self.tool_runtime_api,
+                   session_id,
                    [message],
+                   toolgroup_args,
+                   tool_to_group,
                )
                assert (
                    len(result_messages) == 1
@@ -628,6 +653,7 @@ class ChatAgent(ShieldRunnerMixin):
                event=AgentTurnResponseEvent(
                    payload=AgentTurnResponseStepCompletePayload(
                        step_type=StepType.tool_execution.value,
+                       step_id=step_id,
                        step_details=ToolExecutionStep(
                            step_id=step_id,
                            turn_id=turn_id,
@@ -647,7 +673,7 @@ class ChatAgent(ShieldRunnerMixin):
                # TODO: add tool-input touchpoint and a "start" event for this step also
                # but that needs a lot more refactoring of Tool code potentially

-               if out_attachment := interpret_content_as_attachment(
+               if out_attachment := _interpret_content_as_attachment(
                    result_message.content
                ):
                    # NOTE: when we push this message back to the model, the model may ignore the
@ -659,6 +685,150 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
n_iter += 1
|
||||
|
||||
async def _get_tool_defs(
|
||||
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
||||
) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
|
||||
# Determine which tools to include
|
||||
agent_config_toolgroups = set(
|
||||
(
|
||||
toolgroup.name
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs)
|
||||
else toolgroup
|
||||
)
|
||||
for toolgroup in self.agent_config.toolgroups
|
||||
)
|
||||
toolgroups_for_turn_set = (
|
||||
agent_config_toolgroups
|
||||
if toolgroups_for_turn is None
|
||||
else {
|
||||
(
|
||||
toolgroup.name
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs)
|
||||
else toolgroup
|
||||
)
|
||||
for toolgroup in toolgroups_for_turn
|
||||
}
|
||||
)
|
||||
|
||||
tool_def_map = {}
|
||||
tool_to_group = {}
|
||||
|
||||
for tool_def in self.agent_config.client_tools:
|
||||
if tool_def_map.get(tool_def.name, None):
|
||||
raise ValueError(f"Tool {tool_def.name} already exists")
|
||||
tool_def_map[tool_def.name] = ToolDefinition(
|
||||
tool_name=tool_def.name,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[tool_def.name] = "__client_tools__"
|
||||
for toolgroup_name in agent_config_toolgroups:
|
||||
if toolgroup_name not in toolgroups_for_turn_set:
|
||||
continue
|
||||
tools = await self.tool_groups_api.list_tools(tool_group_id=toolgroup_name)
|
||||
for tool_def in tools:
|
||||
if (
|
||||
toolgroup_name.startswith("builtin")
|
||||
and toolgroup_name != MEMORY_GROUP
|
||||
):
|
||||
tool_name = tool_def.identifier
|
||||
built_in_type = BuiltinTool.brave_search
|
||||
if tool_name == "web_search":
|
||||
built_in_type = BuiltinTool.brave_search
|
||||
else:
|
||||
built_in_type = BuiltinTool(tool_name)
|
||||
|
||||
if tool_def_map.get(built_in_type, None):
|
||||
raise ValueError(f"Tool {built_in_type} already exists")
|
||||
|
||||
tool_def_map[built_in_type] = ToolDefinition(
|
||||
tool_name=built_in_type
|
||||
)
|
||||
tool_to_group[built_in_type] = tool_def.toolgroup_id
|
||||
continue
|
||||
|
||||
if tool_def_map.get(tool_def.identifier, None):
|
||||
raise ValueError(f"Tool {tool_def.identifier} already exists")
|
||||
tool_def_map[tool_def.identifier] = ToolDefinition(
|
||||
tool_name=tool_def.identifier,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
|
||||
|
||||
return tool_def_map, tool_to_group
|
||||
|
||||
async def handle_documents(
|
||||
self,
|
||||
session_id: str,
|
||||
documents: List[Document],
|
||||
input_messages: List[Message],
|
||||
tool_defs: Dict[str, ToolDefinition],
|
||||
) -> None:
|
||||
memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
|
||||
code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
|
||||
content_items = []
|
||||
url_items = []
|
||||
pattern = re.compile("^(https?://|file://|data:)")
|
||||
for d in documents:
|
||||
if isinstance(d.content, URL):
|
||||
url_items.append(d.content)
|
||||
elif pattern.match(d.content):
|
||||
url_items.append(URL(uri=d.content))
|
||||
else:
|
||||
content_items.append(d)
|
||||
|
||||
# Save the contents to a tempdir and use its path as a URL if code interpreter is present
|
||||
if code_interpreter_tool:
|
||||
for c in content_items:
|
||||
temp_file_path = os.path.join(
|
||||
self.tempdir, f"{make_random_string()}.txt"
|
||||
)
|
||||
with open(temp_file_path, "w") as temp_file:
|
||||
temp_file.write(c.content)
|
||||
url_items.append(URL(uri=f"file://{temp_file_path}"))
|
||||
|
||||
if memory_tool and code_interpreter_tool:
|
||||
# if both memory and code_interpreter are available, we download the URLs
|
||||
# and attach the data to the last message.
|
||||
msg = await attachment_message(self.tempdir, url_items)
|
||||
input_messages.append(msg)
|
||||
# Since memory is present, add all the data to the memory bank
|
||||
await self.add_to_session_memory_bank(session_id, documents)
|
||||
elif code_interpreter_tool:
|
||||
# if only code_interpreter is available, we download the URLs to a tempdir
|
||||
# and attach the path to them as a message to inference with the
|
||||
# assumption that the model invokes the code_interpreter tool with the path
|
||||
msg = await attachment_message(self.tempdir, url_items)
|
||||
input_messages.append(msg)
|
||||
elif memory_tool:
|
||||
# if only memory is available, we load the data from the URLs and content items to the memory bank
|
||||
await self.add_to_session_memory_bank(session_id, documents)
|
||||
else:
|
||||
# if no memory or code_interpreter tool is available,
|
||||
# we try to load the data from the URLs and content items as a message to inference
|
||||
# and add it to the last message's context
|
||||
input_messages[-1].context = "\n".join(
|
||||
[doc.content for doc in content_items]
|
||||
+ await load_data_from_urls(url_items)
|
||||
)
|
||||
|
||||
async def _ensure_memory_bank(self, session_id: str) -> str:
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
|
@@ -679,129 +849,39 @@ class ChatAgent(ShieldRunnerMixin):

        return bank_id

-   async def _should_retrieve_context(
-       self, messages: List[Message], attachments: List[Attachment]
-   ) -> bool:
-       enabled_tools = set(t.type for t in self.agent_config.tools)
-       if attachments:
-           if (
-               AgentTool.code_interpreter.value in enabled_tools
-               and self.agent_config.tool_choice == ToolChoice.required
-           ):
-               return False
-           else:
-               return True
-
-       return AgentTool.memory.value in enabled_tools
-
-   def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
-       for t in self.agent_config.tools:
-           if t.type == AgentTool.memory.value:
-               return t
-
-       return None
-
-   async def _retrieve_context(
-       self, session_id: str, messages: List[Message], attachments: List[Attachment]
-   ) -> Tuple[Optional[InterleavedContent], List[int]]:  # (rag_context, bank_ids)
-       bank_ids = []
-
-       memory = self._memory_tool_definition()
-       assert memory is not None, "Memory tool not configured"
-       bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
-
-       if attachments:
-           bank_id = await self._ensure_memory_bank(session_id)
-           bank_ids.append(bank_id)
-
-           documents = [
-               MemoryBankDocument(
-                   document_id=str(uuid.uuid4()),
-                   content=a.content,
-                   mime_type=a.mime_type,
-                   metadata={},
-               )
-               for a in attachments
-           ]
-           with tracing.span("insert_documents"):
-               await self.memory_api.insert_documents(bank_id, documents)
-       else:
-           session_info = await self.storage.get_session_info(session_id)
-           if session_info.memory_bank_id:
-               bank_ids.append(session_info.memory_bank_id)
-
-       if not bank_ids:
-           # this can happen if the per-session memory bank is not yet populated
-           # (i.e., no prior turns uploaded an Attachment)
-           return None, []
-
-       query = await generate_rag_query(
-           memory.query_generator_config, messages, inference_api=self.inference_api
-       )
-       tasks = [
-           self.memory_api.query_documents(
-               bank_id=bank_id,
-               query=query,
-               params={
-                   "max_chunks": 5,
-               },
+   async def add_to_session_memory_bank(
+       self, session_id: str, data: List[Document]
+   ) -> None:
+       bank_id = await self._ensure_memory_bank(session_id)
+       documents = [
+           MemoryBankDocument(
+               document_id=str(uuid.uuid4()),
+               content=a.content,
+               mime_type=a.mime_type,
+               metadata={},
            )
-           for bank_id in bank_ids
+           for a in data
        ]
-       results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
-       chunks = [c for r in results for c in r.chunks]
-       scores = [s for r in results for s in r.scores]
-
-       if not chunks:
-           return None, bank_ids
-
-       # sort by score
-       chunks, scores = zip(
-           *sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
+       await self.memory_api.insert_documents(
+           bank_id=bank_id,
+           documents=documents,
        )

-       tokens = 0
-       picked = []
-       for c in chunks[: memory.max_chunks]:
-           tokens += c.token_count
-           if tokens > memory.max_tokens_in_context:
-               log.error(
-                   f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
-               )
-               break
-           picked.append(f"id:{c.document_id}; content:{c.content}")
-
-       return (
-           concat_interleaved_content(
-               [
-                   "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
-                   *picked,
-                   "\n=== END-RETRIEVED-CONTEXT ===\n",
-               ]
-           ),
-           bank_ids,
-       )
-
-   def _get_tools(self) -> List[ToolDefinition]:
-       ret = []
-       for t in self.agent_config.tools:
-           if isinstance(t, SearchToolDefinition):
-               ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search))
-           elif isinstance(t, WolframAlphaToolDefinition):
-               ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha))
-           elif isinstance(t, PhotogenToolDefinition):
-               ret.append(ToolDefinition(tool_name=BuiltinTool.photogen))
-           elif isinstance(t, CodeInterpreterToolDefinition):
-               ret.append(ToolDefinition(tool_name=BuiltinTool.code_interpreter))
-           elif isinstance(t, FunctionCallToolDefinition):
-               ret.append(
-                   ToolDefinition(
-                       tool_name=t.function_name,
-                       description=t.description,
-                       parameters=t.parameters,
-                   )
-               )
-       return ret
+
+async def load_data_from_urls(urls: List[URL]) -> List[str]:
+   data = []
+   for url in urls:
+       uri = url.uri
+       if uri.startswith("file://"):
+           filepath = uri[len("file://") :]
+           with open(filepath, "r") as f:
+               data.append(f.read())
+       elif uri.startswith("http"):
+           async with httpx.AsyncClient() as client:
+               r = await client.get(uri)
+               resp = r.text
+               data.append(resp)
+   return data


 async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
@@ -839,7 +919,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:


 async def execute_tool_call_maybe(
-   tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage]
+   tool_runtime_api: ToolRuntime,
+   session_id: str,
+   messages: List[CompletionMessage],
+   toolgroup_args: Dict[str, Dict[str, Any]],
+   tool_to_group: Dict[str, str],
 ) -> List[ToolResponseMessage]:
    # While Tools.run interface takes a list of messages,
    # All tools currently only run on a single message
@@ -851,11 +935,45 @@ async def execute_tool_call_maybe(

    tool_call = message.tool_calls[0]
    name = tool_call.tool_name
-   assert isinstance(name, BuiltinTool)
+   group_name = tool_to_group.get(name, None)
+   if group_name is None:
+       raise ValueError(f"Tool {name} not found in any tool group")
+   # get the arguments generated by the model and augment with toolgroup arg overrides for the agent
+   tool_call_args = tool_call.arguments
+   tool_call_args.update(toolgroup_args.get(group_name, {}))
+   if isinstance(name, BuiltinTool):
+       if name == BuiltinTool.brave_search:
+           name = WEB_SEARCH_TOOL
+       else:
+           name = name.value

-   name = name.value
+   result = await tool_runtime_api.invoke_tool(
+       tool_name=name,
+       args=dict(
+           session_id=session_id,
+           **tool_call_args,
+       ),
+   )

-   assert name in tools_dict, f"Tool {name} not found"
-   tool = tools_dict[name]
-   result_messages = await tool.run(messages)
-   return result_messages
+   return [
+       ToolResponseMessage(
+           call_id=tool_call.call_id,
+           tool_name=tool_call.tool_name,
+           content=result.content,
+       )
+   ]


+def _interpret_content_as_attachment(
+   content: str,
+) -> Optional[Attachment]:
+   match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
+   if match:
+       snippet = match.group(1)
+       data = json.loads(snippet)
+       return Attachment(
+           url=URL(uri="file://" + data["filepath"]),
+           mime_type=data["mimetype"],
+       )
+
+   return None
@@ -19,17 +19,17 @@ from llama_stack.apis.agents import (
    Agents,
    AgentSessionCreateResponse,
    AgentStepResponse,
+   AgentToolGroup,
    AgentTurnCreateRequest,
-   Attachment,
+   Document,
    Session,
    Turn,
 )
-
 from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
 from llama_stack.apis.memory import Memory
 from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.safety import Safety
-
+from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl

 from .agent_instance import ChatAgent
@@ -47,12 +47,16 @@ class MetaReferenceAgentsImpl(Agents):
        memory_api: Memory,
        safety_api: Safety,
        memory_banks_api: MemoryBanks,
+       tool_runtime_api: ToolRuntime,
+       tool_groups_api: ToolGroups,
    ):
        self.config = config
        self.inference_api = inference_api
        self.memory_api = memory_api
        self.safety_api = safety_api
        self.memory_banks_api = memory_banks_api
+       self.tool_runtime_api = tool_runtime_api
+       self.tool_groups_api = tool_groups_api

        self.in_memory_store = InmemoryKVStoreImpl()
        self.tempdir = tempfile.mkdtemp()
@@ -112,6 +116,8 @@ class MetaReferenceAgentsImpl(Agents):
            safety_api=self.safety_api,
            memory_api=self.memory_api,
            memory_banks_api=self.memory_banks_api,
+           tool_runtime_api=self.tool_runtime_api,
+           tool_groups_api=self.tool_groups_api,
            persistence_store=(
                self.persistence_store
                if agent_config.enable_session_persistence
@@ -141,15 +147,17 @@ class MetaReferenceAgentsImpl(Agents):
                ToolResponseMessage,
            ]
        ],
-       attachments: Optional[List[Attachment]] = None,
+       toolgroups: Optional[List[AgentToolGroup]] = None,
+       documents: Optional[List[Document]] = None,
        stream: Optional[bool] = False,
    ) -> AsyncGenerator:
        request = AgentTurnCreateRequest(
            agent_id=agent_id,
            session_id=session_id,
            messages=messages,
-           attachments=attachments,
            stream=True,
+           toolgroups=toolgroups,
+           documents=documents,
        )
        if stream:
            return self._create_agent_turn_streaming(request)
@@ -8,13 +8,11 @@ import json
 import logging
 import uuid
 from datetime import datetime
-
 from typing import List, Optional

 from pydantic import BaseModel

 from llama_stack.apis.agents import Turn
-
 from llama_stack.providers.utils.kvstore import KVStore

 log = logging.getLogger(__name__)
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
@@ -1,72 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import List
-
-from jinja2 import Template
-
-from llama_stack.apis.agents import (
-    DefaultMemoryQueryGeneratorConfig,
-    LLMMemoryQueryGeneratorConfig,
-    MemoryQueryGenerator,
-    MemoryQueryGeneratorConfig,
-)
-from llama_stack.apis.inference import Message, UserMessage
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    interleaved_content_as_str,
-)
-
-
-async def generate_rag_query(
-    config: MemoryQueryGeneratorConfig,
-    messages: List[Message],
-    **kwargs,
-):
-    """
-    Generates a query that will be used for
-    retrieving relevant information from the memory bank.
-    """
-    if config.type == MemoryQueryGenerator.default.value:
-        query = await default_rag_query_generator(config, messages, **kwargs)
-    elif config.type == MemoryQueryGenerator.llm.value:
-        query = await llm_rag_query_generator(config, messages, **kwargs)
-    else:
-        raise NotImplementedError(f"Unsupported memory query generator {config.type}")
-    return query
-
-
-async def default_rag_query_generator(
-    config: DefaultMemoryQueryGeneratorConfig,
-    messages: List[Message],
-    **kwargs,
-):
-    return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
-
-
-async def llm_rag_query_generator(
-    config: LLMMemoryQueryGeneratorConfig,
-    messages: List[Message],
-    **kwargs,
-):
-    assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
-    inference_api = kwargs["inference_api"]
-
-    m_dict = {"messages": [m.model_dump() for m in messages]}
-
-    template = Template(config.template)
-    content = template.render(m_dict)
-
-    model = config.model
-    message = UserMessage(content=content)
-    response = await inference_api.chat_completion(
-        model_id=model,
-        messages=[message],
-        stream=False,
-    )
-
-    query = response.completion_message.content
-
-    return query
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
@@ -1,93 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import unittest
-
-from llama_models.llama3.api.datatypes import (
-    Attachment,
-    BuiltinTool,
-    CompletionMessage,
-    StopReason,
-    ToolCall,
-)
-
-from ..tools.builtin import CodeInterpreterTool
-
-
-class TestCodeInterpreter(unittest.IsolatedAsyncioTestCase):
-    async def test_matplotlib(self):
-        tool = CodeInterpreterTool()
-        code = """
-import matplotlib.pyplot as plt
-import numpy as np
-
-x = np.array([1, 1])
-y = np.array([0, 10])
-
-plt.plot(x, y)
-plt.title('x = 1')
-plt.xlabel('x')
-plt.ylabel('y')
-plt.grid(True)
-plt.axvline(x=1, color='r')
-plt.show()
-"""
-        message = CompletionMessage(
-            role="assistant",
-            content="",
-            tool_calls=[
-                ToolCall(
-                    call_id="call_id",
-                    tool_name=BuiltinTool.code_interpreter,
-                    arguments={"code": code},
-                )
-            ],
-            stop_reason=StopReason.end_of_message,
-        )
-        ret = await tool.run([message])
-
-        self.assertEqual(len(ret), 1)
-
-        output = ret[0].content
-        self.assertIsInstance(output, Attachment)
-        self.assertEqual(output.mime_type, "image/png")
-
-    async def test_path_unlink(self):
-        tool = CodeInterpreterTool()
-        code = """
-import os
-from pathlib import Path
-import tempfile
-
-dpath = Path(os.environ["MPLCONFIGDIR"])
-with open(dpath / "test", "w") as f:
-    f.write("hello")
-
-Path(dpath / "test").unlink()
-print("_OK_")
-"""
-        message = CompletionMessage(
-            role="assistant",
-            content="",
-            tool_calls=[
-                ToolCall(
-                    call_id="call_id",
-                    tool_name=BuiltinTool.code_interpreter,
-                    arguments={"code": code},
-                )
-            ],
-            stop_reason=StopReason.end_of_message,
-        )
-        ret = await tool.run([message])
-
-        self.assertEqual(len(ret), 1)
-
-        output = ret[0].content
-        self.assertTrue("_OK_" in output)
-
-
-if __name__ == "__main__":
-    unittest.main()
@@ -4,21 +4,26 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import tempfile
 from typing import AsyncIterator, List, Optional, Union

 import pytest
+from llama_models.llama3.api.datatypes import BuiltinTool

 from llama_stack.apis.agents import (
    AgentConfig,
+   AgentToolGroupWithArgs,
    AgentTurnCreateRequest,
    AgentTurnResponseTurnCompletePayload,
+   StepType,
 )

+from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseEvent,
    ChatCompletionResponseStreamChunk,
    CompletionMessage,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,
@@ -27,13 +32,24 @@ from llama_stack.apis.inference import (
    UserMessage,
 )
+from llama_stack.apis.memory import MemoryBank
+from llama_stack.apis.memory_banks import BankParams, VectorMemoryBank
 from llama_stack.apis.safety import RunShieldResponse

-from ..agents import (
-   AGENT_INSTANCES_BY_ID,
-   MetaReferenceAgentsImpl,
-   MetaReferenceInferenceConfig,
+from llama_stack.apis.tools import (
+   Tool,
+   ToolDef,
+   ToolGroup,
+   ToolHost,
+   ToolInvocationResult,
+   ToolPromptFormat,
 )
+from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
+   MEMORY_QUERY_TOOL,
+)
+from llama_stack.providers.inline.agents.meta_reference.agents import (
+   MetaReferenceAgentsImpl,
+   MetaReferenceAgentsImplConfig,
+)
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


 class MockInferenceAPI:
@@ -48,10 +64,10 @@ class MockInferenceAPI:
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
-   ) -> AsyncIterator[
-       Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
+   ) -> Union[
+       ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]:
-       if stream:
+       async def stream_response():
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type="start",
@@ -65,19 +81,7 @@ class MockInferenceAPI:
                    delta="AI is a fascinating field...",
                )
            )
-           # yield ChatCompletionResponseStreamChunk(
-           #     event=ChatCompletionResponseEvent(
-           #         event_type="progress",
-           #         delta=ToolCallDelta(
-           #             content=ToolCall(
-           #                 call_id="123",
-           #                 tool_name=BuiltinTool.brave_search.value,
-           #                 arguments={"query": "AI history"},
-           #             ),
-           #             parse_status="success",
-           #         ),
-           #     )
-           # )

            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type="complete",
@@ -85,12 +89,17 @@ class MockInferenceAPI:
                    stop_reason="end_of_turn",
                )
            )

+       if stream:
+           return stream_response()
        else:
-           yield ChatCompletionResponse(
+           return ChatCompletionResponse(
                completion_message=CompletionMessage(
-                   role="assistant", content="Mock response", stop_reason="end_of_turn"
+                   role="assistant",
+                   content="Mock response",
+                   stop_reason="end_of_turn",
                ),
-               logprobs=[0.1, 0.2, 0.3] if logprobs else None,
+               logprobs={"token_logprobs": [0.1, 0.2, 0.3]} if logprobs else None,
            )

@@ -165,6 +174,98 @@ class MockMemoryAPI:
        self.documents[bank_id].pop(doc_id, None)


+class MockToolGroupsAPI:
+   async def register_tool_group(
+       self, toolgroup_id: str, provider_id: str, mcp_endpoint=None, args=None
+   ) -> None:
+       pass
+
+   async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
+       return ToolGroup(
+           identifier=toolgroup_id,
+           provider_resource_id=toolgroup_id,
+       )
+
+   async def list_tool_groups(self) -> List[ToolGroup]:
+       return []
+
+   async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
+       if tool_group_id == MEMORY_TOOLGROUP:
+           return [
+               Tool(
+                   identifier=MEMORY_QUERY_TOOL,
+                   provider_resource_id=MEMORY_QUERY_TOOL,
+                   toolgroup_id=MEMORY_TOOLGROUP,
+                   tool_host=ToolHost.client,
+                   description="Mock tool",
+                   provider_id="builtin::memory",
+                   parameters=[],
+               )
+           ]
+       if tool_group_id == CODE_INTERPRETER_TOOLGROUP:
+           return [
+               Tool(
+                   identifier="code_interpreter",
+                   provider_resource_id="code_interpreter",
+                   toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
+                   tool_host=ToolHost.client,
+                   description="Mock tool",
+                   provider_id="builtin::code_interpreter",
+                   parameters=[],
+               )
+           ]
+       return []
+
+   async def get_tool(self, tool_name: str) -> Tool:
+       return Tool(
+           identifier=tool_name,
+           provider_resource_id=tool_name,
+           toolgroup_id="mock_group",
+           tool_host=ToolHost.client,
+           description="Mock tool",
+           provider_id="mock_provider",
+           parameters=[],
+       )
+
+   async def unregister_tool_group(self, tool_group_id: str) -> None:
+       pass
+
+
+class MockToolRuntimeAPI:
+   async def list_runtime_tools(
+       self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+   ) -> List[ToolDef]:
+       return []
+
+   async def invoke_tool(self, tool_name: str, args: dict) -> ToolInvocationResult:
+       return ToolInvocationResult(content={"result": "Mock tool result"})
+
+
+class MockMemoryBanksAPI:
+   async def list_memory_banks(self) -> List[MemoryBank]:
+       return []
+
+   async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
+       return None
+
+   async def register_memory_bank(
+       self,
+       memory_bank_id: str,
+       params: BankParams,
+       provider_id: Optional[str] = None,
+       provider_memory_bank_id: Optional[str] = None,
+   ) -> MemoryBank:
+       return VectorMemoryBank(
+           identifier=memory_bank_id,
+           provider_resource_id=provider_memory_bank_id or memory_bank_id,
+           embedding_model="mock_model",
+           chunk_size_in_tokens=512,
+       )
+
+   async def unregister_memory_bank(self, memory_bank_id: str) -> None:
+       pass
+
+
 @pytest.fixture
 def mock_inference_api():
    return MockInferenceAPI()
@@ -181,64 +282,107 @@ def mock_memory_api():


 @pytest.fixture
-async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
+def mock_tool_groups_api():
+   return MockToolGroupsAPI()
+
+
+@pytest.fixture
+def mock_tool_runtime_api():
+   return MockToolRuntimeAPI()
+
+
+@pytest.fixture
+def mock_memory_banks_api():
+   return MockMemoryBanksAPI()
+
+
+@pytest.fixture
+async def get_agents_impl(
+   mock_inference_api,
+   mock_safety_api,
+   mock_memory_api,
+   mock_memory_banks_api,
+   mock_tool_runtime_api,
+   mock_tool_groups_api,
+):
+   sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
    impl = MetaReferenceAgentsImpl(
-       config=MetaReferenceInferenceConfig(),
+       config=MetaReferenceAgentsImplConfig(
+           persistence_store=SqliteKVStoreConfig(
+               db_name=sqlite_file.name,
+           ),
+       ),
        inference_api=mock_inference_api,
        safety_api=mock_safety_api,
        memory_api=mock_memory_api,
+       memory_banks_api=mock_memory_banks_api,
+       tool_runtime_api=mock_tool_runtime_api,
+       tool_groups_api=mock_tool_groups_api,
    )
+   await impl.initialize()
+   return impl
+
+
+@pytest.fixture
+async def get_chat_agent(get_agents_impl):
+   impl = await get_agents_impl
    agent_config = AgentConfig(
        model="test_model",
        instructions="You are a helpful assistant.",
        sampling_params=SamplingParams(),
-       tools=[
-           # SearchToolDefinition(
-           #     name="brave_search",
-           #     api_key="test_key",
-           # ),
-       ],
+       toolgroups=[],
        tool_choice=ToolChoice.auto,
        enable_session_persistence=False,
-       input_shields=[],
-       output_shields=[],
+       input_shields=["test_shield"],
    )
    response = await impl.create_agent(agent_config)
-   agent = AGENT_INSTANCES_BY_ID[response.agent_id]
-   return agent
+   return await impl.get_agent(response.agent_id)
+
+
+MEMORY_TOOLGROUP = "builtin::memory"
+CODE_INTERPRETER_TOOLGROUP = "builtin::code_interpreter"
+
+
+@pytest.fixture
+async def get_chat_agent_with_tools(get_agents_impl, request):
+   impl = await get_agents_impl
+   toolgroups = request.param
+   agent_config = AgentConfig(
+       model="test_model",
+       instructions="You are a helpful assistant.",
+       toolgroups=toolgroups,
+       tool_choice=ToolChoice.auto,
+       enable_session_persistence=False,
+       input_shields=["test_shield"],
+   )
+   response = await impl.create_agent(agent_config)
+   return await impl.get_agent(response.agent_id)


 @pytest.mark.asyncio
-async def test_chat_agent_create_session(chat_agent):
-   session = chat_agent.create_session("Test Session")
-   assert session.session_name == "Test Session"
-   assert session.turns == []
-   assert session.session_id in chat_agent.sessions
-
-
-@pytest.mark.asyncio
-async def test_chat_agent_create_and_execute_turn(chat_agent):
-   session = chat_agent.create_session("Test Session")
+async def test_chat_agent_create_and_execute_turn(get_chat_agent):
+   chat_agent = await get_chat_agent
+   session_id = await chat_agent.create_session("Test Session")
    request = AgentTurnCreateRequest(
-       agent_id="random",
-       session_id=session.session_id,
+       agent_id=chat_agent.agent_id,
+       session_id=session_id,
        messages=[UserMessage(content="Hello")],
        stream=True,
    )

    responses = []
    async for response in chat_agent.create_and_execute_turn(request):
        responses.append(response)

    print(responses)
    assert len(responses) > 0
-   assert len(responses) == 4  # TurnStart, StepStart, StepComplete, TurnComplete
+   assert (
+       len(responses) == 7
+   )  # TurnStart, ShieldCallStart, ShieldCallComplete, StepStart, StepProgress, StepComplete, TurnComplete
    assert responses[0].event.payload.turn_id is not None


 @pytest.mark.asyncio
-async def test_run_multiple_shields_wrapper(chat_agent):
+async def test_run_multiple_shields_wrapper(get_chat_agent):
+   chat_agent = await get_chat_agent
    messages = [UserMessage(content="Test message")]
    shields = ["test_shield"]

@@ -254,69 +398,95 @@ async def test_run_multiple_shields_wrapper(chat_agent):

    assert len(responses) == 2  # StepStart, StepComplete
    assert responses[0].event.payload.step_type.value == "shield_call"
-   assert not responses[1].event.payload.step_details.response.is_violation
+   assert not responses[1].event.payload.step_details.violation


 @pytest.mark.asyncio
-@pytest.mark.skip(reason="Not yet implemented; need to mock out tool execution easily")
-async def test_chat_agent_complex_turn(chat_agent):
-   # Setup
-   session = chat_agent.create_session("Test Session")
+async def test_chat_agent_complex_turn(get_chat_agent):
+   chat_agent = await get_chat_agent
+   session_id = await chat_agent.create_session("Test Session")
    request = AgentTurnCreateRequest(
-       agent_id="random",
-       session_id=session.session_id,
+       agent_id=chat_agent.agent_id,
+       session_id=session_id,
        messages=[UserMessage(content="Tell me about AI and then use a tool.")],
        stream=True,
    )

-   # Execute the turn
    responses = []
    async for response in chat_agent.create_and_execute_turn(request):
        responses.append(response)

-   # Assertions
    assert len(responses) > 0

    # Check for the presence of different step types
    step_types = [
        response.event.payload.step_type
        for response in responses
        if hasattr(response.event.payload, "step_type")
    ]

-   assert "shield_call" in step_types, "Shield call step is missing"
-   assert "inference" in step_types, "Inference step is missing"
-   assert "tool_execution" in step_types, "Tool execution step is missing"
+   assert StepType.shield_call in step_types, "Shield call step is missing"
+   assert StepType.inference in step_types, "Inference step is missing"

    # Check for the presence of start and complete events
    event_types = [
        response.event.payload.event_type
        for response in responses
        if hasattr(response.event.payload, "event_type")
    ]
-   assert "start" in event_types, "Start event is missing"
-   assert "complete" in event_types, "Complete event is missing"
+   assert "turn_start" in event_types, "Start event is missing"
+   assert "turn_complete" in event_types, "Complete event is missing"

-   # Check for the presence of tool call
-   tool_calls = [
-       response.event.payload.tool_call
-       for response in responses
-       if hasattr(response.event.payload, "tool_call")
-   ]
-   assert any(
-       tool_call
-       for tool_call in tool_calls
-       if tool_call and tool_call.content.get("name") == "memory"
-   ), "Memory tool call is missing"
-
    # Check for the final turn complete event
-   assert any(
-       isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
-       for response in responses
-   ), "Turn complete event is missing"
+   turn_complete_payload = next(
+       response.event.payload
+       for response in responses
+       if isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
+   )
+   turn = turn_complete_payload.turn
+   assert turn.input_messages == request.messages, "Input messages do not match"

-   # Verify the turn was added to the session
-   assert len(session.turns) == 1, "Turn was not added to the session"
-   assert (
-       session.turns[0].input_messages == request.messages
-   ), "Input messages do not match"

+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+   "toolgroups, expected_memory, expected_code_interpreter",
+   [
+       ([], False, False),  # no tools
+       ([MEMORY_TOOLGROUP], True, False),  # memory only
+       ([CODE_INTERPRETER_TOOLGROUP], False, True),  # code interpreter only
+       ([MEMORY_TOOLGROUP, CODE_INTERPRETER_TOOLGROUP], True, True),  # all tools
+   ],
+)
+async def test_chat_agent_tools(
+   get_agents_impl, toolgroups, expected_memory, expected_code_interpreter
+):
+   impl = await get_agents_impl
+   agent_config = AgentConfig(
+       model="test_model",
+       instructions="You are a helpful assistant.",
+       toolgroups=toolgroups,
+       tool_choice=ToolChoice.auto,
+       enable_session_persistence=False,
+       input_shields=["test_shield"],
+   )
+   response = await impl.create_agent(agent_config)
+   chat_agent = await impl.get_agent(response.agent_id)
+
+   tool_defs, _ = await chat_agent._get_tool_defs()
+   if expected_memory:
+       assert MEMORY_QUERY_TOOL in tool_defs
+   if expected_code_interpreter:
+       assert BuiltinTool.code_interpreter in tool_defs
+   if expected_memory and expected_code_interpreter:
+       # override the tools for turn
+       new_tool_defs, _ = await chat_agent._get_tool_defs(
+           toolgroups_for_turn=[
+               AgentToolGroupWithArgs(
+                   name=MEMORY_TOOLGROUP,
+                   args={"memory_banks": ["test_memory_bank"]},
+               )
+           ]
+       )
+       assert MEMORY_QUERY_TOOL in new_tool_defs
+       assert BuiltinTool.code_interpreter not in new_tool_defs
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
@@ -1,20 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from abc import ABC, abstractmethod
-from typing import List
-
-from llama_stack.apis.inference import Message
-
-
-class BaseTool(ABC):
-    @abstractmethod
-    def get_name(self) -> str:
-        raise NotImplementedError
-
-    @abstractmethod
-    async def run(self, messages: List[Message]) -> List[Message]:
-        raise NotImplementedError
@@ -1,396 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import logging
import re
import tempfile

from abc import abstractmethod
from typing import List, Optional

import requests

from .ipython_tool.code_execution import (
    CodeExecutionContext,
    CodeExecutionRequest,
    CodeExecutor,
    TOOLS_ATTACHMENT_KEY_REGEX,
)

from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.apis.agents import *  # noqa: F403

from .base import BaseTool


log = logging.getLogger(__name__)


def interpret_content_as_attachment(content: str) -> Optional[Attachment]:
    match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
    if match:
        snippet = match.group(1)
        data = json.loads(snippet)
        return Attachment(
            url=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"]
        )

    return None


class SingleMessageBuiltinTool(BaseTool):
    async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
        assert len(messages) == 1, f"Expected single message, got {len(messages)}"

        message = messages[0]
        assert len(message.tool_calls) == 1, "Expected a single tool call"

        tool_call = messages[0].tool_calls[0]

        query = tool_call.arguments["query"]
        response: str = await self.run_impl(query)

        message = ToolResponseMessage(
            call_id=tool_call.call_id,
            tool_name=tool_call.tool_name,
            content=response,
        )
        return [message]

    @abstractmethod
    async def run_impl(self, query: str) -> str:
        raise NotImplementedError()


class PhotogenTool(SingleMessageBuiltinTool):
    def __init__(self, dump_dir: str) -> None:
        self.dump_dir = dump_dir

    def get_name(self) -> str:
        return BuiltinTool.photogen.value

    async def run_impl(self, query: str) -> str:
        """
        Implement this to give the model an ability to generate images.

        Return:
            info = {
                "filepath": str(image_filepath),
                "mimetype": "image/png",
            }
        """
        raise NotImplementedError()


class SearchTool(SingleMessageBuiltinTool):
    def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None:
        self.api_key = api_key
        self.engine_type = engine
        if engine == SearchEngineType.bing:
            self.engine = BingSearch(api_key, **kwargs)
        elif engine == SearchEngineType.brave:
            self.engine = BraveSearch(api_key, **kwargs)
        elif engine == SearchEngineType.tavily:
            self.engine = TavilySearch(api_key, **kwargs)
        else:
            raise ValueError(f"Unknown search engine: {engine}")

    def get_name(self) -> str:
        return BuiltinTool.brave_search.value

    async def run_impl(self, query: str) -> str:
        return await self.engine.search(query)


class BingSearch:
    def __init__(self, api_key: str, top_k: int = 3, **kwargs) -> None:
        self.api_key = api_key
        self.top_k = top_k

    async def search(self, query: str) -> str:
        url = "https://api.bing.microsoft.com/v7.0/search"
        headers = {
            "Ocp-Apim-Subscription-Key": self.api_key,
        }
        params = {
            "count": self.top_k,
            "textDecorations": True,
            "textFormat": "HTML",
            "q": query,
        }

        response = requests.get(url=url, params=params, headers=headers)
        response.raise_for_status()
        clean = self._clean_response(response.json())
        return json.dumps(clean)

    def _clean_response(self, search_response):
        clean_response = []
        query = search_response["queryContext"]["originalQuery"]
        if "webPages" in search_response:
            pages = search_response["webPages"]["value"]
            for p in pages:
                selected_keys = {"name", "url", "snippet"}
                clean_response.append(
                    {k: v for k, v in p.items() if k in selected_keys}
                )
        if "news" in search_response:
            clean_news = []
            news = search_response["news"]["value"]
            for n in news:
                selected_keys = {"name", "url", "description"}
                clean_news.append({k: v for k, v in n.items() if k in selected_keys})

            clean_response.append(clean_news)

        return {"query": query, "top_k": clean_response}


class BraveSearch:
    def __init__(self, api_key: str) -> None:
        self.api_key = api_key

    async def search(self, query: str) -> str:
        url = "https://api.search.brave.com/res/v1/web/search"
        headers = {
            "X-Subscription-Token": self.api_key,
            "Accept-Encoding": "gzip",
            "Accept": "application/json",
        }
        payload = {"q": query}
        response = requests.get(url=url, params=payload, headers=headers)
        return json.dumps(self._clean_brave_response(response.json()))

    def _clean_brave_response(self, search_response, top_k=3):
        query = None
        clean_response = []
        if "query" in search_response:
            if "original" in search_response["query"]:
                query = search_response["query"]["original"]
        if "mixed" in search_response:
            mixed_results = search_response["mixed"]
            for m in mixed_results["main"][:top_k]:
                r_type = m["type"]
                results = search_response[r_type]["results"]
                if r_type == "web":
                    # For web data - add a single output from the search
                    idx = m["index"]
                    selected_keys = [
                        "type",
                        "title",
                        "url",
                        "description",
                        "date",
                        "extra_snippets",
                    ]
                    cleaned = {
                        k: v for k, v in results[idx].items() if k in selected_keys
                    }
                elif r_type == "faq":
                    # For faq data - take a list of all the questions & answers
                    selected_keys = ["type", "question", "answer", "title", "url"]
                    cleaned = []
                    for q in results:
                        cleaned.append(
                            {k: v for k, v in q.items() if k in selected_keys}
                        )
                elif r_type == "infobox":
                    idx = m["index"]
                    selected_keys = [
                        "type",
                        "title",
                        "url",
                        "description",
                        "long_desc",
                    ]
                    cleaned = {
                        k: v for k, v in results[idx].items() if k in selected_keys
                    }
                elif r_type == "videos":
                    selected_keys = [
                        "type",
                        "url",
                        "title",
                        "description",
                        "date",
                    ]
                    cleaned = []
                    for q in results:
                        cleaned.append(
                            {k: v for k, v in q.items() if k in selected_keys}
                        )
                elif r_type == "locations":
                    # For locations data - take a list of all the location results
                    selected_keys = [
                        "type",
                        "title",
                        "url",
                        "description",
                        "coordinates",
                        "postal_address",
                        "contact",
                        "rating",
                        "distance",
                        "zoom_level",
                    ]
                    cleaned = []
                    for q in results:
                        cleaned.append(
                            {k: v for k, v in q.items() if k in selected_keys}
                        )
                elif r_type == "news":
                    # For news data - take a list of all the news results
                    selected_keys = [
                        "type",
                        "title",
                        "url",
                        "description",
                    ]
                    cleaned = []
                    for q in results:
                        cleaned.append(
                            {k: v for k, v in q.items() if k in selected_keys}
                        )
                else:
                    cleaned = []

                clean_response.append(cleaned)

        return {"query": query, "top_k": clean_response}


class TavilySearch:
    def __init__(self, api_key: str) -> None:
        self.api_key = api_key

    async def search(self, query: str) -> str:
        response = requests.post(
            "https://api.tavily.com/search",
            json={"api_key": self.api_key, "query": query},
        )
        return json.dumps(self._clean_tavily_response(response.json()))

    def _clean_tavily_response(self, search_response, top_k=3):
        return {"query": search_response["query"], "top_k": search_response["results"]}


class WolframAlphaTool(SingleMessageBuiltinTool):
    def __init__(self, api_key: str) -> None:
        self.api_key = api_key
        self.url = "https://api.wolframalpha.com/v2/query"

    def get_name(self) -> str:
        return BuiltinTool.wolfram_alpha.value

    async def run_impl(self, query: str) -> str:
        params = {
            "input": query,
            "appid": self.api_key,
            "format": "plaintext",
            "output": "json",
        }
        response = requests.get(
            self.url,
            params=params,
        )

        return json.dumps(self._clean_wolfram_alpha_response(response.json()))

    def _clean_wolfram_alpha_response(self, wa_response):
        remove = {
            "queryresult": [
                "datatypes",
                "error",
                "timedout",
                "timedoutpods",
                "numpods",
                "timing",
                "parsetiming",
                "parsetimedout",
                "recalculate",
                "id",
                "host",
                "server",
                "related",
                "version",
                {
                    "pods": [
                        "scanner",
                        "id",
                        "error",
                        "expressiontypes",
                        "states",
                        "infos",
                        "position",
                        "numsubpods",
                    ]
                },
                "assumptions",
            ],
        }
        for main_key in remove:
            for key_to_remove in remove[main_key]:
                try:
                    if key_to_remove == "assumptions":
                        if "assumptions" in wa_response[main_key]:
                            del wa_response[main_key][key_to_remove]
                    if isinstance(key_to_remove, dict):
                        for sub_key in key_to_remove:
                            if sub_key == "pods":
                                for i in range(len(wa_response[main_key][sub_key])):
                                    if (
                                        wa_response[main_key][sub_key][i]["title"]
                                        == "Result"
                                    ):
                                        del wa_response[main_key][sub_key][i + 1 :]
                                        break
                            sub_items = wa_response[main_key][sub_key]
                            for i in range(len(sub_items)):
                                for sub_key_to_remove in key_to_remove[sub_key]:
                                    if sub_key_to_remove in sub_items[i]:
                                        del sub_items[i][sub_key_to_remove]
                    elif key_to_remove in wa_response[main_key]:
                        del wa_response[main_key][key_to_remove]
                except KeyError:
                    pass
        return wa_response


class CodeInterpreterTool(BaseTool):
    def __init__(self) -> None:
        ctx = CodeExecutionContext(
            matplotlib_dump_dir=tempfile.mkdtemp(),
        )
        self.code_executor = CodeExecutor(ctx)

    def get_name(self) -> str:
        return BuiltinTool.code_interpreter.value

    async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
        message = messages[0]
        assert len(message.tool_calls) == 1, "Expected a single tool call"

        tool_call = messages[0].tool_calls[0]
        script = tool_call.arguments["code"]

        req = CodeExecutionRequest(scripts=[script])
        res = self.code_executor.execute(req)

        pieces = [res["process_status"]]
        for out_type in ["stdout", "stderr"]:
            res_out = res[out_type]
            if res_out != "":
                pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
                if out_type == "stderr":
                    log.error(f"ipython tool error: ↓\n{res_out}")

        message = ToolResponseMessage(
            call_id=tool_call.call_id,
            tool_name=tool_call.tool_name,
            content="\n".join(pieces),
        )
        return [message]
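The SingleMessageBuiltinTool base class above fixes a simple contract: extract the single "query" argument from the one tool call, compute a string response in run_impl, and wrap it back into a ToolResponseMessage. A minimal hypothetical subclass makes the contract concrete (EchoTool is illustrative and not part of this codebase):

class EchoTool(SingleMessageBuiltinTool):
    """Illustrative only: echo the query back to the model."""

    def get_name(self) -> str:
        return "echo"  # hypothetical tool name

    async def run_impl(self, query: str) -> str:
        return f"echo: {query}"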
@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@@ -1,133 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import errno

# Disabling potentially dangerous functions
import os as _os
from functools import partial

os_funcs_to_disable = [
    "kill",
    "system",
    "putenv",
    "remove",
    "removedirs",
    "rmdir",
    "fchdir",
    "setuid",
    "fork",
    "forkpty",
    "killpg",
    "rename",
    "renames",
    "truncate",
    "replace",
    # "unlink",  # Commented out, as this was blocking matplotlib from rendering plots correctly
    "fchmod",
    "fchown",
    "chmod",
    "chown",
    "chroot",
    "fchdir",
    "lchflags",
    "lchmod",
    "lchown",
    "chdir",
]


def call_not_allowed(*args, **kwargs):
    raise OSError(errno.EPERM, "Calls are not permitted in this environment")


for func_name in os_funcs_to_disable:
    if hasattr(_os, func_name):
        setattr(_os, func_name, partial(call_not_allowed, _func_name=f"os.{func_name}"))

import shutil as _shutil

for func_name in ["rmtree", "move", "chown"]:
    if hasattr(_shutil, func_name):
        setattr(
            _shutil,
            func_name,
            partial(call_not_allowed, _func_name=f"shutil.{func_name}"),
        )

import subprocess as _subprocess


def popen_not_allowed(*args, **kwargs):
    raise _subprocess.CalledProcessError(
        -1,
        args[0] if args else "unknown",
        stderr="subprocess.Popen is not allowed in this environment",
    )


_subprocess.Popen = popen_not_allowed


import atexit as _atexit
import builtins as _builtins
import io as _io
import json as _json
import sys as _sys

# NB! The following "unused" imports are crucial; make sure not to remove
# them with linters - they're used in code_execution.py
from contextlib import (  # noqa
    contextmanager as _contextmanager,
    redirect_stderr as _redirect_stderr,
    redirect_stdout as _redirect_stdout,
)
from multiprocessing.connection import Connection as _Connection

# Mangle imports to avoid polluting model execution namespace.

_IO_SINK = _io.StringIO()
_NETWORK_TIMEOUT = 5
_NETWORK_CONNECTIONS = None


def _open_connections():
    global _NETWORK_CONNECTIONS
    if _NETWORK_CONNECTIONS is not None:
        # Ensure connections are only opened once.
        return _NETWORK_CONNECTIONS
    req_w_fd, resp_r_fd = _sys.argv[1], _sys.argv[2]
    req_con = _Connection(int(req_w_fd), readable=False)
    resp_con = _Connection(int(resp_r_fd), writable=False)
    _NETWORK_CONNECTIONS = (req_con, resp_con)
    return _NETWORK_CONNECTIONS


_builtins._open_connections = _open_connections


@_atexit.register
def _close_connections():
    global _NETWORK_CONNECTIONS
    if _NETWORK_CONNECTIONS is None:
        return
    for con in _NETWORK_CONNECTIONS:
        con.close()
    del _NETWORK_CONNECTIONS


def _network_call(request):
    # NOTE: We communicate with the parent process in json, encoded
    # in raw bytes. We do this because native send/recv methods use
    # pickle which involves execution of arbitrary code.
    _open_connections()
    req_con, resp_con = _NETWORK_CONNECTIONS

    req_con.send_bytes(_json.dumps(request).encode("utf-8"))
    if resp_con.poll(timeout=_NETWORK_TIMEOUT) is None:
        raise Exception(f"Network request timed out: {_json.dumps(request)}")
    else:
        return _json.loads(resp_con.recv_bytes().decode("utf-8"))
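_network_call above is the child-process side of a small JSON-over-pipe protocol: a JSON-serializable dict goes out on the request Connection, and a decoded JSON reply is expected on the response Connection within _NETWORK_TIMEOUT seconds. A minimal sketch of a call, using the one request type the parent actually routes (see execute_subprocess_request in the next file); the empty image_data payload is illustrative only:

# Sketch only: "matplotlib" is the single request type the parent handles.
reply = _network_call({"type": "matplotlib", "image_data": []})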
@@ -1,256 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import base64
import json
import multiprocessing
import os
import re
import subprocess
import sys
import tempfile
import textwrap
import time
from dataclasses import dataclass
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import List

from PIL import Image

from .utils import get_code_env_prefix

TOOLS_ATTACHMENT_KEY = "__tools_attachment__"
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")

DIRNAME = Path(__file__).parent

CODE_EXEC_TIMEOUT = 20
CODE_ENV_PREFIX = get_code_env_prefix()

STDOUTERR_SINK_WRAPPER_TEMPLATE = """\
with _redirect_stdout(_IO_SINK), _redirect_stderr(_IO_SINK):
{code}\
"""

TRYEXCEPT_WRAPPER_TEMPLATE = """\
try:
{code}
except:
    pass\
"""


def generate_bwrap_command(bind_dirs: List[str]) -> str:
    """
    Generate the bwrap command string for binding all
    directories in the current directory read-only.
    """
    bwrap_args = ""
    bwrap_args += "--ro-bind / / "
    # Add the --dev flag to mount device files
    bwrap_args += "--dev /dev "
    for d in bind_dirs:
        bwrap_args += f"--bind {d} {d} "

    # Add the --unshare-all flag to isolate the sandbox from the rest of the system
    bwrap_args += "--unshare-all "
    # Add the --die-with-parent flag to ensure the child process dies when bwrap's parent dies
    bwrap_args += "--die-with-parent "
    return bwrap_args


@dataclass
class CodeExecutionContext:
    matplotlib_dump_dir: str
    use_proxy: bool = False


@dataclass
class CodeExecutionRequest:
    scripts: List[str]
    only_last_cell_stdouterr: bool = True
    only_last_cell_fail: bool = True
    seed: int = 0
    strip_fpaths_in_stderr: bool = True


class CodeExecutor:
    def __init__(self, context: CodeExecutionContext):
        self.context = context

    def execute(self, req: CodeExecutionRequest) -> dict:
        scripts = req.scripts
        for i in range(len(scripts) - 1):
            if req.only_last_cell_stdouterr:
                scripts[i] = STDOUTERR_SINK_WRAPPER_TEMPLATE.format(
                    code=textwrap.indent(scripts[i], " " * 4)
                )
            if req.only_last_cell_fail:
                scripts[i] = TRYEXCEPT_WRAPPER_TEMPLATE.format(
                    code=textwrap.indent(scripts[i], " " * 4)
                )

        # Seeds prefix:
        seed = req.seed
        seeds_prefix = f"""\
def _set_seeds():
    import random
    random.seed({seed})
    import numpy as np
    np.random.seed({seed})
_set_seeds()\
"""

        script = "\n\n".join([seeds_prefix] + [CODE_ENV_PREFIX] + scripts)
        with tempfile.TemporaryDirectory() as dpath:
            bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath])
            cmd = [*bwrap_prefix.split(), sys.executable, "-c", script]
            code_fpath = os.path.join(dpath, "code.py")
            with open(code_fpath, "w") as f:
                f.write(script)

            try:
                python_path = os.environ.get("PYTHONPATH", "")
                env = dict(
                    os.environ,
                    PYTHONHASHSEED=str(seed),
                    MPLCONFIGDIR=dpath,
                    MPLBACKEND="module://matplotlib_custom_backend",
                    PYTHONPATH=f"{DIRNAME}:{python_path}",
                )
                stdout, stderr, returncode = do_subprocess(
                    cmd=cmd,
                    env=env,
                    ctx=self.context,
                )

                stderr = stderr.strip()
                if req.strip_fpaths_in_stderr:
                    pattern = r'File "([^"]+)", line (\d+)'
                    stderr = re.sub(pattern, r"line \2", stderr)

                return {
                    "process_status": "completed",
                    "returncode": returncode,
                    "stdout": stdout.strip(),
                    "stderr": stderr,
                }

            except subprocess.TimeoutExpired:
                return {
                    "process_status": "timeout",
                    "stdout": "Timed out",
                    "stderr": "Timed out",
                }

            except Exception as e:
                return {
                    "process_status": "error",
                    "error_type": type(e).__name__,
                    "stderr": str(e),
                    "stdout": str(e),
                }


def process_matplotlib_response(response, matplotlib_dump_dir: str):
    image_data = response["image_data"]
    # Convert the base64 string to a bytes object
    images = [base64.b64decode(d["image_base64"]) for d in image_data]
    # Create a list of PIL images from the bytes objects
    images = [Image.open(BytesIO(img)) for img in images]
    # Create a list of image paths
    image_paths = []
    for i, img in enumerate(images):
        # create new directory for each day to better organize data:
        dump_dname = datetime.today().strftime("%Y-%m-%d")
        dump_dpath = Path(matplotlib_dump_dir, dump_dname)
        dump_dpath.mkdir(parents=True, exist_ok=True)
        # save image into a file
        dump_fname = f"matplotlib_{str(time.time()).replace('.', '_')}_{i}.png"
        dump_fpath = dump_dpath / dump_fname
        img.save(dump_fpath, "PNG")
        image_paths.append(str(dump_fpath))

    # this is kind of convoluted, we send back this response to the subprocess which
    # prints it out
    info = {
        "filepath": str(image_paths[-1]),
        "mimetype": "image/png",
    }
    return f"{TOOLS_ATTACHMENT_KEY}={json.dumps(info)}"


def execute_subprocess_request(request, ctx: CodeExecutionContext):
    "Route requests from the subprocess (via network Pipes) to the internet/tools."
    if request["type"] == "matplotlib":
        return process_matplotlib_response(request, ctx.matplotlib_dump_dir)
    else:
        raise Exception(f'Unrecognised network request type: {request["type"]}')


def do_subprocess(*, cmd: list, env: dict, ctx: CodeExecutionContext):
    # Create Pipes to be used for any external tool/network requests.
    req_r, req_w = multiprocessing.Pipe(duplex=False)
    resp_r, resp_w = multiprocessing.Pipe(duplex=False)

    cmd += [str(req_w.fileno()), str(resp_r.fileno())]
    proc = subprocess.Popen(
        cmd,
        pass_fds=(req_w.fileno(), resp_r.fileno()),
        text=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        close_fds=True,
        env=env,
    )

    # Close unnecessary fds.
    req_w.close()
    resp_r.close()

    pipe_close = False
    done_read = False
    start = time.monotonic()
    while proc.poll() is None and not pipe_close:
        if req_r.poll(0.1):
            # NB: Python pipe semantics for poll and recv mean that
            # poll() returns True if a pipe is closed.
            # Cf. old school PEP from '09
            # https://bugs.python.org/issue5573
            try:
                request = json.loads(req_r.recv_bytes().decode("utf-8"))
                response = execute_subprocess_request(request, ctx)

                resp_w.send_bytes(json.dumps(response).encode("utf-8"))
            except EOFError:
                # The request pipe is closed - set a marker to exit
                # after the next attempt at reading stdout/stderr.
                pipe_close = True

        try:
            # If lots has been printed, the pipe might be full but
            # proc cannot exit until all the stdout/stderr
            # has been written/read.
            stdout, stderr = proc.communicate(timeout=0.3)
            done_read = True
        except subprocess.TimeoutExpired:
            # The program has not terminated. Ignore it, there
            # may be more network/tool requests.
            continue
        if time.monotonic() - start > CODE_EXEC_TIMEOUT:
            proc.terminate()
            raise subprocess.TimeoutExpired(cmd, CODE_EXEC_TIMEOUT)

    if not done_read:
        # Solve race condition where process terminates before
        # we hit the while loop.
        stdout, stderr = proc.communicate(timeout=0.3)

    resp_w.close()
    req_r.close()
    return stdout, stderr, proc.returncode
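For a sense of what the sandboxed invocation looks like, here is a hedged sanity check of the prefix that execute() builds from generate_bwrap_command; the bind path is illustrative:

# Sketch: mirror how execute() assembles the sandbox prefix.
print("bwrap " + generate_bwrap_command(bind_dirs=["/tmp/sandbox"]))
# bwrap --ro-bind / / --dev /dev --bind /tmp/sandbox /tmp/sandbox --unshare-all --die-with-parent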
@@ -1,90 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
A custom Matplotlib backend that overrides the show method to return image bytes.
"""

import base64
import io
import json as _json
import logging

import matplotlib
from matplotlib.backend_bases import FigureManagerBase

# Import necessary components from Matplotlib
from matplotlib.backends.backend_agg import FigureCanvasAgg

log = logging.getLogger(__name__)


class CustomFigureCanvas(FigureCanvasAgg):
    def show(self):
        # Save the figure to a BytesIO object
        buf = io.BytesIO()
        self.print_png(buf)
        image_bytes = buf.getvalue()
        buf.close()
        return image_bytes


class CustomFigureManager(FigureManagerBase):
    def __init__(self, canvas, num):
        super().__init__(canvas, num)


# Mimic module initialization that integrates with the Matplotlib backend system
def _create_figure_manager(num, *args, **kwargs):
    """
    Create a custom figure manager instance.
    """
    FigureClass = kwargs.pop("FigureClass", None)  # noqa: N806
    if FigureClass is None:
        from matplotlib.figure import Figure

        FigureClass = Figure  # noqa: N806
    fig = FigureClass(*args, **kwargs)
    canvas = CustomFigureCanvas(fig)
    manager = CustomFigureManager(canvas, num)
    return manager


def show():
    """
    Handle all figures and potentially return their images as bytes.

    This function iterates over all figures registered with the custom backend,
    renders them as images in bytes format, and could return a list of bytes objects,
    one for each figure, or handle them as needed.
    """
    image_data = []
    for manager in matplotlib._pylab_helpers.Gcf.get_all_fig_managers():
        # Get the figure from the manager
        fig = manager.canvas.figure
        buf = io.BytesIO()  # Create a buffer for the figure
        fig.savefig(buf, format="png")  # Save the figure to the buffer in PNG format
        buf.seek(0)  # Go to the beginning of the buffer
        image_bytes = buf.getvalue()  # Retrieve bytes value
        image_base64 = base64.b64encode(image_bytes).decode("utf-8")
        image_data.append({"image_base64": image_base64})
        buf.close()

    req_con, resp_con = _open_connections()

    _json_dump = _json.dumps(
        {
            "type": "matplotlib",
            "image_data": image_data,
        }
    )
    req_con.send_bytes(_json_dump.encode("utf-8"))
    resp = _json.loads(resp_con.recv_bytes().decode("utf-8"))
    log.info(resp)


FigureCanvas = CustomFigureCanvas
FigureManager = CustomFigureManager
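The backend above is activated in code_execution.py by pointing Matplotlib at this module before the sandboxed script imports pyplot. A minimal sketch of that mechanism, using Matplotlib's standard module:// backend selection with the module name from this codebase:

import os

os.environ["MPLBACKEND"] = "module://matplotlib_custom_backend"
import matplotlib.pyplot as plt  # noqa: E402  # plt.show() now routes through CustomFigureCanvas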
@@ -1,21 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os

DIR = os.path.dirname(os.path.realpath(__file__))
CODE_ENV_PREFIX_FILE = os.path.join(DIR, "code_env_prefix.py")
CODE_ENV_PREFIX = None


def get_code_env_prefix() -> str:
    global CODE_ENV_PREFIX

    if CODE_ENV_PREFIX is None:
        with open(CODE_ENV_PREFIX_FILE, "r") as f:
            CODE_ENV_PREFIX = f.read()

    return CODE_ENV_PREFIX
@@ -1,42 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety

from ..safety import ShieldRunnerMixin
from .builtin import BaseTool


class SafeTool(BaseTool, ShieldRunnerMixin):
    """A tool wrapper that makes other tools safety-enabled."""

    def __init__(
        self,
        tool: BaseTool,
        safety_api: Safety,
        input_shields: List[str] = None,
        output_shields: List[str] = None,
    ):
        self._tool = tool
        ShieldRunnerMixin.__init__(
            self, safety_api, input_shields=input_shields, output_shields=output_shields
        )

    def get_name(self) -> str:
        return self._tool.get_name()

    async def run(self, messages: List[Message]) -> List[Message]:
        if self.input_shields:
            await self.run_multiple_shields(messages, self.input_shields)
        # run the underlying tool
        res = await self._tool.run(messages)
        if self.output_shields:
            await self.run_multiple_shields(messages, self.output_shields)

        return res
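A hypothetical composition of the wrapper above; the engine, API key, safety_api instance, and shield identifier are assumptions for illustration, not values from this diff:

safe_search = SafeTool(
    tool=SearchTool(engine=SearchEngineType.brave, api_key="..."),  # assumed engine/key
    safety_api=safety_api,  # assumed Safety implementation provided by the stack
    input_shields=["llama_guard"],  # hypothetical shield identifier
)
messages = await safe_search.run(messages)  # shields run before and after the inner tool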