API Updates (#73)

* API Keys passed from Client instead of distro configuration

* delete distribution registry

* Rename the "package" word away

* Introduce a "Router" layer for providers

Some providers need to be factorized and considered as thin routing
layers on top of other providers. Consider two examples:

- The inference API should be a routing layer over inference providers,
  routed using the "model" key
- The memory banks API is another instance where various memory bank
  types will be provided by independent providers (e.g., a vector store
  is served by Chroma while a keyvalue memory can be served by Redis or
  PGVector)

This commit introduces a generalized routing layer for this purpose.

* update `apis_to_serve`

* llama_toolchain -> llama_stack

* Codemod from llama_toolchain -> llama_stack

- added providers/registry
- cleaned up api/ subdirectories and moved impls away
- restructured api/api.py
- from llama_stack.apis.<api> import foo should work now
- update imports to do llama_stack.apis.<api>
- update many other imports
- added __init__, fixed some registry imports
- updated registry imports
- create_agentic_system -> create_agent
- AgenticSystem -> Agent

* Moved some stuff out of common/; re-generated OpenAPI spec

* llama-toolchain -> llama-stack (hyphens)

* add control plane API

* add redis adapter + sqlite provider

* move core -> distribution

* Some more toolchain -> stack changes

* small naming shenanigans

* Removing custom tool and agent utilities and moving them client side

* Move control plane to distribution server for now

* Remove control plane from API list

* no codeshield dependency randomly plzzzzz

* Add "fire" as a dependency

* add back event loggers

* stack configure fixes

* use brave instead of bing in the example client

* add init file so it gets packaged

* add init files so it gets packaged

* Update MANIFEST

* bug fix

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
Co-authored-by: Xi Yan <xiyan@meta.com>
Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
This commit is contained in:
Ashwin Bharambe 2024-09-17 19:51:35 -07:00 committed by GitHub
parent f294eac5f5
commit 9487ad8294
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
213 changed files with 1725 additions and 1204 deletions

View file

@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import MetaReferenceImplConfig
async def get_provider_impl(
config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
):
from .agents import MetaReferenceAgentsImpl
assert isinstance(
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"
impl = MetaReferenceAgentsImpl(
config,
deps[Api.inference],
deps[Api.memory],
deps[Api.safety],
)
await impl.initialize()
return impl

View file

@ -0,0 +1,793 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import copy
import os
import secrets
import shutil
import string
import tempfile
import uuid
from datetime import datetime
from typing import AsyncGenerator, List, Tuple
from urllib.parse import urlparse
import httpx
from termcolor import cprint
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from .rag.context_retriever import generate_rag_query
from .safety import SafetyException, ShieldRunnerMixin
from .tools.base import BaseTool
from .tools.builtin import interpret_content_as_attachment, SingleMessageBuiltinTool
def make_random_string(length: int = 8):
return "".join(
secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
)
class ChatAgent(ShieldRunnerMixin):
def __init__(
self,
agent_config: AgentConfig,
inference_api: Inference,
memory_api: Memory,
safety_api: Safety,
builtin_tools: List[SingleMessageBuiltinTool],
max_infer_iters: int = 10,
):
self.agent_config = agent_config
self.inference_api = inference_api
self.memory_api = memory_api
self.safety_api = safety_api
self.max_infer_iters = max_infer_iters
self.tools_dict = {t.get_name(): t for t in builtin_tools}
self.tempdir = tempfile.mkdtemp()
self.sessions = {}
ShieldRunnerMixin.__init__(
self,
safety_api,
input_shields=agent_config.input_shields,
output_shields=agent_config.output_shields,
)
def __del__(self):
shutil.rmtree(self.tempdir)
def turn_to_messages(self, turn: Turn) -> List[Message]:
messages = []
# We do not want to keep adding RAG context to the input messages
# May be this should be a parameter of the agentic instance
# that can define its behavior in a custom way
for m in turn.input_messages:
msg = m.copy()
if isinstance(msg, UserMessage):
msg.context = None
messages.append(msg)
# messages.extend(turn.input_messages)
for step in turn.steps:
if step.step_type == StepType.inference.value:
messages.append(step.model_response)
elif step.step_type == StepType.tool_execution.value:
for response in step.tool_responses:
messages.append(
ToolResponseMessage(
call_id=response.call_id,
tool_name=response.tool_name,
content=response.content,
)
)
elif step.step_type == StepType.shield_call.value:
response = step.response
if response.is_violation:
# CompletionMessage itself in the ShieldResponse
messages.append(
CompletionMessage(
content=response.violation_return_message,
stop_reason=StopReason.end_of_turn,
)
)
# print_dialog(messages)
return messages
def create_session(self, name: str) -> Session:
session_id = str(uuid.uuid4())
session = Session(
session_id=session_id,
session_name=name,
turns=[],
started_at=datetime.now(),
)
self.sessions[session_id] = session
return session
async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
assert (
request.session_id in self.sessions
), f"Session {request.session_id} not found"
session = self.sessions[request.session_id]
messages = []
for i, turn in enumerate(session.turns):
messages.extend(self.turn_to_messages(turn))
messages.extend(request.messages)
# print("processed dialog ======== ")
# print_dialog(messages)
turn_id = str(uuid.uuid4())
start_time = datetime.now()
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnStartPayload(
turn_id=turn_id,
)
)
)
steps = []
output_message = None
async for chunk in self.run(
session=session,
turn_id=turn_id,
input_messages=messages,
attachments=request.attachments or [],
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
):
if isinstance(chunk, CompletionMessage):
cprint(
f"{chunk.role.capitalize()}: {chunk.content}",
"white",
attrs=["bold"],
)
output_message = chunk
continue
assert isinstance(
chunk, AgentTurnResponseStreamChunk
), f"Unexpected type {type(chunk)}"
event = chunk.event
if (
event.payload.event_type
== AgentTurnResponseEventType.step_complete.value
):
steps.append(event.payload.step_details)
yield chunk
assert output_message is not None
turn = Turn(
turn_id=turn_id,
session_id=request.session_id,
input_messages=request.messages,
output_message=output_message,
started_at=start_time,
completed_at=datetime.now(),
steps=steps,
)
session.turns.append(turn)
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
)
yield chunk
async def run(
self,
session: Session,
turn_id: str,
input_messages: List[Message],
attachments: List[Attachment],
sampling_params: SamplingParams,
stream: bool = False,
) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
# return a "final value" for the `yield from` statement. we simulate that by yielding a
# final boolean (to see whether an exception happened) and then explicitly testing for it.
async for res in self.run_shields_wrapper(
turn_id, input_messages, self.input_shields, "user-input"
):
if isinstance(res, bool):
return
else:
yield res
async for res in self._run(
session, turn_id, input_messages, attachments, sampling_params, stream
):
if isinstance(res, bool):
return
elif isinstance(res, CompletionMessage):
final_response = res
break
else:
yield res
assert final_response is not None
# for output shields run on the full input and output combination
messages = input_messages + [final_response]
async for res in self.run_shields_wrapper(
turn_id, messages, self.output_shields, "assistant-output"
):
if isinstance(res, bool):
return
else:
yield res
yield final_response
async def run_shields_wrapper(
self,
turn_id: str,
messages: List[Message],
shields: List[ShieldDefinition],
touchpoint: str,
) -> AsyncGenerator:
if len(shields) == 0:
return
step_id = str(uuid.uuid4())
try:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.shield_call.value,
step_id=step_id,
metadata=dict(touchpoint=touchpoint),
)
)
)
await self.run_shields(messages, shields)
except SafetyException as e:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
response=e.response,
),
)
)
)
yield CompletionMessage(
content=str(e),
stop_reason=StopReason.end_of_turn,
)
yield False
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
response=ShieldResponse(
# TODO: fix this, give each shield a shield type method and
# fire one event for each shield run
shield_type=BuiltinShield.llama_guard,
is_violation=False,
),
),
)
)
)
async def _run(
self,
session: Session,
turn_id: str,
input_messages: List[Message],
attachments: List[Attachment],
sampling_params: SamplingParams,
stream: bool = False,
) -> AsyncGenerator:
enabled_tools = set(t.type for t in self.agent_config.tools)
need_rag_context = await self._should_retrieve_context(
input_messages, attachments
)
if need_rag_context:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.memory_retrieval.value,
step_id=step_id,
)
)
)
# TODO: find older context from the session and either replace it
# or append with a sliding window. this is really a very simplistic implementation
rag_context, bank_ids = await self._retrieve_context(
session, input_messages, attachments
)
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.memory_retrieval.value,
step_id=step_id,
step_details=MemoryRetrievalStep(
turn_id=turn_id,
step_id=step_id,
memory_bank_ids=bank_ids,
inserted_context=rag_context or "",
),
)
)
)
if rag_context:
last_message = input_messages[-1]
last_message.context = "\n".join(rag_context)
elif attachments and AgentTool.code_interpreter.value in enabled_tools:
urls = [a.content for a in attachments if isinstance(a.content, URL)]
msg = await attachment_message(self.tempdir, urls)
input_messages.append(msg)
output_attachments = []
n_iter = 0
while True:
msg = input_messages[-1]
if msg.role == Role.user.value:
color = "blue"
elif msg.role == Role.ipython.value:
color = "yellow"
else:
color = None
cprint(f"{str(msg)}", color=color)
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.inference.value,
step_id=step_id,
)
)
)
tool_calls = []
content = ""
stop_reason = None
async for chunk in self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=self._get_tools(),
tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True,
sampling_params=sampling_params,
):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
continue
elif event.event_type == ChatCompletionResponseEventType.complete:
stop_reason = StopReason.end_of_turn
continue
delta = event.delta
if isinstance(delta, ToolCallDelta):
if delta.parse_status == ToolCallParseStatus.success:
tool_calls.append(delta.content)
if stream:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta="",
tool_call_delta=delta,
)
)
)
elif isinstance(delta, str):
content += delta
if stream and event.stop_reason is None:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta=event.delta,
)
)
)
else:
raise ValueError(f"Unexpected delta type {type(delta)}")
if event.stop_reason is not None:
stop_reason = event.stop_reason
stop_reason = stop_reason or StopReason.out_of_tokens
message = CompletionMessage(
content=content,
stop_reason=stop_reason,
tool_calls=tool_calls,
)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.inference.value,
step_id=step_id,
step_details=InferenceStep(
# somewhere deep, we are re-assigning message or closing over some
# variable which causes message to mutate later on. fix with a
# `deepcopy` for now, but this is symptomatic of a deeper issue.
step_id=step_id,
turn_id=turn_id,
model_response=copy.deepcopy(message),
),
)
)
)
if n_iter >= self.max_infer_iters:
cprint("Done with MAX iterations, exiting.")
yield message
break
if stop_reason == StopReason.out_of_tokens:
cprint("Out of token budget, exiting.")
yield message
break
if len(message.tool_calls) == 0:
if stop_reason == StopReason.end_of_turn:
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
if len(output_attachments) > 0:
if isinstance(message.content, list):
message.content += attachments
else:
message.content = [message.content] + attachments
yield message
else:
cprint(f"Partial message: {str(message)}", color="green")
input_messages = input_messages + [message]
else:
cprint(f"{str(message)}", color="green")
try:
tool_call = message.tool_calls[0]
name = tool_call.tool_name
if not isinstance(name, BuiltinTool):
yield message
return
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
)
)
)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
tool_call=tool_call,
)
)
)
result_messages = await execute_tool_call_maybe(
self.tools_dict,
[message],
)
assert (
len(result_messages) == 1
), "Currently not supporting multiple messages"
result_message = result_messages[0]
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_details=ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[
ToolResponse(
call_id=result_message.call_id,
tool_name=result_message.tool_name,
content=result_message.content,
)
],
),
)
)
)
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=str(uuid.uuid4()),
turn_id=turn_id,
response=ShieldResponse(
# TODO: fix this, give each shield a shield type method and
# fire one event for each shield run
shield_type=BuiltinShield.llama_guard,
is_violation=False,
),
),
)
)
)
except SafetyException as e:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=str(uuid.uuid4()),
turn_id=turn_id,
response=e.response,
),
)
)
)
yield CompletionMessage(
content=str(e),
stop_reason=StopReason.end_of_turn,
)
yield False
return
if out_attachment := interpret_content_as_attachment(
result_message.content
):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to final message
output_attachments.append(out_attachment)
input_messages = input_messages + [message, result_message]
n_iter += 1
async def _ensure_memory_bank(self, session: Session) -> MemoryBank:
if session.memory_bank is None:
session.memory_bank = await self.memory_api.create_memory_bank(
name=f"memory_bank_{session.session_id}",
config=VectorMemoryBankConfig(
embedding_model="sentence-transformer/all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
),
)
return session.memory_bank
async def _should_retrieve_context(
self, messages: List[Message], attachments: List[Attachment]
) -> bool:
enabled_tools = set(t.type for t in self.agent_config.tools)
if attachments:
if (
AgentTool.code_interpreter.value in enabled_tools
and self.agent_config.tool_choice == ToolChoice.required
):
return False
else:
return True
return AgentTool.memory.value in enabled_tools
def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
for t in self.agent_config.tools:
if t.type == AgentTool.memory.value:
return t
return None
async def _retrieve_context(
self, session: Session, messages: List[Message], attachments: List[Attachment]
) -> Tuple[List[str], List[int]]: # (rag_context, bank_ids)
bank_ids = []
memory = self._memory_tool_definition()
assert memory is not None, "Memory tool not configured"
bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
if attachments:
bank = await self._ensure_memory_bank(session)
bank_ids.append(bank.bank_id)
documents = [
MemoryBankDocument(
document_id=str(uuid.uuid4()),
content=a.content,
mime_type=a.mime_type,
metadata={},
)
for a in attachments
]
await self.memory_api.insert_documents(bank.bank_id, documents)
elif session.memory_bank:
bank_ids.append(session.memory_bank.bank_id)
if not bank_ids:
# this can happen if the per-session memory bank is not yet populated
# (i.e., no prior turns uploaded an Attachment)
return None, []
query = await generate_rag_query(
memory.query_generator_config, messages, inference_api=self.inference_api
)
tasks = [
self.memory_api.query_documents(
bank_id=bank_id,
query=query,
params={
"max_chunks": 5,
},
)
for bank_id in bank_ids
]
results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores]
# sort by score
chunks, scores = zip(
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
)
if not chunks:
return None, bank_ids
tokens = 0
picked = []
for c in chunks[: memory.max_chunks]:
tokens += c.token_count
if tokens > memory.max_tokens_in_context:
cprint(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
"red",
)
break
picked.append(f"id:{c.document_id}; content:{c.content}")
return [
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
*picked,
"\n=== END-RETRIEVED-CONTEXT ===\n",
], bank_ids
def _get_tools(self) -> List[ToolDefinition]:
ret = []
for t in self.agent_config.tools:
if isinstance(t, SearchToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search))
elif isinstance(t, WolframAlphaToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha))
elif isinstance(t, PhotogenToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.photogen))
elif isinstance(t, CodeInterpreterToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.code_interpreter))
elif isinstance(t, FunctionCallToolDefinition):
ret.append(
ToolDefinition(
tool_name=t.function_name,
description=t.description,
parameters=t.parameters,
)
)
return ret
async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
content = []
for url in urls:
uri = url.uri
if uri.startswith("file://"):
filepath = uri[len("file://") :]
elif uri.startswith("http"):
path = urlparse(uri).path
basename = os.path.basename(path)
filepath = f"{tempdir}/{make_random_string() + basename}"
print(f"Downloading {url} -> {filepath}")
async with httpx.AsyncClient() as client:
r = await client.get(uri)
resp = r.text
with open(filepath, "w") as fp:
fp.write(resp)
else:
raise ValueError(f"Unsupported URL {url}")
content.append(f'# There is a file accessible to you at "{filepath}"\n')
return ToolResponseMessage(
call_id="",
tool_name=BuiltinTool.code_interpreter,
content=content,
)
async def execute_tool_call_maybe(
tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage]
) -> List[ToolResponseMessage]:
# While Tools.run interface takes a list of messages,
# All tools currently only run on a single message
# When this changes, we can drop this assert
# Whether to call tools on each message and aggregate
# or aggregate and call tool once, reamins to be seen.
assert len(messages) == 1, "Expected single message"
message = messages[0]
tool_call = message.tool_calls[0]
name = tool_call.tool_name
assert isinstance(name, BuiltinTool)
name = name.value
assert name in tools_dict, f"Tool {name} not found"
tool = tools_dict[name]
result_messages = await tool.run(messages)
return result_messages
def print_dialog(messages: List[Message]):
for i, m in enumerate(messages):
if m.role == Role.user.value:
color = "red"
elif m.role == Role.assistant.value:
color = "white"
elif m.role == Role.ipython.value:
color = "yellow"
elif m.role == Role.system.value:
color = "green"
else:
color = "white"
s = str(m)
cprint(f"{i} ::: {s[:100]}...", color=color)

