diff --git a/.gitignore b/.gitignore
index 321e946a9..067fe2c6f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
+.env
__pycache__
dist
*.egg-info
diff --git a/llama_toolchain/agentic_system/adapters.py b/llama_toolchain/agentic_system/adapters.py
new file mode 100644
index 000000000..df8e8c9d6
--- /dev/null
+++ b/llama_toolchain/agentic_system/adapters.py
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_toolchain.distribution.datatypes import Adapter, ApiSurface, SourceAdapter
+
+
+def available_agentic_system_adapters() -> List[Adapter]:
+ return [
+ SourceAdapter(
+ api_surface=ApiSurface.agentic_system,
+ adapter_id="meta-reference",
+ pip_packages=[
+ "codeshield",
+ "torch",
+ "transformers",
+ ],
+ module="llama_toolchain.agentic_system.agentic_system",
+ config_class="llama_toolchain.agentic_system.config.AgenticSystemConfig",
+ adapter_dependencies=[
+ ApiSurface.inference,
+ ApiSurface.safety,
+ ],
+ ),
+ ]
diff --git a/llama_toolchain/agentic_system/agentic_system.py b/llama_toolchain/agentic_system/agentic_system.py
new file mode 100644
index 000000000..011893b44
--- /dev/null
+++ b/llama_toolchain/agentic_system/agentic_system.py
@@ -0,0 +1,786 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+from llama_toolchain.agentic_system.api import AgenticSystem
+
+from llama_toolchain.distribution.datatypes import Adapter, ApiSurface
+from llama_toolchain.inference.api import Inference
+from llama_toolchain.safety.api import Safety
+
+from .config import AgenticSystemConfig
+from .api.endpoints import * # noqa
+
+import logging
+import os
+import uuid
+from datetime import datetime
+from typing import AsyncGenerator, Dict, List, Optional
+
+from llama_toolchain.inference.api import ChatCompletionRequest
+
+from llama_toolchain.inference.api.datatypes import (
+ Attachment,
+ BuiltinTool,
+ ChatCompletionResponseEventType,
+ CompletionMessage,
+ Message,
+ Role,
+ SamplingParams,
+ StopReason,
+ ToolCallDelta,
+ ToolCallParseStatus,
+ ToolDefinition,
+ ToolResponse,
+ ToolResponseMessage,
+ URL,
+)
+from llama_toolchain.safety.api.datatypes import (
+ BuiltinShield,
+ ShieldDefinition,
+ ShieldResponse,
+)
+
+from termcolor import cprint
+
+from .api.datatypes import (
+ AgenticSystemInstanceConfig,
+ AgenticSystemTurnResponseEvent,
+ AgenticSystemTurnResponseEventType,
+ AgenticSystemTurnResponseStepCompletePayload,
+ AgenticSystemTurnResponseStepProgressPayload,
+ AgenticSystemTurnResponseStepStartPayload,
+ AgenticSystemTurnResponseTurnCompletePayload,
+ AgenticSystemTurnResponseTurnStartPayload,
+ InferenceStep,
+ Session,
+ ShieldCallStep,
+ StepType,
+ ToolExecutionStep,
+ Turn,
+)
+from .api.endpoints import (
+ AgenticSystemCreateRequest,
+ AgenticSystemCreateResponse,
+ AgenticSystemSessionCreateRequest,
+ AgenticSystemSessionCreateResponse,
+ AgenticSystemTurnCreateRequest,
+ AgenticSystemTurnResponseStreamChunk,
+)
+from .safety import SafetyException, ShieldRunnerMixin
+
+from .system_prompt import get_agentic_prefix_messages
+from .tools.base import BaseTool
+from .tools.builtin import (
+ BraveSearchTool,
+ CodeInterpreterTool,
+ PhotogenTool,
+ SingleMessageBuiltinTool,
+ WolframAlphaTool,
+)
+from .tools.safety import with_safety
+
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+
+
+async def get_adapter_impl(
+ config: AgenticSystemConfig, deps: Dict[ApiSurface, Adapter]
+):
+ assert isinstance(
+ config, AgenticSystemConfig
+ ), f"Unexpected config type: {type(config)}"
+
+ impl = MetaReferenceAgenticSystemImpl(
+ deps[ApiSurface.inference],
+ deps[ApiSurface.safety],
+ )
+ await impl.initialize()
+ return impl
+
+
+async def execute_tool_call_maybe(
+ tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage]
+) -> List[ToolResponseMessage]:
+    # While the Tools.run interface takes a list of messages,
+    # all tools currently operate on a single message.
+    # When this changes, we can drop this assert.
+    # Whether to call tools on each message and aggregate,
+    # or aggregate and call the tool once, remains to be seen.
+ assert len(messages) == 1, "Expected single message"
+ message = messages[0]
+
+ tool_call = message.tool_calls[0]
+ name = tool_call.tool_name
+ assert isinstance(name, BuiltinTool)
+
+ name = name.value
+
+ assert name in tools_dict, f"Tool {name} not found"
+ tool = tools_dict[name]
+ result_messages = await tool.run(messages)
+ return result_messages
+
+
+def print_dialog(messages: List[Message]):
+ for i, m in enumerate(messages):
+ if m.role == Role.user.value:
+ color = "red"
+ elif m.role == Role.assistant.value:
+ color = "white"
+ elif m.role == Role.ipython.value:
+ color = "yellow"
+ elif m.role == Role.system.value:
+ color = "green"
+ else:
+ color = "white"
+
+ s = str(m)
+ cprint(f"{i} ::: {s[:100]}...", color=color)
+
+
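+# In-memory registry of live agent instances keyed by system_id; nothing is
+# persisted, so restarting the process drops all agents and their sessions.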
+AGENT_INSTANCES_BY_ID = {}
+
+
+class AgentInstance(ShieldRunnerMixin):
+ def __init__(
+ self,
+ system_id: int,
+ instance_config: AgenticSystemInstanceConfig,
+ model: str,
+ inference_api: Inference,
+ safety_api: Safety,
+ builtin_tools: List[SingleMessageBuiltinTool],
+ custom_tool_definitions: List[ToolDefinition],
+ input_shields: List[ShieldDefinition],
+ output_shields: List[ShieldDefinition],
+ max_infer_iters: int = 10,
+ prefix_messages: Optional[List[Message]] = None,
+ ):
+ self.system_id = system_id
+ self.instance_config = instance_config
+
+ self.model = model
+ self.inference_api = inference_api
+ self.safety_api = safety_api
+
+ if prefix_messages is not None and len(prefix_messages) > 0:
+ self.prefix_messages = prefix_messages
+ else:
+ self.prefix_messages = get_agentic_prefix_messages(
+ builtin_tools, custom_tool_definitions
+ )
+
+ for m in self.prefix_messages:
+ print(m.content)
+
+ self.max_infer_iters = max_infer_iters
+ self.tools_dict = {t.get_name(): t for t in builtin_tools}
+
+ self.sessions = {}
+
+ ShieldRunnerMixin.__init__(
+ self,
+ safety_api,
+ input_shields=input_shields,
+ output_shields=output_shields,
+ )
+
+ def create_session(self, name: str) -> Session:
+ session_id = str(uuid.uuid4())
+ session = Session(
+ session_id=session_id,
+ session_name=name,
+ turns=[],
+ started_at=datetime.now(),
+ )
+ self.sessions[session_id] = session
+ return session
+
+ async def create_and_execute_turn(
+ self, request: AgenticSystemTurnCreateRequest
+ ) -> AsyncGenerator:
+ assert (
+ request.session_id in self.sessions
+ ), f"Session {request.session_id} not found"
+
+ session = self.sessions[request.session_id]
+
+ messages = []
+ for i, turn in enumerate(session.turns):
+ # print(f"turn {i}")
+ # print_dialog(turn.input_messages)
+ messages.extend(turn.input_messages)
+ for step in turn.steps:
+ if step.step_type == StepType.inference.value:
+ messages.append(step.model_response)
+ elif step.step_type == StepType.tool_execution.value:
+ for response in step.tool_responses:
+ messages.append(
+ ToolResponseMessage(
+ call_id=response.call_id,
+ tool_name=response.tool_name,
+ content=response.content,
+ )
+ )
+ elif step.step_type == StepType.shield_call.value:
+ response = step.response
+ if response.is_violation:
+ # TODO: Properly persist the
+ # CompletionMessage itself in the ShieldResponse
+ messages.append(
+ CompletionMessage(
+ content=response.violation_return_message,
+ stop_reason=StopReason.end_of_turn,
+ )
+ )
+
+ messages.extend(request.messages)
+
+ # print("processed dialog ======== ")
+ # print_dialog(messages)
+
+ turn_id = str(uuid.uuid4())
+ params = self.instance_config.sampling_params
+ start_time = datetime.now()
+ yield AgenticSystemTurnResponseStreamChunk(
+ event=AgenticSystemTurnResponseEvent(
+ payload=AgenticSystemTurnResponseTurnStartPayload(
+ turn_id=turn_id,
+ )
+ )
+ )
+
+ steps = []
+ output_message = None
+ async for chunk in self.run(
+ turn_id=turn_id,
+ input_messages=messages,
+ temperature=params.temperature,
+ top_p=params.top_p,
+ stream=request.stream,
+ max_gen_len=params.max_tokens,
+ ):
+ if isinstance(chunk, CompletionMessage):
+ cprint(
+ f"{chunk.role.capitalize()}: {chunk.content}",
+ "white",
+ attrs=["bold"],
+ )
+ output_message = chunk
+ continue
+
+ assert isinstance(
+ chunk, AgenticSystemTurnResponseStreamChunk
+ ), f"Unexpected type {type(chunk)}"
+ event = chunk.event
+ if (
+ event.payload.event_type
+ == AgenticSystemTurnResponseEventType.step_complete.value
+ ):
+ steps.append(event.payload.step_details)
+
+ yield chunk
+
+ assert output_message is not None
+
+ turn = Turn(
+ turn_id=turn_id,
+ session_id=request.session_id,
+ input_messages=request.messages,
+ output_message=output_message,
+ started_at=start_time,
+ completed_at=datetime.now(),
+ steps=steps,
+ )
+ session.turns.append(turn)
+
+ yield AgenticSystemTurnResponseStreamChunk(
+ event=AgenticSystemTurnResponseEvent(
+ payload=AgenticSystemTurnResponseTurnCompletePayload(
+ turn=turn,
+ )
+ )
+ )
+
+ async def run_shields_wrapper(
+ self,
+ turn_id: str,
+ messages: List[Message],
+ shields: List[ShieldDefinition],
+ touchpoint: str,
+ ) -> AsyncGenerator:
+ if len(shields) == 0:
+ return
+
+ step_id = str(uuid.uuid4())
+ try:
+ yield AgenticSystemTurnResponseStreamChunk(
+ event=AgenticSystemTurnResponseEvent(
+ payload=AgenticSystemTurnResponseStepStartPayload(
+ step_type=StepType.shield_call.value,
+ step_id=step_id,
+ metadata=dict(touchpoint=touchpoint),
+ )
+ )
+ )
+ await self.run_shields(messages, shields)
+
+        except SafetyException as e:
+ yield AgenticSystemTurnResponseStreamChunk(
+ event=AgenticSystemTurnResponseEvent(
+ payload=AgenticSystemTurnResponseStepCompletePayload(
+ step_type=StepType.shield_call.value,
+ step_details=ShieldCallStep(
+ step_id=step_id,
+ turn_id=turn_id,
+ response=e.response,
+ ),
+ )
+ )
+ )
+
+ yield CompletionMessage(
+ content=str(e),
+ stop_reason=StopReason.end_of_turn,
+ )
+            yield False
+            return  # don't fall through to the "no violation" event below
+
+ yield AgenticSystemTurnResponseStreamChunk(
+ event=AgenticSystemTurnResponseEvent(
+ payload=AgenticSystemTurnResponseStepCompletePayload(
+ step_type=StepType.shield_call.value,
+ step_details=ShieldCallStep(
+ step_id=step_id,
+ turn_id=turn_id,
+ response=ShieldResponse(
+ # TODO: fix this, give each shield a shield type method and
+ # fire one event for each shield run
+ shield_type=BuiltinShield.llama_guard,
+ is_violation=False,
+ ),
+ ),
+ )
+ )
+ )
+
+ async def run(
+ self,
+ turn_id: str,
+ input_messages: List[Message],
+ temperature: float,
+ top_p: float,
+ stream: bool = False,
+ max_gen_len: Optional[int] = None,
+ ) -> AsyncGenerator:
+        # Using async generators makes the downstream code much simpler and everything
+        # amenable to streaming. However, it also complicates things here because an
+        # AsyncGenerator cannot return a "final value" the way `yield from` can. We
+        # simulate that by yielding a final boolean (indicating whether an exception
+        # happened) and then explicitly testing for it.
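+        #
+        # In sketch form, each shield wrapper below is consumed as:
+        #
+        #   async for res in self.run_shields_wrapper(...):
+        #       if isinstance(res, bool):  # sentinel: a violation was already surfaced
+        #           return                 # abort the rest of the turn
+        #       yield res                  # otherwise forward the chunk to the caller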
+
+ async for res in self.run_shields_wrapper(
+ turn_id, input_messages, self.input_shields, "user-input"
+ ):
+ if isinstance(res, bool):
+ return
+ else:
+ yield res
+
+        final_response = None  # ensure defined even if _run yields no CompletionMessage
+        async for res in self._run(
+ turn_id, input_messages, temperature, top_p, stream, max_gen_len
+ ):
+ if isinstance(res, bool):
+ return
+ elif isinstance(res, CompletionMessage):
+ final_response = res
+ break
+ else:
+ yield res
+
+ assert final_response is not None
+        # output shields run on the combined input and final output
+ messages = input_messages + [final_response]
+
+ async for res in self.run_shields_wrapper(
+ turn_id, messages, self.output_shields, "assistant-output"
+ ):
+ if isinstance(res, bool):
+ return
+ else:
+ yield res
+
+ yield final_response
+
+ async def _run(
+ self,
+ turn_id: str,
+ input_messages: List[Message],
+ temperature: float,
+ top_p: float,
+ stream: bool = False,
+ max_gen_len: Optional[int] = None,
+ ) -> AsyncGenerator:
+ input_messages = preprocess_dialog(input_messages, self.prefix_messages)
+
+ attachments = []
+
+ n_iter = 0
+ while True:
+ msg = input_messages[-1]
+ if msg.role == Role.user.value:
+ color = "blue"
+ elif msg.role == Role.ipython.value:
+ color = "yellow"
+ else:
+ color = None
+ cprint(f"{str(msg)}", color=color)
+
+ step_id = str(uuid.uuid4())
+ yield AgenticSystemTurnResponseStreamChunk(
+ event=AgenticSystemTurnResponseEvent(
+ payload=AgenticSystemTurnResponseStepStartPayload(
+ step_type=StepType.inference.value,
+ step_id=step_id,
+ )
+ )
+ )
+
+            # ask the model for its next step, passing along the configured
+            # tool definitions so it can emit tool calls
+ req = ChatCompletionRequest(
+ model=self.model,
+ messages=input_messages,
+ available_tools=self.instance_config.available_tools,
+ stream=True,
+ sampling_params=SamplingParams(
+ temperature=temperature,
+ top_p=top_p,
+ max_tokens=max_gen_len,
+ ),
+ )
+
+ tool_calls = []
+ content = ""
+ stop_reason = None
+ async for chunk in self.inference_api.chat_completion(req):
+ event = chunk.event
+ if event.event_type != ChatCompletionResponseEventType.progress:
+ continue
+
+ delta = event.delta
+ if isinstance(delta, ToolCallDelta):
+ if delta.parse_status == ToolCallParseStatus.success:
+ tool_calls.append(delta.content)
+
+ if stream:
+ yield AgenticSystemTurnResponseStreamChunk(
+ event=AgenticSystemTurnResponseEvent(
+ payload=AgenticSystemTurnResponseStepProgressPayload(
+ step_type=StepType.inference.value,
+ step_id=step_id,
+ model_response_text_delta="",
+ tool_call_delta=delta,
+ )
+ )
+ )
+
+ elif isinstance(delta, str):
+ content += delta
+ if stream and event.stop_reason is None:
+ yield AgenticSystemTurnResponseStreamChunk(
+ event=AgenticSystemTurnResponseEvent(
+ payload=AgenticSystemTurnResponseStepProgressPayload(
+ step_type=StepType.inference.value,
+ step_id=step_id,
+ model_response_text_delta=event.delta,
+ )
+ )
+ )
+ else:
+ raise ValueError(f"Unexpected delta type {type(delta)}")
+
+ if event.stop_reason is not None:
+ stop_reason = event.stop_reason
+
+ stop_reason = stop_reason or StopReason.out_of_tokens
+ message = CompletionMessage(
+ content=content,
+ stop_reason=stop_reason,
+ tool_calls=tool_calls,
+ )
+
+ yield AgenticSystemTurnResponseStreamChunk(
+ event=AgenticSystemTurnResponseEvent(
+ payload=AgenticSystemTurnResponseStepCompletePayload(
+ step_type=StepType.inference.value,
+ step_id=step_id,
+ step_details=InferenceStep(
+ step_id=step_id, turn_id=turn_id, model_response=message
+ ),
+ )
+ )
+ )
+
+ if n_iter >= self.max_infer_iters:
+ cprint("Done with MAX iterations, exiting.")
+ yield message
+ break
+
+ if stop_reason == StopReason.out_of_tokens:
+ cprint("Out of token budget, exiting.")
+ yield message
+ break
+
+ if len(message.tool_calls) == 0:
+ if stop_reason == StopReason.end_of_turn:
+ if len(attachments) > 0:
+ if isinstance(message.content, list):
+ message.content += attachments
+ else:
+ message.content = [message.content] + attachments
+ yield message
+ else:
+ cprint(f"Partial message: {str(message)}", color="green")
+ input_messages = input_messages + [message]
+ else:
+ cprint(f"{str(message)}", color="green")
+ try:
+ tool_call = message.tool_calls[0]
+
+ name = tool_call.tool_name
+ if not isinstance(name, BuiltinTool):
+ yield message
+ return
+
+ step_id = str(uuid.uuid4())
+ yield AgenticSystemTurnResponseStreamChunk(
+ event=AgenticSystemTurnResponseEvent(
+ payload=AgenticSystemTurnResponseStepStartPayload(
+ step_type=StepType.tool_execution.value,
+ step_id=step_id,
+ )
+ )
+ )
+ yield AgenticSystemTurnResponseStreamChunk(
+ event=AgenticSystemTurnResponseEvent(
+ payload=AgenticSystemTurnResponseStepProgressPayload(
+ step_type=StepType.tool_execution.value,
+ step_id=step_id,
+ tool_call=tool_call,
+ )
+ )
+ )
+
+ result_messages = await execute_tool_call_maybe(
+ self.tools_dict,
+ [message],
+ )
+ assert (
+ len(result_messages) == 1
+ ), "Currently not supporting multiple messages"
+ result_message = result_messages[0]
+
+ yield AgenticSystemTurnResponseStreamChunk(
+ event=AgenticSystemTurnResponseEvent(
+ payload=AgenticSystemTurnResponseStepCompletePayload(
+ step_type=StepType.tool_execution.value,
+ step_details=ToolExecutionStep(
+ step_id=step_id,
+ turn_id=turn_id,
+ tool_calls=[tool_call],
+ tool_responses=[
+ ToolResponse(
+ call_id=result_message.call_id,
+ tool_name=result_message.tool_name,
+ content=result_message.content,
+ )
+ ],
+ ),
+ )
+ )
+ )
+
+ # TODO: add tool-input touchpoint and a "start" event for this step also
+ # but that needs a lot more refactoring of Tool code potentially
+ yield AgenticSystemTurnResponseStreamChunk(
+ event=AgenticSystemTurnResponseEvent(
+ payload=AgenticSystemTurnResponseStepCompletePayload(
+ step_type=StepType.shield_call.value,
+ step_details=ShieldCallStep(
+ step_id=str(uuid.uuid4()),
+ turn_id=turn_id,
+ response=ShieldResponse(
+ # TODO: fix this, give each shield a shield type method and
+ # fire one event for each shield run
+ shield_type=BuiltinShield.llama_guard,
+ is_violation=False,
+ ),
+ ),
+ )
+ )
+ )
+
+ except SafetyException as e:
+ yield AgenticSystemTurnResponseStreamChunk(
+ event=AgenticSystemTurnResponseEvent(
+ payload=AgenticSystemTurnResponseStepCompletePayload(
+ step_type=StepType.shield_call.value,
+ step_details=ShieldCallStep(
+ step_id=str(uuid.uuid4()),
+ turn_id=turn_id,
+ response=e.response,
+ ),
+ )
+ )
+ )
+
+ yield CompletionMessage(
+ content=str(e),
+ stop_reason=StopReason.end_of_turn,
+ )
+ yield False
+ return
+
+ if isinstance(result_message.content, Attachment):
+ # NOTE: when we push this message back to the model, the model may ignore the
+ # attached file path etc. since the model is trained to only provide a user message
+ # with the summary. We keep all generated attachments and then attach them to final message
+ attachments.append(result_message.content)
+ elif isinstance(result_message.content, list) or isinstance(
+ result_message.content, tuple
+ ):
+ for c in result_message.content:
+ if isinstance(c, Attachment):
+ attachments.append(c)
+
+ input_messages = input_messages + [message, result_message]
+
+ n_iter += 1
+
+
+class MetaReferenceAgenticSystemImpl(AgenticSystem):
+ def __init__(self, inference_api: Inference, safety_api: Safety):
+ self.inference_api = inference_api
+ self.safety_api = safety_api
+
+ async def initialize(self) -> None:
+ pass
+
+ async def create_agentic_system(
+ self,
+ request: AgenticSystemCreateRequest,
+ ) -> AgenticSystemCreateResponse:
+ system_id = str(uuid.uuid4())
+
+ builtin_tools = []
+ custom_tool_definitions = []
+ cfg = request.instance_config
+ for dfn in cfg.available_tools:
+ if isinstance(dfn.tool_name, BuiltinTool):
+ if dfn.tool_name == BuiltinTool.wolfram_alpha:
+ tool = WolframAlphaTool(os.environ.get("WOLFRAM_ALPHA_API_KEY"))
+ elif dfn.tool_name == BuiltinTool.brave_search:
+ tool = BraveSearchTool(os.environ.get("BRAVE_SEARCH_API_KEY"))
+ elif dfn.tool_name == BuiltinTool.code_interpreter:
+ tool = CodeInterpreterTool()
+ elif dfn.tool_name == BuiltinTool.photogen:
+ tool = PhotogenTool(
+ dump_dir="/tmp/photogen_dump_" + os.environ["USER"],
+ )
+ else:
+ raise ValueError(f"Unknown builtin tool: {dfn.tool_name}")
+
+ builtin_tools.append(
+ with_safety(
+ tool, self.safety_api, dfn.input_shields, dfn.output_shields
+ )
+ )
+ else:
+ custom_tool_definitions.append(dfn)
+
+ AGENT_INSTANCES_BY_ID[system_id] = AgentInstance(
+ system_id=system_id,
+ instance_config=request.instance_config,
+ model=request.model,
+ inference_api=self.inference_api,
+ builtin_tools=builtin_tools,
+ custom_tool_definitions=custom_tool_definitions,
+ safety_api=self.safety_api,
+ input_shields=cfg.input_shields,
+ output_shields=cfg.output_shields,
+ prefix_messages=cfg.debug_prefix_messages,
+ )
+
+ return AgenticSystemCreateResponse(
+ system_id=system_id,
+ )
+
+ async def create_agentic_system_session(
+ self,
+ request: AgenticSystemSessionCreateRequest,
+ ) -> AgenticSystemSessionCreateResponse:
+ system_id = request.system_id
+ assert system_id in AGENT_INSTANCES_BY_ID, f"System {system_id} not found"
+ agent = AGENT_INSTANCES_BY_ID[system_id]
+
+ session = agent.create_session(request.session_name)
+ return AgenticSystemSessionCreateResponse(
+ session_id=session.session_id,
+ )
+
+ async def create_agentic_system_turn(
+ self,
+ request: AgenticSystemTurnCreateRequest,
+ ) -> AsyncGenerator:
+ system_id = request.system_id
+ assert system_id in AGENT_INSTANCES_BY_ID, f"System {system_id} not found"
+ agent = AGENT_INSTANCES_BY_ID[system_id]
+
+ assert (
+ request.session_id in agent.sessions
+ ), f"Session {request.session_id} not found"
+ async for event in agent.create_and_execute_turn(request):
+ yield event
+
+
+def attachment_message(url: URL) -> ToolResponseMessage:
+ uri = url.uri
+ assert uri.startswith("file://")
+ filepath = uri[len("file://") :]
+
+ return ToolResponseMessage(
+ call_id="",
+ tool_name=BuiltinTool.code_interpreter,
+ content=f'# There is a file accessible to you at "{filepath}"',
+ )
+
+
+def preprocess_dialog(
+ messages: List[Message], prefix_messages: List[Message]
+) -> List[Message]:
+ """
+    Preprocess the dialog: drop any incoming system messages and prepend the
+    configured prefix (system) messages instead; attachments in the messages
+    are surfaced to the model as tool-response messages pointing at the file.
+ """
+ ret = prefix_messages.copy()
+
+ for m in messages:
+ if m.role == Role.system.value:
+ continue
+
+ # NOTE: the ideal behavior is to use `file_path = ...` but that
+        # means we need to have stateful execution of code which we currently
+ # do not have.
+ if isinstance(m.content, Attachment):
+ ret.append(attachment_message(m.content.url))
+ elif isinstance(m.content, list):
+ for c in m.content:
+ if isinstance(c, Attachment):
+ ret.append(attachment_message(c.url))
+
+ ret.append(m)
+
+ return ret
diff --git a/llama_toolchain/agentic_system/api/__init__.py b/llama_toolchain/agentic_system/api/__init__.py
new file mode 100644
index 000000000..4cefa053f
--- /dev/null
+++ b/llama_toolchain/agentic_system/api/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .datatypes import * # noqa
+from .endpoints import * # noqa
diff --git a/llama_toolchain/agentic_system/api/datatypes.py b/llama_toolchain/agentic_system/api/datatypes.py
new file mode 100644
index 000000000..45700a75a
--- /dev/null
+++ b/llama_toolchain/agentic_system/api/datatypes.py
@@ -0,0 +1,199 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from datetime import datetime
+from enum import Enum
+from typing import Any, Dict, List, Literal, Optional, Union
+
+from pydantic import BaseModel, Field
+from strong_typing.schema import json_schema_type
+from typing_extensions import Annotated
+
+from llama_toolchain.common.deployment_types import * # noqa: F403
+from llama_toolchain.inference.api import * # noqa: F403
+from llama_toolchain.safety.api.datatypes import * # noqa: F403
+from llama_toolchain.memory.api.datatypes import * # noqa: F403
+
+
+@json_schema_type
+class AgenticSystemToolDefinition(ToolDefinition):
+ execution_config: Optional[RestAPIExecutionConfig] = None
+ input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
+ output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
+
+
+class StepCommon(BaseModel):
+ turn_id: str
+ step_id: str
+ started_at: Optional[datetime] = None
+ completed_at: Optional[datetime] = None
+
+
+class StepType(Enum):
+ inference = "inference"
+ tool_execution = "tool_execution"
+ shield_call = "shield_call"
+ memory_retrieval = "memory_retrieval"
+
+
+@json_schema_type
+class InferenceStep(StepCommon):
+ step_type: Literal[StepType.inference.value] = StepType.inference.value
+ model_response: CompletionMessage
+
+
+@json_schema_type
+class ToolExecutionStep(StepCommon):
+ step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
+ tool_calls: List[ToolCall]
+ tool_responses: List[ToolResponse]
+
+
+@json_schema_type
+class ShieldCallStep(StepCommon):
+ step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
+ response: ShieldResponse
+
+
+@json_schema_type
+class MemoryRetrievalStep(StepCommon):
+ step_type: Literal[StepType.memory_retrieval.value] = (
+ StepType.memory_retrieval.value
+ )
+ memory_bank_ids: List[str]
+ documents: List[MemoryBankDocument]
+ scores: List[float]
+
+
+Step = Annotated[
+ Union[
+ InferenceStep,
+ ToolExecutionStep,
+ ShieldCallStep,
+ MemoryRetrievalStep,
+ ],
+ Field(discriminator="step_type"),
+]
+
+
+@json_schema_type
+class Turn(BaseModel):
+ """A single turn in an interaction with an Agentic System."""
+
+ turn_id: str
+ session_id: str
+ input_messages: List[
+ Union[
+ UserMessage,
+ ToolResponseMessage,
+ ]
+ ]
+ steps: List[Step]
+ output_message: CompletionMessage
+ started_at: datetime
+ completed_at: Optional[datetime] = None
+
+
+@json_schema_type
+class Session(BaseModel):
+ """A single session of an interaction with an Agentic System."""
+
+ session_id: str
+ session_name: str
+ turns: List[Turn]
+ started_at: datetime
+
+
+@json_schema_type
+class AgenticSystemInstanceConfig(BaseModel):
+ instructions: str
+ sampling_params: Optional[SamplingParams] = SamplingParams()
+ # zero-shot or built-in tool configurations as input to the model
+ available_tools: Optional[List[AgenticSystemToolDefinition]] = Field(
+ default_factory=list
+ )
+
+ input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
+ output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
+
+ quantization_config: Optional[QuantizationConfig] = None
+
+ # if you completely want to replace the messages prefixed by the system,
+ # this is debug only
+ debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
+
+
+class AgenticSystemTurnResponseEventType(Enum):
+ step_start = "step_start"
+ step_complete = "step_complete"
+ step_progress = "step_progress"
+
+ turn_start = "turn_start"
+ turn_complete = "turn_complete"
+
+
+@json_schema_type
+class AgenticSystemTurnResponseStepStartPayload(BaseModel):
+ event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
+ AgenticSystemTurnResponseEventType.step_start.value
+ )
+ step_type: StepType
+ step_id: str
+ metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
+
+
+@json_schema_type
+class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
+ event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
+ AgenticSystemTurnResponseEventType.step_complete.value
+ )
+ step_type: StepType
+ step_details: Step
+
+
+@json_schema_type
+class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
+ event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
+ AgenticSystemTurnResponseEventType.step_progress.value
+ )
+ step_type: StepType
+ step_id: str
+
+ model_response_text_delta: Optional[str] = None
+ tool_call_delta: Optional[ToolCallDelta] = None
+ tool_response_text_delta: Optional[str] = None
+
+
+@json_schema_type
+class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
+ event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
+ AgenticSystemTurnResponseEventType.turn_start.value
+ )
+ turn_id: str
+
+
+@json_schema_type
+class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
+ event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
+ AgenticSystemTurnResponseEventType.turn_complete.value
+ )
+ turn: Turn
+
+
+@json_schema_type
+class AgenticSystemTurnResponseEvent(BaseModel):
+ """Streamed agent execution response."""
+
+ payload: Annotated[
+ Union[
+ AgenticSystemTurnResponseStepStartPayload,
+ AgenticSystemTurnResponseStepProgressPayload,
+ AgenticSystemTurnResponseStepCompletePayload,
+ AgenticSystemTurnResponseTurnStartPayload,
+ AgenticSystemTurnResponseTurnCompletePayload,
+ ],
+ Field(discriminator="event_type"),
+ ]
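+
+
+# For orientation (illustrative only, field values made up): a step_progress
+# event serializes on the wire roughly as
+#   {"payload": {"event_type": "step_progress", "step_type": "inference",
+#                "step_id": "...", "model_response_text_delta": "Hello"}}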
diff --git a/llama_toolchain/agentic_system/api/endpoints.py b/llama_toolchain/agentic_system/api/endpoints.py
new file mode 100644
index 000000000..89ccc2995
--- /dev/null
+++ b/llama_toolchain/agentic_system/api/endpoints.py
@@ -0,0 +1,132 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .datatypes import * # noqa: F403
+from typing import Protocol
+
+# this dependency is annoying and we need a forked up version anyway
+from pyopenapi import webmethod
+
+from strong_typing.schema import json_schema_type
+
+
+@json_schema_type
+class AgenticSystemCreateRequest(BaseModel):
+ model: str
+ instance_config: AgenticSystemInstanceConfig
+
+
+@json_schema_type
+class AgenticSystemCreateResponse(BaseModel):
+ system_id: str
+
+
+@json_schema_type
+class AgenticSystemSessionCreateRequest(BaseModel):
+ system_id: str
+ session_name: str
+
+
+@json_schema_type
+class AgenticSystemSessionCreateResponse(BaseModel):
+ session_id: str
+
+
+@json_schema_type
+# what's the URI?
+class AgenticSystemTurnCreateRequest(BaseModel):
+ system_id: str
+ session_id: str
+
+ messages: List[
+ Union[
+ UserMessage,
+ ToolResponseMessage,
+ ]
+ ]
+
+ stream: Optional[bool] = False
+ override_config: Optional[AgenticSystemInstanceConfig] = None
+
+
+@json_schema_type(
+    schema={"description": "Server-sent events (SSE) stream of these events"}
+)
+class AgenticSystemTurnResponseStreamChunk(BaseModel):
+ event: AgenticSystemTurnResponseEvent
+
+
+@json_schema_type
+class AgenticSystemStepResponse(BaseModel):
+ step: Step
+
+
+class AgenticSystem(Protocol):
+
+ @webmethod(route="/agentic_system/create")
+ async def create_agentic_system(
+ self,
+ request: AgenticSystemCreateRequest,
+ ) -> AgenticSystemCreateResponse: ...
+
+ @webmethod(route="/agentic_system/turn/create")
+ async def create_agentic_system_turn(
+ self,
+ request: AgenticSystemTurnCreateRequest,
+ ) -> AgenticSystemTurnResponseStreamChunk: ...
+
+ @webmethod(route="/agentic_system/turn/get")
+ async def get_agentic_system_turn(
+ self,
+ agent_id: str,
+ turn_id: str,
+ ) -> Turn: ...
+
+ @webmethod(route="/agentic_system/step/get")
+ async def get_agentic_system_step(
+ self, agent_id: str, turn_id: str, step_id: str
+ ) -> AgenticSystemStepResponse: ...
+
+ @webmethod(route="/agentic_system/session/create")
+ async def create_agentic_system_session(
+ self,
+ request: AgenticSystemSessionCreateRequest,
+ ) -> AgenticSystemSessionCreateResponse: ...
+
+ @webmethod(route="/agentic_system/memory_bank/attach")
+ async def attach_memory_bank_to_agentic_system(
+ self,
+ agent_id: str,
+ session_id: str,
+ memory_bank_ids: List[str],
+ ) -> None: ...
+
+ @webmethod(route="/agentic_system/memory_bank/detach")
+ async def detach_memory_bank_from_agentic_system(
+ self,
+ agent_id: str,
+ session_id: str,
+ memory_bank_ids: List[str],
+ ) -> None: ...
+
+ @webmethod(route="/agentic_system/session/get")
+ async def get_agentic_system_session(
+ self,
+ agent_id: str,
+ session_id: str,
+ turn_ids: Optional[List[str]] = None,
+ ) -> Session: ...
+
+ @webmethod(route="/agentic_system/session/delete")
+ async def delete_agentic_system_session(
+ self, agent_id: str, session_id: str
+ ) -> None: ...
+
+ @webmethod(route="/agentic_system/delete")
+ async def delete_agentic_system(
+ self,
+ agent_id: str,
+ ) -> None: ...
diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py
new file mode 100644
index 000000000..71c578e2f
--- /dev/null
+++ b/llama_toolchain/agentic_system/client.py
@@ -0,0 +1,130 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+import json
+
+from typing import AsyncGenerator
+
+import fire
+
+import httpx
+
+from llama_models.llama3_1.api.datatypes import BuiltinTool, SamplingParams
+
+from .api import (
+ AgenticSystem,
+ AgenticSystemCreateRequest,
+ AgenticSystemCreateResponse,
+ AgenticSystemInstanceConfig,
+ AgenticSystemSessionCreateRequest,
+ AgenticSystemSessionCreateResponse,
+ AgenticSystemToolDefinition,
+ AgenticSystemTurnCreateRequest,
+ AgenticSystemTurnResponseStreamChunk,
+)
+
+
+async def get_client_impl(base_url: str):
+ return AgenticSystemClient(base_url)
+
+
+class AgenticSystemClient(AgenticSystem):
+ def __init__(self, base_url: str):
+ self.base_url = base_url
+
+ async def create_agentic_system(
+ self, request: AgenticSystemCreateRequest
+ ) -> AgenticSystemCreateResponse:
+ async with httpx.AsyncClient() as client:
+ response = await client.post(
+ f"{self.base_url}/agentic_system/create",
+ data=request.json(),
+ headers={"Content-Type": "application/json"},
+ )
+ response.raise_for_status()
+ return AgenticSystemCreateResponse(**response.json())
+
+ async def create_agentic_system_session(
+ self,
+ request: AgenticSystemSessionCreateRequest,
+ ) -> AgenticSystemSessionCreateResponse:
+ async with httpx.AsyncClient() as client:
+ response = await client.post(
+ f"{self.base_url}/agentic_system/session/create",
+ data=request.json(),
+ headers={"Content-Type": "application/json"},
+ )
+ response.raise_for_status()
+ return AgenticSystemSessionCreateResponse(**response.json())
+
+ async def create_agentic_system_turn(
+ self,
+ request: AgenticSystemTurnCreateRequest,
+ ) -> AsyncGenerator:
+ async with httpx.AsyncClient() as client:
+ async with client.stream(
+ "POST",
+ f"{self.base_url}/agentic_system/turn/create",
+ data=request.json(),
+ headers={"Content-Type": "application/json"},
+ timeout=20,
+ ) as response:
+ async for line in response.aiter_lines():
+ if line.startswith("data:"):
+ data = line[len("data: ") :]
+ try:
+ yield AgenticSystemTurnResponseStreamChunk(
+ **json.loads(data)
+ )
+ except Exception as e:
+ print(data)
+ print(f"Error with parsing or validation: {e}")
+
+
+async def run_main(host: str, port: int):
+ # client to test remote impl of agentic system
+    api = AgenticSystemClient(f"http://{host}:{port}")
+
+ tool_definitions = [
+ AgenticSystemToolDefinition(
+ tool_name=BuiltinTool.brave_search,
+ ),
+ AgenticSystemToolDefinition(
+ tool_name=BuiltinTool.wolfram_alpha,
+ ),
+ AgenticSystemToolDefinition(
+ tool_name=BuiltinTool.photogen,
+ ),
+ AgenticSystemToolDefinition(
+ tool_name=BuiltinTool.code_interpreter,
+ ),
+ ]
+
+ create_request = AgenticSystemCreateRequest(
+ model="Meta-Llama3.1-8B-Instruct",
+ instance_config=AgenticSystemInstanceConfig(
+ instructions="You are a helpful assistant",
+ sampling_params=SamplingParams(),
+ available_tools=tool_definitions,
+ input_shields=[],
+ output_shields=[],
+ quantization_config=None,
+ debug_prefix_messages=[],
+ ),
+ )
+
+ create_response = await api.create_agentic_system(create_request)
+ print(create_response)
+ # TODO: Add chat session / turn apis to test e2e
+
+
+def main(host: str, port: int):
+ asyncio.run(run_main(host, port))
+
+
+if __name__ == "__main__":
+ fire.Fire(main)
diff --git a/llama_toolchain/agentic_system/config.py b/llama_toolchain/agentic_system/config.py
new file mode 100644
index 000000000..349784021
--- /dev/null
+++ b/llama_toolchain/agentic_system/config.py
@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel
+
+
+class AgenticSystemConfig(BaseModel):
+ # placeholder, no separate configuration is needed for now
+ pass
diff --git a/llama_toolchain/agentic_system/safety.py b/llama_toolchain/agentic_system/safety.py
new file mode 100644
index 000000000..f066baf59
--- /dev/null
+++ b/llama_toolchain/agentic_system/safety.py
@@ -0,0 +1,65 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_models.llama3_1.api.datatypes import Message, Role
+from llama_toolchain.safety.api.datatypes import (
+ OnViolationAction,
+ ShieldDefinition,
+ ShieldResponse,
+)
+from llama_toolchain.safety.api.endpoints import RunShieldRequest, Safety
+from termcolor import cprint
+
+
+class SafetyException(Exception): # noqa: N818
+ def __init__(self, response: ShieldResponse):
+ self.response = response
+ super().__init__(response.violation_return_message)
+
+
+class ShieldRunnerMixin:
+
+ def __init__(
+ self,
+ safety_api: Safety,
+ input_shields: List[ShieldDefinition] = None,
+ output_shields: List[ShieldDefinition] = None,
+ ):
+ self.safety_api = safety_api
+ self.input_shields = input_shields
+ self.output_shields = output_shields
+
+ async def run_shields(
+ self, messages: List[Message], shields: List[ShieldDefinition]
+ ) -> List[ShieldResponse]:
+        # Some shields (e.g. llama-guard) require the first message to have the
+        # user role; since the dialog may start with a tool response, coerce the
+        # first role to user here.
+ if len(messages) > 0 and messages[0].role != Role.user.value:
+ # TODO(ashwin): we need to change the type of the message, this kind of modification
+ # is no longer appropriate
+ messages[0].role = Role.user.value
+
+ res = await self.safety_api.run_shields(
+ RunShieldRequest(
+ messages=messages,
+ shields=shields,
+ )
+ )
+
+ results = res.responses
+ for shield, r in zip(shields, results):
+ if r.is_violation:
+ if shield.on_violation_action == OnViolationAction.RAISE:
+ raise SafetyException(r)
+ elif shield.on_violation_action == OnViolationAction.WARN:
+ cprint(
+ f"[Warn]{shield.__class__.__name__} raised a warning",
+ color="red",
+ )
+
+ return results
diff --git a/llama_toolchain/agentic_system/system_prompt.py b/llama_toolchain/agentic_system/system_prompt.py
new file mode 100644
index 000000000..c8c616285
--- /dev/null
+++ b/llama_toolchain/agentic_system/system_prompt.py
@@ -0,0 +1,152 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+from datetime import datetime
+from typing import List
+
+from llama_toolchain.inference.api import (
+ BuiltinTool,
+ Message,
+ SystemMessage,
+ ToolDefinition,
+)
+
+from .tools.builtin import SingleMessageBuiltinTool
+
+
+def get_agentic_prefix_messages(
+ builtin_tools: List[SingleMessageBuiltinTool], custom_tools: List[ToolDefinition]
+) -> List[Message]:
+ messages = []
+ content = ""
+ if builtin_tools:
+ content += "Environment: ipython\n"
+
+ tool_str = ", ".join(
+ [
+ t.get_name()
+ for t in builtin_tools
+ if t.get_name() != BuiltinTool.code_interpreter.value
+ ]
+ )
+ if tool_str:
+ content += f"Tools: {tool_str}\n"
+
+ current_date = datetime.now()
+ formatted_date = current_date.strftime("%d %B %Y")
+ date_str = f"""
+Cutting Knowledge Date: December 2023
+Today Date: {formatted_date}\n\n"""
+ content += date_str
+
+ if custom_tools:
+ custom_message = get_system_prompt_for_custom_tools(custom_tools)
+ content += custom_message
+
+    # TODO: Replace this hard-coded message with the instructions coming in the request
+ if False:
+ content += "You are a helpful Assistant."
+
+ messages.append(SystemMessage(content=content))
+ return messages
+
+
+def get_system_prompt_for_custom_tools(custom_tools: List[ToolDefinition]) -> str:
+ custom_tool_params = ""
+ for t in custom_tools:
+ custom_tool_params += get_instruction_string(t) + "\n"
+ custom_tool_params += get_parameters_string(t) + "\n\n"
+
+ content = f"""
+You have access to the following functions:
+
+{custom_tool_params}
+Think very carefully before calling functions.
+If you choose to call a function ONLY reply in the following format with no prefix or suffix:
+
+{{"example_name": "example_value"}}
+
+Reminder:
+- If looking for real time information use relevant functions before falling back to brave_search
+- Function calls MUST follow the specified format, start with <function= and end with </function>
+- Required parameters MUST be specified
+- Only call one function at a time
+- Put the entire function call reply on one line
+
+"""
+ return content
+
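+# For reference: given the format above, a call to a hypothetical tool named
+# "get_weather" would be emitted by the model as a single line, e.g.
+#   <function=get_weather>{"city": "Paris"}</function>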
+
+def get_instruction_string(custom_tool_definition) -> str:
+ return f"Use the function '{custom_tool_definition.tool_name}' to '{custom_tool_definition.description}'"
+
+
+def get_parameters_string(custom_tool_definition) -> str:
+ return json.dumps(
+ {
+ "name": custom_tool_definition.tool_name,
+ "description": custom_tool_definition.description,
+ "parameters": {
+ name: definition.__dict__
+ for name, definition in custom_tool_definition.parameters.items()
+ },
+ }
+ )
+
+
+# NOTE: Unused right now
+def translate_custom_tool_definition_to_json(tool_def):
+ """Translates ToolDefinition to json as expected by model
+ eg. output for a function
+ {
+ "type": "function",
+ "function": {
+ "name": "conv_int",
+ "description": "Convert serialized fract24 integer into int value.",
+ "parameters": {
+ "type": "object",
+ "properties": [
+ {
+ "data": {
+ "type": "object",
+ "description": ""
+ }
+ }
+ ],
+ "required": ["data"]
+ }
+ }
+ }
+ """
+ assert isinstance(tool_def.tool_name, str)
+ func_def = {"type": "function", "function": {}}
+ func_def["function"]["name"] = tool_def.tool_name
+ func_def["function"]["description"] = tool_def.description or ""
+ if tool_def.parameters:
+ required = []
+ properties = []
+ for p_name, p_def in tool_def.parameters.items():
+ properties.append(
+ {
+ p_name: {
+ # TODO: see if this should not always be object
+ "type": "object",
+ "description": p_def.description or "",
+ }
+ }
+ )
+ if p_def.required:
+ required.append(p_name)
+ func_def["function"]["parameters"] = {
+ "type": "object",
+ "properties": properties,
+ "required": required,
+ }
+ else:
+ func_def["function"]["parameters"] = {}
+
+ return json.dumps(func_def)
diff --git a/llama_toolchain/agentic_system/tools/__init__.py b/llama_toolchain/agentic_system/tools/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_toolchain/agentic_system/tools/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_toolchain/agentic_system/tools/base.py b/llama_toolchain/agentic_system/tools/base.py
new file mode 100644
index 000000000..3c2722305
--- /dev/null
+++ b/llama_toolchain/agentic_system/tools/base.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+from typing import List
+
+from llama_toolchain.inference.api import Message
+
+
+class BaseTool(ABC):
+
+ @abstractmethod
+ def get_name(self) -> str:
+ raise NotImplementedError
+
+ @abstractmethod
+ async def run(self, messages: List[Message]) -> List[Message]:
+ raise NotImplementedError
diff --git a/llama_toolchain/agentic_system/tools/builtin.py b/llama_toolchain/agentic_system/tools/builtin.py
new file mode 100644
index 000000000..4487a2692
--- /dev/null
+++ b/llama_toolchain/agentic_system/tools/builtin.py
@@ -0,0 +1,326 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+import os
+import re
+
+from abc import abstractmethod
+from typing import List, Optional
+
+import requests
+from termcolor import cprint
+
+from .ipython_tool.code_execution import (
+ CodeExecutionContext,
+ CodeExecutionRequest,
+ CodeExecutor,
+ TOOLS_ATTACHMENT_KEY_REGEX,
+)
+
+from llama_toolchain.inference.api import * # noqa: F403
+
+from .base import BaseTool
+
+
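+# Tools signal a generated file by embedding a JSON snippet, matched by
+# TOOLS_ATTACHMENT_KEY_REGEX, that carries "filepath" and "mimetype" keys;
+# when present, surface it as an Attachment.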
+def interpret_content_as_attachment(content: str) -> Optional[Attachment]:
+ match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
+ if match:
+ snippet = match.group(1)
+ data = json.loads(snippet)
+ return Attachment(
+ url=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"]
+ )
+
+ return None
+
+
+class SingleMessageBuiltinTool(BaseTool):
+ async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
+ assert len(messages) == 1, f"Expected single message, got {len(messages)}"
+
+ message = messages[0]
+ assert len(message.tool_calls) == 1, "Expected a single tool call"
+
+ tool_call = messages[0].tool_calls[0]
+
+ query = tool_call.arguments["query"]
+ response: str = await self.run_impl(query)
+
+ message = ToolResponseMessage(
+ call_id=tool_call.call_id,
+ tool_name=tool_call.tool_name,
+ content=response,
+ )
+ if attachment := interpret_content_as_attachment(response):
+ message.content = attachment
+
+ return [message]
+
+ @abstractmethod
+ async def run_impl(self, query: str) -> str:
+ raise NotImplementedError()
+
+
+class PhotogenTool(SingleMessageBuiltinTool):
+
+ def __init__(self, dump_dir: str) -> None:
+ self.dump_dir = dump_dir
+
+ def get_name(self) -> str:
+ return BuiltinTool.photogen.value
+
+ async def run_impl(self, query: str) -> str:
+ """
+        Implement this to give the model the ability to generate images.
+
+ Return:
+ info = {
+ "filepath": str(image_filepath),
+ "mimetype": "image/png",
+ }
+ """
+ raise NotImplementedError()
+
+
+class BraveSearchTool(SingleMessageBuiltinTool):
+
+ def __init__(self, api_key: str) -> None:
+ self.api_key = api_key
+
+ def get_name(self) -> str:
+ return BuiltinTool.brave_search.value
+
+ async def run_impl(self, query: str) -> str:
+ url = "https://api.search.brave.com/res/v1/web/search"
+ headers = {
+ "X-Subscription-Token": self.api_key,
+ "Accept-Encoding": "gzip",
+ "Accept": "application/json",
+ }
+ payload = {"q": query}
+ response = requests.get(url=url, params=payload, headers=headers)
+ return json.dumps(self._clean_brave_response(response.json()))
+
+ def _clean_brave_response(self, search_response, top_k=3):
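+        # Keep only the top_k entries from the mixed "main" ranking, projecting
+        # each result type (web, faq, infobox, videos, locations, news) down to
+        # a small set of useful keys.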
+ query = None
+ clean_response = []
+ if "query" in search_response:
+ if "original" in search_response["query"]:
+ query = search_response["query"]["original"]
+ if "mixed" in search_response:
+ mixed_results = search_response["mixed"]
+ for m in mixed_results["main"][:top_k]:
+ r_type = m["type"]
+ results = search_response[r_type]["results"]
+ if r_type == "web":
+ # For web data - add a single output from the search
+ idx = m["index"]
+ selected_keys = [
+ "type",
+ "title",
+ "url",
+ "description",
+ "date",
+ "extra_snippets",
+ ]
+ cleaned = {
+ k: v for k, v in results[idx].items() if k in selected_keys
+ }
+ elif r_type == "faq":
+                    # For FAQ data - take a list of all the questions & answers
+ selected_keys = ["type", "question", "answer", "title", "url"]
+ cleaned = []
+ for q in results:
+ cleaned.append(
+ {k: v for k, v in q.items() if k in selected_keys}
+ )
+ elif r_type == "infobox":
+ idx = m["index"]
+ selected_keys = [
+ "type",
+ "title",
+ "url",
+ "description",
+ "long_desc",
+ ]
+ cleaned = {
+ k: v for k, v in results[idx].items() if k in selected_keys
+ }
+ elif r_type == "videos":
+ selected_keys = [
+ "type",
+ "url",
+ "title",
+ "description",
+ "date",
+ ]
+ cleaned = []
+ for q in results:
+ cleaned.append(
+ {k: v for k, v in q.items() if k in selected_keys}
+ )
+ elif r_type == "locations":
+                    # For location data - take a list of results with selected keys
+ selected_keys = [
+ "type",
+ "title",
+ "url",
+ "description",
+ "coordinates",
+ "postal_address",
+ "contact",
+ "rating",
+ "distance",
+ "zoom_level",
+ ]
+ cleaned = []
+ for q in results:
+ cleaned.append(
+ {k: v for k, v in q.items() if k in selected_keys}
+ )
+ elif r_type == "news":
+                    # For news data - take a list of results with selected keys
+ selected_keys = [
+ "type",
+ "title",
+ "url",
+ "description",
+ ]
+ cleaned = []
+ for q in results:
+ cleaned.append(
+ {k: v for k, v in q.items() if k in selected_keys}
+ )
+ else:
+ cleaned = []
+
+ clean_response.append(cleaned)
+
+ return {"query": query, "top_k": clean_response}
+
+
+class WolframAlphaTool(SingleMessageBuiltinTool):
+
+ def __init__(self, api_key: str) -> None:
+ self.api_key = api_key
+ self.url = "https://api.wolframalpha.com/v2/query"
+
+ def get_name(self) -> str:
+ return BuiltinTool.wolfram_alpha.value
+
+ async def run_impl(self, query: str) -> str:
+ params = {
+ "input": query,
+ "appid": self.api_key,
+ "format": "plaintext",
+ "output": "json",
+ }
+ response = requests.get(
+ self.url,
+ params=params,
+ )
+
+ return json.dumps(self._clean_wolfram_alpha_response(response.json()))
+
+ def _clean_wolfram_alpha_response(self, wa_response):
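+        # `remove` maps a top-level response key to the keys to delete beneath
+        # it; the nested {"pods": [...]} entry prunes those keys from every pod,
+        # and pods following the "Result" pod are dropped entirely.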
+ remove = {
+ "queryresult": [
+ "datatypes",
+ "error",
+ "timedout",
+ "timedoutpods",
+ "numpods",
+ "timing",
+ "parsetiming",
+ "parsetimedout",
+ "recalculate",
+ "id",
+ "host",
+ "server",
+ "related",
+ "version",
+ {
+ "pods": [
+ "scanner",
+ "id",
+ "error",
+ "expressiontypes",
+ "states",
+ "infos",
+ "position",
+ "numsubpods",
+ ]
+ },
+ "assumptions",
+ ],
+ }
+ for main_key in remove:
+ for key_to_remove in remove[main_key]:
+ try:
+ if key_to_remove == "assumptions":
+ if "assumptions" in wa_response[main_key]:
+ del wa_response[main_key][key_to_remove]
+ if isinstance(key_to_remove, dict):
+ for sub_key in key_to_remove:
+ if sub_key == "pods":
+ for i in range(len(wa_response[main_key][sub_key])):
+ if (
+ wa_response[main_key][sub_key][i]["title"]
+ == "Result"
+ ):
+ del wa_response[main_key][sub_key][i + 1 :]
+ break
+ sub_items = wa_response[main_key][sub_key]
+ for i in range(len(sub_items)):
+ for sub_key_to_remove in key_to_remove[sub_key]:
+ if sub_key_to_remove in sub_items[i]:
+ del sub_items[i][sub_key_to_remove]
+ elif key_to_remove in wa_response[main_key]:
+ del wa_response[main_key][key_to_remove]
+ except KeyError:
+ pass
+ return wa_response
+
+
+class CodeInterpreterTool(BaseTool):
+
+ def __init__(self) -> None:
+ ctx = CodeExecutionContext(
+ matplotlib_dump_dir=f"/tmp/{os.environ['USER']}_matplotlib_dump",
+ )
+ self.code_executor = CodeExecutor(ctx)
+
+ def get_name(self) -> str:
+ return BuiltinTool.code_interpreter.value
+
+ async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
+ message = messages[0]
+ assert len(message.tool_calls) == 1, "Expected a single tool call"
+
+ tool_call = messages[0].tool_calls[0]
+ script = tool_call.arguments["code"]
+
+ req = CodeExecutionRequest(scripts=[script])
+ res = self.code_executor.execute(req)
+
+ pieces = [res["process_status"]]
+ for out_type in ["stdout", "stderr"]:
+ res_out = res[out_type]
+ if res_out != "":
+ pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
+ if out_type == "stderr":
+ cprint(f"ipython tool error: ↓\n{res_out}", color="red")
+
+ message = ToolResponseMessage(
+ call_id=tool_call.call_id,
+ tool_name=tool_call.tool_name,
+ content="\n".join(pieces),
+ )
+ if attachment := interpret_content_as_attachment(res["stdout"]):
+ message.content = attachment
+
+ return [message]
diff --git a/llama_toolchain/agentic_system/tools/custom.py b/llama_toolchain/agentic_system/tools/custom.py
new file mode 100644
index 000000000..35e3dd57d
--- /dev/null
+++ b/llama_toolchain/agentic_system/tools/custom.py
@@ -0,0 +1,103 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+
+from abc import abstractmethod
+from typing import Dict, List
+
+from llama_models.llama3_1.api.datatypes import * # noqa: F403
+from llama_toolchain.agentic_system.api import * # noqa: F403
+
+from .builtin import interpret_content_as_attachment
+
+
+class CustomTool:
+ """
+    Developers can define custom tools that models can use
+    by extending this class.
+
+    Developers need to provide:
+    - a name
+    - a description
+    - a params_definition
+    - the tool's behavior, implemented in the `run_impl` method
+
+    NOTE: The return value of the `run` method needs to be JSON serializable
+ """
+
+ @abstractmethod
+ def get_name(self) -> str:
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_description(self) -> str:
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_params_definition(self) -> Dict[str, ToolParamDefinition]:
+ raise NotImplementedError
+
+ def get_instruction_string(self) -> str:
+ return f"Use the function '{self.get_name()}' to: {self.get_description()}"
+
+ def parameters_for_system_prompt(self) -> str:
+ return json.dumps(
+ {
+ "name": self.get_name(),
+ "description": self.get_description(),
+ "parameters": {
+ name: definition.__dict__
+ for name, definition in self.get_params_definition().items()
+ },
+ }
+ )
+
+ def get_tool_definition(self) -> AgenticSystemToolDefinition:
+ return AgenticSystemToolDefinition(
+ tool_name=self.get_name(),
+ description=self.get_description(),
+ parameters=self.get_params_definition(),
+ )
+
+ @abstractmethod
+ async def run(self, messages: List[Message]) -> List[Message]:
+ raise NotImplementedError
+
+
+class SingleMessageCustomTool(CustomTool):
+ """
+    Helper class to handle custom tools that take a single message.
+    Extend this class and implement the `run_impl` method to make the tool
+    callable by the model, with the necessary plumbing handled for you.
+ """
+
+ async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
+ assert len(messages) == 1, "Expected single message"
+
+ message = messages[0]
+
+ tool_call = message.tool_calls[0]
+
+ try:
+ response = await self.run_impl(**tool_call.arguments)
+ response_str = json.dumps(response, ensure_ascii=False)
+ except Exception as e:
+ response_str = f"Error when running tool: {e}"
+
+ message = ToolResponseMessage(
+ call_id=tool_call.call_id,
+ tool_name=tool_call.tool_name,
+ content=response_str,
+ )
+ if attachment := interpret_content_as_attachment(response_str):
+ message.content = attachment
+
+ return [message]
+
+ @abstractmethod
+ async def run_impl(self, *args, **kwargs):
+ raise NotImplementedError()
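+
+
+# A minimal sketch of a custom tool (names and parameters are illustrative,
+# not part of the API):
+#
+# class GetTimeTool(SingleMessageCustomTool):
+#     def get_name(self) -> str:
+#         return "get_current_time"
+#
+#     def get_description(self) -> str:
+#         return "Get the current UTC time as an ISO-8601 string"
+#
+#     def get_params_definition(self) -> Dict[str, ToolParamDefinition]:
+#         return {}
+#
+#     async def run_impl(self) -> str:
+#         from datetime import datetime, timezone
+#         return datetime.now(timezone.utc).isoformat()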
diff --git a/llama_toolchain/agentic_system/tools/execute.py b/llama_toolchain/agentic_system/tools/execute.py
new file mode 100644
index 000000000..2a7625f65
--- /dev/null
+++ b/llama_toolchain/agentic_system/tools/execute.py
@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, AsyncGenerator, List
+
+from llama_models.llama3_1.api.datatypes import StopReason, ToolResponseMessage
+
+from llama_toolchain.agentic_system.api import (
+ AgenticSystem,
+ AgenticSystemTurnCreateRequest,
+ AgenticSystemTurnResponseEventType as EventType,
+)
+
+from llama_toolchain.inference.api import Message
+
+
+async def execute_with_custom_tools(
+ system: AgenticSystem,
+ system_id: str,
+ session_id: str,
+ messages: List[Message],
+ custom_tools: List[Any],
+ max_iters: int = 5,
+ stream: bool = True,
+) -> AsyncGenerator:
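+    """
+    Client-side driver loop: create a turn and, whenever the model calls one of
+    the provided custom tools, execute it locally and feed the tool response
+    back as the next turn's input, until no tool call remains or max_iters is
+    reached.
+    """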
+ # first create a session, or do you keep a persistent session?
+ tools_dict = {t.get_name(): t for t in custom_tools}
+
+ current_messages = messages.copy()
+ n_iter = 0
+ while n_iter < max_iters:
+ n_iter += 1
+
+ request = AgenticSystemTurnCreateRequest(
+ system_id=system_id,
+ session_id=session_id,
+ messages=current_messages,
+ stream=stream,
+ )
+
+ turn = None
+ async for chunk in system.create_agentic_system_turn(request):
+ if chunk.event.payload.event_type != EventType.turn_complete.value:
+ yield chunk
+ else:
+ turn = chunk.event.payload.turn
+ break
+
+        assert turn is not None, "turn_complete event was not received"
+        message = turn.output_message
+ if len(message.tool_calls) == 0:
+ yield chunk
+ return
+
+ if message.stop_reason == StopReason.out_of_tokens:
+ yield chunk
+ return
+
+ tool_call = message.tool_calls[0]
+ if tool_call.tool_name not in tools_dict:
+ m = ToolResponseMessage(
+ call_id=tool_call.call_id,
+ tool_name=tool_call.tool_name,
+ content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else",
+ )
+ next_message = m
+ else:
+ tool = tools_dict[tool_call.tool_name]
+ result_messages = await execute_custom_tool(tool, message)
+ next_message = result_messages[0]
+
+ yield next_message
+ current_messages = [next_message]
+
+
+async def execute_custom_tool(tool: Any, message: Message) -> List[Message]:
+ result_messages = await tool.run([message])
+ assert (
+ len(result_messages) == 1
+ ), f"Expected single message, got {len(result_messages)}"
+
+ return result_messages
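+
+
+# Hypothetical driver sketch; `client`, the two ids, `UserMessage`, and
+# `GetTimeTool` are illustrative, not defined in this module:
+#
+#     async for event in execute_with_custom_tools(
+#         system=client,
+#         system_id=system_id,
+#         session_id=session_id,
+#         messages=[UserMessage(content="What time is it?")],
+#         custom_tools=[GetTimeTool()],
+#     ):
+#         print(event)  # streamed chunks and tool-response messages, interleaved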
diff --git a/llama_toolchain/agentic_system/tools/ipython_tool/__init__.py b/llama_toolchain/agentic_system/tools/ipython_tool/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_toolchain/agentic_system/tools/ipython_tool/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_toolchain/agentic_system/tools/ipython_tool/code_env_prefix.py b/llama_toolchain/agentic_system/tools/ipython_tool/code_env_prefix.py
new file mode 100644
index 000000000..10f64ec94
--- /dev/null
+++ b/llama_toolchain/agentic_system/tools/ipython_tool/code_env_prefix.py
@@ -0,0 +1,133 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import errno
+
+# Disabling potentially dangerous functions
+import os as _os
+from functools import partial
+
+os_funcs_to_disable = [
+ "kill",
+ "system",
+ "putenv",
+ "remove",
+ "removedirs",
+ "rmdir",
+ "fchdir",
+ "setuid",
+ "fork",
+ "forkpty",
+ "killpg",
+ "rename",
+ "renames",
+ "truncate",
+ "replace",
+ # "unlink", # Commenting as this was blocking matpltlib from rendering plots correctly
+ "fchmod",
+ "fchown",
+ "chmod",
+ "chown",
+ "chroot",
+ "fchdir",
+ "lchflags",
+ "lchmod",
+ "lchown",
+ "chdir",
+]
+
+
+def call_not_allowed(*args, **kwargs):
+ func_name = kwargs.get("_func_name", "This call")
+ raise OSError(errno.EPERM, f"{func_name} is not permitted in this environment")
+
+
+for func_name in os_funcs_to_disable:
+ if hasattr(_os, func_name):
+ setattr(_os, func_name, partial(call_not_allowed, _func_name=f"os.{func_name}"))
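+
+# After this, calls such as os.system("ls") raise OSError(EPERM) instead of
+# executing.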
+
+import shutil as _shutil
+
+for func_name in ["rmtree", "move", "chown"]:
+ if hasattr(_shutil, func_name):
+ setattr(
+ _shutil,
+ func_name,
+ partial(call_not_allowed, _func_name=f"shutil.{func_name}"),
+ )
+
+import subprocess as _subprocess
+
+
+def popen_not_allowed(*args, **kwargs):
+ raise _subprocess.CalledProcessError(
+ -1,
+ args[0] if args else "unknown",
+ stderr="subprocess.Popen is not allowed in this environment",
+ )
+
+
+_subprocess.Popen = popen_not_allowed
+
+
+import atexit as _atexit
+import builtins as _builtins
+import io as _io
+import json as _json
+import sys as _sys
+
+# NB! The following "unused" imports are crucial; make sure linters do not
+# remove them - they are referenced from code_execution.py
+from contextlib import ( # noqa
+ contextmanager as _contextmanager,
+ redirect_stderr as _redirect_stderr,
+ redirect_stdout as _redirect_stdout,
+)
+from multiprocessing.connection import Connection as _Connection
+
+# Mangle imports to avoid polluting model execution namespace.
+
+_IO_SINK = _io.StringIO()
+_NETWORK_TIMEOUT = 5
+_NETWORK_CONNECTIONS = None
+
+
+def _open_connections():
+ global _NETWORK_CONNECTIONS
+ if _NETWORK_CONNECTIONS is not None:
+ # Ensure connections only opened once.
+ return _NETWORK_CONNECTIONS
+ req_w_fd, resp_r_fd = _sys.argv[1], _sys.argv[2]
+ req_con = _Connection(int(req_w_fd), readable=False)
+ resp_con = _Connection(int(resp_r_fd), writable=False)
+ _NETWORK_CONNECTIONS = (req_con, resp_con)
+ return _NETWORK_CONNECTIONS
+
+
+_builtins._open_connections = _open_connections
+
+
+@_atexit.register
+def _close_connections():
+ global _NETWORK_CONNECTIONS
+ if _NETWORK_CONNECTIONS is None:
+ return
+ for con in _NETWORK_CONNECTIONS:
+ con.close()
+ del _NETWORK_CONNECTIONS
+
+
+def _network_call(request):
+ # NOTE: We communicate with the parent process in JSON encoded
+ # as raw bytes, because the native send/recv methods use pickle,
+ # which can execute arbitrary code during deserialization.
+ _open_connections()
+ req_con, resp_con = _NETWORK_CONNECTIONS
+
+ req_con.send_bytes(_json.dumps(request).encode("utf-8"))
+ # poll() returns a bool (never None); False means the timeout expired
+ if not resp_con.poll(timeout=_NETWORK_TIMEOUT):
+ raise Exception(f"Network request timed out: {_json.dumps(request)}")
+ else:
+ return _json.loads(resp_con.recv_bytes().decode("utf-8"))
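+
+
+# Example request/response shape, matching the "matplotlib" handler in
+# code_execution.py:
+#
+#     _network_call({"type": "matplotlib", "image_data": [...]})
+#     # -> the parent's JSON-decoded reply, e.g. an attachment-marker string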
diff --git a/llama_toolchain/agentic_system/tools/ipython_tool/code_execution.py b/llama_toolchain/agentic_system/tools/ipython_tool/code_execution.py
new file mode 100644
index 000000000..fa2e367e5
--- /dev/null
+++ b/llama_toolchain/agentic_system/tools/ipython_tool/code_execution.py
@@ -0,0 +1,256 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import base64
+import json
+import multiprocessing
+import os
+import re
+import subprocess
+import sys
+import tempfile
+import textwrap
+import time
+from dataclasses import dataclass
+from datetime import datetime
+from io import BytesIO
+from pathlib import Path
+from typing import List
+
+from PIL import Image
+
+from .utils import get_code_env_prefix
+
+TOOLS_ATTACHMENT_KEY = "__tools_attachment__"
+TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
+
+DIRNAME = Path(__file__).parent
+
+CODE_EXEC_TIMEOUT = 20
+CODE_ENV_PREFIX = get_code_env_prefix()
+
+STDOUTERR_SINK_WRAPPER_TEMPLATE = """\
+with _redirect_stdout(_IO_SINK), _redirect_stderr(_IO_SINK):
+{code}\
+"""
+
+TRYEXCEPT_WRAPPER_TEMPLATE = """\
+try:
+{code}
+except:
+ pass\
+"""
+
+
+def generate_bwrap_command(bind_dirs: List[str]) -> str:
+ """
+ Generate the bwrap command string for binding all
+ directories in the current directory read-only.
+ """
+ bwrap_args = ""
+ bwrap_args += "--ro-bind / / "
+ # Add the --dev flag to mount device files
+ bwrap_args += "--dev /dev "
+ for d in bind_dirs:
+ bwrap_args += f"--bind {d} {d} "
+
+ # Add the --unshare-all flag to isolate the sandbox from the rest of the system
+ bwrap_args += "--unshare-all "
+ # Add the --die-with-parent flag to ensure the child process dies when bwrap's parent dies
+ bwrap_args += "--die-with-parent "
+ return bwrap_args
+
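+# For bind_dirs=["/tmp/work"], the generated arguments are:
+#
+#     --ro-bind / / --dev /dev --bind /tmp/work /tmp/work --unshare-all --die-with-parent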
+
+@dataclass
+class CodeExecutionContext:
+ matplotlib_dump_dir: str
+ use_proxy: bool = False
+
+
+@dataclass
+class CodeExecutionRequest:
+ scripts: List[str]
+ only_last_cell_stdouterr: bool = True
+ only_last_cell_fail: bool = True
+ seed: int = 0
+ strip_fpaths_in_stderr: bool = True
+
+
+class CodeExecutor:
+ def __init__(self, context: CodeExecutionContext):
+ self.context = context
+
+ def execute(self, req: CodeExecutionRequest) -> dict:
+ scripts = req.scripts
+ for i in range(len(scripts) - 1):
+ if req.only_last_cell_stdouterr:
+ scripts[i] = STDOUTERR_SINK_WRAPPER_TEMPLATE.format(
+ code=textwrap.indent(scripts[i], " " * 4)
+ )
+ if req.only_last_cell_fail:
+ scripts[i] = TRYEXCEPT_WRAPPER_TEMPLATE.format(
+ code=textwrap.indent(scripts[i], " " * 4)
+ )
+
+ # Seeds prefix:
+ seed = req.seed
+ seeds_prefix = f"""\
+def _set_seeds():
+ import random
+ random.seed({seed})
+ import numpy as np
+ np.random.seed({seed})
+_set_seeds()\
+"""
+
+ script = "\n\n".join([seeds_prefix] + [CODE_ENV_PREFIX] + scripts)
+ with tempfile.TemporaryDirectory() as dpath:
+ bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath])
+ cmd = [*bwrap_prefix.split(), sys.executable, "-c", script]
+ code_fpath = os.path.join(dpath, "code.py")
+ with open(code_fpath, "w") as f:
+ f.write(script)
+
+ try:
+ python_path = os.environ.get("PYTHONPATH", "")
+ env = dict(
+ os.environ,
+ PYTHONHASHSEED=str(seed),
+ MPLCONFIGDIR=dpath,
+ MPLBACKEND="module://matplotlib_custom_backend",
+ PYTHONPATH=f"{DIRNAME}:{python_path}",
+ )
+ stdout, stderr, returncode = do_subprocess(
+ cmd=cmd,
+ env=env,
+ ctx=self.context,
+ )
+
+ stderr = stderr.strip()
+ if req.strip_fpaths_in_stderr:
+ pattern = r'File "([^"]+)", line (\d+)'
+ stderr = re.sub(pattern, r"line \2", stderr)
+
+ return {
+ "process_status": "completed",
+ "returncode": returncode,
+ "stdout": stdout.strip(),
+ "stderr": stderr,
+ }
+
+ except subprocess.TimeoutExpired:
+ return {
+ "process_status": "timeout",
+ "stdout": "Timed out",
+ "stderr": "Timed out",
+ }
+
+ except Exception as e:
+ return {
+ "process_status": "error",
+ "error_type": type(e).__name__,
+ "stderr": str(e),
+ "stdout": str(e),
+ }
+
+
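+# A hypothetical end-to-end sketch of CodeExecutor usage:
+#
+#     ctx = CodeExecutionContext(matplotlib_dump_dir="/tmp/plots")
+#     req = CodeExecutionRequest(scripts=["x = 40 + 2", "print(x)"])
+#     result = CodeExecutor(ctx).execute(req)
+#     # result["process_status"] == "completed"; result["stdout"] == "42"
+#
+# Note: only the last cell's stdout/stderr survive by default; earlier cells
+# are wrapped to sink their output and swallow their failures.
+
+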
+def process_matplotlib_response(response, matplotlib_dump_dir: str):
+ image_data = response["image_data"]
+ # Convert the base64 string to a bytes object
+ images = [base64.b64decode(d["image_base64"]) for d in image_data]
+ # Create a list of PIL images from the bytes objects
+ images = [Image.open(BytesIO(img)) for img in images]
+ # Create a list of image paths
+ image_paths = []
+ for i, img in enumerate(images):
+ # create a new directory per day to better organize the dumps
+ dump_dname = datetime.today().strftime("%Y-%m-%d")
+ dump_dpath = Path(matplotlib_dump_dir, dump_dname)
+ dump_dpath.mkdir(parents=True, exist_ok=True)
+ # save image into a file
+ dump_fname = f"matplotlib_{str(time.time()).replace('.', '_')}_{i}.png"
+ dump_fpath = dump_dpath / dump_fname
+ img.save(dump_fpath, "PNG")
+ image_paths.append(str(dump_fpath))
+
+ # Somewhat convoluted: this response is sent back to the subprocess,
+ # which prints it so the attachment marker ends up in its stdout.
+ info = {
+ "filepath": str(image_paths[-1]),
+ "mimetype": "image/png",
+ }
+ return f"{TOOLS_ATTACHMENT_KEY}={json.dumps(info)}"
+
+
+def execute_subprocess_request(request, ctx: CodeExecutionContext):
+ "Route requests from the subprocess (via network Pipes) to the internet/tools."
+ if request["type"] == "matplotlib":
+ return process_matplotlib_response(request, ctx.matplotlib_dump_dir)
+ else:
+ raise Exception(f'Unrecognised network request type: {request["type"]}')
+
+
+def do_subprocess(*, cmd: list, env: dict, ctx: CodeExecutionContext):
+ # Create Pipes to be used for any external tool/network requests.
+ req_r, req_w = multiprocessing.Pipe(duplex=False)
+ resp_r, resp_w = multiprocessing.Pipe(duplex=False)
+
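+ # The child's ends of these pipes are passed as fd numbers on argv;
+ # _open_connections() in code_env_prefix.py rebuilds Connection objects
+ # from those fds inside the sandboxed subprocess.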
+ cmd += [str(req_w.fileno()), str(resp_r.fileno())]
+ proc = subprocess.Popen(
+ cmd,
+ pass_fds=(req_w.fileno(), resp_r.fileno()),
+ text=True,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ close_fds=True,
+ env=env,
+ )
+
+ # Close unnecessary fds.
+ req_w.close()
+ resp_r.close()
+
+ pipe_close = False
+ done_read = False
+ start = time.monotonic()
+ while proc.poll() is None and not pipe_close:
+ if req_r.poll(0.1):
+ # NB: Python pipe semantics for poll and recv mean that
+ # poll() also returns True if the pipe has been closed.
+ # See this long-standing CPython issue from '09:
+ # https://bugs.python.org/issue5573
+ try:
+ request = json.loads(req_r.recv_bytes().decode("utf-8"))
+ response = execute_subprocess_request(request, ctx)
+
+ resp_w.send_bytes(json.dumps(response).encode("utf-8"))
+ except EOFError:
+ # The request pipe is closed - set a marker to exit
+ # after the next attempt at reading stdout/stderr.
+ pipe_close = True
+
+ try:
+ # If a lot has been printed, the pipe may be full; the
+ # process cannot exit until all of its stdout/stderr has
+ # been written and read.
+ stdout, stderr = proc.communicate(timeout=0.3)
+ done_read = True
+ except subprocess.TimeoutExpired:
+ # The program has not terminated. Ignore it, there
+ # may be more network/tool requests.
+ continue
+ if time.monotonic() - start > CODE_EXEC_TIMEOUT:
+ proc.terminate()
+ raise subprocess.TimeoutExpired(cmd, CODE_EXEC_TIMEOUT)
+
+ if not done_read:
+ # Handle the race where the process terminates before we
+ # ever enter the while loop above.
+ stdout, stderr = proc.communicate(timeout=0.3)
+
+ resp_w.close()
+ req_r.close()
+ return stdout, stderr, proc.returncode
diff --git a/llama_toolchain/agentic_system/tools/ipython_tool/matplotlib_custom_backend.py b/llama_toolchain/agentic_system/tools/ipython_tool/matplotlib_custom_backend.py
new file mode 100644
index 000000000..3aba2ef21
--- /dev/null
+++ b/llama_toolchain/agentic_system/tools/ipython_tool/matplotlib_custom_backend.py
@@ -0,0 +1,87 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""
+A custom Matplotlib backend that overrides the show method to return image bytes.
+"""
+
+import base64
+import io
+import json as _json
+
+import matplotlib
+from matplotlib.backend_bases import FigureManagerBase
+
+# Import necessary components from Matplotlib
+from matplotlib.backends.backend_agg import FigureCanvasAgg
+
+
+class CustomFigureCanvas(FigureCanvasAgg):
+ def show(self):
+ # Save the figure to a BytesIO object
+ buf = io.BytesIO()
+ self.print_png(buf)
+ image_bytes = buf.getvalue()
+ buf.close()
+ return image_bytes
+
+
+class CustomFigureManager(FigureManagerBase):
+ def __init__(self, canvas, num):
+ super().__init__(canvas, num)
+
+
+# Mimic module initialization that integrates with the Matplotlib backend system
+def _create_figure_manager(num, *args, **kwargs):
+ """
+ Create a custom figure manager instance.
+ """
+ FigureClass = kwargs.pop("FigureClass", None) # noqa: N806
+ if FigureClass is None:
+ from matplotlib.figure import Figure
+
+ FigureClass = Figure # noqa: N806
+ fig = FigureClass(*args, **kwargs)
+ canvas = CustomFigureCanvas(fig)
+ manager = CustomFigureManager(canvas, num)
+ return manager
+
+
+def show():
+ """
+ Handle all figures and potentially return their images as bytes.
+
+ This function iterates over all figures registered with the custom backend,
+ renders them as images in bytes format, and could return a list of bytes objects,
+ one for each figure, or handle them as needed.
+ """
+ image_data = []
+ for manager in matplotlib._pylab_helpers.Gcf.get_all_fig_managers():
+ # Get the figure from the manager
+ fig = manager.canvas.figure
+ buf = io.BytesIO() # Create a buffer for the figure
+ fig.savefig(buf, format="png") # Save the figure to the buffer in PNG format
+ buf.seek(0) # Go to the beginning of the buffer
+ image_bytes = buf.getvalue() # Retrieve bytes value
+ image_base64 = base64.b64encode(image_bytes).decode("utf-8")
+ image_data.append({"image_base64": image_base64})
+ buf.close()
+
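+ # _open_connections is injected into builtins by code_env_prefix.py,
+ # which runs ahead of user code in the sandboxed subprocess.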
+ req_con, resp_con = _open_connections()
+
+ _json_dump = _json.dumps(
+ {
+ "type": "matplotlib",
+ "image_data": image_data,
+ }
+ )
+ req_con.send_bytes(_json_dump.encode("utf-8"))
+ resp = _json.loads(resp_con.recv_bytes().decode("utf-8"))
+ print(resp)
+
+
+FigureCanvas = CustomFigureCanvas
+FigureManager = CustomFigureManager
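+
+# Matplotlib's `module://` backend mechanism imports this module (selected via
+# MPLBACKEND=module://matplotlib_custom_backend in code_execution.py) and
+# resolves these module-level names, along with `show()`, to drive rendering.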
diff --git a/llama_toolchain/agentic_system/tools/ipython_tool/utils.py b/llama_toolchain/agentic_system/tools/ipython_tool/utils.py
new file mode 100644
index 000000000..d6f539a39
--- /dev/null
+++ b/llama_toolchain/agentic_system/tools/ipython_tool/utils.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+
+DIR = os.path.dirname(os.path.realpath(__file__))
+CODE_ENV_PREFIX_FILE = os.path.join(DIR, "code_env_prefix.py")
+CODE_ENV_PREFIX = None
+
+
+def get_code_env_prefix() -> str:
+ global CODE_ENV_PREFIX
+
+ if CODE_ENV_PREFIX is None:
+ with open(CODE_ENV_PREFIX_FILE, "r") as f:
+ CODE_ENV_PREFIX = f.read()
+
+ return CODE_ENV_PREFIX
diff --git a/llama_toolchain/agentic_system/tools/safety.py b/llama_toolchain/agentic_system/tools/safety.py
new file mode 100644
index 000000000..da0abe10a
--- /dev/null
+++ b/llama_toolchain/agentic_system/tools/safety.py
@@ -0,0 +1,59 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_toolchain.agentic_system.safety import ShieldRunnerMixin
+
+from llama_toolchain.inference.api import Message
+from llama_toolchain.safety.api.datatypes import ShieldDefinition
+from llama_toolchain.safety.api.endpoints import Safety
+
+from .builtin import BaseTool
+
+
+class SafeTool(BaseTool, ShieldRunnerMixin):
+ """A tool that makes other tools safety enabled"""
+
+ def __init__(
+ self,
+ tool: BaseTool,
+ safety_api: Safety,
+ input_shields: List[ShieldDefinition] = None,
+ output_shields: List[ShieldDefinition] = None,
+ ):
+ self._tool = tool
+ ShieldRunnerMixin.__init__(
+ self, safety_api, input_shields=input_shields, output_shields=output_shields
+ )
+
+ def get_name(self) -> str:
+ # return the name of the wrapped tool
+ return self._tool.get_name()
+
+ async def run(self, messages: List[Message]) -> List[Message]:
+ if self.input_shields:
+ await self.run_shields(messages, self.input_shields)
+ # run the underlying tool
+ res = await self._tool.run(messages)
+ if self.output_shields:
+ # run output shields on the tool's output, not the original input
+ await self.run_shields(res, self.output_shields)
+
+ return res
+
+
+def with_safety(
+ tool: BaseTool,
+ safety_api: Safety,
+ input_shields: List[ShieldDefinition] = None,
+ output_shields: List[ShieldDefinition] = None,
+) -> SafeTool:
+ return SafeTool(
+ tool,
+ safety_api,
+ input_shields=input_shields,
+ output_shields=output_shields,
+ )
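+
+
+# Hypothetical usage sketch; the constructor arguments and shield definition
+# are illustrative, and a shield violation is assumed to raise SafetyException:
+#
+#     safe_search = with_safety(
+#         BraveSearchTool(api_key),
+#         safety_api,
+#         input_shields=[ShieldDefinition(shield_type=BuiltinShield.llama_guard)],
+#     )
+#     out_messages = await safe_search.run(in_messages)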
diff --git a/llama_toolchain/common/prompt_for_config.py b/llama_toolchain/common/prompt_for_config.py
index c708b96d7..071d40cb6 100644
--- a/llama_toolchain/common/prompt_for_config.py
+++ b/llama_toolchain/common/prompt_for_config.py
@@ -144,7 +144,11 @@ def prompt_for_config(
nested_type = get_non_none_type(field_type)
print(f"Entering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(nested_type, existing_value)
- elif inspect.isclass(field_type) and issubclass(field_type, BaseModel):
+ elif (
+ inspect.isclass(field_type)
+ and issubclass(field_type, BaseModel)
+ and len(field_type.__fields__) > 0
+ ):
print(f"\nEntering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(
field_type,
diff --git a/llama_toolchain/distribution/datatypes.py b/llama_toolchain/distribution/datatypes.py
index 82196a00b..762b3d487 100644
--- a/llama_toolchain/distribution/datatypes.py
+++ b/llama_toolchain/distribution/datatypes.py
@@ -15,6 +15,7 @@ from strong_typing.schema import json_schema_type
class ApiSurface(Enum):
inference = "inference"
safety = "safety"
+ agentic_system = "agentic_system"
@json_schema_type
@@ -39,14 +40,19 @@ class SourceAdapter(Adapter):
module: str = Field(
...,
description="""
-Fully-qualified name of the module to import. The module is expected to have
-a `get_adapter_instance()` method which will be passed a validated config object
-of type `config_class`.""",
+Fully-qualified name of the module to import. The module is expected to have:
+
+ - `get_adapter_impl(config, deps)`: returns the local implementation
+""",
)
config_class: str = Field(
...,
description="Fully-qualified classname of the config for this adapter",
)
+ adapter_dependencies: List[ApiSurface] = Field(
+ default_factory=list,
+ description="Higher-level API surfaces may depend on other adapters to provide their functionality",
+ )
@json_schema_type
@@ -56,6 +62,13 @@ class PassthroughApiAdapter(Adapter):
default_factory=dict,
description="Headers (e.g., authorization) to send with the request",
)
+ module: str = Field(
+ ...,
+ description="""
+Fully-qualified name of the module to import. The module is expected to have:
+ - `get_client_impl(base_url)`: returns a client which can be used to call the remote implementation
+""",
+ )
class Distribution(BaseModel):
diff --git a/llama_toolchain/distribution/distribution.py b/llama_toolchain/distribution/distribution.py
index 773d05f26..03bd5d3a5 100644
--- a/llama_toolchain/distribution/distribution.py
+++ b/llama_toolchain/distribution/distribution.py
@@ -7,6 +7,7 @@
import inspect
from typing import Dict, List
+from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
from llama_toolchain.inference.api.endpoints import Inference
from llama_toolchain.safety.api.endpoints import Safety
@@ -29,6 +30,7 @@ def api_surface_endpoints() -> Dict[ApiSurface, List[ApiSurfaceEndpoint]]:
protocols = {
ApiSurface.inference: Inference,
ApiSurface.safety: Safety,
+ ApiSurface.agentic_system: AgenticSystem,
}
for surface, protocol in protocols.items():
diff --git a/llama_toolchain/distribution/dynamic.py b/llama_toolchain/distribution/dynamic.py
index 483c08d79..62e954d46 100644
--- a/llama_toolchain/distribution/dynamic.py
+++ b/llama_toolchain/distribution/dynamic.py
@@ -8,7 +8,7 @@ import asyncio
import importlib
from typing import Any, Dict
-from .datatypes import SourceAdapter
+from .datatypes import Adapter, PassthroughApiAdapter, SourceAdapter
def instantiate_class_type(fully_qualified_name):
@@ -18,9 +18,17 @@ def instantiate_class_type(fully_qualified_name):
# returns a class implementing the protocol corresponding to the ApiSurface
-def instantiate_adapter(adapter: SourceAdapter, adapter_config: Dict[str, Any]):
+def instantiate_adapter(
+ adapter: SourceAdapter, adapter_config: Dict[str, Any], deps: Dict[str, Adapter]
+):
module = importlib.import_module(adapter.module)
config_type = instantiate_class_type(adapter.config_class)
config = config_type(**adapter_config)
- return asyncio.run(module.get_adapter_impl(config))
+ return asyncio.run(module.get_adapter_impl(config, deps))
+
+
+def instantiate_client(adapter: PassthroughApiAdapter, base_url: str):
+ module = importlib.import_module(adapter.module)
+
+ return asyncio.run(module.get_client_impl(base_url))
diff --git a/llama_toolchain/distribution/registry.py b/llama_toolchain/distribution/registry.py
index eda7a2c1a..897b8f9d0 100644
--- a/llama_toolchain/distribution/registry.py
+++ b/llama_toolchain/distribution/registry.py
@@ -7,6 +7,8 @@
from functools import lru_cache
from typing import List, Optional
+from llama_toolchain.agentic_system.adapters import available_agentic_system_adapters
+
from llama_toolchain.inference.adapters import available_inference_adapters
from llama_toolchain.safety.adapters import available_safety_adapters
@@ -43,10 +45,26 @@ COMMON_DEPENDENCIES = [
]
+def client_module(api_surface: ApiSurface) -> str:
+ return f"llama_toolchain.{api_surface.value}.client"
+
+
+def passthrough(api_surface: ApiSurface, port: int) -> PassthroughApiAdapter:
+ return PassthroughApiAdapter(
+ api_surface=api_surface,
+ adapter_id=f"{api_surface.value}-passthrough",
+ base_url=f"http://localhost:{port}",
+ module=client_module(api_surface),
+ )
+
+
@lru_cache()
def available_distributions() -> List[Distribution]:
inference_adapters_by_id = {a.adapter_id: a for a in available_inference_adapters()}
safety_adapters_by_id = {a.adapter_id: a for a in available_safety_adapters()}
+ agentic_system_adapters_by_id = {
+ a.adapter_id: a for a in available_agentic_system_adapters()
+ }
return [
Distribution(
@@ -56,6 +74,9 @@ def available_distributions() -> List[Distribution]:
adapters={
ApiSurface.inference: inference_adapters_by_id["meta-reference"],
ApiSurface.safety: safety_adapters_by_id["meta-reference"],
+ ApiSurface.agentic_system: agentic_system_adapters_by_id[
+ "meta-reference"
+ ],
},
),
Distribution(
@@ -76,16 +97,9 @@ def available_distributions() -> List[Distribution]:
"uvicorn",
],
adapters={
- ApiSurface.inference: PassthroughApiAdapter(
- api_surface=ApiSurface.inference,
- adapter_id="inference-passthrough",
- base_url="http://localhost:5001",
- ),
- ApiSurface.safety: PassthroughApiAdapter(
- api_surface=ApiSurface.safety,
- adapter_id="safety-passthrough",
- base_url="http://localhost:5001",
- ),
+ ApiSurface.inference: passthrough(ApiSurface.inference, 5001),
+ ApiSurface.safety: passthrough(ApiSurface.safety, 5001),
+ ApiSurface.agentic_system: passthrough(ApiSurface.agentic_system, 5001),
},
),
Distribution(
@@ -95,6 +109,9 @@ def available_distributions() -> List[Distribution]:
adapters={
ApiSurface.inference: inference_adapters_by_id["meta-ollama"],
ApiSurface.safety: safety_adapters_by_id["meta-reference"],
+ ApiSurface.agentic_system: agentic_system_adapters_by_id[
+ "meta-reference"
+ ],
},
),
]
diff --git a/llama_toolchain/distribution/server.py b/llama_toolchain/distribution/server.py
index fa57322bd..5ac133d4b 100644
--- a/llama_toolchain/distribution/server.py
+++ b/llama_toolchain/distribution/server.py
@@ -12,7 +12,16 @@ from collections.abc import (
AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
-from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
+from typing import (
+ Any,
+ AsyncGenerator,
+ AsyncIterator,
+ Dict,
+ get_type_hints,
+ List,
+ Optional,
+ Set,
+)
import fire
import httpx
@@ -27,9 +36,9 @@ from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
-from .datatypes import PassthroughApiAdapter
+from .datatypes import Adapter, ApiSurface, PassthroughApiAdapter
from .distribution import api_surface_endpoints
-from .dynamic import instantiate_adapter
+from .dynamic import instantiate_adapter, instantiate_client
from .registry import resolve_distribution
@@ -213,6 +222,29 @@ def create_dynamic_typed_route(func: Any):
return endpoint
+def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
+ by_surface = {x.api_surface: x for x in adapters}
+
+ def dfs(a: Adapter, visited: Set[ApiSurface], stack: List[ApiSurface]):
+ visited.add(a.api_surface)
+
+ for surface in a.adapter_dependencies:
+ if surface not in visited:
+ dfs(by_surface[surface], visited, stack)
+
+ stack.append(a.api_surface)
+
+ visited = set()
+ stack = []
+
+ for a in adapters:
+ if a.api_surface not in visited:
+ dfs(a, visited, stack)
+
+ return [by_surface[x] for x in stack]
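+
+# Example: in the meta-reference distribution, agentic_system declares
+# dependencies on inference and safety, so the sorted order is
+# [inference, safety, agentic_system] - dependencies before dependents.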
+
+
def main(
dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
):
@@ -228,7 +260,13 @@ def main(
all_endpoints = api_surface_endpoints()
adapter_configs = config["adapters"]
- for surface, adapter in dist.adapters.items():
+ adapters = topological_sort(dist.adapters.values())
+
+ # TODO: split this into two parts, first you resolve all impls
+ # and then you create the routes.
+ impls = {}
+ for adapter in adapters:
+ surface = adapter.api_surface
if surface.value not in adapter_configs:
raise ValueError(
f"Could not find adapter config for {surface}. Please add it to the config"
@@ -242,8 +280,11 @@ def main(
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)
+ impls[surface] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
else:
- impl = instantiate_adapter(adapter, adapter_config)
+ deps = {surface: impls[surface] for surface in adapter.adapter_dependencies}
+ impl = instantiate_adapter(adapter, adapter_config, deps)
+ impls[surface] = impl
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already
diff --git a/llama_toolchain/inference/client.py b/llama_toolchain/inference/client.py
index 178452fde..aa84f906d 100644
--- a/llama_toolchain/inference/client.py
+++ b/llama_toolchain/inference/client.py
@@ -23,6 +23,10 @@ from .api import (
from .event_logger import EventLogger
+async def get_client_impl(base_url: str):
+ return InferenceClient(base_url)
+
+
class InferenceClient(Inference):
def __init__(self, base_url: str):
print(f"Initializing client for {base_url}")
diff --git a/llama_toolchain/inference/inference.py b/llama_toolchain/inference/inference.py
index beeb6dd65..7b54313c4 100644
--- a/llama_toolchain/inference/inference.py
+++ b/llama_toolchain/inference/inference.py
@@ -6,11 +6,13 @@
import asyncio
-from typing import AsyncIterator, Union
+from typing import AsyncIterator, Dict, Union
from llama_models.llama3_1.api.datatypes import StopReason
from llama_models.sku_list import resolve_model
+from llama_toolchain.distribution.datatypes import Adapter, ApiSurface
+
from .api.config import MetaReferenceImplConfig
from .api.datatypes import (
ChatCompletionResponseEvent,
@@ -27,7 +29,9 @@ from .api.endpoints import (
from .model_parallel import LlamaModelParallelGenerator
-async def get_adapter_impl(config: MetaReferenceImplConfig):
+async def get_adapter_impl(
+ config: MetaReferenceImplConfig, _deps: Dict[ApiSurface, Adapter]
+):
assert isinstance(
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"
diff --git a/llama_toolchain/safety/client.py b/llama_toolchain/safety/client.py
index fb37bde1a..2bceebc68 100644
--- a/llama_toolchain/safety/client.py
+++ b/llama_toolchain/safety/client.py
@@ -21,6 +21,10 @@ from .api import (
)
+async def get_client_impl(base_url: str):
+ return SafetyClient(base_url)
+
+
class SafetyClient(Safety):
def __init__(self, base_url: str):
print(f"Initializing client for {base_url}")
diff --git a/llama_toolchain/safety/safety.py b/llama_toolchain/safety/safety.py
index 12405c161..21b7e6f1f 100644
--- a/llama_toolchain/safety/safety.py
+++ b/llama_toolchain/safety/safety.py
@@ -6,6 +6,10 @@
import asyncio
+from typing import Dict
+
+from llama_toolchain.distribution.datatypes import Adapter, ApiSurface
+
from .config import SafetyConfig
from .api.endpoints import * # noqa
from .shields import (
@@ -19,7 +23,7 @@ from .shields import (
)
-async def get_adapter_impl(config: SafetyConfig):
+async def get_adapter_impl(config: SafetyConfig, _deps: Dict[ApiSurface, Adapter]):
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
impl = MetaReferenceSafetyImpl(config)