From be19b223913af73905386e43ed74a23698aa92b8 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sun, 4 Aug 2024 10:53:38 -0700 Subject: [PATCH] Bring agentic system api to toolchain Add adapter dependencies and resolve adapters using a topological sort --- .gitignore | 1 + llama_toolchain/agentic_system/adapters.py | 29 + .../agentic_system/agentic_system.py | 786 ++++++++++++++++++ .../agentic_system/api/__init__.py | 8 + .../agentic_system/api/datatypes.py | 199 +++++ .../agentic_system/api/endpoints.py | 132 +++ llama_toolchain/agentic_system/client.py | 130 +++ llama_toolchain/agentic_system/config.py | 12 + llama_toolchain/agentic_system/safety.py | 65 ++ .../agentic_system/system_prompt.py | 152 ++++ .../agentic_system/tools/__init__.py | 5 + llama_toolchain/agentic_system/tools/base.py | 21 + .../agentic_system/tools/builtin.py | 326 ++++++++ .../agentic_system/tools/custom.py | 103 +++ .../agentic_system/tools/execute.py | 84 ++ .../tools/ipython_tool/__init__.py | 5 + .../tools/ipython_tool/code_env_prefix.py | 133 +++ .../tools/ipython_tool/code_execution.py | 256 ++++++ .../ipython_tool/matplotlib_custom_backend.py | 87 ++ .../tools/ipython_tool/utils.py | 21 + .../agentic_system/tools/safety.py | 59 ++ llama_toolchain/common/prompt_for_config.py | 6 +- llama_toolchain/distribution/datatypes.py | 19 +- llama_toolchain/distribution/distribution.py | 2 + llama_toolchain/distribution/dynamic.py | 14 +- llama_toolchain/distribution/registry.py | 37 +- llama_toolchain/distribution/server.py | 51 +- llama_toolchain/inference/client.py | 4 + llama_toolchain/inference/inference.py | 8 +- llama_toolchain/safety/client.py | 4 + llama_toolchain/safety/safety.py | 6 +- 31 files changed, 2740 insertions(+), 25 deletions(-) create mode 100644 llama_toolchain/agentic_system/adapters.py create mode 100644 llama_toolchain/agentic_system/agentic_system.py create mode 100644 llama_toolchain/agentic_system/api/__init__.py create mode 100644 llama_toolchain/agentic_system/api/datatypes.py create mode 100644 llama_toolchain/agentic_system/api/endpoints.py create mode 100644 llama_toolchain/agentic_system/client.py create mode 100644 llama_toolchain/agentic_system/config.py create mode 100644 llama_toolchain/agentic_system/safety.py create mode 100644 llama_toolchain/agentic_system/system_prompt.py create mode 100644 llama_toolchain/agentic_system/tools/__init__.py create mode 100644 llama_toolchain/agentic_system/tools/base.py create mode 100644 llama_toolchain/agentic_system/tools/builtin.py create mode 100644 llama_toolchain/agentic_system/tools/custom.py create mode 100644 llama_toolchain/agentic_system/tools/execute.py create mode 100644 llama_toolchain/agentic_system/tools/ipython_tool/__init__.py create mode 100644 llama_toolchain/agentic_system/tools/ipython_tool/code_env_prefix.py create mode 100644 llama_toolchain/agentic_system/tools/ipython_tool/code_execution.py create mode 100644 llama_toolchain/agentic_system/tools/ipython_tool/matplotlib_custom_backend.py create mode 100644 llama_toolchain/agentic_system/tools/ipython_tool/utils.py create mode 100644 llama_toolchain/agentic_system/tools/safety.py diff --git a/.gitignore b/.gitignore index 321e946a9..067fe2c6f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.env __pycache__ dist *.egg-info diff --git a/llama_toolchain/agentic_system/adapters.py b/llama_toolchain/agentic_system/adapters.py new file mode 100644 index 000000000..df8e8c9d6 --- /dev/null +++ b/llama_toolchain/agentic_system/adapters.py @@ -0,0 +1,29 @@ +# 
Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import List + +from llama_toolchain.distribution.datatypes import Adapter, ApiSurface, SourceAdapter + + +def available_agentic_system_adapters() -> List[Adapter]: + return [ + SourceAdapter( + api_surface=ApiSurface.agentic_system, + adapter_id="meta-reference", + pip_packages=[ + "codeshield", + "torch", + "transformers", + ], + module="llama_toolchain.agentic_system.agentic_system", + config_class="llama_toolchain.agentic_system.config.AgenticSystemConfig", + adapter_dependencies=[ + ApiSurface.inference, + ApiSurface.safety, + ], + ), + ] diff --git a/llama_toolchain/agentic_system/agentic_system.py b/llama_toolchain/agentic_system/agentic_system.py new file mode 100644 index 000000000..011893b44 --- /dev/null +++ b/llama_toolchain/agentic_system/agentic_system.py @@ -0,0 +1,786 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +from llama_toolchain.agentic_system.api import AgenticSystem + +from llama_toolchain.distribution.datatypes import Adapter, ApiSurface +from llama_toolchain.inference.api import Inference +from llama_toolchain.safety.api import Safety + +from .config import AgenticSystemConfig +from .api.endpoints import * # noqa + +import logging +import os +import uuid +from datetime import datetime +from typing import AsyncGenerator, Dict, List, Optional + +from llama_toolchain.inference.api import ChatCompletionRequest + +from llama_toolchain.inference.api.datatypes import ( + Attachment, + BuiltinTool, + ChatCompletionResponseEventType, + CompletionMessage, + Message, + Role, + SamplingParams, + StopReason, + ToolCallDelta, + ToolCallParseStatus, + ToolDefinition, + ToolResponse, + ToolResponseMessage, + URL, +) +from llama_toolchain.safety.api.datatypes import ( + BuiltinShield, + ShieldDefinition, + ShieldResponse, +) + +from termcolor import cprint + +from .api.datatypes import ( + AgenticSystemInstanceConfig, + AgenticSystemTurnResponseEvent, + AgenticSystemTurnResponseEventType, + AgenticSystemTurnResponseStepCompletePayload, + AgenticSystemTurnResponseStepProgressPayload, + AgenticSystemTurnResponseStepStartPayload, + AgenticSystemTurnResponseTurnCompletePayload, + AgenticSystemTurnResponseTurnStartPayload, + InferenceStep, + Session, + ShieldCallStep, + StepType, + ToolExecutionStep, + Turn, +) +from .api.endpoints import ( + AgenticSystemCreateRequest, + AgenticSystemCreateResponse, + AgenticSystemSessionCreateRequest, + AgenticSystemSessionCreateResponse, + AgenticSystemTurnCreateRequest, + AgenticSystemTurnResponseStreamChunk, +) +from .safety import SafetyException, ShieldRunnerMixin + +from .system_prompt import get_agentic_prefix_messages +from .tools.base import BaseTool +from .tools.builtin import ( + BraveSearchTool, + CodeInterpreterTool, + PhotogenTool, + SingleMessageBuiltinTool, + WolframAlphaTool, +) +from .tools.safety import with_safety + +logger = logging.getLogger() +logger.setLevel(logging.INFO) + + +async def get_adapter_impl( + config: AgenticSystemConfig, deps: Dict[ApiSurface, Adapter] +): + assert isinstance( + config, AgenticSystemConfig + ), f"Unexpected config type: {type(config)}" + + impl = MetaReferenceAgenticSystemImpl( + deps[ApiSurface.inference], + 
        deps[ApiSurface.safety],
+    )
+    await impl.initialize()
+    return impl
+
+
+async def execute_tool_call_maybe(
+    tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage]
+) -> List[ToolResponseMessage]:
+    # While the Tools.run interface takes a list of messages,
+    # all tools currently only run on a single message.
+    # Whether to call the tool on each message and aggregate the results,
+    # or to aggregate the messages and call the tool once, remains to be
+    # seen. When that changes, we can drop this assert.
+    assert len(messages) == 1, "Expected single message"
+    message = messages[0]
+
+    tool_call = message.tool_calls[0]
+    name = tool_call.tool_name
+    assert isinstance(name, BuiltinTool)
+
+    name = name.value
+
+    assert name in tools_dict, f"Tool {name} not found"
+    tool = tools_dict[name]
+    result_messages = await tool.run(messages)
+    return result_messages
+
+
+def print_dialog(messages: List[Message]):
+    for i, m in enumerate(messages):
+        if m.role == Role.user.value:
+            color = "red"
+        elif m.role == Role.assistant.value:
+            color = "white"
+        elif m.role == Role.ipython.value:
+            color = "yellow"
+        elif m.role == Role.system.value:
+            color = "green"
+        else:
+            color = "white"
+
+        s = str(m)
+        cprint(f"{i} ::: {s[:100]}...", color=color)
+
+
+AGENT_INSTANCES_BY_ID = {}
+
+
+class AgentInstance(ShieldRunnerMixin):
+    def __init__(
+        self,
+        system_id: str,
+        instance_config: AgenticSystemInstanceConfig,
+        model: str,
+        inference_api: Inference,
+        safety_api: Safety,
+        builtin_tools: List[SingleMessageBuiltinTool],
+        custom_tool_definitions: List[ToolDefinition],
+        input_shields: List[ShieldDefinition],
+        output_shields: List[ShieldDefinition],
+        max_infer_iters: int = 10,
+        prefix_messages: Optional[List[Message]] = None,
+    ):
+        self.system_id = system_id
+        self.instance_config = instance_config
+
+        self.model = model
+        self.inference_api = inference_api
+        self.safety_api = safety_api
+
+        if prefix_messages is not None and len(prefix_messages) > 0:
+            self.prefix_messages = prefix_messages
+        else:
+            self.prefix_messages = get_agentic_prefix_messages(
+                builtin_tools, custom_tool_definitions
+            )
+
+        for m in self.prefix_messages:
+            print(m.content)
+
+        self.max_infer_iters = max_infer_iters
+        self.tools_dict = {t.get_name(): t for t in builtin_tools}
+
+        self.sessions = {}
+
+        ShieldRunnerMixin.__init__(
+            self,
+            safety_api,
+            input_shields=input_shields,
+            output_shields=output_shields,
+        )
+
+    def create_session(self, name: str) -> Session:
+        session_id = str(uuid.uuid4())
+        session = Session(
+            session_id=session_id,
+            session_name=name,
+            turns=[],
+            started_at=datetime.now(),
+        )
+        self.sessions[session_id] = session
+        return session
+
+    async def create_and_execute_turn(
+        self, request: AgenticSystemTurnCreateRequest
+    ) -> AsyncGenerator:
+        assert (
+            request.session_id in self.sessions
+        ), f"Session {request.session_id} not found"
+
+        session = self.sessions[request.session_id]
+
+        messages = []
+        for i, turn in enumerate(session.turns):
+            # print(f"turn {i}")
+            # print_dialog(turn.input_messages)
+            messages.extend(turn.input_messages)
+            for step in turn.steps:
+                if step.step_type == StepType.inference.value:
+                    messages.append(step.model_response)
+                elif step.step_type == StepType.tool_execution.value:
+                    for response in step.tool_responses:
+                        messages.append(
+                            ToolResponseMessage(
+                                call_id=response.call_id,
+                                tool_name=response.tool_name,
+                                content=response.content,
+                            )
+                        )
+                elif step.step_type == StepType.shield_call.value:
+                    response = step.response
+                    if response.is_violation:
+                        # TODO: Properly persist the
+                        # CompletionMessage itself in the ShieldResponse
+                        messages.append(
+                            CompletionMessage(
+                                content=response.violation_return_message,
+                                stop_reason=StopReason.end_of_turn,
+                            )
+                        )
+
+        messages.extend(request.messages)
+
+        # print("processed dialog ======== ")
+        # print_dialog(messages)
+
+        turn_id = str(uuid.uuid4())
+        params = self.instance_config.sampling_params
+        start_time = datetime.now()
+        yield AgenticSystemTurnResponseStreamChunk(
+            event=AgenticSystemTurnResponseEvent(
+                payload=AgenticSystemTurnResponseTurnStartPayload(
+                    turn_id=turn_id,
+                )
+            )
+        )
+
+        steps = []
+        output_message = None
+        async for chunk in self.run(
+            turn_id=turn_id,
+            input_messages=messages,
+            temperature=params.temperature,
+            top_p=params.top_p,
+            stream=request.stream,
+            max_gen_len=params.max_tokens,
+        ):
+            if isinstance(chunk, CompletionMessage):
+                cprint(
+                    f"{chunk.role.capitalize()}: {chunk.content}",
+                    "white",
+                    attrs=["bold"],
+                )
+                output_message = chunk
+                continue
+
+            assert isinstance(
+                chunk, AgenticSystemTurnResponseStreamChunk
+            ), f"Unexpected type {type(chunk)}"
+            event = chunk.event
+            if (
+                event.payload.event_type
+                == AgenticSystemTurnResponseEventType.step_complete.value
+            ):
+                steps.append(event.payload.step_details)
+
+            yield chunk
+
+        assert output_message is not None
+
+        turn = Turn(
+            turn_id=turn_id,
+            session_id=request.session_id,
+            input_messages=request.messages,
+            output_message=output_message,
+            started_at=start_time,
+            completed_at=datetime.now(),
+            steps=steps,
+        )
+        session.turns.append(turn)
+
+        yield AgenticSystemTurnResponseStreamChunk(
+            event=AgenticSystemTurnResponseEvent(
+                payload=AgenticSystemTurnResponseTurnCompletePayload(
+                    turn=turn,
+                )
+            )
+        )
+
+    async def run_shields_wrapper(
+        self,
+        turn_id: str,
+        messages: List[Message],
+        shields: List[ShieldDefinition],
+        touchpoint: str,
+    ) -> AsyncGenerator:
+        if len(shields) == 0:
+            return
+
+        step_id = str(uuid.uuid4())
+        try:
+            yield AgenticSystemTurnResponseStreamChunk(
+                event=AgenticSystemTurnResponseEvent(
+                    payload=AgenticSystemTurnResponseStepStartPayload(
+                        step_type=StepType.shield_call.value,
+                        step_id=step_id,
+                        metadata=dict(touchpoint=touchpoint),
+                    )
+                )
+            )
+            await self.run_shields(messages, shields)
+
+        except SafetyException as e:
+            yield AgenticSystemTurnResponseStreamChunk(
+                event=AgenticSystemTurnResponseEvent(
+                    payload=AgenticSystemTurnResponseStepCompletePayload(
+                        step_type=StepType.shield_call.value,
+                        step_details=ShieldCallStep(
+                            step_id=step_id,
+                            turn_id=turn_id,
+                            response=e.response,
+                        ),
+                    )
+                )
+            )
+
+            yield CompletionMessage(
+                content=str(e),
+                stop_reason=StopReason.end_of_turn,
+            )
+            yield False
+            return
+
+        yield AgenticSystemTurnResponseStreamChunk(
+            event=AgenticSystemTurnResponseEvent(
+                payload=AgenticSystemTurnResponseStepCompletePayload(
+                    step_type=StepType.shield_call.value,
+                    step_details=ShieldCallStep(
+                        step_id=step_id,
+                        turn_id=turn_id,
+                        response=ShieldResponse(
+                            # TODO: fix this, give each shield a shield type method and
+                            # fire one event for each shield run
+                            shield_type=BuiltinShield.llama_guard,
+                            is_violation=False,
+                        ),
+                    ),
+                )
+            )
+        )
+
+    async def run(
+        self,
+        turn_id: str,
+        input_messages: List[Message],
+        temperature: float,
+        top_p: float,
+        stream: bool = False,
+        max_gen_len: Optional[int] = None,
+    ) -> AsyncGenerator:
+        # Doing async generators makes downstream code much simpler and everything amenable to
+        # streaming. However, it also makes things complicated here because AsyncGenerators cannot
+        # return a "final value" for the `yield from` statement. We simulate that by yielding a
+        # final boolean (to see whether an exception happened) and then explicitly testing for it.
+
+        async for res in self.run_shields_wrapper(
+            turn_id, input_messages, self.input_shields, "user-input"
+        ):
+            if isinstance(res, bool):
+                return
+            else:
+                yield res
+
+        final_response = None
+        async for res in self._run(
+            turn_id, input_messages, temperature, top_p, stream, max_gen_len
+        ):
+            if isinstance(res, bool):
+                return
+            elif isinstance(res, CompletionMessage):
+                final_response = res
+                break
+            else:
+                yield res
+
+        assert final_response is not None
+        # For output shields, run on the full input and output combination
+        messages = input_messages + [final_response]
+
+        async for res in self.run_shields_wrapper(
+            turn_id, messages, self.output_shields, "assistant-output"
+        ):
+            if isinstance(res, bool):
+                return
+            else:
+                yield res
+
+        yield final_response
+
+    async def _run(
+        self,
+        turn_id: str,
+        input_messages: List[Message],
+        temperature: float,
+        top_p: float,
+        stream: bool = False,
+        max_gen_len: Optional[int] = None,
+    ) -> AsyncGenerator:
+        input_messages = preprocess_dialog(input_messages, self.prefix_messages)
+
+        attachments = []
+
+        n_iter = 0
+        while True:
+            msg = input_messages[-1]
+            if msg.role == Role.user.value:
+                color = "blue"
+            elif msg.role == Role.ipython.value:
+                color = "yellow"
+            else:
+                color = None
+            cprint(f"{str(msg)}", color=color)
+
+            step_id = str(uuid.uuid4())
+            yield AgenticSystemTurnResponseStreamChunk(
+                event=AgenticSystemTurnResponseEvent(
+                    payload=AgenticSystemTurnResponseStepStartPayload(
+                        step_type=StepType.inference.value,
+                        step_id=step_id,
+                    )
+                )
+            )
+
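+            # Ask the model for the next step, streaming the completion so that
+            # token and tool-call deltas can be surfaced as they are generated.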
+ req = ChatCompletionRequest( + model=self.model, + messages=input_messages, + available_tools=self.instance_config.available_tools, + stream=True, + sampling_params=SamplingParams( + temperature=temperature, + top_p=top_p, + max_tokens=max_gen_len, + ), + ) + + tool_calls = [] + content = "" + stop_reason = None + async for chunk in self.inference_api.chat_completion(req): + event = chunk.event + if event.event_type != ChatCompletionResponseEventType.progress: + continue + + delta = event.delta + if isinstance(delta, ToolCallDelta): + if delta.parse_status == ToolCallParseStatus.success: + tool_calls.append(delta.content) + + if stream: + yield AgenticSystemTurnResponseStreamChunk( + event=AgenticSystemTurnResponseEvent( + payload=AgenticSystemTurnResponseStepProgressPayload( + step_type=StepType.inference.value, + step_id=step_id, + model_response_text_delta="", + tool_call_delta=delta, + ) + ) + ) + + elif isinstance(delta, str): + content += delta + if stream and event.stop_reason is None: + yield AgenticSystemTurnResponseStreamChunk( + event=AgenticSystemTurnResponseEvent( + payload=AgenticSystemTurnResponseStepProgressPayload( + step_type=StepType.inference.value, + step_id=step_id, + model_response_text_delta=event.delta, + ) + ) + ) + else: + raise ValueError(f"Unexpected delta type {type(delta)}") + + if event.stop_reason is not None: + stop_reason = event.stop_reason + + stop_reason = stop_reason or StopReason.out_of_tokens + message = CompletionMessage( + content=content, + stop_reason=stop_reason, + tool_calls=tool_calls, + ) + + yield AgenticSystemTurnResponseStreamChunk( + event=AgenticSystemTurnResponseEvent( + payload=AgenticSystemTurnResponseStepCompletePayload( + step_type=StepType.inference.value, + step_id=step_id, + step_details=InferenceStep( + step_id=step_id, turn_id=turn_id, model_response=message + ), + ) + ) + ) + + if n_iter >= self.max_infer_iters: + cprint("Done with MAX iterations, exiting.") + yield message + break + + if stop_reason == StopReason.out_of_tokens: + cprint("Out of token budget, exiting.") + yield message + break + + if len(message.tool_calls) == 0: + if stop_reason == StopReason.end_of_turn: + if len(attachments) > 0: + if isinstance(message.content, list): + message.content += attachments + else: + message.content = [message.content] + attachments + yield message + else: + cprint(f"Partial message: {str(message)}", color="green") + input_messages = input_messages + [message] + else: + cprint(f"{str(message)}", color="green") + try: + tool_call = message.tool_calls[0] + + name = tool_call.tool_name + if not isinstance(name, BuiltinTool): + yield message + return + + step_id = str(uuid.uuid4()) + yield AgenticSystemTurnResponseStreamChunk( + event=AgenticSystemTurnResponseEvent( + payload=AgenticSystemTurnResponseStepStartPayload( + step_type=StepType.tool_execution.value, + step_id=step_id, + ) + ) + ) + yield AgenticSystemTurnResponseStreamChunk( + event=AgenticSystemTurnResponseEvent( + payload=AgenticSystemTurnResponseStepProgressPayload( + step_type=StepType.tool_execution.value, + step_id=step_id, + tool_call=tool_call, + ) + ) + ) + + result_messages = await execute_tool_call_maybe( + self.tools_dict, + [message], + ) + assert ( + len(result_messages) == 1 + ), "Currently not supporting multiple messages" + result_message = result_messages[0] + + yield AgenticSystemTurnResponseStreamChunk( + event=AgenticSystemTurnResponseEvent( + payload=AgenticSystemTurnResponseStepCompletePayload( + step_type=StepType.tool_execution.value, + 
step_details=ToolExecutionStep( + step_id=step_id, + turn_id=turn_id, + tool_calls=[tool_call], + tool_responses=[ + ToolResponse( + call_id=result_message.call_id, + tool_name=result_message.tool_name, + content=result_message.content, + ) + ], + ), + ) + ) + ) + + # TODO: add tool-input touchpoint and a "start" event for this step also + # but that needs a lot more refactoring of Tool code potentially + yield AgenticSystemTurnResponseStreamChunk( + event=AgenticSystemTurnResponseEvent( + payload=AgenticSystemTurnResponseStepCompletePayload( + step_type=StepType.shield_call.value, + step_details=ShieldCallStep( + step_id=str(uuid.uuid4()), + turn_id=turn_id, + response=ShieldResponse( + # TODO: fix this, give each shield a shield type method and + # fire one event for each shield run + shield_type=BuiltinShield.llama_guard, + is_violation=False, + ), + ), + ) + ) + ) + + except SafetyException as e: + yield AgenticSystemTurnResponseStreamChunk( + event=AgenticSystemTurnResponseEvent( + payload=AgenticSystemTurnResponseStepCompletePayload( + step_type=StepType.shield_call.value, + step_details=ShieldCallStep( + step_id=str(uuid.uuid4()), + turn_id=turn_id, + response=e.response, + ), + ) + ) + ) + + yield CompletionMessage( + content=str(e), + stop_reason=StopReason.end_of_turn, + ) + yield False + return + + if isinstance(result_message.content, Attachment): + # NOTE: when we push this message back to the model, the model may ignore the + # attached file path etc. since the model is trained to only provide a user message + # with the summary. We keep all generated attachments and then attach them to final message + attachments.append(result_message.content) + elif isinstance(result_message.content, list) or isinstance( + result_message.content, tuple + ): + for c in result_message.content: + if isinstance(c, Attachment): + attachments.append(c) + + input_messages = input_messages + [message, result_message] + + n_iter += 1 + + +class MetaReferenceAgenticSystemImpl(AgenticSystem): + def __init__(self, inference_api: Inference, safety_api: Safety): + self.inference_api = inference_api + self.safety_api = safety_api + + async def initialize(self) -> None: + pass + + async def create_agentic_system( + self, + request: AgenticSystemCreateRequest, + ) -> AgenticSystemCreateResponse: + system_id = str(uuid.uuid4()) + + builtin_tools = [] + custom_tool_definitions = [] + cfg = request.instance_config + for dfn in cfg.available_tools: + if isinstance(dfn.tool_name, BuiltinTool): + if dfn.tool_name == BuiltinTool.wolfram_alpha: + tool = WolframAlphaTool(os.environ.get("WOLFRAM_ALPHA_API_KEY")) + elif dfn.tool_name == BuiltinTool.brave_search: + tool = BraveSearchTool(os.environ.get("BRAVE_SEARCH_API_KEY")) + elif dfn.tool_name == BuiltinTool.code_interpreter: + tool = CodeInterpreterTool() + elif dfn.tool_name == BuiltinTool.photogen: + tool = PhotogenTool( + dump_dir="/tmp/photogen_dump_" + os.environ["USER"], + ) + else: + raise ValueError(f"Unknown builtin tool: {dfn.tool_name}") + + builtin_tools.append( + with_safety( + tool, self.safety_api, dfn.input_shields, dfn.output_shields + ) + ) + else: + custom_tool_definitions.append(dfn) + + AGENT_INSTANCES_BY_ID[system_id] = AgentInstance( + system_id=system_id, + instance_config=request.instance_config, + model=request.model, + inference_api=self.inference_api, + builtin_tools=builtin_tools, + custom_tool_definitions=custom_tool_definitions, + safety_api=self.safety_api, + input_shields=cfg.input_shields, + output_shields=cfg.output_shields, + 
            prefix_messages=cfg.debug_prefix_messages,
+        )
+
+        return AgenticSystemCreateResponse(
+            system_id=system_id,
+        )
+
+    async def create_agentic_system_session(
+        self,
+        request: AgenticSystemSessionCreateRequest,
+    ) -> AgenticSystemSessionCreateResponse:
+        system_id = request.system_id
+        assert system_id in AGENT_INSTANCES_BY_ID, f"System {system_id} not found"
+        agent = AGENT_INSTANCES_BY_ID[system_id]
+
+        session = agent.create_session(request.session_name)
+        return AgenticSystemSessionCreateResponse(
+            session_id=session.session_id,
+        )
+
+    async def create_agentic_system_turn(
+        self,
+        request: AgenticSystemTurnCreateRequest,
+    ) -> AsyncGenerator:
+        system_id = request.system_id
+        assert system_id in AGENT_INSTANCES_BY_ID, f"System {system_id} not found"
+        agent = AGENT_INSTANCES_BY_ID[system_id]
+
+        assert (
+            request.session_id in agent.sessions
+        ), f"Session {request.session_id} not found"
+        async for event in agent.create_and_execute_turn(request):
+            yield event
+
+
+def attachment_message(url: URL) -> ToolResponseMessage:
+    uri = url.uri
+    assert uri.startswith("file://")
+    filepath = uri[len("file://") :]
+
+    return ToolResponseMessage(
+        call_id="",
+        tool_name=BuiltinTool.code_interpreter,
+        content=f'# There is a file accessible to you at "{filepath}"',
+    )
+
+
+def preprocess_dialog(
+    messages: List[Message], prefix_messages: List[Message]
+) -> List[Message]:
+    """
+    Preprocesses the dialog by dropping any incoming system messages and
+    prepending the agent's prefix messages (which carry the system prompt)
+    to the beginning of the dialog.
+    """
+    ret = prefix_messages.copy()
+
+    for m in messages:
+        if m.role == Role.system.value:
+            continue
+
+        # NOTE: the ideal behavior is to use `file_path = ...` but that
+        # means we need to have stateful execution of code which we currently
+        # do not have.
+        if isinstance(m.content, Attachment):
+            ret.append(attachment_message(m.content.url))
+        elif isinstance(m.content, list):
+            for c in m.content:
+                if isinstance(c, Attachment):
+                    ret.append(attachment_message(c.url))
+
+        ret.append(m)
+
+    return ret
diff --git a/llama_toolchain/agentic_system/api/__init__.py b/llama_toolchain/agentic_system/api/__init__.py
new file mode 100644
index 000000000..4cefa053f
--- /dev/null
+++ b/llama_toolchain/agentic_system/api/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .datatypes import *  # noqa
+from .endpoints import *  # noqa
diff --git a/llama_toolchain/agentic_system/api/datatypes.py b/llama_toolchain/agentic_system/api/datatypes.py
new file mode 100644
index 000000000..45700a75a
--- /dev/null
+++ b/llama_toolchain/agentic_system/api/datatypes.py
@@ -0,0 +1,199 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from datetime import datetime
+from enum import Enum
+from typing import Any, Dict, List, Literal, Optional, Union
+
+from pydantic import BaseModel, Field
+from strong_typing.schema import json_schema_type
+from typing_extensions import Annotated
+
+from llama_toolchain.common.deployment_types import *  # noqa: F403
+from llama_toolchain.inference.api import *  # noqa: F403
+from llama_toolchain.safety.api.datatypes import *  # noqa: F403
+from llama_toolchain.memory.api.datatypes import *  # noqa: F403
+
+
+@json_schema_type
+class AgenticSystemToolDefinition(ToolDefinition):
+    execution_config: Optional[RestAPIExecutionConfig] = None
+    input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
+    output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
+
+
+class StepCommon(BaseModel):
+    turn_id: str
+    step_id: str
+    started_at: Optional[datetime] = None
+    completed_at: Optional[datetime] = None
+
+
+class StepType(Enum):
+    inference = "inference"
+    tool_execution = "tool_execution"
+    shield_call = "shield_call"
+    memory_retrieval = "memory_retrieval"
+
+
+@json_schema_type
+class InferenceStep(StepCommon):
+    step_type: Literal[StepType.inference.value] = StepType.inference.value
+    model_response: CompletionMessage
+
+
+@json_schema_type
+class ToolExecutionStep(StepCommon):
+    step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
+    tool_calls: List[ToolCall]
+    tool_responses: List[ToolResponse]
+
+
+@json_schema_type
+class ShieldCallStep(StepCommon):
+    step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
+    response: ShieldResponse
+
+
+@json_schema_type
+class MemoryRetrievalStep(StepCommon):
+    step_type: Literal[StepType.memory_retrieval.value] = (
+        StepType.memory_retrieval.value
+    )
+    memory_bank_ids: List[str]
+    documents: List[MemoryBankDocument]
+    scores: List[float]
+
+
+Step = Annotated[
+    Union[
+        InferenceStep,
+        ToolExecutionStep,
+        ShieldCallStep,
+        MemoryRetrievalStep,
+    ],
+    Field(discriminator="step_type"),
+]
+
+
+@json_schema_type
+class Turn(BaseModel):
+    """A single turn in an interaction with an Agentic System."""
+
+    turn_id: str
+    session_id: str
+    input_messages: List[
+        Union[
+            UserMessage,
+            ToolResponseMessage,
+        ]
+    ]
+    steps: List[Step]
+    output_message: CompletionMessage
+    started_at: datetime
+    completed_at: Optional[datetime] = None
+
+
+@json_schema_type
+class Session(BaseModel):
+    """A single session of an interaction with an Agentic System."""
+
+    session_id: str
+    session_name: str
+    turns: List[Turn]
+    started_at: datetime
+
+
+@json_schema_type
+class AgenticSystemInstanceConfig(BaseModel):
+    instructions: str
+    sampling_params: Optional[SamplingParams] = SamplingParams()
+    # zero-shot or built-in tool configurations as input to the model
+    available_tools: Optional[List[AgenticSystemToolDefinition]] = Field(
+        default_factory=list
+    )
+
+    input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
+    output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
+
+    quantization_config: Optional[QuantizationConfig] = None
+
+    # debug only: set this if you want to completely replace the prefix
+    # messages that the system constructs for the model
+    debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
+
+
+class AgenticSystemTurnResponseEventType(Enum):
+    step_start = "step_start"
+    step_complete = "step_complete"
+    step_progress = "step_progress"
+
+    turn_start = "turn_start"
+    turn_complete = "turn_complete"
+
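+# Each payload below corresponds to one streamed event type; the Literal
+# `event_type` field is the discriminator used to select the payload type.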
+@json_schema_type
+class AgenticSystemTurnResponseStepStartPayload(BaseModel):
+    event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
+        AgenticSystemTurnResponseEventType.step_start.value
+    )
+    step_type: StepType
+    step_id: str
+    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
+
+
+@json_schema_type
+class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
+    event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
+        AgenticSystemTurnResponseEventType.step_complete.value
+    )
+    step_type: StepType
+    step_details: Step
+
+
+@json_schema_type
+class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
+    event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
+        AgenticSystemTurnResponseEventType.step_progress.value
+    )
+    step_type: StepType
+    step_id: str
+
+    model_response_text_delta: Optional[str] = None
+    tool_call_delta: Optional[ToolCallDelta] = None
+    tool_response_text_delta: Optional[str] = None
+
+
+@json_schema_type
+class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
+    event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
+        AgenticSystemTurnResponseEventType.turn_start.value
+    )
+    turn_id: str
+
+
+@json_schema_type
+class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
+    event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
+        AgenticSystemTurnResponseEventType.turn_complete.value
+    )
+    turn: Turn
+
+
+@json_schema_type
+class AgenticSystemTurnResponseEvent(BaseModel):
+    """Streamed agent execution response."""
+
+    payload: Annotated[
+        Union[
+            AgenticSystemTurnResponseStepStartPayload,
+            AgenticSystemTurnResponseStepProgressPayload,
+            AgenticSystemTurnResponseStepCompletePayload,
+            AgenticSystemTurnResponseTurnStartPayload,
+            AgenticSystemTurnResponseTurnCompletePayload,
+        ],
+        Field(discriminator="event_type"),
+    ]
diff --git a/llama_toolchain/agentic_system/api/endpoints.py b/llama_toolchain/agentic_system/api/endpoints.py
new file mode 100644
index 000000000..89ccc2995
--- /dev/null
+++ b/llama_toolchain/agentic_system/api/endpoints.py
@@ -0,0 +1,132 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .datatypes import *  # noqa: F403
+from typing import Protocol
+
+# this dependency is annoying and we need a forked up version anyway
+from pyopenapi import webmethod
+
+from strong_typing.schema import json_schema_type
+
+
+@json_schema_type
+class AgenticSystemCreateRequest(BaseModel):
+    model: str
+    instance_config: AgenticSystemInstanceConfig
+
+
+@json_schema_type
+class AgenticSystemCreateResponse(BaseModel):
+    system_id: str
+
+
+@json_schema_type
+class AgenticSystemSessionCreateRequest(BaseModel):
+    system_id: str
+    session_name: str
+
+
+@json_schema_type
+class AgenticSystemSessionCreateResponse(BaseModel):
+    session_id: str
+
+
+@json_schema_type
+class AgenticSystemTurnCreateRequest(BaseModel):
+    system_id: str
+    session_id: str
+
+    messages: List[
+        Union[
+            UserMessage,
+            ToolResponseMessage,
+        ]
+    ]
+
+    stream: Optional[bool] = False
+    override_config: Optional[AgenticSystemInstanceConfig] = None
+
+
+@json_schema_type(
+    schema={"description": "Server-sent event (SSE) stream of these events"}
+)
+class AgenticSystemTurnResponseStreamChunk(BaseModel):
+    event: AgenticSystemTurnResponseEvent
+
+
+@json_schema_type
+class AgenticSystemStepResponse(BaseModel):
+    step: Step
+
+
+class AgenticSystem(Protocol):
+
+    @webmethod(route="/agentic_system/create")
+    async def create_agentic_system(
+        self,
+        request: AgenticSystemCreateRequest,
+    ) -> AgenticSystemCreateResponse: ...
+
+    @webmethod(route="/agentic_system/turn/create")
+    async def create_agentic_system_turn(
+        self,
+        request: AgenticSystemTurnCreateRequest,
+    ) -> AgenticSystemTurnResponseStreamChunk: ...
+
+    @webmethod(route="/agentic_system/turn/get")
+    async def get_agentic_system_turn(
+        self,
+        agent_id: str,
+        turn_id: str,
+    ) -> Turn: ...
+
+    @webmethod(route="/agentic_system/step/get")
+    async def get_agentic_system_step(
+        self, agent_id: str, turn_id: str, step_id: str
+    ) -> AgenticSystemStepResponse: ...
+
+    @webmethod(route="/agentic_system/session/create")
+    async def create_agentic_system_session(
+        self,
+        request: AgenticSystemSessionCreateRequest,
+    ) -> AgenticSystemSessionCreateResponse: ...
+
+    @webmethod(route="/agentic_system/memory_bank/attach")
+    async def attach_memory_bank_to_agentic_system(
+        self,
+        agent_id: str,
+        session_id: str,
+        memory_bank_ids: List[str],
+    ) -> None: ...
+
+    @webmethod(route="/agentic_system/memory_bank/detach")
+    async def detach_memory_bank_from_agentic_system(
+        self,
+        agent_id: str,
+        session_id: str,
+        memory_bank_ids: List[str],
+    ) -> None: ...
+
+    @webmethod(route="/agentic_system/session/get")
+    async def get_agentic_system_session(
+        self,
+        agent_id: str,
+        session_id: str,
+        turn_ids: Optional[List[str]] = None,
+    ) -> Session: ...
+
+    @webmethod(route="/agentic_system/session/delete")
+    async def delete_agentic_system_session(
+        self, agent_id: str, session_id: str
+    ) -> None: ...
+
+    @webmethod(route="/agentic_system/delete")
+    async def delete_agentic_system(
+        self,
+        agent_id: str,
+    ) -> None: ...
diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py
new file mode 100644
index 000000000..71c578e2f
--- /dev/null
+++ b/llama_toolchain/agentic_system/client.py
@@ -0,0 +1,130 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+import json
+
+from typing import AsyncGenerator
+
+import fire
+
+import httpx
+
+from llama_models.llama3_1.api.datatypes import BuiltinTool, SamplingParams
+
+from .api import (
+    AgenticSystem,
+    AgenticSystemCreateRequest,
+    AgenticSystemCreateResponse,
+    AgenticSystemInstanceConfig,
+    AgenticSystemSessionCreateRequest,
+    AgenticSystemSessionCreateResponse,
+    AgenticSystemToolDefinition,
+    AgenticSystemTurnCreateRequest,
+    AgenticSystemTurnResponseStreamChunk,
+)
+
+
+async def get_client_impl(base_url: str):
+    return AgenticSystemClient(base_url)
+
+
+class AgenticSystemClient(AgenticSystem):
+    def __init__(self, base_url: str):
+        self.base_url = base_url
+
+    async def create_agentic_system(
+        self, request: AgenticSystemCreateRequest
+    ) -> AgenticSystemCreateResponse:
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                f"{self.base_url}/agentic_system/create",
+                data=request.json(),
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            return AgenticSystemCreateResponse(**response.json())
+
+    async def create_agentic_system_session(
+        self,
+        request: AgenticSystemSessionCreateRequest,
+    ) -> AgenticSystemSessionCreateResponse:
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                f"{self.base_url}/agentic_system/session/create",
+                data=request.json(),
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            return AgenticSystemSessionCreateResponse(**response.json())
+
+    async def create_agentic_system_turn(
+        self,
+        request: AgenticSystemTurnCreateRequest,
+    ) -> AsyncGenerator:
+        async with httpx.AsyncClient() as client:
+            async with client.stream(
+                "POST",
+                f"{self.base_url}/agentic_system/turn/create",
+                data=request.json(),
+                headers={"Content-Type": "application/json"},
+                timeout=20,
+            ) as response:
+                async for line in response.aiter_lines():
+                    if line.startswith("data:"):
+                        data = line[len("data: ") :]
+                        try:
+                            yield AgenticSystemTurnResponseStreamChunk(
+                                **json.loads(data)
+                            )
+                        except Exception as e:
+                            print(data)
+                            print(f"Error with parsing or validation: {e}")
+
+
+async def run_main(host: str, port: int):
+    # client to test remote impl of agentic system
+    api = AgenticSystemClient(f"http://{host}:{port}")
+
+    tool_definitions = [
+        AgenticSystemToolDefinition(
+            tool_name=BuiltinTool.brave_search,
+        ),
+        AgenticSystemToolDefinition(
+            tool_name=BuiltinTool.wolfram_alpha,
+        ),
+        AgenticSystemToolDefinition(
+            tool_name=BuiltinTool.photogen,
+        ),
+        AgenticSystemToolDefinition(
+            tool_name=BuiltinTool.code_interpreter,
+        ),
+    ]
+
+    create_request = AgenticSystemCreateRequest(
+        model="Meta-Llama3.1-8B-Instruct",
+        instance_config=AgenticSystemInstanceConfig(
+            instructions="You are a helpful assistant",
+            sampling_params=SamplingParams(),
+            available_tools=tool_definitions,
+            input_shields=[],
+            output_shields=[],
+            quantization_config=None,
+            debug_prefix_messages=[],
+        ),
+    )
+
+    create_response = await api.create_agentic_system(create_request)
+    print(create_response)
+    # TODO: Add chat session / turn apis to test e2e
+
+
+def main(host: str, port: int):
+    asyncio.run(run_main(host, port))
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
diff --git a/llama_toolchain/agentic_system/config.py b/llama_toolchain/agentic_system/config.py
new file mode 100644
index 000000000..349784021
--- /dev/null
+++ b/llama_toolchain/agentic_system/config.py
@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel
+
+
+class AgenticSystemConfig(BaseModel):
+    # placeholder, no separate configuration is needed for now
+    pass
diff --git a/llama_toolchain/agentic_system/safety.py b/llama_toolchain/agentic_system/safety.py
new file mode 100644
index 000000000..f066baf59
--- /dev/null
+++ b/llama_toolchain/agentic_system/safety.py
@@ -0,0 +1,65 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_models.llama3_1.api.datatypes import Message, Role
+from llama_toolchain.safety.api.datatypes import (
+    OnViolationAction,
+    ShieldDefinition,
+    ShieldResponse,
+)
+from llama_toolchain.safety.api.endpoints import RunShieldRequest, Safety
+from termcolor import cprint
+
+
+class SafetyException(Exception):  # noqa: N818
+    def __init__(self, response: ShieldResponse):
+        self.response = response
+        super().__init__(response.violation_return_message)
+
+
+class ShieldRunnerMixin:
+
+    def __init__(
+        self,
+        safety_api: Safety,
+        input_shields: List[ShieldDefinition] = None,
+        output_shields: List[ShieldDefinition] = None,
+    ):
+        self.safety_api = safety_api
+        self.input_shields = input_shields
+        self.output_shields = output_shields
+
+    async def run_shields(
+        self, messages: List[Message], shields: List[ShieldDefinition]
+    ) -> List[ShieldResponse]:
+        # Some shields, like llama-guard, require the first message to be a user
+        # message; since the dialog may start with a tool call, the first role
+        # might not be user.
+        if len(messages) > 0 and messages[0].role != Role.user.value:
+            # TODO(ashwin): we need to change the type of the message, this kind of modification
+            # is no longer appropriate
+            messages[0].role = Role.user.value
+
+        res = await self.safety_api.run_shields(
+            RunShieldRequest(
+                messages=messages,
+                shields=shields,
+            )
+        )
+
+        results = res.responses
+        for shield, r in zip(shields, results):
+            if r.is_violation:
+                if shield.on_violation_action == OnViolationAction.RAISE:
+                    raise SafetyException(r)
+                elif shield.on_violation_action == OnViolationAction.WARN:
+                    cprint(
+                        f"[Warn] {shield.__class__.__name__} raised a warning",
+                        color="red",
+                    )
+
+        return results
diff --git a/llama_toolchain/agentic_system/system_prompt.py b/llama_toolchain/agentic_system/system_prompt.py
new file mode 100644
index 000000000..c8c616285
--- /dev/null
+++ b/llama_toolchain/agentic_system/system_prompt.py
@@ -0,0 +1,152 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+from datetime import datetime
+from typing import List
+
+from llama_toolchain.inference.api import (
+    BuiltinTool,
+    Message,
+    SystemMessage,
+    ToolDefinition,
+)
+
+from .tools.builtin import SingleMessageBuiltinTool
+
+
+def get_agentic_prefix_messages(
+    builtin_tools: List[SingleMessageBuiltinTool], custom_tools: List[ToolDefinition]
+) -> List[Message]:
+    messages = []
+    content = ""
+    if builtin_tools:
+        content += "Environment: ipython\n"
+
+        tool_str = ", ".join(
+            [
+                t.get_name()
+                for t in builtin_tools
+                if t.get_name() != BuiltinTool.code_interpreter.value
+            ]
+        )
+        if tool_str:
+            content += f"Tools: {tool_str}\n"
+
+    current_date = datetime.now()
+    formatted_date = current_date.strftime("%d %B %Y")
+    date_str = f"""
+Cutting Knowledge Date: December 2023
+Today Date: {formatted_date}\n\n"""
+    content += date_str
+
+    if custom_tools:
+        custom_message = get_system_prompt_for_custom_tools(custom_tools)
+        content += custom_message
+
+    # TODO: Replace this hard coded message with instructions coming in the request
+    if False:
+        content += "You are a helpful Assistant."
+
+    messages.append(SystemMessage(content=content))
+    return messages
+
+
+def get_system_prompt_for_custom_tools(custom_tools: List[ToolDefinition]) -> str:
+    custom_tool_params = ""
+    for t in custom_tools:
+        custom_tool_params += get_instruction_string(t) + "\n"
+        custom_tool_params += get_parameters_string(t) + "\n\n"
+
+    content = f"""
+You have access to the following functions:
+
+{custom_tool_params}
+Think very carefully before calling functions.
+If you choose to call a function ONLY reply in the following format with no prefix or suffix:
+
+<function=example_function_name>{{"example_name": "example_value"}}</function>
+
+Reminder:
+- If looking for real time information use relevant functions before falling back to brave_search
+- Function calls MUST follow the specified format, start with <function= and end with </function>
+- Required parameters MUST be specified
+- Only call one function at a time
+- Put the entire function call reply on one line
+
+"""
+    return content
+
+
+def get_instruction_string(custom_tool_definition) -> str:
+    return f"Use the function '{custom_tool_definition.tool_name}' to '{custom_tool_definition.description}'"
+
+
+def get_parameters_string(custom_tool_definition) -> str:
+    return json.dumps(
+        {
+            "name": custom_tool_definition.tool_name,
+            "description": custom_tool_definition.description,
+            "parameters": {
+                name: definition.__dict__
+                for name, definition in custom_tool_definition.parameters.items()
+            },
+        }
+    )
+
+
+# NOTE: Unused right now
+def translate_custom_tool_definition_to_json(tool_def):
+    """Translates ToolDefinition to json as expected by model
+    eg.
output for a function + { + "type": "function", + "function": { + "name": "conv_int", + "description": "Convert serialized fract24 integer into int value.", + "parameters": { + "type": "object", + "properties": [ + { + "data": { + "type": "object", + "description": "" + } + } + ], + "required": ["data"] + } + } + } + """ + assert isinstance(tool_def.tool_name, str) + func_def = {"type": "function", "function": {}} + func_def["function"]["name"] = tool_def.tool_name + func_def["function"]["description"] = tool_def.description or "" + if tool_def.parameters: + required = [] + properties = [] + for p_name, p_def in tool_def.parameters.items(): + properties.append( + { + p_name: { + # TODO: see if this should not always be object + "type": "object", + "description": p_def.description or "", + } + } + ) + if p_def.required: + required.append(p_name) + func_def["function"]["parameters"] = { + "type": "object", + "properties": properties, + "required": required, + } + else: + func_def["function"]["parameters"] = {} + + return json.dumps(func_def) diff --git a/llama_toolchain/agentic_system/tools/__init__.py b/llama_toolchain/agentic_system/tools/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_toolchain/agentic_system/tools/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_toolchain/agentic_system/tools/base.py b/llama_toolchain/agentic_system/tools/base.py new file mode 100644 index 000000000..3c2722305 --- /dev/null +++ b/llama_toolchain/agentic_system/tools/base.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from abc import ABC, abstractmethod +from typing import List + +from llama_toolchain.inference.api import Message + + +class BaseTool(ABC): + + @abstractmethod + def get_name(self) -> str: + raise NotImplementedError + + @abstractmethod + async def run(self, messages: List[Message]) -> List[Message]: + raise NotImplementedError diff --git a/llama_toolchain/agentic_system/tools/builtin.py b/llama_toolchain/agentic_system/tools/builtin.py new file mode 100644 index 000000000..4487a2692 --- /dev/null +++ b/llama_toolchain/agentic_system/tools/builtin.py @@ -0,0 +1,326 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
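+
+# NOTE: Implementations of the built-in tools. Each tool consumes a single
+# tool-call message and returns a single tool response message.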
+ +import json +import os +import re + +from abc import abstractmethod +from typing import List, Optional + +import requests +from termcolor import cprint + +from .ipython_tool.code_execution import ( + CodeExecutionContext, + CodeExecutionRequest, + CodeExecutor, + TOOLS_ATTACHMENT_KEY_REGEX, +) + +from llama_toolchain.inference.api import * # noqa: F403 + +from .base import BaseTool + + +def interpret_content_as_attachment(content: str) -> Optional[Attachment]: + match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content) + if match: + snippet = match.group(1) + data = json.loads(snippet) + return Attachment( + url=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"] + ) + + return None + + +class SingleMessageBuiltinTool(BaseTool): + async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]: + assert len(messages) == 1, f"Expected single message, got {len(messages)}" + + message = messages[0] + assert len(message.tool_calls) == 1, "Expected a single tool call" + + tool_call = messages[0].tool_calls[0] + + query = tool_call.arguments["query"] + response: str = await self.run_impl(query) + + message = ToolResponseMessage( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=response, + ) + if attachment := interpret_content_as_attachment(response): + message.content = attachment + + return [message] + + @abstractmethod + async def run_impl(self, query: str) -> str: + raise NotImplementedError() + + +class PhotogenTool(SingleMessageBuiltinTool): + + def __init__(self, dump_dir: str) -> None: + self.dump_dir = dump_dir + + def get_name(self) -> str: + return BuiltinTool.photogen.value + + async def run_impl(self, query: str) -> str: + """ + Implement this to give the model an ability to generate images. 
+
+        Return:
+            info = {
+                "filepath": str(image_filepath),
+                "mimetype": "image/png",
+            }
+        """
+        raise NotImplementedError()
+
+
+class BraveSearchTool(SingleMessageBuiltinTool):
+
+    def __init__(self, api_key: str) -> None:
+        self.api_key = api_key
+
+    def get_name(self) -> str:
+        return BuiltinTool.brave_search.value
+
+    async def run_impl(self, query: str) -> str:
+        url = "https://api.search.brave.com/res/v1/web/search"
+        headers = {
+            "X-Subscription-Token": self.api_key,
+            "Accept-Encoding": "gzip",
+            "Accept": "application/json",
+        }
+        payload = {"q": query}
+        response = requests.get(url=url, params=payload, headers=headers)
+        return json.dumps(self._clean_brave_response(response.json()))
+
+    def _clean_brave_response(self, search_response, top_k=3):
+        query = None
+        clean_response = []
+        if "query" in search_response:
+            if "original" in search_response["query"]:
+                query = search_response["query"]["original"]
+        if "mixed" in search_response:
+            mixed_results = search_response["mixed"]
+            for m in mixed_results["main"][:top_k]:
+                r_type = m["type"]
+                results = search_response[r_type]["results"]
+                if r_type == "web":
+                    # For web data - add a single output from the search
+                    idx = m["index"]
+                    selected_keys = [
+                        "type",
+                        "title",
+                        "url",
+                        "description",
+                        "date",
+                        "extra_snippets",
+                    ]
+                    cleaned = {
+                        k: v for k, v in results[idx].items() if k in selected_keys
+                    }
+                elif r_type == "faq":
+                    # For faq data - take a list of all the questions & answers
+                    selected_keys = ["type", "question", "answer", "title", "url"]
+                    cleaned = []
+                    for q in results:
+                        cleaned.append(
+                            {k: v for k, v in q.items() if k in selected_keys}
+                        )
+                elif r_type == "infobox":
+                    idx = m["index"]
+                    selected_keys = [
+                        "type",
+                        "title",
+                        "url",
+                        "description",
+                        "long_desc",
+                    ]
+                    cleaned = {
+                        k: v for k, v in results[idx].items() if k in selected_keys
+                    }
+                elif r_type == "videos":
+                    selected_keys = [
+                        "type",
+                        "url",
+                        "title",
+                        "description",
+                        "date",
+                    ]
+                    cleaned = []
+                    for q in results:
+                        cleaned.append(
+                            {k: v for k, v in q.items() if k in selected_keys}
+                        )
+                elif r_type == "locations":
+                    # For location data - take a list of the location results
+                    selected_keys = [
+                        "type",
+                        "title",
+                        "url",
+                        "description",
+                        "coordinates",
+                        "postal_address",
+                        "contact",
+                        "rating",
+                        "distance",
+                        "zoom_level",
+                    ]
+                    cleaned = []
+                    for q in results:
+                        cleaned.append(
+                            {k: v for k, v in q.items() if k in selected_keys}
+                        )
+                elif r_type == "news":
+                    # For news data - take a list of the news results
+                    selected_keys = [
+                        "type",
+                        "title",
+                        "url",
+                        "description",
+                    ]
+                    cleaned = []
+                    for q in results:
+                        cleaned.append(
+                            {k: v for k, v in q.items() if k in selected_keys}
+                        )
+                else:
+                    cleaned = []
+
+                clean_response.append(cleaned)
+
+        return {"query": query, "top_k": clean_response}
+
+
+class WolframAlphaTool(SingleMessageBuiltinTool):
+
+    def __init__(self, api_key: str) -> None:
+        self.api_key = api_key
+        self.url = "https://api.wolframalpha.com/v2/query"
+
+    def get_name(self) -> str:
+        return BuiltinTool.wolfram_alpha.value
+
+    async def run_impl(self, query: str) -> str:
+        params = {
+            "input": query,
+            "appid": self.api_key,
+            "format": "plaintext",
+            "output": "json",
+        }
+        response = requests.get(
+            self.url,
+            params=params,
+        )
+
+        return json.dumps(self._clean_wolfram_alpha_response(response.json()))
+
+    def _clean_wolfram_alpha_response(self, wa_response):
+        remove = {
+            "queryresult": [
+                "datatypes",
+                "error",
+                "timedout",
+                "timedoutpods",
+                "numpods",
+                "timing",
+
"parsetiming", + "parsetimedout", + "recalculate", + "id", + "host", + "server", + "related", + "version", + { + "pods": [ + "scanner", + "id", + "error", + "expressiontypes", + "states", + "infos", + "position", + "numsubpods", + ] + }, + "assumptions", + ], + } + for main_key in remove: + for key_to_remove in remove[main_key]: + try: + if key_to_remove == "assumptions": + if "assumptions" in wa_response[main_key]: + del wa_response[main_key][key_to_remove] + if isinstance(key_to_remove, dict): + for sub_key in key_to_remove: + if sub_key == "pods": + for i in range(len(wa_response[main_key][sub_key])): + if ( + wa_response[main_key][sub_key][i]["title"] + == "Result" + ): + del wa_response[main_key][sub_key][i + 1 :] + break + sub_items = wa_response[main_key][sub_key] + for i in range(len(sub_items)): + for sub_key_to_remove in key_to_remove[sub_key]: + if sub_key_to_remove in sub_items[i]: + del sub_items[i][sub_key_to_remove] + elif key_to_remove in wa_response[main_key]: + del wa_response[main_key][key_to_remove] + except KeyError: + pass + return wa_response + + +class CodeInterpreterTool(BaseTool): + + def __init__(self) -> None: + ctx = CodeExecutionContext( + matplotlib_dump_dir=f"/tmp/{os.environ['USER']}_matplotlib_dump", + ) + self.code_executor = CodeExecutor(ctx) + + def get_name(self) -> str: + return BuiltinTool.code_interpreter.value + + async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]: + message = messages[0] + assert len(message.tool_calls) == 1, "Expected a single tool call" + + tool_call = messages[0].tool_calls[0] + script = tool_call.arguments["code"] + + req = CodeExecutionRequest(scripts=[script]) + res = self.code_executor.execute(req) + + pieces = [res["process_status"]] + for out_type in ["stdout", "stderr"]: + res_out = res[out_type] + if res_out != "": + pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"]) + if out_type == "stderr": + cprint(f"ipython tool error: ↓\n{res_out}", color="red") + + message = ToolResponseMessage( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content="\n".join(pieces), + ) + if attachment := interpret_content_as_attachment(res["stdout"]): + message.content = attachment + + return [message] diff --git a/llama_toolchain/agentic_system/tools/custom.py b/llama_toolchain/agentic_system/tools/custom.py new file mode 100644 index 000000000..35e3dd57d --- /dev/null +++ b/llama_toolchain/agentic_system/tools/custom.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json + +from abc import abstractmethod +from typing import Dict, List + +from llama_models.llama3_1.api.datatypes import * # noqa: F403 +from llama_toolchain.agentic_system.api import * # noqa: F403 + +from .builtin import interpret_content_as_attachment + + +class CustomTool: + """ + Developers can define their custom tools that models can use + by extending this class. 
+
+    Developers need to provide
+        - name
+        - description
+        - params_definition
+        - an implementation of the tool's behavior in the `run_impl` method
+
+    NOTE: The return of the `run` method needs to be json serializable
+    """
+
+    @abstractmethod
+    def get_name(self) -> str:
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_description(self) -> str:
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_params_definition(self) -> Dict[str, ToolParamDefinition]:
+        raise NotImplementedError
+
+    def get_instruction_string(self) -> str:
+        return f"Use the function '{self.get_name()}' to: {self.get_description()}"
+
+    def parameters_for_system_prompt(self) -> str:
+        return json.dumps(
+            {
+                "name": self.get_name(),
+                "description": self.get_description(),
+                "parameters": {
+                    name: definition.__dict__
+                    for name, definition in self.get_params_definition().items()
+                },
+            }
+        )
+
+    def get_tool_definition(self) -> AgenticSystemToolDefinition:
+        return AgenticSystemToolDefinition(
+            tool_name=self.get_name(),
+            description=self.get_description(),
+            parameters=self.get_params_definition(),
+        )
+
+    @abstractmethod
+    async def run(self, messages: List[Message]) -> List[Message]:
+        raise NotImplementedError
+
+
+class SingleMessageCustomTool(CustomTool):
+    """
+    Helper class for custom tools that take a single message.
+    Extending this class and implementing the `run_impl` method will
+    allow the tool to be called by the model, with the necessary
+    plumbing handled for you.
+    """
+
+    async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
+        assert len(messages) == 1, "Expected single message"
+
+        message = messages[0]
+
+        tool_call = message.tool_calls[0]
+
+        try:
+            response = await self.run_impl(**tool_call.arguments)
+            response_str = json.dumps(response, ensure_ascii=False)
+        except Exception as e:
+            response_str = f"Error when running tool: {e}"
+
+        message = ToolResponseMessage(
+            call_id=tool_call.call_id,
+            tool_name=tool_call.tool_name,
+            content=response_str,
+        )
+        if attachment := interpret_content_as_attachment(response_str):
+            message.content = attachment
+
+        return [message]
+
+    @abstractmethod
+    async def run_impl(self, *args, **kwargs):
+        raise NotImplementedError()
diff --git a/llama_toolchain/agentic_system/tools/execute.py b/llama_toolchain/agentic_system/tools/execute.py
new file mode 100644
index 000000000..2a7625f65
--- /dev/null
+++ b/llama_toolchain/agentic_system/tools/execute.py
@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, AsyncGenerator, List
+
+from llama_models.llama3_1.api.datatypes import StopReason, ToolResponseMessage
+
+from llama_toolchain.agentic_system.api import (
+    AgenticSystem,
+    AgenticSystemTurnCreateRequest,
+    AgenticSystemTurnResponseEventType as EventType,
+)
+
+from llama_toolchain.inference.api import Message
+
+
+async def execute_with_custom_tools(
+    system: AgenticSystem,
+    system_id: str,
+    session_id: str,
+    messages: List[Message],
+    custom_tools: List[Any],
+    max_iters: int = 5,
+    stream: bool = True,
+) -> AsyncGenerator:
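+    # Index the custom tools by name so that tool calls coming back from the
+    # model can be dispatched to their local implementations.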
+ tools_dict = {t.get_name(): t for t in custom_tools} + + current_messages = messages.copy() + n_iter = 0 + while n_iter < max_iters: + n_iter += 1 + + request = AgenticSystemTurnCreateRequest( + system_id=system_id, + session_id=session_id, + messages=current_messages, + stream=stream, + ) + + turn = None + async for chunk in system.create_agentic_system_turn(request): + if chunk.event.payload.event_type != EventType.turn_complete.value: + yield chunk + else: + turn = chunk.event.payload.turn + break + + message = turn.output_message + if len(message.tool_calls) == 0: + yield chunk + return + + if message.stop_reason == StopReason.out_of_tokens: + yield chunk + return + + tool_call = message.tool_calls[0] + if tool_call.tool_name not in tools_dict: + m = ToolResponseMessage( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else", + ) + next_message = m + else: + tool = tools_dict[tool_call.tool_name] + result_messages = await execute_custom_tool(tool, message) + next_message = result_messages[0] + + yield next_message + current_messages = [next_message] + + +async def execute_custom_tool(tool: Any, message: Message) -> List[Message]: + result_messages = await tool.run([message]) + assert ( + len(result_messages) == 1 + ), f"Expected single message, got {len(result_messages)}" + + return result_messages diff --git a/llama_toolchain/agentic_system/tools/ipython_tool/__init__.py b/llama_toolchain/agentic_system/tools/ipython_tool/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_toolchain/agentic_system/tools/ipython_tool/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_toolchain/agentic_system/tools/ipython_tool/code_env_prefix.py b/llama_toolchain/agentic_system/tools/ipython_tool/code_env_prefix.py new file mode 100644 index 000000000..10f64ec94 --- /dev/null +++ b/llama_toolchain/agentic_system/tools/ipython_tool/code_env_prefix.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
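+
+# NOTE: this module is not imported at runtime. code_execution.py loads its
+# source text via get_code_env_prefix() (see utils.py in this directory) and
+# prepends it to every user script, so everything below runs first inside the
+# bwrap sandbox and hardens the environment that model-written code will see.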
+
+import errno
+
+# Disabling potentially dangerous functions
+import os as _os
+from functools import partial
+
+os_funcs_to_disable = [
+    "kill",
+    "system",
+    "putenv",
+    "remove",
+    "removedirs",
+    "rmdir",
+    "fchdir",
+    "setuid",
+    "fork",
+    "forkpty",
+    "killpg",
+    "rename",
+    "renames",
+    "truncate",
+    "replace",
+    # "unlink",  # Commented out as this was blocking matplotlib from rendering plots correctly
+    "fchmod",
+    "fchown",
+    "chmod",
+    "chown",
+    "chroot",
+    "lchflags",
+    "lchmod",
+    "lchown",
+    "chdir",
+]
+
+
+def call_not_allowed(*args, **kwargs):
+    func_name = kwargs.get("_func_name", "This call")
+    raise OSError(errno.EPERM, f"{func_name} is not permitted in this environment")
+
+
+for func_name in os_funcs_to_disable:
+    if hasattr(_os, func_name):
+        setattr(_os, func_name, partial(call_not_allowed, _func_name=f"os.{func_name}"))
+
+import shutil as _shutil
+
+for func_name in ["rmtree", "move", "chown"]:
+    if hasattr(_shutil, func_name):
+        setattr(
+            _shutil,
+            func_name,
+            partial(call_not_allowed, _func_name=f"shutil.{func_name}"),
+        )
+
+import subprocess as _subprocess
+
+
+def popen_not_allowed(*args, **kwargs):
+    raise _subprocess.CalledProcessError(
+        -1,
+        args[0] if args else "unknown",
+        stderr="subprocess.Popen is not allowed in this environment",
+    )
+
+
+_subprocess.Popen = popen_not_allowed
+
+
+import atexit as _atexit
+import builtins as _builtins
+import io as _io
+import json as _json
+import sys as _sys
+
+# NB! The following "unused" imports are crucial; make sure not to remove
+# them with linters - they're used in code_execution.py
+from contextlib import (  # noqa
+    contextmanager as _contextmanager,
+    redirect_stderr as _redirect_stderr,
+    redirect_stdout as _redirect_stdout,
+)
+from multiprocessing.connection import Connection as _Connection
+
+# Imports are mangled (underscore-prefixed aliases) to avoid polluting the
+# namespace that model-written code executes in.
+
+_IO_SINK = _io.StringIO()
+_NETWORK_TIMEOUT = 5
+_NETWORK_CONNECTIONS = None
+
+
+def _open_connections():
+    global _NETWORK_CONNECTIONS
+    if _NETWORK_CONNECTIONS is not None:
+        # Ensure connections are only opened once.
+        return _NETWORK_CONNECTIONS
+    req_w_fd, resp_r_fd = _sys.argv[1], _sys.argv[2]
+    req_con = _Connection(int(req_w_fd), readable=False)
+    resp_con = _Connection(int(resp_r_fd), writable=False)
+    _NETWORK_CONNECTIONS = (req_con, resp_con)
+    return _NETWORK_CONNECTIONS
+
+
+_builtins._open_connections = _open_connections
+
+
+@_atexit.register
+def _close_connections():
+    global _NETWORK_CONNECTIONS
+    if _NETWORK_CONNECTIONS is None:
+        return
+    for con in _NETWORK_CONNECTIONS:
+        con.close()
+    del _NETWORK_CONNECTIONS
+
+
+def _network_call(request):
+    # NOTE: We communicate with the parent process in json, encoded
+    # in raw bytes. We do this because the native send/recv methods use
+    # pickle, which involves execution of arbitrary code.
+    _open_connections()
+    req_con, resp_con = _NETWORK_CONNECTIONS
+
+    req_con.send_bytes(_json.dumps(request).encode("utf-8"))
+    if resp_con.poll(timeout=_NETWORK_TIMEOUT) is None:
+        raise Exception(f"Network request timed out: {_json.dumps(request)}")
+    else:
+        return _json.loads(resp_con.recv_bytes().decode("utf-8"))
diff --git a/llama_toolchain/agentic_system/tools/ipython_tool/code_execution.py b/llama_toolchain/agentic_system/tools/ipython_tool/code_execution.py
new file mode 100644
index 000000000..fa2e367e5
--- /dev/null
+++ b/llama_toolchain/agentic_system/tools/ipython_tool/code_execution.py
@@ -0,0 +1,256 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import base64
+import json
+import multiprocessing
+import os
+import re
+import subprocess
+import sys
+import tempfile
+import textwrap
+import time
+from dataclasses import dataclass
+from datetime import datetime
+from io import BytesIO
+from pathlib import Path
+from typing import List
+
+from PIL import Image
+
+from .utils import get_code_env_prefix
+
+TOOLS_ATTACHMENT_KEY = "__tools_attachment__"
+TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
+
+DIRNAME = Path(__file__).parent
+
+CODE_EXEC_TIMEOUT = 20
+CODE_ENV_PREFIX = get_code_env_prefix()
+
+STDOUTERR_SINK_WRAPPER_TEMPLATE = """\
+with _redirect_stdout(_IO_SINK), _redirect_stderr(_IO_SINK):
+{code}\
+"""
+
+TRYEXCEPT_WRAPPER_TEMPLATE = """\
+try:
+{code}
+except:
+    pass\
+"""
+
+
+def generate_bwrap_command(bind_dirs: List[str]) -> str:
+    """
+    Generate the bwrap argument string: mount the whole filesystem read-only,
+    bind-mount the given directories read-write, and isolate the sandbox from
+    the rest of the system.
+    """
+    bwrap_args = ""
+    bwrap_args += "--ro-bind / / "
+    # Add the --dev flag to mount device files
+    bwrap_args += "--dev /dev "
+    for d in bind_dirs:
+        bwrap_args += f"--bind {d} {d} "
+
+    # Add the --unshare-all flag to isolate the sandbox from the rest of the system
+    bwrap_args += "--unshare-all "
+    # Add the --die-with-parent flag to ensure the child process dies when bwrap's parent dies
+    bwrap_args += "--die-with-parent "
+    return bwrap_args
+
+
+@dataclass
+class CodeExecutionContext:
+    matplotlib_dump_dir: str
+    use_proxy: bool = False
+
+
+@dataclass
+class CodeExecutionRequest:
+    scripts: List[str]
+    only_last_cell_stdouterr: bool = True
+    only_last_cell_fail: bool = True
+    seed: int = 0
+    strip_fpaths_in_stderr: bool = True
+
+
+class CodeExecutor:
+    def __init__(self, context: CodeExecutionContext):
+        self.context = context
+
+    def execute(self, req: CodeExecutionRequest) -> dict:
+        scripts = req.scripts
+        for i in range(len(scripts) - 1):
+            if req.only_last_cell_stdouterr:
+                scripts[i] = STDOUTERR_SINK_WRAPPER_TEMPLATE.format(
+                    code=textwrap.indent(scripts[i], " " * 4)
+                )
+            if req.only_last_cell_fail:
+                scripts[i] = TRYEXCEPT_WRAPPER_TEMPLATE.format(
+                    code=textwrap.indent(scripts[i], " " * 4)
+                )
+
+        # Prefix that seeds the RNGs for deterministic execution:
+        seed = req.seed
+        seeds_prefix = f"""\
+def _set_seeds():
+    import random
+    random.seed({seed})
+    import numpy as np
+    np.random.seed({seed})
+_set_seeds()\
+"""
+
+        script = "\n\n".join([seeds_prefix] + [CODE_ENV_PREFIX] + scripts)
+        with tempfile.TemporaryDirectory() as dpath:
+            bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath])
+            cmd = [*bwrap_prefix.split(), sys.executable, "-c", script]
+            code_fpath = os.path.join(dpath, "code.py")
+            with open(code_fpath, "w") as f:
+                f.write(script)
+
+            try:
+                python_path = os.environ.get("PYTHONPATH", "")
+                env = dict(
+                    os.environ,
+                    PYTHONHASHSEED=str(seed),
+                    MPLCONFIGDIR=dpath,
+                    MPLBACKEND="module://matplotlib_custom_backend",
+                    PYTHONPATH=f"{DIRNAME}:{python_path}",
+                )
+                stdout, stderr, returncode = do_subprocess(
+                    cmd=cmd,
+                    env=env,
+                    ctx=self.context,
+                )
+
+                stderr = stderr.strip()
+                if req.strip_fpaths_in_stderr:
+                    pattern = r'File "([^"]+)", line (\d+)'
+                    stderr = re.sub(pattern, r"line \2", stderr)
+
+                return {
+                    "process_status": "completed",
+                    "returncode": returncode,
+                    "stdout": stdout.strip(),
+                    "stderr": stderr,
+                }
+
+            except subprocess.TimeoutExpired:
+                return {
+                    "process_status": "timeout",
+                    "stdout": "Timed out",
"stderr": "Timed out", + } + + except Exception as e: + return { + "process_status": "error", + "error_type": type(e).__name__, + "stderr": str(e), + "stdout": str(e), + } + + +def process_matplotlib_response(response, matplotlib_dump_dir: str): + image_data = response["image_data"] + # Convert the base64 string to a bytes object + images = [base64.b64decode(d["image_base64"]) for d in image_data] + # Create a list of PIL images from the bytes objects + images = [Image.open(BytesIO(img)) for img in images] + # Create a list of image paths + image_paths = [] + for i, img in enumerate(images): + # create new directory for each day to better organize data: + dump_dname = datetime.today().strftime("%Y-%m-%d") + dump_dpath = Path(matplotlib_dump_dir, dump_dname) + dump_dpath.mkdir(parents=True, exist_ok=True) + # save image into a file + dump_fname = f"matplotlib_{str(time.time()).replace('.', '_')}_{i}.png" + dump_fpath = dump_dpath / dump_fname + img.save(dump_fpath, "PNG") + image_paths.append(str(dump_fpath)) + + # this is kind of convoluted, we send back this response to the subprocess which + # prints it out + info = { + "filepath": str(image_paths[-1]), + "mimetype": "image/png", + } + return f"{TOOLS_ATTACHMENT_KEY}={json.dumps(info)}" + + +def execute_subprocess_request(request, ctx: CodeExecutionContext): + "Route requests from the subprocess (via network Pipes) to the internet/tools." + if request["type"] == "matplotlib": + return process_matplotlib_response(request, ctx.matplotlib_dump_dir) + else: + raise Exception(f'Unrecognised network request type: {request["type"]}') + + +def do_subprocess(*, cmd: list, env: dict, ctx: CodeExecutionContext): + # Create Pipes to be used for any external tool/network requests. + req_r, req_w = multiprocessing.Pipe(duplex=False) + resp_r, resp_w = multiprocessing.Pipe(duplex=False) + + cmd += [str(req_w.fileno()), str(resp_r.fileno())] + proc = subprocess.Popen( + cmd, + pass_fds=(req_w.fileno(), resp_r.fileno()), + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + close_fds=True, + env=env, + ) + + # Close unnecessary fds. + req_w.close() + resp_r.close() + + pipe_close = False + done_read = False + start = time.monotonic() + while proc.poll() is None and not pipe_close: + if req_r.poll(0.1): + # NB: Python pipe semantics for poll and recv mean that + # poll() returns True is a pipe is closed. + # CF old school PEP from '09 + # https://bugs.python.org/issue5573 + try: + request = json.loads(req_r.recv_bytes().decode("utf-8")) + response = execute_subprocess_request(request, ctx) + + resp_w.send_bytes(json.dumps(response).encode("utf-8")) + except EOFError: + # The request pipe is closed - set a marker to exit + # after the next attempt at reading stdout/stderr. + pipe_close = True + + try: + # If lots has been printed, pipe might be full but + # proc cannot exit until all the stdout/stderr + # been written/read. + stdout, stderr = proc.communicate(timeout=0.3) + done_read = True + except subprocess.TimeoutExpired: + # The program has not terminated. Ignore it, there + # may be more network/tool requests. + continue + if time.monotonic() - start > CODE_EXEC_TIMEOUT: + proc.terminate() + raise subprocess.TimeoutExpired(cmd, CODE_EXEC_TIMEOUT) + + if not done_read: + # Solve race condition where process terminates before + # we hit the while loop. 
+ stdout, stderr = proc.communicate(timeout=0.3) + + resp_w.close() + req_r.close() + return stdout, stderr, proc.returncode diff --git a/llama_toolchain/agentic_system/tools/ipython_tool/matplotlib_custom_backend.py b/llama_toolchain/agentic_system/tools/ipython_tool/matplotlib_custom_backend.py new file mode 100644 index 000000000..3aba2ef21 --- /dev/null +++ b/llama_toolchain/agentic_system/tools/ipython_tool/matplotlib_custom_backend.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +A custom Matplotlib backend that overrides the show method to return image bytes. +""" + +import base64 +import io +import json as _json + +import matplotlib +from matplotlib.backend_bases import FigureManagerBase + +# Import necessary components from Matplotlib +from matplotlib.backends.backend_agg import FigureCanvasAgg + + +class CustomFigureCanvas(FigureCanvasAgg): + def show(self): + # Save the figure to a BytesIO object + buf = io.BytesIO() + self.print_png(buf) + image_bytes = buf.getvalue() + buf.close() + return image_bytes + + +class CustomFigureManager(FigureManagerBase): + def __init__(self, canvas, num): + super().__init__(canvas, num) + + +# Mimic module initialization that integrates with the Matplotlib backend system +def _create_figure_manager(num, *args, **kwargs): + """ + Create a custom figure manager instance. + """ + FigureClass = kwargs.pop("FigureClass", None) # noqa: N806 + if FigureClass is None: + from matplotlib.figure import Figure + + FigureClass = Figure # noqa: N806 + fig = FigureClass(*args, **kwargs) + canvas = CustomFigureCanvas(fig) + manager = CustomFigureManager(canvas, num) + return manager + + +def show(): + """ + Handle all figures and potentially return their images as bytes. + + This function iterates over all figures registered with the custom backend, + renders them as images in bytes format, and could return a list of bytes objects, + one for each figure, or handle them as needed. + """ + image_data = [] + for manager in matplotlib._pylab_helpers.Gcf.get_all_fig_managers(): + # Get the figure from the manager + fig = manager.canvas.figure + buf = io.BytesIO() # Create a buffer for the figure + fig.savefig(buf, format="png") # Save the figure to the buffer in PNG format + buf.seek(0) # Go to the beginning of the buffer + image_bytes = buf.getvalue() # Retrieve bytes value + image_base64 = base64.b64encode(image_bytes).decode("utf-8") + image_data.append({"image_base64": image_base64}) + buf.close() + + req_con, resp_con = _open_connections() + + _json_dump = _json.dumps( + { + "type": "matplotlib", + "image_data": image_data, + } + ) + req_con.send_bytes(_json_dump.encode("utf-8")) + resp = _json.loads(resp_con.recv_bytes().decode("utf-8")) + print(resp) + + +FigureCanvas = CustomFigureCanvas +FigureManager = CustomFigureManager diff --git a/llama_toolchain/agentic_system/tools/ipython_tool/utils.py b/llama_toolchain/agentic_system/tools/ipython_tool/utils.py new file mode 100644 index 000000000..d6f539a39 --- /dev/null +++ b/llama_toolchain/agentic_system/tools/ipython_tool/utils.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
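+
+# Small helper for loading the sandbox prelude: reads code_env_prefix.py once
+# and caches its source so code_execution.py can prepend it to every script.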
+
+import os
+
+DIR = os.path.dirname(os.path.realpath(__file__))
+CODE_ENV_PREFIX_FILE = os.path.join(DIR, "code_env_prefix.py")
+CODE_ENV_PREFIX = None
+
+
+def get_code_env_prefix() -> str:
+    global CODE_ENV_PREFIX
+
+    if CODE_ENV_PREFIX is None:
+        with open(CODE_ENV_PREFIX_FILE, "r") as f:
+            CODE_ENV_PREFIX = f.read()
+
+    return CODE_ENV_PREFIX
diff --git a/llama_toolchain/agentic_system/tools/safety.py b/llama_toolchain/agentic_system/tools/safety.py
new file mode 100644
index 000000000..da0abe10a
--- /dev/null
+++ b/llama_toolchain/agentic_system/tools/safety.py
@@ -0,0 +1,59 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_toolchain.agentic_system.safety import ShieldRunnerMixin
+
+from llama_toolchain.inference.api import Message
+from llama_toolchain.safety.api.datatypes import ShieldDefinition
+from llama_toolchain.safety.api.endpoints import Safety
+
+from .builtin import BaseTool
+
+
+class SafeTool(BaseTool, ShieldRunnerMixin):
+    """A wrapper that makes another tool safety-enabled by running shields around it"""
+
+    def __init__(
+        self,
+        tool: BaseTool,
+        safety_api: Safety,
+        input_shields: List[ShieldDefinition] = None,
+        output_shields: List[ShieldDefinition] = None,
+    ):
+        self._tool = tool
+        ShieldRunnerMixin.__init__(
+            self, safety_api, input_shields=input_shields, output_shields=output_shields
+        )
+
+    def get_name(self) -> str:
+        # return the name of the wrapped tool
+        return self._tool.get_name()
+
+    async def run(self, messages: List[Message]) -> List[Message]:
+        if self.input_shields:
+            await self.run_shields(messages, self.input_shields)
+        # run the underlying tool
+        res = await self._tool.run(messages)
+        if self.output_shields:
+            # run the output shields on the tool's response, not the input
+            await self.run_shields(res, self.output_shields)
+
+        return res
+
+
+def with_safety(
+    tool: BaseTool,
+    safety_api: Safety,
+    input_shields: List[ShieldDefinition] = None,
+    output_shields: List[ShieldDefinition] = None,
+) -> SafeTool:
+    return SafeTool(
+        tool,
+        safety_api,
+        input_shields=input_shields,
+        output_shields=output_shields,
+    )
diff --git a/llama_toolchain/common/prompt_for_config.py b/llama_toolchain/common/prompt_for_config.py
index c708b96d7..071d40cb6 100644
--- a/llama_toolchain/common/prompt_for_config.py
+++ b/llama_toolchain/common/prompt_for_config.py
@@ -144,7 +144,11 @@ def prompt_for_config(
             nested_type = get_non_none_type(field_type)
             print(f"Entering sub-configuration for {field_name}:")
             config_data[field_name] = prompt_for_config(nested_type, existing_value)
-        elif inspect.isclass(field_type) and issubclass(field_type, BaseModel):
+        elif (
+            inspect.isclass(field_type)
+            and issubclass(field_type, BaseModel)
+            and len(field_type.__fields__) > 0
+        ):
             print(f"\nEntering sub-configuration for {field_name}:")
             config_data[field_name] = prompt_for_config(
                 field_type,
diff --git a/llama_toolchain/distribution/datatypes.py b/llama_toolchain/distribution/datatypes.py
index 82196a00b..762b3d487 100644
--- a/llama_toolchain/distribution/datatypes.py
+++ b/llama_toolchain/distribution/datatypes.py
@@ -15,6 +15,7 @@ from strong_typing.schema import json_schema_type
 class ApiSurface(Enum):
     inference = "inference"
     safety = "safety"
+    agentic_system = "agentic_system"
 
 
 @json_schema_type
@@ -39,14 +40,19 @@ class SourceAdapter(Adapter):
     module: str = Field(
         ...,
         description="""
-Fully-qualified name of the module to import.
The module is expected to have -a `get_adapter_instance()` method which will be passed a validated config object -of type `config_class`.""", +Fully-qualified name of the module to import. The module is expected to have: + + - `get_adapter_impl(config, deps)`: returns the local implementation +""", ) config_class: str = Field( ..., description="Fully-qualified classname of the config for this adapter", ) + adapter_dependencies: List[ApiSurface] = Field( + default_factory=list, + description="Higher-level API surfaces may depend on other adapters to provide their functionality", + ) @json_schema_type @@ -56,6 +62,13 @@ class PassthroughApiAdapter(Adapter): default_factory=dict, description="Headers (e.g., authorization) to send with the request", ) + module: str = Field( + ..., + description=""" +Fully-qualified name of the module to import. The module is expected to have: + - `get_client_impl(base_url)`: returns a client which can be used to call the remote implementation +""", + ) class Distribution(BaseModel): diff --git a/llama_toolchain/distribution/distribution.py b/llama_toolchain/distribution/distribution.py index 773d05f26..03bd5d3a5 100644 --- a/llama_toolchain/distribution/distribution.py +++ b/llama_toolchain/distribution/distribution.py @@ -7,6 +7,7 @@ import inspect from typing import Dict, List +from llama_toolchain.agentic_system.api.endpoints import AgenticSystem from llama_toolchain.inference.api.endpoints import Inference from llama_toolchain.safety.api.endpoints import Safety @@ -29,6 +30,7 @@ def api_surface_endpoints() -> Dict[ApiSurface, List[ApiSurfaceEndpoint]]: protocols = { ApiSurface.inference: Inference, ApiSurface.safety: Safety, + ApiSurface.agentic_system: AgenticSystem, } for surface, protocol in protocols.items(): diff --git a/llama_toolchain/distribution/dynamic.py b/llama_toolchain/distribution/dynamic.py index 483c08d79..62e954d46 100644 --- a/llama_toolchain/distribution/dynamic.py +++ b/llama_toolchain/distribution/dynamic.py @@ -8,7 +8,7 @@ import asyncio import importlib from typing import Any, Dict -from .datatypes import SourceAdapter +from .datatypes import Adapter, PassthroughApiAdapter, SourceAdapter def instantiate_class_type(fully_qualified_name): @@ -18,9 +18,17 @@ def instantiate_class_type(fully_qualified_name): # returns a class implementing the protocol corresponding to the ApiSurface -def instantiate_adapter(adapter: SourceAdapter, adapter_config: Dict[str, Any]): +def instantiate_adapter( + adapter: SourceAdapter, adapter_config: Dict[str, Any], deps: Dict[str, Adapter] +): module = importlib.import_module(adapter.module) config_type = instantiate_class_type(adapter.config_class) config = config_type(**adapter_config) - return asyncio.run(module.get_adapter_impl(config)) + return asyncio.run(module.get_adapter_impl(config, deps)) + + +def instantiate_client(adapter: PassthroughApiAdapter, base_url: str): + module = importlib.import_module(adapter.module) + + return asyncio.run(module.get_client_impl(base_url)) diff --git a/llama_toolchain/distribution/registry.py b/llama_toolchain/distribution/registry.py index eda7a2c1a..897b8f9d0 100644 --- a/llama_toolchain/distribution/registry.py +++ b/llama_toolchain/distribution/registry.py @@ -7,6 +7,8 @@ from functools import lru_cache from typing import List, Optional +from llama_toolchain.agentic_system.adapters import available_agentic_system_adapters + from llama_toolchain.inference.adapters import available_inference_adapters from llama_toolchain.safety.adapters import 
available_safety_adapters @@ -43,10 +45,26 @@ COMMON_DEPENDENCIES = [ ] +def client_module(api_surface: ApiSurface) -> str: + return f"llama_toolchain.{api_surface.value}.client" + + +def passthrough(api_surface: ApiSurface, port: int) -> PassthroughApiAdapter: + return PassthroughApiAdapter( + api_surface=api_surface, + adapter_id=f"{api_surface.value}-passthrough", + base_url=f"http://localhost:{port}", + module=client_module(api_surface), + ) + + @lru_cache() def available_distributions() -> List[Distribution]: inference_adapters_by_id = {a.adapter_id: a for a in available_inference_adapters()} safety_adapters_by_id = {a.adapter_id: a for a in available_safety_adapters()} + agentic_system_adapters_by_id = { + a.adapter_id: a for a in available_agentic_system_adapters() + } return [ Distribution( @@ -56,6 +74,9 @@ def available_distributions() -> List[Distribution]: adapters={ ApiSurface.inference: inference_adapters_by_id["meta-reference"], ApiSurface.safety: safety_adapters_by_id["meta-reference"], + ApiSurface.agentic_system: agentic_system_adapters_by_id[ + "meta-reference" + ], }, ), Distribution( @@ -76,16 +97,9 @@ def available_distributions() -> List[Distribution]: "uvicorn", ], adapters={ - ApiSurface.inference: PassthroughApiAdapter( - api_surface=ApiSurface.inference, - adapter_id="inference-passthrough", - base_url="http://localhost:5001", - ), - ApiSurface.safety: PassthroughApiAdapter( - api_surface=ApiSurface.safety, - adapter_id="safety-passthrough", - base_url="http://localhost:5001", - ), + ApiSurface.inference: passthrough(ApiSurface.inference, 5001), + ApiSurface.safety: passthrough(ApiSurface.safety, 5001), + ApiSurface.agentic_system: passthrough(ApiSurface.agentic_system, 5001), }, ), Distribution( @@ -95,6 +109,9 @@ def available_distributions() -> List[Distribution]: adapters={ ApiSurface.inference: inference_adapters_by_id["meta-ollama"], ApiSurface.safety: safety_adapters_by_id["meta-reference"], + ApiSurface.agentic_system: agentic_system_adapters_by_id[ + "meta-reference" + ], }, ), ] diff --git a/llama_toolchain/distribution/server.py b/llama_toolchain/distribution/server.py index fa57322bd..5ac133d4b 100644 --- a/llama_toolchain/distribution/server.py +++ b/llama_toolchain/distribution/server.py @@ -12,7 +12,16 @@ from collections.abc import ( AsyncIterator as AsyncIteratorABC, ) from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional +from typing import ( + Any, + AsyncGenerator, + AsyncIterator, + Dict, + get_type_hints, + List, + Optional, + Set, +) import fire import httpx @@ -27,9 +36,9 @@ from fastapi.routing import APIRoute from pydantic import BaseModel, ValidationError from termcolor import cprint -from .datatypes import PassthroughApiAdapter +from .datatypes import Adapter, ApiSurface, PassthroughApiAdapter from .distribution import api_surface_endpoints -from .dynamic import instantiate_adapter +from .dynamic import instantiate_adapter, instantiate_client from .registry import resolve_distribution @@ -213,6 +222,29 @@ def create_dynamic_typed_route(func: Any): return endpoint +def topological_sort(adapters: List[Adapter]) -> List[Adapter]: + + by_id = {x.api_surface: x for x in adapters} + + def dfs(a: Adapter, visited: Set[ApiSurface], stack: List[ApiSurface]): + visited.add(a.api_surface) + + for surface in a.adapter_dependencies: + if surface not in visited: + dfs(by_id[surface], visited, stack) + + stack.append(a.api_surface) + + visited = set() + stack = [] + + for a 
in adapters: + if a.api_surface not in visited: + dfs(a, visited, stack) + + return [by_id[x] for x in stack] + + def main( dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False ): @@ -228,7 +260,13 @@ def main( all_endpoints = api_surface_endpoints() adapter_configs = config["adapters"] - for surface, adapter in dist.adapters.items(): + adapters = topological_sort(dist.adapters.values()) + + # TODO: split this into two parts, first you resolve all impls + # and then you create the routes. + impls = {} + for adapter in adapters: + surface = adapter.api_surface if surface.value not in adapter_configs: raise ValueError( f"Could not find adapter config for {surface}. Please add it to the config" @@ -242,8 +280,11 @@ def main( getattr(app, endpoint.method)(endpoint.route)( create_dynamic_passthrough(url) ) + impls[surface] = instantiate_client(adapter, adapter.base_url.rstrip("/")) else: - impl = instantiate_adapter(adapter, adapter_config) + deps = {surface: impls[surface] for surface in adapter.adapter_dependencies} + impl = instantiate_adapter(adapter, adapter_config, deps) + impls[surface] = impl for endpoint in endpoints: if not hasattr(impl, endpoint.name): # ideally this should be a typing violation already diff --git a/llama_toolchain/inference/client.py b/llama_toolchain/inference/client.py index 178452fde..aa84f906d 100644 --- a/llama_toolchain/inference/client.py +++ b/llama_toolchain/inference/client.py @@ -23,6 +23,10 @@ from .api import ( from .event_logger import EventLogger +async def get_client_impl(base_url: str): + return InferenceClient(base_url) + + class InferenceClient(Inference): def __init__(self, base_url: str): print(f"Initializing client for {base_url}") diff --git a/llama_toolchain/inference/inference.py b/llama_toolchain/inference/inference.py index beeb6dd65..7b54313c4 100644 --- a/llama_toolchain/inference/inference.py +++ b/llama_toolchain/inference/inference.py @@ -6,11 +6,13 @@ import asyncio -from typing import AsyncIterator, Union +from typing import AsyncIterator, Dict, Union from llama_models.llama3_1.api.datatypes import StopReason from llama_models.sku_list import resolve_model +from llama_toolchain.distribution.datatypes import Adapter, ApiSurface + from .api.config import MetaReferenceImplConfig from .api.datatypes import ( ChatCompletionResponseEvent, @@ -27,7 +29,9 @@ from .api.endpoints import ( from .model_parallel import LlamaModelParallelGenerator -async def get_adapter_impl(config: MetaReferenceImplConfig): +async def get_adapter_impl( + config: MetaReferenceImplConfig, _deps: Dict[ApiSurface, Adapter] +): assert isinstance( config, MetaReferenceImplConfig ), f"Unexpected config type: {type(config)}" diff --git a/llama_toolchain/safety/client.py b/llama_toolchain/safety/client.py index fb37bde1a..2bceebc68 100644 --- a/llama_toolchain/safety/client.py +++ b/llama_toolchain/safety/client.py @@ -21,6 +21,10 @@ from .api import ( ) +async def get_client_impl(base_url: str): + return SafetyClient(base_url) + + class SafetyClient(Safety): def __init__(self, base_url: str): print(f"Initializing client for {base_url}") diff --git a/llama_toolchain/safety/safety.py b/llama_toolchain/safety/safety.py index 12405c161..21b7e6f1f 100644 --- a/llama_toolchain/safety/safety.py +++ b/llama_toolchain/safety/safety.py @@ -6,6 +6,10 @@ import asyncio +from typing import Dict + +from llama_toolchain.distribution.datatypes import Adapter, ApiSurface + from .config import SafetyConfig from .api.endpoints import * # noqa from .shields import 
( @@ -19,7 +23,7 @@ from .shields import ( ) -async def get_adapter_impl(config: SafetyConfig): +async def get_adapter_impl(config: SafetyConfig, _deps: Dict[ApiSurface, Adapter]): assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}" impl = MetaReferenceSafetyImpl(config)
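
For reviewers, the resolution flow that the server.py changes above introduce boils down to the sketch below. It is condensed from the hunks in this patch, with config plumbing elided, so treat names like `adapter_configs` as shorthand rather than exact code:

impls = {}
for adapter in topological_sort(list(dist.adapters.values())):
    surface = adapter.api_surface
    if isinstance(adapter, PassthroughApiAdapter):
        # Remote surface: expose a client that forwards to base_url.
        impls[surface] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
    else:
        # Local surface: hand get_adapter_impl() the impls it depends on,
        # which the topological sort guarantees were already built.
        deps = {dep: impls[dep] for dep in adapter.adapter_dependencies}
        impls[surface] = instantiate_adapter(adapter, adapter_configs[surface.value], deps)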