View file

@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import tempfile
import uuid
from typing import AsyncGenerator
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety
from llama_stack.apis.agents import * # noqa: F403
from .agent_instance import ChatAgent
from .config import MetaReferenceImplConfig
from .tools.builtin import (
CodeInterpreterTool,
PhotogenTool,
SearchTool,
WolframAlphaTool,
)
from .tools.safety import with_safety
logger = logging.getLogger()
logger.setLevel(logging.INFO)
AGENT_INSTANCES_BY_ID = {}
class MetaReferenceAgentsImpl(Agents):
def __init__(
self,
config: MetaReferenceImplConfig,
inference_api: Inference,
memory_api: Memory,
safety_api: Safety,
):
self.config = config
self.inference_api = inference_api
self.memory_api = memory_api
self.safety_api = safety_api
async def initialize(self) -> None:
pass
async def create_agent(
self,
agent_config: AgentConfig,
) -> AgentCreateResponse:
agent_id = str(uuid.uuid4())
builtin_tools = []
for tool_defn in agent_config.tools:
if isinstance(tool_defn, WolframAlphaToolDefinition):
tool = WolframAlphaTool(tool_defn.api_key)
elif isinstance(tool_defn, SearchToolDefinition):
tool = SearchTool(tool_defn.engine, tool_defn.api_key)
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
tool = CodeInterpreterTool()
elif isinstance(tool_defn, PhotogenToolDefinition):
tool = PhotogenTool(dump_dir=tempfile.mkdtemp())
else:
continue
builtin_tools.append(
with_safety(
tool,
self.safety_api,
tool_defn.input_shields,
tool_defn.output_shields,
)
)
AGENT_INSTANCES_BY_ID[agent_id] = ChatAgent(
agent_config=agent_config,
inference_api=self.inference_api,
safety_api=self.safety_api,
memory_api=self.memory_api,
builtin_tools=builtin_tools,
)
return AgentCreateResponse(
agent_id=agent_id,
)
async def create_agent_session(
self,
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse:
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
agent = AGENT_INSTANCES_BY_ID[agent_id]
session = agent.create_session(session_name)
return AgentSessionCreateResponse(
session_id=session.session_id,
)
async def create_agent_turn(
self,
agent_id: str,
session_id: str,
messages: List[
Union[
UserMessage,
ToolResponseMessage,
]
],
attachments: Optional[List[Attachment]] = None,
stream: Optional[bool] = False,
) -> AsyncGenerator:
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = AgentTurnCreateRequest(
agent_id=agent_id,
session_id=session_id,
messages=messages,
attachments=attachments,
stream=stream,
)
agent_id = request.agent_id
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
agent = AGENT_INSTANCES_BY_ID[agent_id]
assert (
request.session_id in agent.sessions
), f"Session {request.session_id} not found"
async for event in agent.create_and_execute_turn(request):
yield event

View file

@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class MetaReferenceImplConfig(BaseModel): ...

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from jinja2 import Template
from llama_models.llama3.api import * # noqa: F403
from llama_stack.apis.agents import (
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
MemoryQueryGenerator,
MemoryQueryGeneratorConfig,
)
from termcolor import cprint # noqa: F401
from llama_stack.apis.inference import * # noqa: F403
async def generate_rag_query(
config: MemoryQueryGeneratorConfig,
messages: List[Message],
**kwargs,
):
"""
Generates a query that will be used for
retrieving relevant information from the memory bank.
"""
if config.type == MemoryQueryGenerator.default.value:
query = await default_rag_query_generator(config, messages, **kwargs)
elif config.type == MemoryQueryGenerator.llm.value:
query = await llm_rag_query_generator(config, messages, **kwargs)
else:
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
# cprint(f"Generated query >>>: {query}", color="green")
return query
async def default_rag_query_generator(
config: DefaultMemoryQueryGeneratorConfig,
messages: List[Message],
**kwargs,
):
return config.sep.join(interleaved_text_media_as_str(m.content) for m in messages)
async def llm_rag_query_generator(
config: LLMMemoryQueryGeneratorConfig,
messages: List[Message],
**kwargs,
):
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
inference_api = kwargs["inference_api"]
m_dict = {"messages": [m.model_dump() for m in messages]}
template = Template(config.template)
content = template.render(m_dict)
model = config.model
message = UserMessage(content=content)
response = inference_api.chat_completion(
ChatCompletionRequest(
model=model,
messages=[message],
stream=False,
)
)
async for chunk in response:
query = chunk.completion_message.content
return query

View file

@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_models.llama3.api.datatypes import Message, Role, UserMessage
from llama_stack.apis.safety import (
OnViolationAction,
RunShieldRequest,
Safety,
ShieldDefinition,
ShieldResponse,
)
from termcolor import cprint
class SafetyException(Exception): # noqa: N818
def __init__(self, response: ShieldResponse):
self.response = response
super().__init__(response.violation_return_message)
class ShieldRunnerMixin:
def __init__(
self,
safety_api: Safety,
input_shields: List[ShieldDefinition] = None,
output_shields: List[ShieldDefinition] = None,
):
self.safety_api = safety_api
self.input_shields = input_shields
self.output_shields = output_shields
async def run_shields(
self, messages: List[Message], shields: List[ShieldDefinition]
) -> List[ShieldResponse]:
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
res = await self.safety_api.run_shields(
RunShieldRequest(
messages=messages,
shields=shields,
)
)
results = res.responses
for shield, r in zip(shields, results):
if r.is_violation:
if shield.on_violation_action == OnViolationAction.RAISE:
raise SafetyException(r)
elif shield.on_violation_action == OnViolationAction.WARN:
cprint(
f"[Warn]{shield.__class__.__name__} raised a warning",
color="red",
)
return results

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import unittest
from llama_models.llama3.api.datatypes import (
Attachment,
BuiltinTool,
CompletionMessage,
StopReason,
ToolCall,
)
from ..tools.builtin import CodeInterpreterTool
class TestCodeInterpreter(unittest.IsolatedAsyncioTestCase):
async def test_matplotlib(self):
tool = CodeInterpreterTool()
code = """
import matplotlib.pyplot as plt
import numpy as np
x = np.array([1, 1])
y = np.array([0, 10])
plt.plot(x, y)
plt.title('x = 1')
plt.xlabel('x')
plt.ylabel('y')
plt.grid(True)
plt.axvline(x=1, color='r')
plt.show()
"""
message = CompletionMessage(
role="assistant",
content="",
tool_calls=[
ToolCall(
call_id="call_id",
tool_name=BuiltinTool.code_interpreter,
arguments={"code": code},
)
],
stop_reason=StopReason.end_of_message,
)
ret = await tool.run([message])
self.assertEqual(len(ret), 1)
output = ret[0].content
self.assertIsInstance(output, Attachment)
self.assertEqual(output.mime_type, "image/png")
async def test_path_unlink(self):
tool = CodeInterpreterTool()
code = """
import os
from pathlib import Path
import tempfile
dpath = Path(os.environ["MPLCONFIGDIR"])
with open(dpath / "test", "w") as f:
f.write("hello")
Path(dpath / "test").unlink()
print("_OK_")
"""
message = CompletionMessage(
role="assistant",
content="",
tool_calls=[
ToolCall(
call_id="call_id",
tool_name=BuiltinTool.code_interpreter,
arguments={"code": code},
)
],
stop_reason=StopReason.end_of_message,
)
ret = await tool.run([message])
self.assertEqual(len(ret), 1)
output = ret[0].content
self.assertTrue("_OK_" in output)
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import List
from llama_stack.apis.inference import Message
class BaseTool(ABC):
@abstractmethod
def get_name(self) -> str:
raise NotImplementedError
@abstractmethod
async def run(self, messages: List[Message]) -> List[Message]:
raise NotImplementedError

View file

@ -0,0 +1,375 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import re
import tempfile
from abc import abstractmethod
from typing import List, Optional
import requests
from termcolor import cprint
from .ipython_tool.code_execution import (
CodeExecutionContext,
CodeExecutionRequest,
CodeExecutor,
TOOLS_ATTACHMENT_KEY_REGEX,
)
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.agents import * # noqa: F403
from .base import BaseTool
def interpret_content_as_attachment(content: str) -> Optional[Attachment]:
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
if match:
snippet = match.group(1)
data = json.loads(snippet)
return Attachment(
content=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"]
)
return None
class SingleMessageBuiltinTool(BaseTool):
async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
assert len(messages) == 1, f"Expected single message, got {len(messages)}"
message = messages[0]
assert len(message.tool_calls) == 1, "Expected a single tool call"
tool_call = messages[0].tool_calls[0]
query = tool_call.arguments["query"]
response: str = await self.run_impl(query)
message = ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=response,
)
return [message]
@abstractmethod
async def run_impl(self, query: str) -> str:
raise NotImplementedError()
class PhotogenTool(SingleMessageBuiltinTool):
def __init__(self, dump_dir: str) -> None:
self.dump_dir = dump_dir
def get_name(self) -> str:
return BuiltinTool.photogen.value
async def run_impl(self, query: str) -> str:
"""
Implement this to give the model an ability to generate images.
Return:
info = {
"filepath": str(image_filepath),
"mimetype": "image/png",
}
"""
raise NotImplementedError()
class SearchTool(SingleMessageBuiltinTool):
def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None:
self.api_key = api_key
if engine == SearchEngineType.bing:
self.engine = BingSearch(api_key, **kwargs)
elif engine == SearchEngineType.brave:
self.engine = BraveSearch(api_key, **kwargs)
else:
raise ValueError(f"Unknown search engine: {engine}")
def get_name(self) -> str:
return BuiltinTool.brave_search.value
async def run_impl(self, query: str) -> str:
return await self.engine.search(query)
class BingSearch:
def __init__(self, api_key: str, top_k: int = 3, **kwargs) -> None:
self.api_key = api_key
self.top_k = top_k
async def search(self, query: str) -> str:
url = "https://api.bing.microsoft.com/v7.0/search"
headers = {
"Ocp-Apim-Subscription-Key": self.api_key,
}
params = {
"count": self.top_k,
"textDecorations": True,
"textFormat": "HTML",
"q": query,
}
response = requests.get(url=url, params=params, headers=headers)
response.raise_for_status()
clean = self._clean_response(response.json())
return json.dumps(clean)
def _clean_response(self, search_response):
clean_response = []
query = search_response["queryContext"]["originalQuery"]
if "webPages" in search_response:
pages = search_response["webPages"]["value"]
for p in pages:
selected_keys = {"name", "url", "snippet"}
clean_response.append(
{k: v for k, v in p.items() if k in selected_keys}
)
if "news" in search_response:
clean_news = []
news = search_response["news"]["value"]
for n in news:
selected_keys = {"name", "url", "description"}
clean_news.append({k: v for k, v in n.items() if k in selected_keys})
clean_response.append(clean_news)
return {"query": query, "top_k": clean_response}
class BraveSearch:
def __init__(self, api_key: str) -> None:
self.api_key = api_key
async def search(self, query: str) -> str:
url = "https://api.search.brave.com/res/v1/web/search"
headers = {
"X-Subscription-Token": self.api_key,
"Accept-Encoding": "gzip",
"Accept": "application/json",
}
payload = {"q": query}
response = requests.get(url=url, params=payload, headers=headers)
return json.dumps(self._clean_brave_response(response.json()))
def _clean_brave_response(self, search_response, top_k=3):
query = None
clean_response = []
if "query" in search_response:
if "original" in search_response["query"]:
query = search_response["query"]["original"]
if "mixed" in search_response:
mixed_results = search_response["mixed"]
for m in mixed_results["main"][:top_k]:
r_type = m["type"]
results = search_response[r_type]["results"]
if r_type == "web":
# For web data - add a single output from the search
idx = m["index"]
selected_keys = [
"type",
"title",
"url",
"description",
"date",
"extra_snippets",
]
cleaned = {
k: v for k, v in results[idx].items() if k in selected_keys
}
elif r_type == "faq":
# For faw data - take a list of all the questions & answers
selected_keys = ["type", "question", "answer", "title", "url"]
cleaned = []
for q in results:
cleaned.append(
{k: v for k, v in q.items() if k in selected_keys}
)
elif r_type == "infobox":
idx = m["index"]
selected_keys = [
"type",
"title",
"url",
"description",
"long_desc",
]
cleaned = {
k: v for k, v in results[idx].items() if k in selected_keys
}
elif r_type == "videos":
selected_keys = [
"type",
"url",
"title",
"description",
"date",
]
cleaned = []
for q in results:
cleaned.append(
{k: v for k, v in q.items() if k in selected_keys}
)
elif r_type == "locations":
# For faw data - take a list of all the questions & answers
selected_keys = [
"type",
"title",
"url",
"description",
"coordinates",
"postal_address",
"contact",
"rating",
"distance",
"zoom_level",
]
cleaned = []
for q in results:
cleaned.append(
{k: v for k, v in q.items() if k in selected_keys}
)
elif r_type == "news":
# For faw data - take a list of all the questions & answers
selected_keys = [
"type",
"title",
"url",
"description",
]
cleaned = []
for q in results:
cleaned.append(
{k: v for k, v in q.items() if k in selected_keys}
)
else:
cleaned = []
clean_response.append(cleaned)
return {"query": query, "top_k": clean_response}
class WolframAlphaTool(SingleMessageBuiltinTool):
def __init__(self, api_key: str) -> None:
self.api_key = api_key
self.url = "https://api.wolframalpha.com/v2/query"
def get_name(self) -> str:
return BuiltinTool.wolfram_alpha.value
async def run_impl(self, query: str) -> str:
params = {
"input": query,
"appid": self.api_key,
"format": "plaintext",
"output": "json",
}
response = requests.get(
self.url,
params=params,
)
return json.dumps(self._clean_wolfram_alpha_response(response.json()))
def _clean_wolfram_alpha_response(self, wa_response):
remove = {
"queryresult": [
"datatypes",
"error",
"timedout",
"timedoutpods",
"numpods",
"timing",
"parsetiming",
"parsetimedout",
"recalculate",
"id",
"host",
"server",
"related",
"version",
{
"pods": [
"scanner",
"id",
"error",
"expressiontypes",
"states",
"infos",
"position",
"numsubpods",
]
},
"assumptions",
],
}
for main_key in remove:
for key_to_remove in remove[main_key]:
try:
if key_to_remove == "assumptions":
if "assumptions" in wa_response[main_key]:
del wa_response[main_key][key_to_remove]
if isinstance(key_to_remove, dict):
for sub_key in key_to_remove:
if sub_key == "pods":
for i in range(len(wa_response[main_key][sub_key])):
if (
wa_response[main_key][sub_key][i]["title"]
== "Result"
):
del wa_response[main_key][sub_key][i + 1 :]
break
sub_items = wa_response[main_key][sub_key]
for i in range(len(sub_items)):
for sub_key_to_remove in key_to_remove[sub_key]:
if sub_key_to_remove in sub_items[i]:
del sub_items[i][sub_key_to_remove]
elif key_to_remove in wa_response[main_key]:
del wa_response[main_key][key_to_remove]
except KeyError:
pass
return wa_response
class CodeInterpreterTool(BaseTool):
def __init__(self) -> None:
ctx = CodeExecutionContext(
matplotlib_dump_dir=tempfile.mkdtemp(),
)
self.code_executor = CodeExecutor(ctx)
def get_name(self) -> str:
return BuiltinTool.code_interpreter.value
async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
message = messages[0]
assert len(message.tool_calls) == 1, "Expected a single tool call"
tool_call = messages[0].tool_calls[0]
script = tool_call.arguments["code"]
req = CodeExecutionRequest(scripts=[script])
res = self.code_executor.execute(req)
pieces = [res["process_status"]]
for out_type in ["stdout", "stderr"]:
res_out = res[out_type]
if res_out != "":
pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
if out_type == "stderr":
cprint(f"ipython tool error: ↓\n{res_out}", color="red")
message = ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content="\n".join(pieces),
)
return [message]

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,133 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import errno
# Disabling potentially dangerous functions
import os as _os
from functools import partial
os_funcs_to_disable = [
"kill",
"system",
"putenv",
"remove",
"removedirs",
"rmdir",
"fchdir",
"setuid",
"fork",
"forkpty",
"killpg",
"rename",
"renames",
"truncate",
"replace",
# "unlink", # Commenting as this was blocking matpltlib from rendering plots correctly
"fchmod",
"fchown",
"chmod",
"chown",
"chroot",
"fchdir",
"lchflags",
"lchmod",
"lchown",
"chdir",
]
def call_not_allowed(*args, **kwargs):
raise OSError(errno.EPERM, "Call are not permitted in this environment")
for func_name in os_funcs_to_disable:
if hasattr(_os, func_name):
setattr(_os, func_name, partial(call_not_allowed, _func_name=f"os.{func_name}"))
import shutil as _shutil
for func_name in ["rmtree", "move", "chown"]:
if hasattr(_shutil, func_name):
setattr(
_shutil,
func_name,
partial(call_not_allowed, _func_name=f"shutil.{func_name}"),
)
import subprocess as _subprocess
def popen_not_allowed(*args, **kwargs):
raise _subprocess.CalledProcessError(
-1,
args[0] if args else "unknown",
stderr="subprocess.Popen is not allowed in this environment",
)
_subprocess.Popen = popen_not_allowed
import atexit as _atexit
import builtins as _builtins
import io as _io
import json as _json
import sys as _sys
# NB! The following "unused" imports crucial, make sure not not to remove
# them with linters - they're used in code_execution.py
from contextlib import ( # noqa
contextmanager as _contextmanager,
redirect_stderr as _redirect_stderr,
redirect_stdout as _redirect_stdout,
)
from multiprocessing.connection import Connection as _Connection
# Mangle imports to avoid polluting model execution namespace.
_IO_SINK = _io.StringIO()
_NETWORK_TIMEOUT = 5
_NETWORK_CONNECTIONS = None
def _open_connections():
global _NETWORK_CONNECTIONS
if _NETWORK_CONNECTIONS is not None:
# Ensure connections only opened once.
return _NETWORK_CONNECTIONS
req_w_fd, resp_r_fd = _sys.argv[1], _sys.argv[2]
req_con = _Connection(int(req_w_fd), readable=False)
resp_con = _Connection(int(resp_r_fd), writable=False)
_NETWORK_CONNECTIONS = (req_con, resp_con)
return _NETWORK_CONNECTIONS
_builtins._open_connections = _open_connections
@_atexit.register
def _close_connections():
global _NETWORK_CONNECTIONS
if _NETWORK_CONNECTIONS is None:
return
for con in _NETWORK_CONNECTIONS:
con.close()
del _NETWORK_CONNECTIONS
def _network_call(request):
# NOTE: We communicate with the parent process in json, encoded
# in raw bytes. We do this because native send/recv methods use
# pickle which involves execution of arbitrary code.
_open_connections()
req_con, resp_con = _NETWORK_CONNECTIONS
req_con.send_bytes(_json.dumps(request).encode("utf-8"))
if resp_con.poll(timeout=_NETWORK_TIMEOUT) is None:
raise Exception(f"Network request timed out: {_json.dumps(request)}")
else:
return _json.loads(resp_con.recv_bytes().decode("utf-8"))

View file

@ -0,0 +1,256 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import json
import multiprocessing
import os
import re
import subprocess
import sys
import tempfile
import textwrap
import time
from dataclasses import dataclass
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import List
from PIL import Image
from .utils import get_code_env_prefix
TOOLS_ATTACHMENT_KEY = "__tools_attachment__"
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
DIRNAME = Path(__file__).parent
CODE_EXEC_TIMEOUT = 20
CODE_ENV_PREFIX = get_code_env_prefix()
STDOUTERR_SINK_WRAPPER_TEMPLATE = """\
with _redirect_stdout(_IO_SINK), _redirect_stderr(_IO_SINK):
{code}\
"""
TRYEXCEPT_WRAPPER_TEMPLATE = """\
try:
{code}
except:
pass\
"""
def generate_bwrap_command(bind_dirs: List[str]) -> str:
"""
Generate the bwrap command string for binding all
directories in the current directory read-only.
"""
bwrap_args = ""
bwrap_args += "--ro-bind / / "
# Add the --dev flag to mount device files
bwrap_args += "--dev /dev "
for d in bind_dirs:
bwrap_args += f"--bind {d} {d} "
# Add the --unshare-all flag to isolate the sandbox from the rest of the system
bwrap_args += "--unshare-all "
# Add the --die-with-parent flag to ensure the child process dies when bwrap's parent dies
bwrap_args += "--die-with-parent "
return bwrap_args
@dataclass
class CodeExecutionContext:
matplotlib_dump_dir: str
use_proxy: bool = False
@dataclass
class CodeExecutionRequest:
scripts: List[str]
only_last_cell_stdouterr: bool = True
only_last_cell_fail: bool = True
seed: int = 0
strip_fpaths_in_stderr: bool = True
class CodeExecutor:
def __init__(self, context: CodeExecutionContext):
self.context = context
def execute(self, req: CodeExecutionRequest) -> dict:
scripts = req.scripts
for i in range(len(scripts) - 1):
if req.only_last_cell_stdouterr:
scripts[i] = STDOUTERR_SINK_WRAPPER_TEMPLATE.format(
code=textwrap.indent(scripts[i], " " * 4)
)
if req.only_last_cell_fail:
scripts[i] = TRYEXCEPT_WRAPPER_TEMPLATE.format(
code=textwrap.indent(scripts[i], " " * 4)
)
# Seeds prefix:
seed = req.seed
seeds_prefix = f"""\
def _set_seeds():
import random
random.seed({seed})
import numpy as np
np.random.seed({seed})
_set_seeds()\
"""
script = "\n\n".join([seeds_prefix] + [CODE_ENV_PREFIX] + scripts)
with tempfile.TemporaryDirectory() as dpath:
bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath])
cmd = [*bwrap_prefix.split(), sys.executable, "-c", script]
code_fpath = os.path.join(dpath, "code.py")
with open(code_fpath, "w") as f:
f.write(script)
try:
python_path = os.environ.get("PYTHONPATH", "")
env = dict(
os.environ,
PYTHONHASHSEED=str(seed),
MPLCONFIGDIR=dpath,
MPLBACKEND="module://matplotlib_custom_backend",
PYTHONPATH=f"{DIRNAME}:{python_path}",
)
stdout, stderr, returncode = do_subprocess(
cmd=cmd,
env=env,
ctx=self.context,
)
stderr = stderr.strip()
if req.strip_fpaths_in_stderr:
pattern = r'File "([^"]+)", line (\d+)'
stderr = re.sub(pattern, r"line \2", stderr)
return {
"process_status": "completed",
"returncode": returncode,
"stdout": stdout.strip(),
"stderr": stderr,
}
except subprocess.TimeoutExpired:
return {
"process_status": "timeout",
"stdout": "Timed out",
"stderr": "Timed out",
}
except Exception as e:
return {
"process_status": "error",
"error_type": type(e).__name__,
"stderr": str(e),
"stdout": str(e),
}
def process_matplotlib_response(response, matplotlib_dump_dir: str):
image_data = response["image_data"]
# Convert the base64 string to a bytes object
images = [base64.b64decode(d["image_base64"]) for d in image_data]
# Create a list of PIL images from the bytes objects
images = [Image.open(BytesIO(img)) for img in images]
# Create a list of image paths
image_paths = []
for i, img in enumerate(images):
# create new directory for each day to better organize data:
dump_dname = datetime.today().strftime("%Y-%m-%d")
dump_dpath = Path(matplotlib_dump_dir, dump_dname)
dump_dpath.mkdir(parents=True, exist_ok=True)
# save image into a file
dump_fname = f"matplotlib_{str(time.time()).replace('.', '_')}_{i}.png"
dump_fpath = dump_dpath / dump_fname
img.save(dump_fpath, "PNG")
image_paths.append(str(dump_fpath))
# this is kind of convoluted, we send back this response to the subprocess which
# prints it out
info = {
"filepath": str(image_paths[-1]),
"mimetype": "image/png",
}
return f"{TOOLS_ATTACHMENT_KEY}={json.dumps(info)}"
def execute_subprocess_request(request, ctx: CodeExecutionContext):
"Route requests from the subprocess (via network Pipes) to the internet/tools."
if request["type"] == "matplotlib":
return process_matplotlib_response(request, ctx.matplotlib_dump_dir)
else:
raise Exception(f'Unrecognised network request type: {request["type"]}')
def do_subprocess(*, cmd: list, env: dict, ctx: CodeExecutionContext):
# Create Pipes to be used for any external tool/network requests.
req_r, req_w = multiprocessing.Pipe(duplex=False)
resp_r, resp_w = multiprocessing.Pipe(duplex=False)
cmd += [str(req_w.fileno()), str(resp_r.fileno())]
proc = subprocess.Popen(
cmd,
pass_fds=(req_w.fileno(), resp_r.fileno()),
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
close_fds=True,
env=env,
)
# Close unnecessary fds.
req_w.close()
resp_r.close()
pipe_close = False
done_read = False
start = time.monotonic()
while proc.poll() is None and not pipe_close:
if req_r.poll(0.1):
# NB: Python pipe semantics for poll and recv mean that
# poll() returns True is a pipe is closed.
# CF old school PEP from '09
# https://bugs.python.org/issue5573
try:
request = json.loads(req_r.recv_bytes().decode("utf-8"))
response = execute_subprocess_request(request, ctx)
resp_w.send_bytes(json.dumps(response).encode("utf-8"))
except EOFError:
# The request pipe is closed - set a marker to exit
# after the next attempt at reading stdout/stderr.
pipe_close = True
try:
# If lots has been printed, pipe might be full but
# proc cannot exit until all the stdout/stderr
# been written/read.
stdout, stderr = proc.communicate(timeout=0.3)
done_read = True
except subprocess.TimeoutExpired:
# The program has not terminated. Ignore it, there
# may be more network/tool requests.
continue
if time.monotonic() - start > CODE_EXEC_TIMEOUT:
proc.terminate()
raise subprocess.TimeoutExpired(cmd, CODE_EXEC_TIMEOUT)
if not done_read:
# Solve race condition where process terminates before
# we hit the while loop.
stdout, stderr = proc.communicate(timeout=0.3)
resp_w.close()
req_r.close()
return stdout, stderr, proc.returncode

View file

@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
A custom Matplotlib backend that overrides the show method to return image bytes.
"""
import base64
import io
import json as _json
import matplotlib
from matplotlib.backend_bases import FigureManagerBase
# Import necessary components from Matplotlib
from matplotlib.backends.backend_agg import FigureCanvasAgg
class CustomFigureCanvas(FigureCanvasAgg):
def show(self):
# Save the figure to a BytesIO object
buf = io.BytesIO()
self.print_png(buf)
image_bytes = buf.getvalue()
buf.close()
return image_bytes
class CustomFigureManager(FigureManagerBase):
def __init__(self, canvas, num):
super().__init__(canvas, num)
# Mimic module initialization that integrates with the Matplotlib backend system
def _create_figure_manager(num, *args, **kwargs):
"""
Create a custom figure manager instance.
"""
FigureClass = kwargs.pop("FigureClass", None) # noqa: N806
if FigureClass is None:
from matplotlib.figure import Figure
FigureClass = Figure # noqa: N806
fig = FigureClass(*args, **kwargs)
canvas = CustomFigureCanvas(fig)
manager = CustomFigureManager(canvas, num)
return manager
def show():
"""
Handle all figures and potentially return their images as bytes.
This function iterates over all figures registered with the custom backend,
renders them as images in bytes format, and could return a list of bytes objects,
one for each figure, or handle them as needed.
"""
image_data = []
for manager in matplotlib._pylab_helpers.Gcf.get_all_fig_managers():
# Get the figure from the manager
fig = manager.canvas.figure
buf = io.BytesIO() # Create a buffer for the figure
fig.savefig(buf, format="png") # Save the figure to the buffer in PNG format
buf.seek(0) # Go to the beginning of the buffer
image_bytes = buf.getvalue() # Retrieve bytes value
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
image_data.append({"image_base64": image_base64})
buf.close()
req_con, resp_con = _open_connections()
_json_dump = _json.dumps(
{
"type": "matplotlib",
"image_data": image_data,
}
)
req_con.send_bytes(_json_dump.encode("utf-8"))
resp = _json.loads(resp_con.recv_bytes().decode("utf-8"))
print(resp)
FigureCanvas = CustomFigureCanvas
FigureManager = CustomFigureManager

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
DIR = os.path.dirname(os.path.realpath(__file__))
CODE_ENV_PREFIX_FILE = os.path.join(DIR, "code_env_prefix.py")
CODE_ENV_PREFIX = None
def get_code_env_prefix() -> str:
global CODE_ENV_PREFIX
if CODE_ENV_PREFIX is None:
with open(CODE_ENV_PREFIX_FILE, "r") as f:
CODE_ENV_PREFIX = f.read()
return CODE_ENV_PREFIX

View file

@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, ShieldDefinition
from llama_stack.providers.impls.meta_reference.agents.safety import ShieldRunnerMixin
from .builtin import BaseTool
class SafeTool(BaseTool, ShieldRunnerMixin):
"""A tool that makes other tools safety enabled"""
def __init__(
self,
tool: BaseTool,
safety_api: Safety,
input_shields: List[ShieldDefinition] = None,
output_shields: List[ShieldDefinition] = None,
):
self._tool = tool
ShieldRunnerMixin.__init__(
self, safety_api, input_shields=input_shields, output_shields=output_shields
)
def get_name(self) -> str:
# return the name of the wrapped tool
return self._tool.get_name()
async def run(self, messages: List[Message]) -> List[Message]:
if self.input_shields:
await self.run_shields(messages, self.input_shields)
# run the underlying tool
res = await self._tool.run(messages)
if self.output_shields:
await self.run_shields(messages, self.output_shields)
return res
def with_safety(
tool: BaseTool,
safety_api: Safety,
input_shields: List[ShieldDefinition] = None,
output_shields: List[ShieldDefinition] = None,
) -> SafeTool:
return SafeTool(
tool,
safety_api,
input_shields=input_shields,
output_shields=output_shields,
)