Bring agentic system API to toolchain

Add adapter dependencies and resolve adapters using a topological sort
Ashwin Bharambe 2024-08-04 10:53:38 -07:00
parent b0e5340645
commit be19b22391
31 changed files with 2740 additions and 25 deletions
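
For context on the commit message: adapter resolution must instantiate each adapter's `adapter_dependencies` before the adapter itself. Below is a minimal sketch of such a topological sort (Kahn's algorithm); the helper name and the standalone dict input are illustrative, not code from this commit:

from collections import deque
from typing import Dict, List

def topological_order(deps: Dict[str, List[str]]) -> List[str]:
    """Order names so every dependency precedes its dependents."""
    indegree = {name: len(ds) for name, ds in deps.items()}
    dependents: Dict[str, List[str]] = {name: [] for name in deps}
    for name, ds in deps.items():
        for d in ds:
            dependents[d].append(name)
    queue = deque(n for n, deg in indegree.items() if deg == 0)
    order: List[str] = []
    while queue:
        n = queue.popleft()
        order.append(n)
        for m in dependents[n]:
            indegree[m] -= 1
            if indegree[m] == 0:
                queue.append(m)
    if len(order) != len(deps):
        raise ValueError("cycle detected among adapter dependencies")
    return order

# The agentic_system adapter below declares dependencies on inference and safety:
assert topological_order(
    {"inference": [], "safety": [], "agentic_system": ["inference", "safety"]}
) == ["inference", "safety", "agentic_system"]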

.gitignore (vendored)

@@ -1,3 +1,4 @@
.env
__pycache__
dist
*.egg-info

@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_toolchain.distribution.datatypes import Adapter, ApiSurface, SourceAdapter
def available_agentic_system_adapters() -> List[Adapter]:
return [
SourceAdapter(
api_surface=ApiSurface.agentic_system,
adapter_id="meta-reference",
pip_packages=[
"codeshield",
"torch",
"transformers",
],
module="llama_toolchain.agentic_system.agentic_system",
config_class="llama_toolchain.agentic_system.config.AgenticSystemConfig",
adapter_dependencies=[
ApiSurface.inference,
ApiSurface.safety,
],
),
]

@@ -0,0 +1,786 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_toolchain.agentic_system.api import AgenticSystem
from llama_toolchain.distribution.datatypes import Adapter, ApiSurface
from llama_toolchain.inference.api import Inference
from llama_toolchain.safety.api import Safety
from .config import AgenticSystemConfig
from .api.endpoints import * # noqa
import logging
import os
import uuid
from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional
from llama_toolchain.inference.api import ChatCompletionRequest
from llama_toolchain.inference.api.datatypes import (
Attachment,
BuiltinTool,
ChatCompletionResponseEventType,
CompletionMessage,
Message,
Role,
SamplingParams,
StopReason,
ToolCallDelta,
ToolCallParseStatus,
ToolDefinition,
ToolResponse,
ToolResponseMessage,
URL,
)
from llama_toolchain.safety.api.datatypes import (
BuiltinShield,
ShieldDefinition,
ShieldResponse,
)
from termcolor import cprint
from .api.datatypes import (
AgenticSystemInstanceConfig,
AgenticSystemTurnResponseEvent,
AgenticSystemTurnResponseEventType,
AgenticSystemTurnResponseStepCompletePayload,
AgenticSystemTurnResponseStepProgressPayload,
AgenticSystemTurnResponseStepStartPayload,
AgenticSystemTurnResponseTurnCompletePayload,
AgenticSystemTurnResponseTurnStartPayload,
InferenceStep,
Session,
ShieldCallStep,
StepType,
ToolExecutionStep,
Turn,
)
from .api.endpoints import (
AgenticSystemCreateRequest,
AgenticSystemCreateResponse,
AgenticSystemSessionCreateRequest,
AgenticSystemSessionCreateResponse,
AgenticSystemTurnCreateRequest,
AgenticSystemTurnResponseStreamChunk,
)
from .safety import SafetyException, ShieldRunnerMixin
from .system_prompt import get_agentic_prefix_messages
from .tools.base import BaseTool
from .tools.builtin import (
BraveSearchTool,
CodeInterpreterTool,
PhotogenTool,
SingleMessageBuiltinTool,
WolframAlphaTool,
)
from .tools.safety import with_safety
logger = logging.getLogger()
logger.setLevel(logging.INFO)
async def get_adapter_impl(
config: AgenticSystemConfig, deps: Dict[ApiSurface, Adapter]
):
assert isinstance(
config, AgenticSystemConfig
), f"Unexpected config type: {type(config)}"
impl = MetaReferenceAgenticSystemImpl(
deps[ApiSurface.inference],
deps[ApiSurface.safety],
)
await impl.initialize()
return impl
async def execute_tool_call_maybe(
tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage]
) -> List[ToolResponseMessage]:
# While the Tools.run interface takes a list of messages,
# all tools currently only run on a single message.
# When this changes, we can drop this assert.
# Whether to call tools on each message and aggregate,
# or aggregate and call the tool once, remains to be seen.
assert len(messages) == 1, "Expected single message"
message = messages[0]
tool_call = message.tool_calls[0]
name = tool_call.tool_name
assert isinstance(name, BuiltinTool)
name = name.value
assert name in tools_dict, f"Tool {name} not found"
tool = tools_dict[name]
result_messages = await tool.run(messages)
return result_messages
def print_dialog(messages: List[Message]):
for i, m in enumerate(messages):
if m.role == Role.user.value:
color = "red"
elif m.role == Role.assistant.value:
color = "white"
elif m.role == Role.ipython.value:
color = "yellow"
elif m.role == Role.system.value:
color = "green"
else:
color = "white"
s = str(m)
cprint(f"{i} ::: {s[:100]}...", color=color)
AGENT_INSTANCES_BY_ID = {}
class AgentInstance(ShieldRunnerMixin):
def __init__(
self,
system_id: int,
instance_config: AgenticSystemInstanceConfig,
model: str,
inference_api: Inference,
safety_api: Safety,
builtin_tools: List[SingleMessageBuiltinTool],
custom_tool_definitions: List[ToolDefinition],
input_shields: List[ShieldDefinition],
output_shields: List[ShieldDefinition],
max_infer_iters: int = 10,
prefix_messages: Optional[List[Message]] = None,
):
self.system_id = system_id
self.instance_config = instance_config
self.model = model
self.inference_api = inference_api
self.safety_api = safety_api
if prefix_messages is not None and len(prefix_messages) > 0:
self.prefix_messages = prefix_messages
else:
self.prefix_messages = get_agentic_prefix_messages(
builtin_tools, custom_tool_definitions
)
for m in self.prefix_messages:
print(m.content)
self.max_infer_iters = max_infer_iters
self.tools_dict = {t.get_name(): t for t in builtin_tools}
self.sessions = {}
ShieldRunnerMixin.__init__(
self,
safety_api,
input_shields=input_shields,
output_shields=output_shields,
)
def create_session(self, name: str) -> Session:
session_id = str(uuid.uuid4())
session = Session(
session_id=session_id,
session_name=name,
turns=[],
started_at=datetime.now(),
)
self.sessions[session_id] = session
return session
async def create_and_execute_turn(
self, request: AgenticSystemTurnCreateRequest
) -> AsyncGenerator:
assert (
request.session_id in self.sessions
), f"Session {request.session_id} not found"
session = self.sessions[request.session_id]
messages = []
for i, turn in enumerate(session.turns):
# print(f"turn {i}")
# print_dialog(turn.input_messages)
messages.extend(turn.input_messages)
for step in turn.steps:
if step.step_type == StepType.inference.value:
messages.append(step.model_response)
elif step.step_type == StepType.tool_execution.value:
for response in step.tool_responses:
messages.append(
ToolResponseMessage(
call_id=response.call_id,
tool_name=response.tool_name,
content=response.content,
)
)
elif step.step_type == StepType.shield_call.value:
response = step.response
if response.is_violation:
# TODO: Properly persist the
# CompletionMessage itself in the ShieldResponse
messages.append(
CompletionMessage(
content=response.violation_return_message,
stop_reason=StopReason.end_of_turn,
)
)
messages.extend(request.messages)
# print("processed dialog ======== ")
# print_dialog(messages)
turn_id = str(uuid.uuid4())
params = self.instance_config.sampling_params
start_time = datetime.now()
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseTurnStartPayload(
turn_id=turn_id,
)
)
)
steps = []
output_message = None
async for chunk in self.run(
turn_id=turn_id,
input_messages=messages,
temperature=params.temperature,
top_p=params.top_p,
stream=request.stream,
max_gen_len=params.max_tokens,
):
if isinstance(chunk, CompletionMessage):
cprint(
f"{chunk.role.capitalize()}: {chunk.content}",
"white",
attrs=["bold"],
)
output_message = chunk
continue
assert isinstance(
chunk, AgenticSystemTurnResponseStreamChunk
), f"Unexpected type {type(chunk)}"
event = chunk.event
if (
event.payload.event_type
== AgenticSystemTurnResponseEventType.step_complete.value
):
steps.append(event.payload.step_details)
yield chunk
assert output_message is not None
turn = Turn(
turn_id=turn_id,
session_id=request.session_id,
input_messages=request.messages,
output_message=output_message,
started_at=start_time,
completed_at=datetime.now(),
steps=steps,
)
session.turns.append(turn)
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseTurnCompletePayload(
turn=turn,
)
)
)
async def run_shields_wrapper(
self,
turn_id: str,
messages: List[Message],
shields: List[ShieldDefinition],
touchpoint: str,
) -> AsyncGenerator:
if len(shields) == 0:
return
step_id = str(uuid.uuid4())
try:
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepStartPayload(
step_type=StepType.shield_call.value,
step_id=step_id,
metadata=dict(touchpoint=touchpoint),
)
)
)
await self.run_shields(messages, shields)
except SafetyException as e:
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
response=e.response,
),
)
)
)
yield CompletionMessage(
content=str(e),
stop_reason=StopReason.end_of_turn,
)
yield False
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
response=ShieldResponse(
# TODO: fix this, give each shield a shield type method and
# fire one event for each shield run
shield_type=BuiltinShield.llama_guard,
is_violation=False,
),
),
)
)
)
async def run(
self,
turn_id: str,
input_messages: List[Message],
temperature: float,
top_p: float,
stream: bool = False,
max_gen_len: Optional[int] = None,
) -> AsyncGenerator:
# Using async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also complicates things here because AsyncGenerators cannot
# return a "final value" for the `yield from` statement. We simulate that by yielding a
# final boolean (to signal whether an exception happened) and then explicitly testing for it.
async for res in self.run_shields_wrapper(
turn_id, input_messages, self.input_shields, "user-input"
):
if isinstance(res, bool):
return
else:
yield res
final_response = None
async for res in self._run(
turn_id, input_messages, temperature, top_p, stream, max_gen_len
):
if isinstance(res, bool):
return
elif isinstance(res, CompletionMessage):
final_response = res
break
else:
yield res
assert final_response is not None
# for output shields run on the full input and output combination
messages = input_messages + [final_response]
async for res in self.run_shields_wrapper(
turn_id, messages, self.output_shields, "assistant-output"
):
if isinstance(res, bool):
return
else:
yield res
yield final_response
async def _run(
self,
turn_id: str,
input_messages: List[Message],
temperature: float,
top_p: float,
stream: bool = False,
max_gen_len: Optional[int] = None,
) -> AsyncGenerator:
input_messages = preprocess_dialog(input_messages, self.prefix_messages)
attachments = []
n_iter = 0
while True:
msg = input_messages[-1]
if msg.role == Role.user.value:
color = "blue"
elif msg.role == Role.ipython.value:
color = "yellow"
else:
color = None
cprint(f"{str(msg)}", color=color)
step_id = str(uuid.uuid4())
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepStartPayload(
step_type=StepType.inference.value,
step_id=step_id,
)
)
)
# where are the available tools?
req = ChatCompletionRequest(
model=self.model,
messages=input_messages,
available_tools=self.instance_config.available_tools,
stream=True,
sampling_params=SamplingParams(
temperature=temperature,
top_p=top_p,
max_tokens=max_gen_len,
),
)
tool_calls = []
content = ""
stop_reason = None
async for chunk in self.inference_api.chat_completion(req):
event = chunk.event
if event.event_type != ChatCompletionResponseEventType.progress:
continue
delta = event.delta
if isinstance(delta, ToolCallDelta):
if delta.parse_status == ToolCallParseStatus.success:
tool_calls.append(delta.content)
if stream:
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta="",
tool_call_delta=delta,
)
)
)
elif isinstance(delta, str):
content += delta
if stream and event.stop_reason is None:
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta=event.delta,
)
)
)
else:
raise ValueError(f"Unexpected delta type {type(delta)}")
if event.stop_reason is not None:
stop_reason = event.stop_reason
stop_reason = stop_reason or StopReason.out_of_tokens
message = CompletionMessage(
content=content,
stop_reason=stop_reason,
tool_calls=tool_calls,
)
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload(
step_type=StepType.inference.value,
step_id=step_id,
step_details=InferenceStep(
step_id=step_id, turn_id=turn_id, model_response=message
),
)
)
)
if n_iter >= self.max_infer_iters:
cprint("Done with MAX iterations, exiting.")
yield message
break
if stop_reason == StopReason.out_of_tokens:
cprint("Out of token budget, exiting.")
yield message
break
if len(message.tool_calls) == 0:
if stop_reason == StopReason.end_of_turn:
if len(attachments) > 0:
if isinstance(message.content, list):
message.content += attachments
else:
message.content = [message.content] + attachments
yield message
else:
cprint(f"Partial message: {str(message)}", color="green")
input_messages = input_messages + [message]
else:
cprint(f"{str(message)}", color="green")
try:
tool_call = message.tool_calls[0]
name = tool_call.tool_name
if not isinstance(name, BuiltinTool):
yield message
return
step_id = str(uuid.uuid4())
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
)
)
)
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
tool_call=tool_call,
)
)
)
result_messages = await execute_tool_call_maybe(
self.tools_dict,
[message],
)
assert (
len(result_messages) == 1
), "Currently not supporting multiple messages"
result_message = result_messages[0]
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_details=ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[
ToolResponse(
call_id=result_message.call_id,
tool_name=result_message.tool_name,
content=result_message.content,
)
],
),
)
)
)
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=str(uuid.uuid4()),
turn_id=turn_id,
response=ShieldResponse(
# TODO: fix this, give each shield a shield type method and
# fire one event for each shield run
shield_type=BuiltinShield.llama_guard,
is_violation=False,
),
),
)
)
)
except SafetyException as e:
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=str(uuid.uuid4()),
turn_id=turn_id,
response=e.response,
),
)
)
)
yield CompletionMessage(
content=str(e),
stop_reason=StopReason.end_of_turn,
)
yield False
return
if isinstance(result_message.content, Attachment):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc., since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to the final message.
attachments.append(result_message.content)
elif isinstance(result_message.content, list) or isinstance(
result_message.content, tuple
):
for c in result_message.content:
if isinstance(c, Attachment):
attachments.append(c)
input_messages = input_messages + [message, result_message]
n_iter += 1
class MetaReferenceAgenticSystemImpl(AgenticSystem):
def __init__(self, inference_api: Inference, safety_api: Safety):
self.inference_api = inference_api
self.safety_api = safety_api
async def initialize(self) -> None:
pass
async def create_agentic_system(
self,
request: AgenticSystemCreateRequest,
) -> AgenticSystemCreateResponse:
system_id = str(uuid.uuid4())
builtin_tools = []
custom_tool_definitions = []
cfg = request.instance_config
for dfn in cfg.available_tools:
if isinstance(dfn.tool_name, BuiltinTool):
if dfn.tool_name == BuiltinTool.wolfram_alpha:
tool = WolframAlphaTool(os.environ.get("WOLFRAM_ALPHA_API_KEY"))
elif dfn.tool_name == BuiltinTool.brave_search:
tool = BraveSearchTool(os.environ.get("BRAVE_SEARCH_API_KEY"))
elif dfn.tool_name == BuiltinTool.code_interpreter:
tool = CodeInterpreterTool()
elif dfn.tool_name == BuiltinTool.photogen:
tool = PhotogenTool(
dump_dir="/tmp/photogen_dump_" + os.environ["USER"],
)
else:
raise ValueError(f"Unknown builtin tool: {dfn.tool_name}")
builtin_tools.append(
with_safety(
tool, self.safety_api, dfn.input_shields, dfn.output_shields
)
)
else:
custom_tool_definitions.append(dfn)
AGENT_INSTANCES_BY_ID[system_id] = AgentInstance(
system_id=system_id,
instance_config=request.instance_config,
model=request.model,
inference_api=self.inference_api,
builtin_tools=builtin_tools,
custom_tool_definitions=custom_tool_definitions,
safety_api=self.safety_api,
input_shields=cfg.input_shields,
output_shields=cfg.output_shields,
prefix_messages=cfg.debug_prefix_messages,
)
return AgenticSystemCreateResponse(
system_id=system_id,
)
async def create_agentic_system_session(
self,
request: AgenticSystemSessionCreateRequest,
) -> AgenticSystemSessionCreateResponse:
system_id = request.system_id
assert system_id in AGENT_INSTANCES_BY_ID, f"System {system_id} not found"
agent = AGENT_INSTANCES_BY_ID[system_id]
session = agent.create_session(request.session_name)
return AgenticSystemSessionCreateResponse(
session_id=session.session_id,
)
async def create_agentic_system_turn(
self,
request: AgenticSystemTurnCreateRequest,
) -> AsyncGenerator:
system_id = request.system_id
assert system_id in AGENT_INSTANCES_BY_ID, f"System {system_id} not found"
agent = AGENT_INSTANCES_BY_ID[system_id]
assert (
request.session_id in agent.sessions
), f"Session {request.session_id} not found"
async for event in agent.create_and_execute_turn(request):
yield event
def attachment_message(url: URL) -> ToolResponseMessage:
uri = url.uri
assert uri.startswith("file://")
filepath = uri[len("file://") :]
return ToolResponseMessage(
call_id="",
tool_name=BuiltinTool.code_interpreter,
content=f'# There is a file accessible to you at "{filepath}"',
)
def preprocess_dialog(
messages: List[Message], prefix_messages: List[Message]
) -> List[Message]:
"""
Preprocesses the dialog by dropping any incoming system messages and
prepending the prefix messages, which carry the agent's system prompt.
"""
ret = prefix_messages.copy()
for m in messages:
if m.role == Role.system.value:
continue
# NOTE: the ideal behavior is to use `file_path = ...` but that
# means we need to have stateful execution of code, which we
# currently do not have.
if isinstance(m.content, Attachment):
ret.append(attachment_message(m.content.url))
elif isinstance(m.content, list):
for c in m.content:
if isinstance(c, Attachment):
ret.append(attachment_message(c.url))
ret.append(m)
return ret
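
To make the control flow above concrete, here is a hedged usage sketch of this implementation; it assumes you already hold initialized Inference and Safety implementations, and that `UserMessage` comes from the inference API datatypes (everything else is defined in this file or its imports):

async def demo(inference_api, safety_api):
    impl = MetaReferenceAgenticSystemImpl(inference_api, safety_api)
    await impl.initialize()
    created = await impl.create_agentic_system(
        AgenticSystemCreateRequest(
            model="Meta-Llama3.1-8B-Instruct",
            instance_config=AgenticSystemInstanceConfig(
                instructions="You are a helpful assistant",
            ),
        )
    )
    session = await impl.create_agentic_system_session(
        AgenticSystemSessionCreateRequest(
            system_id=created.system_id, session_name="demo"
        )
    )
    turn_request = AgenticSystemTurnCreateRequest(
        system_id=created.system_id,
        session_id=session.session_id,
        messages=[UserMessage(content="Hello!")],
        stream=True,
    )
    # Each chunk is an AgenticSystemTurnResponseStreamChunk; the final one
    # carries a turn_complete payload holding the full Turn.
    async for chunk in impl.create_agentic_system_turn(turn_request):
        print(chunk.event.payload.event_type)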

@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa
from .endpoints import * # noqa

@@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
from strong_typing.schema import json_schema_type
from typing_extensions import Annotated
from llama_toolchain.common.deployment_types import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403
from llama_toolchain.safety.api.datatypes import * # noqa: F403
from llama_toolchain.memory.api.datatypes import * # noqa: F403
@json_schema_type
class AgenticSystemToolDefinition(ToolDefinition):
execution_config: Optional[RestAPIExecutionConfig] = None
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
class StepCommon(BaseModel):
turn_id: str
step_id: str
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
class StepType(Enum):
inference = "inference"
tool_execution = "tool_execution"
shield_call = "shield_call"
memory_retrieval = "memory_retrieval"
@json_schema_type
class InferenceStep(StepCommon):
step_type: Literal[StepType.inference.value] = StepType.inference.value
model_response: CompletionMessage
@json_schema_type
class ToolExecutionStep(StepCommon):
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
tool_calls: List[ToolCall]
tool_responses: List[ToolResponse]
@json_schema_type
class ShieldCallStep(StepCommon):
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
response: ShieldResponse
@json_schema_type
class MemoryRetrievalStep(StepCommon):
step_type: Literal[StepType.memory_retrieval.value] = (
StepType.memory_retrieval.value
)
memory_bank_ids: List[str]
documents: List[MemoryBankDocument]
scores: List[float]
Step = Annotated[
Union[
InferenceStep,
ToolExecutionStep,
ShieldCallStep,
MemoryRetrievalStep,
],
Field(discriminator="step_type"),
]
@json_schema_type
class Turn(BaseModel):
"""A single turn in an interaction with an Agentic System."""
turn_id: str
session_id: str
input_messages: List[
Union[
UserMessage,
ToolResponseMessage,
]
]
steps: List[Step]
output_message: CompletionMessage
started_at: datetime
completed_at: Optional[datetime] = None
@json_schema_type
class Session(BaseModel):
"""A single session of an interaction with an Agentic System."""
session_id: str
session_name: str
turns: List[Turn]
started_at: datetime
@json_schema_type
class AgenticSystemInstanceConfig(BaseModel):
instructions: str
sampling_params: Optional[SamplingParams] = SamplingParams()
# zero-shot or built-in tool configurations as input to the model
available_tools: Optional[List[AgenticSystemToolDefinition]] = Field(
default_factory=list
)
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
quantization_config: Optional[QuantizationConfig] = None
# if you completely want to replace the messages prefixed by the system,
# this is debug only
debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
class AgenticSystemTurnResponseEventType(Enum):
step_start = "step_start"
step_complete = "step_complete"
step_progress = "step_progress"
turn_start = "turn_start"
turn_complete = "turn_complete"
@json_schema_type
class AgenticSystemTurnResponseStepStartPayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
AgenticSystemTurnResponseEventType.step_start.value
)
step_type: StepType
step_id: str
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
@json_schema_type
class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
AgenticSystemTurnResponseEventType.step_complete.value
)
step_type: StepType
step_details: Step
@json_schema_type
class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
AgenticSystemTurnResponseEventType.step_progress.value
)
step_type: StepType
step_id: str
model_response_text_delta: Optional[str] = None
tool_call_delta: Optional[ToolCallDelta] = None
tool_response_text_delta: Optional[str] = None
@json_schema_type
class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
AgenticSystemTurnResponseEventType.turn_start.value
)
turn_id: str
@json_schema_type
class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
AgenticSystemTurnResponseEventType.turn_complete.value
)
turn: Turn
@json_schema_type
class AgenticSystemTurnResponseEvent(BaseModel):
"""Streamed agent execution response."""
payload: Annotated[
Union[
AgenticSystemTurnResponseStepStartPayload,
AgenticSystemTurnResponseStepProgressPayload,
AgenticSystemTurnResponseStepCompletePayload,
AgenticSystemTurnResponseTurnStartPayload,
AgenticSystemTurnResponseTurnCompletePayload,
],
Field(discriminator="event_type"),
]
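
Since `payload` is a tagged union discriminated on `event_type`, pydantic parses each streamed event into the matching payload class, so consumers can dispatch with plain isinstance checks. A small illustrative handler (not part of this diff):

def describe_event(event: AgenticSystemTurnResponseEvent) -> str:
    p = event.payload
    if isinstance(p, AgenticSystemTurnResponseTurnStartPayload):
        return f"turn {p.turn_id} started"
    if isinstance(p, AgenticSystemTurnResponseStepStartPayload):
        return f"{p.step_type} step {p.step_id} started"
    if isinstance(p, AgenticSystemTurnResponseStepProgressPayload):
        return f"{p.step_type} step {p.step_id} progressing"
    if isinstance(p, AgenticSystemTurnResponseStepCompletePayload):
        return f"{p.step_type} step complete"
    return f"turn {p.turn.turn_id} complete"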

@@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa: F403
from typing import Protocol
# this dependency is annoying and we need a forked up version anyway
from pyopenapi import webmethod
from strong_typing.schema import json_schema_type
@json_schema_type
class AgenticSystemCreateRequest(BaseModel):
model: str
instance_config: AgenticSystemInstanceConfig
@json_schema_type
class AgenticSystemCreateResponse(BaseModel):
system_id: str
@json_schema_type
class AgenticSystemSessionCreateRequest(BaseModel):
system_id: str
session_name: str
@json_schema_type
class AgenticSystemSessionCreateResponse(BaseModel):
session_id: str
@json_schema_type
# what's the URI?
class AgenticSystemTurnCreateRequest(BaseModel):
system_id: str
session_id: str
messages: List[
Union[
UserMessage,
ToolResponseMessage,
]
]
stream: Optional[bool] = False
override_config: Optional[AgenticSystemInstanceConfig] = None
@json_schema_type(
schema={"description": "Server side event (SSE) stream of these events"}
)
class AgenticSystemTurnResponseStreamChunk(BaseModel):
event: AgenticSystemTurnResponseEvent
@json_schema_type
class AgenticSystemStepResponse(BaseModel):
step: Step
class AgenticSystem(Protocol):
@webmethod(route="/agentic_system/create")
async def create_agentic_system(
self,
request: AgenticSystemCreateRequest,
) -> AgenticSystemCreateResponse: ...
@webmethod(route="/agentic_system/turn/create")
async def create_agentic_system_turn(
self,
request: AgenticSystemTurnCreateRequest,
) -> AgenticSystemTurnResponseStreamChunk: ...
@webmethod(route="/agentic_system/turn/get")
async def get_agentic_system_turn(
self,
agent_id: str,
turn_id: str,
) -> Turn: ...
@webmethod(route="/agentic_system/step/get")
async def get_agentic_system_step(
self, agent_id: str, turn_id: str, step_id: str
) -> AgenticSystemStepResponse: ...
@webmethod(route="/agentic_system/session/create")
async def create_agentic_system_session(
self,
request: AgenticSystemSessionCreateRequest,
) -> AgenticSystemSessionCreateResponse: ...
@webmethod(route="/agentic_system/memory_bank/attach")
async def attach_memory_bank_to_agentic_system(
self,
agent_id: str,
session_id: str,
memory_bank_ids: List[str],
) -> None: ...
@webmethod(route="/agentic_system/memory_bank/detach")
async def detach_memory_bank_from_agentic_system(
self,
agent_id: str,
session_id: str,
memory_bank_ids: List[str],
) -> None: ...
@webmethod(route="/agentic_system/session/get")
async def get_agentic_system_session(
self,
agent_id: str,
session_id: str,
turn_ids: Optional[List[str]] = None,
) -> Session: ...
@webmethod(route="/agentic_system/session/delete")
async def delete_agentic_system_session(
self, agent_id: str, session_id: str
) -> None: ...
@webmethod(route="/agentic_system/delete")
async def delete_agentic_system(
self,
agent_id: str,
) -> None: ...

@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
from typing import AsyncGenerator
import fire
import httpx
from llama_models.llama3_1.api.datatypes import BuiltinTool, SamplingParams
from .api import (
AgenticSystem,
AgenticSystemCreateRequest,
AgenticSystemCreateResponse,
AgenticSystemInstanceConfig,
AgenticSystemSessionCreateRequest,
AgenticSystemSessionCreateResponse,
AgenticSystemToolDefinition,
AgenticSystemTurnCreateRequest,
AgenticSystemTurnResponseStreamChunk,
)
async def get_client_impl(base_url: str):
return AgenticSystemClient(base_url)
class AgenticSystemClient(AgenticSystem):
def __init__(self, base_url: str):
self.base_url = base_url
async def create_agentic_system(
self, request: AgenticSystemCreateRequest
) -> AgenticSystemCreateResponse:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/agentic_system/create",
data=request.json(),
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
return AgenticSystemCreateResponse(**response.json())
async def create_agentic_system_session(
self,
request: AgenticSystemSessionCreateRequest,
) -> AgenticSystemSessionCreateResponse:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/agentic_system/session/create",
data=request.json(),
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
return AgenticSystemSessionCreateResponse(**response.json())
async def create_agentic_system_turn(
self,
request: AgenticSystemTurnCreateRequest,
) -> AsyncGenerator:
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
f"{self.base_url}/agentic_system/turn/create",
data=request.json(),
headers={"Content-Type": "application/json"},
timeout=20,
) as response:
async for line in response.aiter_lines():
if line.startswith("data: "):
data = line[len("data: ") :]
try:
yield AgenticSystemTurnResponseStreamChunk(
**json.loads(data)
)
except Exception as e:
print(data)
print(f"Error with parsing or validation: {e}")
async def run_main(host: str, port: int):
# client to test remote impl of agentic system
api = AgenticSystemClient(f"http://{host}:{port}")
tool_definitions = [
AgenticSystemToolDefinition(
tool_name=BuiltinTool.brave_search,
),
AgenticSystemToolDefinition(
tool_name=BuiltinTool.wolfram_alpha,
),
AgenticSystemToolDefinition(
tool_name=BuiltinTool.photogen,
),
AgenticSystemToolDefinition(
tool_name=BuiltinTool.code_interpreter,
),
]
create_request = AgenticSystemCreateRequest(
model="Meta-Llama3.1-8B-Instruct",
instance_config=AgenticSystemInstanceConfig(
instructions="You are a helpful assistant",
sampling_params=SamplingParams(),
available_tools=tool_definitions,
input_shields=[],
output_shields=[],
quantization_config=None,
debug_prefix_messages=[],
),
)
create_response = await api.create_agentic_system(create_request)
print(create_response)
# TODO: Add chat session / turn apis to test e2e
def main(host: str, port: int):
asyncio.run(run_main(host, port))
if __name__ == "__main__":
fire.Fire(main)

@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class AgenticSystemConfig(BaseModel):
# placeholder, no separate configuration is needed for now
pass

@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_models.llama3_1.api.datatypes import Message, Role
from llama_toolchain.safety.api.datatypes import (
OnViolationAction,
ShieldDefinition,
ShieldResponse,
)
from llama_toolchain.safety.api.endpoints import RunShieldRequest, Safety
from termcolor import cprint
class SafetyException(Exception): # noqa: N818
def __init__(self, response: ShieldResponse):
self.response = response
super().__init__(response.violation_return_message)
class ShieldRunnerMixin:
def __init__(
self,
safety_api: Safety,
input_shields: List[ShieldDefinition] = None,
output_shields: List[ShieldDefinition] = None,
):
self.safety_api = safety_api
self.input_shields = input_shields
self.output_shields = output_shields
async def run_shields(
self, messages: List[Message], shields: List[ShieldDefinition]
) -> List[ShieldResponse]:
# Some shields, like llama-guard, require the first message to be a user message.
# Since this might be a tool call, the first role might not be "user".
if len(messages) > 0 and messages[0].role != Role.user.value:
# TODO(ashwin): we need to change the type of the message, this kind of modification
# is no longer appropriate
messages[0].role = Role.user.value
res = await self.safety_api.run_shields(
RunShieldRequest(
messages=messages,
shields=shields,
)
)
results = res.responses
for shield, r in zip(shields, results):
if r.is_violation:
if shield.on_violation_action == OnViolationAction.RAISE:
raise SafetyException(r)
elif shield.on_violation_action == OnViolationAction.WARN:
cprint(
f"[Warn]{shield.__class__.__name__} raised a warning",
color="red",
)
return results

@@ -0,0 +1,152 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import datetime
from typing import List
from llama_toolchain.inference.api import (
BuiltinTool,
Message,
SystemMessage,
ToolDefinition,
)
from .tools.builtin import SingleMessageBuiltinTool
def get_agentic_prefix_messages(
builtin_tools: List[SingleMessageBuiltinTool], custom_tools: List[ToolDefinition]
) -> List[Message]:
messages = []
content = ""
if builtin_tools:
content += "Environment: ipython\n"
tool_str = ", ".join(
[
t.get_name()
for t in builtin_tools
if t.get_name() != BuiltinTool.code_interpreter.value
]
)
if tool_str:
content += f"Tools: {tool_str}\n"
current_date = datetime.now()
formatted_date = current_date.strftime("%d %B %Y")
date_str = f"""
Cutting Knowledge Date: December 2023
Today Date: {formatted_date}\n\n"""
content += date_str
if custom_tools:
custom_message = get_system_prompt_for_custom_tools(custom_tools)
content += custom_message
# TODO: Replace this hard coded message with instructions coming in the request
if False:
content += "You are a helpful Assistant."
messages.append(SystemMessage(content=content))
return messages
def get_system_prompt_for_custom_tools(custom_tools: List[ToolDefinition]) -> str:
custom_tool_params = ""
for t in custom_tools:
custom_tool_params += get_instruction_string(t) + "\n"
custom_tool_params += get_parameters_string(t) + "\n\n"
content = f"""
You have access to the following functions:
{custom_tool_params}
Think very carefully before calling functions.
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
<function=example_function_name>{{"example_name": "example_value"}}</function>
Reminder:
- If looking for real time information use relevant functions before falling back to brave_search
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
"""
return content
def get_instruction_string(custom_tool_definition) -> str:
return f"Use the function '{custom_tool_definition.tool_name}' to '{custom_tool_definition.description}'"
def get_parameters_string(custom_tool_definition) -> str:
return json.dumps(
{
"name": custom_tool_definition.tool_name,
"description": custom_tool_definition.description,
"parameters": {
name: definition.__dict__
for name, definition in custom_tool_definition.parameters.items()
},
}
)
# NOTE: Unused right now
def translate_custom_tool_definition_to_json(tool_def):
"""Translates ToolDefinition to json as expected by model
eg. output for a function
{
"type": "function",
"function": {
"name": "conv_int",
"description": "Convert serialized fract24 integer into int value.",
"parameters": {
"type": "object",
"properties": [
{
"data": {
"type": "object",
"description": ""
}
}
],
"required": ["data"]
}
}
}
"""
assert isinstance(tool_def.tool_name, str)
func_def = {"type": "function", "function": {}}
func_def["function"]["name"] = tool_def.tool_name
func_def["function"]["description"] = tool_def.description or ""
if tool_def.parameters:
required = []
properties = []
for p_name, p_def in tool_def.parameters.items():
properties.append(
{
p_name: {
# TODO: see if this should not always be object
"type": "object",
"description": p_def.description or "",
}
}
)
if p_def.required:
required.append(p_name)
func_def["function"]["parameters"] = {
"type": "object",
"properties": properties,
"required": required,
}
else:
func_def["function"]["parameters"] = {}
return json.dumps(func_def)
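
As a quick illustration of the prompt helpers above, any object exposing `tool_name` and `description` attributes works; `SimpleNamespace` stands in for a real ToolDefinition here:

from types import SimpleNamespace

fake_tool = SimpleNamespace(
    tool_name="get_boiling_point",
    description="Get the boiling point of a liquid",
)
print(get_instruction_string(fake_tool))
# Use the function 'get_boiling_point' to 'Get the boiling point of a liquid'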

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import List
from llama_toolchain.inference.api import Message
class BaseTool(ABC):
@abstractmethod
def get_name(self) -> str:
raise NotImplementedError
@abstractmethod
async def run(self, messages: List[Message]) -> List[Message]:
raise NotImplementedError

@@ -0,0 +1,326 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
import re
from abc import abstractmethod
from typing import List, Optional
import requests
from termcolor import cprint
from .ipython_tool.code_execution import (
CodeExecutionContext,
CodeExecutionRequest,
CodeExecutor,
TOOLS_ATTACHMENT_KEY_REGEX,
)
from llama_toolchain.inference.api import * # noqa: F403
from .base import BaseTool
def interpret_content_as_attachment(content: str) -> Optional[Attachment]:
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
if match:
snippet = match.group(1)
data = json.loads(snippet)
return Attachment(
url=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"]
)
return None
class SingleMessageBuiltinTool(BaseTool):
async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
assert len(messages) == 1, f"Expected single message, got {len(messages)}"
message = messages[0]
assert len(message.tool_calls) == 1, "Expected a single tool call"
tool_call = messages[0].tool_calls[0]
query = tool_call.arguments["query"]
response: str = await self.run_impl(query)
message = ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=response,
)
if attachment := interpret_content_as_attachment(response):
message.content = attachment
return [message]
@abstractmethod
async def run_impl(self, query: str) -> str:
raise NotImplementedError()
class PhotogenTool(SingleMessageBuiltinTool):
def __init__(self, dump_dir: str) -> None:
self.dump_dir = dump_dir
def get_name(self) -> str:
return BuiltinTool.photogen.value
async def run_impl(self, query: str) -> str:
"""
Implement this to give the model an ability to generate images.
Return:
info = {
"filepath": str(image_filepath),
"mimetype": "image/png",
}
"""
raise NotImplementedError()
class BraveSearchTool(SingleMessageBuiltinTool):
def __init__(self, api_key: str) -> None:
self.api_key = api_key
def get_name(self) -> str:
return BuiltinTool.brave_search.value
async def run_impl(self, query: str) -> str:
url = "https://api.search.brave.com/res/v1/web/search"
headers = {
"X-Subscription-Token": self.api_key,
"Accept-Encoding": "gzip",
"Accept": "application/json",
}
payload = {"q": query}
response = requests.get(url=url, params=payload, headers=headers)
return json.dumps(self._clean_brave_response(response.json()))
def _clean_brave_response(self, search_response, top_k=3):
query = None
clean_response = []
if "query" in search_response:
if "original" in search_response["query"]:
query = search_response["query"]["original"]
if "mixed" in search_response:
mixed_results = search_response["mixed"]
for m in mixed_results["main"][:top_k]:
r_type = m["type"]
results = search_response[r_type]["results"]
if r_type == "web":
# For web data - add a single output from the search
idx = m["index"]
selected_keys = [
"type",
"title",
"url",
"description",
"date",
"extra_snippets",
]
cleaned = {
k: v for k, v in results[idx].items() if k in selected_keys
}
elif r_type == "faq":
# For FAQ data - take a list of all the questions & answers
selected_keys = ["type", "question", "answer", "title", "url"]
cleaned = []
for q in results:
cleaned.append(
{k: v for k, v in q.items() if k in selected_keys}
)
elif r_type == "infobox":
idx = m["index"]
selected_keys = [
"type",
"title",
"url",
"description",
"long_desc",
]
cleaned = {
k: v for k, v in results[idx].items() if k in selected_keys
}
elif r_type == "videos":
selected_keys = [
"type",
"url",
"title",
"description",
"date",
]
cleaned = []
for q in results:
cleaned.append(
{k: v for k, v in q.items() if k in selected_keys}
)
elif r_type == "locations":
# For location data - take a list of all the location results
selected_keys = [
"type",
"title",
"url",
"description",
"coordinates",
"postal_address",
"contact",
"rating",
"distance",
"zoom_level",
]
cleaned = []
for q in results:
cleaned.append(
{k: v for k, v in q.items() if k in selected_keys}
)
elif r_type == "news":
# For news data - take a list of all the news results
selected_keys = [
"type",
"title",
"url",
"description",
]
cleaned = []
for q in results:
cleaned.append(
{k: v for k, v in q.items() if k in selected_keys}
)
else:
cleaned = []
clean_response.append(cleaned)
return {"query": query, "top_k": clean_response}
class WolframAlphaTool(SingleMessageBuiltinTool):
def __init__(self, api_key: str) -> None:
self.api_key = api_key
self.url = "https://api.wolframalpha.com/v2/query"
def get_name(self) -> str:
return BuiltinTool.wolfram_alpha.value
async def run_impl(self, query: str) -> str:
params = {
"input": query,
"appid": self.api_key,
"format": "plaintext",
"output": "json",
}
response = requests.get(
self.url,
params=params,
)
return json.dumps(self._clean_wolfram_alpha_response(response.json()))
def _clean_wolfram_alpha_response(self, wa_response):
remove = {
"queryresult": [
"datatypes",
"error",
"timedout",
"timedoutpods",
"numpods",
"timing",
"parsetiming",
"parsetimedout",
"recalculate",
"id",
"host",
"server",
"related",
"version",
{
"pods": [
"scanner",
"id",
"error",
"expressiontypes",
"states",
"infos",
"position",
"numsubpods",
]
},
"assumptions",
],
}
for main_key in remove:
for key_to_remove in remove[main_key]:
try:
if key_to_remove == "assumptions":
if "assumptions" in wa_response[main_key]:
del wa_response[main_key][key_to_remove]
if isinstance(key_to_remove, dict):
for sub_key in key_to_remove:
if sub_key == "pods":
for i in range(len(wa_response[main_key][sub_key])):
if (
wa_response[main_key][sub_key][i]["title"]
== "Result"
):
del wa_response[main_key][sub_key][i + 1 :]
break
sub_items = wa_response[main_key][sub_key]
for i in range(len(sub_items)):
for sub_key_to_remove in key_to_remove[sub_key]:
if sub_key_to_remove in sub_items[i]:
del sub_items[i][sub_key_to_remove]
elif key_to_remove in wa_response[main_key]:
del wa_response[main_key][key_to_remove]
except KeyError:
pass
return wa_response
class CodeInterpreterTool(BaseTool):
def __init__(self) -> None:
ctx = CodeExecutionContext(
matplotlib_dump_dir=f"/tmp/{os.environ['USER']}_matplotlib_dump",
)
self.code_executor = CodeExecutor(ctx)
def get_name(self) -> str:
return BuiltinTool.code_interpreter.value
async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
message = messages[0]
assert len(message.tool_calls) == 1, "Expected a single tool call"
tool_call = messages[0].tool_calls[0]
script = tool_call.arguments["code"]
req = CodeExecutionRequest(scripts=[script])
res = self.code_executor.execute(req)
pieces = [res["process_status"]]
for out_type in ["stdout", "stderr"]:
res_out = res[out_type]
if res_out != "":
pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
if out_type == "stderr":
cprint(f"ipython tool error: ↓\n{res_out}", color="red")
message = ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content="\n".join(pieces),
)
if attachment := interpret_content_as_attachment(res["stdout"]):
message.content = attachment
return [message]

@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from abc import abstractmethod
from typing import Dict, List
from llama_models.llama3_1.api.datatypes import * # noqa: F403
from llama_toolchain.agentic_system.api import * # noqa: F403
from .builtin import interpret_content_as_attachment
class CustomTool:
"""
Developers can define their custom tools that models can use
by extending this class.
Developers need to provide
- name
- description
- params_definition
- the tool's behavior, implemented in the `run_impl` method
NOTE: The return of the `run` method needs to be json serializable
"""
@abstractmethod
def get_name(self) -> str:
raise NotImplementedError
@abstractmethod
def get_description(self) -> str:
raise NotImplementedError
@abstractmethod
def get_params_definition(self) -> Dict[str, ToolParamDefinition]:
raise NotImplementedError
def get_instruction_string(self) -> str:
return f"Use the function '{self.get_name()}' to: {self.get_description()}"
def parameters_for_system_prompt(self) -> str:
return json.dumps(
{
"name": self.get_name(),
"description": self.get_description(),
"parameters": {
name: definition.__dict__
for name, definition in self.get_params_definition().items()
},
}
)
def get_tool_definition(self) -> AgenticSystemToolDefinition:
return AgenticSystemToolDefinition(
tool_name=self.get_name(),
description=self.get_description(),
parameters=self.get_params_definition(),
)
@abstractmethod
async def run(self, messages: List[Message]) -> List[Message]:
raise NotImplementedError
class SingleMessageCustomTool(CustomTool):
"""
Helper class to handle custom tools that take a single message.
Extend this class and implement the `run_impl` method; the model can
then call the tool, with the necessary plumbing handled for you.
"""
async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
assert len(messages) == 1, "Expected single message"
message = messages[0]
tool_call = message.tool_calls[0]
try:
response = await self.run_impl(**tool_call.arguments)
response_str = json.dumps(response, ensure_ascii=False)
except Exception as e:
response_str = f"Error when running tool: {e}"
message = ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=response_str,
)
if attachment := interpret_content_as_attachment(response_str):
message.content = attachment
return [message]
@abstractmethod
async def run_impl(self, *args, **kwargs):
raise NotImplementedError()
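
For illustration, a minimal hypothetical subclass; the `GetTimeTool` name and behavior are invented for this sketch, and `ToolParamDefinition` comes from the star imports above:

from datetime import datetime, timezone

class GetTimeTool(SingleMessageCustomTool):
    """Toy tool: report the current UTC time."""

    def get_name(self) -> str:
        return "get_utc_time"

    def get_description(self) -> str:
        return "Get the current UTC time as an ISO-8601 string"

    def get_params_definition(self) -> Dict[str, ToolParamDefinition]:
        return {}  # no parameters needed

    async def run_impl(self) -> str:
        return datetime.now(timezone.utc).isoformat()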

@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, List
from llama_models.llama3_1.api.datatypes import StopReason, ToolResponseMessage
from llama_toolchain.agentic_system.api import (
AgenticSystem,
AgenticSystemTurnCreateRequest,
AgenticSystemTurnResponseEventType as EventType,
)
from llama_toolchain.inference.api import Message
async def execute_with_custom_tools(
system: AgenticSystem,
system_id: str,
session_id: str,
messages: List[Message],
custom_tools: List[Any],
max_iters: int = 5,
stream: bool = True,
) -> AsyncGenerator:
# first create a session, or do you keep a persistent session?
tools_dict = {t.get_name(): t for t in custom_tools}
current_messages = messages.copy()
n_iter = 0
while n_iter < max_iters:
n_iter += 1
request = AgenticSystemTurnCreateRequest(
system_id=system_id,
session_id=session_id,
messages=current_messages,
stream=stream,
)
turn = None
async for chunk in system.create_agentic_system_turn(request):
if chunk.event.payload.event_type != EventType.turn_complete.value:
yield chunk
else:
turn = chunk.event.payload.turn
break
message = turn.output_message
if len(message.tool_calls) == 0:
yield chunk
return
if message.stop_reason == StopReason.out_of_tokens:
yield chunk
return
tool_call = message.tool_calls[0]
if tool_call.tool_name not in tools_dict:
m = ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else",
)
next_message = m
else:
tool = tools_dict[tool_call.tool_name]
result_messages = await execute_custom_tool(tool, message)
next_message = result_messages[0]
yield next_message
current_messages = [next_message]
async def execute_custom_tool(tool: Any, message: Message) -> List[Message]:
result_messages = await tool.run([message])
assert (
len(result_messages) == 1
), f"Expected single message, got {len(result_messages)}"
return result_messages
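
A hedged sketch of driving this loop end to end; `client`, `system_id`, `session_id`, `UserMessage`, and `MyCalculatorTool` are placeholders you would obtain from the client and custom-tool helpers shown earlier:

async def chat_once(client, system_id: str, session_id: str):
    async for chunk in execute_with_custom_tools(
        system=client,
        system_id=system_id,
        session_id=session_id,
        messages=[UserMessage(content="What is 2 + 2? Use a tool if needed.")],
        custom_tools=[MyCalculatorTool()],  # hypothetical CustomTool subclass
    ):
        print(chunk)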

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@@ -0,0 +1,133 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import errno
# Disabling potentially dangerous functions
import os as _os
from functools import partial
os_funcs_to_disable = [
"kill",
"system",
"putenv",
"remove",
"removedirs",
"rmdir",
"fchdir",
"setuid",
"fork",
"forkpty",
"killpg",
"rename",
"renames",
"truncate",
"replace",
# "unlink", # Commenting as this was blocking matpltlib from rendering plots correctly
"fchmod",
"fchown",
"chmod",
"chown",
"chroot",
"fchdir",
"lchflags",
"lchmod",
"lchown",
"chdir",
]
def call_not_allowed(*args, **kwargs):
raise OSError(errno.EPERM, "Calls are not permitted in this environment")
for func_name in os_funcs_to_disable:
if hasattr(_os, func_name):
setattr(_os, func_name, partial(call_not_allowed, _func_name=f"os.{func_name}"))
import shutil as _shutil
for func_name in ["rmtree", "move", "chown"]:
if hasattr(_shutil, func_name):
setattr(
_shutil,
func_name,
partial(call_not_allowed, _func_name=f"shutil.{func_name}"),
)
import subprocess as _subprocess
def popen_not_allowed(*args, **kwargs):
raise _subprocess.CalledProcessError(
-1,
args[0] if args else "unknown",
stderr="subprocess.Popen is not allowed in this environment",
)
_subprocess.Popen = popen_not_allowed
import atexit as _atexit
import builtins as _builtins
import io as _io
import json as _json
import sys as _sys
# NB! The following "unused" imports are crucial; make sure not to remove
# them with linters - they're used in code_execution.py
from contextlib import ( # noqa
contextmanager as _contextmanager,
redirect_stderr as _redirect_stderr,
redirect_stdout as _redirect_stdout,
)
from multiprocessing.connection import Connection as _Connection
# Mangle imports to avoid polluting model execution namespace.
_IO_SINK = _io.StringIO()
_NETWORK_TIMEOUT = 5
_NETWORK_CONNECTIONS = None
def _open_connections():
global _NETWORK_CONNECTIONS
if _NETWORK_CONNECTIONS is not None:
# Ensure connections are only opened once.
return _NETWORK_CONNECTIONS
req_w_fd, resp_r_fd = _sys.argv[1], _sys.argv[2]
req_con = _Connection(int(req_w_fd), readable=False)
resp_con = _Connection(int(resp_r_fd), writable=False)
_NETWORK_CONNECTIONS = (req_con, resp_con)
return _NETWORK_CONNECTIONS
_builtins._open_connections = _open_connections
@_atexit.register
def _close_connections():
global _NETWORK_CONNECTIONS
if _NETWORK_CONNECTIONS is None:
return
for con in _NETWORK_CONNECTIONS:
con.close()
del _NETWORK_CONNECTIONS
def _network_call(request):
# NOTE: We communicate with the parent process in json, encoded
# in raw bytes. We do this because native send/recv methods use
# pickle which involves execution of arbitrary code.
_open_connections()
req_con, resp_con = _NETWORK_CONNECTIONS
req_con.send_bytes(_json.dumps(request).encode("utf-8"))
if resp_con.poll(timeout=_NETWORK_TIMEOUT) is None:
raise Exception(f"Network request timed out: {_json.dumps(request)}")
else:
return _json.loads(resp_con.recv_bytes().decode("utf-8"))
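
Put differently, sandboxed code talks to the parent process over a pair of raw-byte JSON pipes. An illustrative round trip, with the request shape matching what the parent-side request router handles:

# Inside the sandbox: send a JSON request and block (up to _NETWORK_TIMEOUT
# seconds) for the parent's reply. The parent routes by request["type"];
# "matplotlib" is the only type handled in this commit.
reply = _network_call({"type": "matplotlib", "image_data": [...]})  # illustrative payload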

@@ -0,0 +1,256 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import json
import multiprocessing
import os
import re
import subprocess
import sys
import tempfile
import textwrap
import time
from dataclasses import dataclass
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import List
from PIL import Image
from .utils import get_code_env_prefix
TOOLS_ATTACHMENT_KEY = "__tools_attachment__"
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
DIRNAME = Path(__file__).parent
CODE_EXEC_TIMEOUT = 20
CODE_ENV_PREFIX = get_code_env_prefix()
STDOUTERR_SINK_WRAPPER_TEMPLATE = """\
with _redirect_stdout(_IO_SINK), _redirect_stderr(_IO_SINK):
{code}\
"""
TRYEXCEPT_WRAPPER_TEMPLATE = """\
try:
{code}
except:
pass\
"""
def generate_bwrap_command(bind_dirs: List[str]) -> str:
"""
Generate the bwrap argument string: the whole filesystem is bound
read-only, and each directory in bind_dirs is bound read-write.
"""
bwrap_args = ""
bwrap_args += "--ro-bind / / "
# Add the --dev flag to mount device files
bwrap_args += "--dev /dev "
for d in bind_dirs:
bwrap_args += f"--bind {d} {d} "
# Add the --unshare-all flag to isolate the sandbox from the rest of the system
bwrap_args += "--unshare-all "
# Add the --die-with-parent flag to ensure the child process dies when bwrap's parent dies
bwrap_args += "--die-with-parent "
return bwrap_args
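# For example (illustrative), generate_bwrap_command(bind_dirs=["/tmp/work"]) returns:
#   "--ro-bind / / --dev /dev --bind /tmp/work /tmp/work --unshare-all --die-with-parent "
# The caller prefixes this with "bwrap " and splits it into argv (see
# CodeExecutor.execute below).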
@dataclass
class CodeExecutionContext:
matplotlib_dump_dir: str
use_proxy: bool = False
@dataclass
class CodeExecutionRequest:
scripts: List[str]
only_last_cell_stdouterr: bool = True
only_last_cell_fail: bool = True
seed: int = 0
strip_fpaths_in_stderr: bool = True
class CodeExecutor:
def __init__(self, context: CodeExecutionContext):
self.context = context
def execute(self, req: CodeExecutionRequest) -> dict:
scripts = req.scripts
for i in range(len(scripts) - 1):
if req.only_last_cell_stdouterr:
scripts[i] = STDOUTERR_SINK_WRAPPER_TEMPLATE.format(
code=textwrap.indent(scripts[i], " " * 4)
)
if req.only_last_cell_fail:
scripts[i] = TRYEXCEPT_WRAPPER_TEMPLATE.format(
code=textwrap.indent(scripts[i], " " * 4)
)
# Seeds prefix:
seed = req.seed
seeds_prefix = f"""\
def _set_seeds():
import random
random.seed({seed})
import numpy as np
np.random.seed({seed})
_set_seeds()\
"""
script = "\n\n".join([seeds_prefix] + [CODE_ENV_PREFIX] + scripts)
with tempfile.TemporaryDirectory() as dpath:
bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath])
cmd = [*bwrap_prefix.split(), sys.executable, "-c", script]
code_fpath = os.path.join(dpath, "code.py")
with open(code_fpath, "w") as f:
f.write(script)
try:
python_path = os.environ.get("PYTHONPATH", "")
env = dict(
os.environ,
PYTHONHASHSEED=str(seed),
MPLCONFIGDIR=dpath,
MPLBACKEND="module://matplotlib_custom_backend",
PYTHONPATH=f"{DIRNAME}:{python_path}",
)
stdout, stderr, returncode = do_subprocess(
cmd=cmd,
env=env,
ctx=self.context,
)
stderr = stderr.strip()
if req.strip_fpaths_in_stderr:
pattern = r'File "([^"]+)", line (\d+)'
stderr = re.sub(pattern, r"line \2", stderr)
return {
"process_status": "completed",
"returncode": returncode,
"stdout": stdout.strip(),
"stderr": stderr,
}
except subprocess.TimeoutExpired:
return {
"process_status": "timeout",
"stdout": "Timed out",
"stderr": "Timed out",
}
except Exception as e:
return {
"process_status": "error",
"error_type": type(e).__name__,
"stderr": str(e),
"stdout": str(e),
}
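
A minimal usage sketch of the executor defined above, assuming `bwrap` is installed on the host (the dump directory is illustrative):

ctx = CodeExecutionContext(matplotlib_dump_dir="/tmp/matplotlib_dumps")
executor = CodeExecutor(ctx)
req = CodeExecutionRequest(scripts=["x = 40 + 2", "print(x)"])
result = executor.execute(req)
# Earlier cells are sink-wrapped, so only the last cell's output remains:
# result == {"process_status": "completed", "returncode": 0,
#            "stdout": "42", "stderr": ""}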
def process_matplotlib_response(response, matplotlib_dump_dir: str):
image_data = response["image_data"]
# Decode each base64 string into raw PNG bytes
images = [base64.b64decode(d["image_base64"]) for d in image_data]
# Create a list of PIL images from the bytes objects
images = [Image.open(BytesIO(img)) for img in images]
# Create a list of image paths
image_paths = []
for i, img in enumerate(images):
# Create a new directory per day to keep the dumped images organized
dump_dname = datetime.today().strftime("%Y-%m-%d")
dump_dpath = Path(matplotlib_dump_dir, dump_dname)
dump_dpath.mkdir(parents=True, exist_ok=True)
# save image into a file
dump_fname = f"matplotlib_{str(time.time()).replace('.', '_')}_{i}.png"
dump_fpath = dump_dpath / dump_fname
img.save(dump_fpath, "PNG")
image_paths.append(str(dump_fpath))
# This is a bit convoluted: we send this response back to the subprocess,
# which then prints it to its stdout.
info = {
"filepath": str(image_paths[-1]),
"mimetype": "image/png",
}
return f"{TOOLS_ATTACHMENT_KEY}={json.dumps(info)}"
def execute_subprocess_request(request, ctx: CodeExecutionContext):
"Route requests from the subprocess (via network Pipes) to the internet/tools."
if request["type"] == "matplotlib":
return process_matplotlib_response(request, ctx.matplotlib_dump_dir)
else:
raise Exception(f'Unrecognised network request type: {request["type"]}')
def do_subprocess(*, cmd: list, env: dict, ctx: CodeExecutionContext):
# Create Pipes to be used for any external tool/network requests.
req_r, req_w = multiprocessing.Pipe(duplex=False)
resp_r, resp_w = multiprocessing.Pipe(duplex=False)
cmd += [str(req_w.fileno()), str(resp_r.fileno())]
proc = subprocess.Popen(
cmd,
pass_fds=(req_w.fileno(), resp_r.fileno()),
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
close_fds=True,
env=env,
)
# Close unnecessary fds.
req_w.close()
resp_r.close()
pipe_close = False
done_read = False
start = time.monotonic()
while proc.poll() is None and not pipe_close:
if req_r.poll(0.1):
# NB: Python pipe semantics for poll and recv mean that
# poll() also returns True if the pipe has been closed.
# See the long-standing CPython issue from '09:
# https://bugs.python.org/issue5573
try:
request = json.loads(req_r.recv_bytes().decode("utf-8"))
response = execute_subprocess_request(request, ctx)
resp_w.send_bytes(json.dumps(response).encode("utf-8"))
except EOFError:
# The request pipe is closed - set a marker to exit
# after the next attempt at reading stdout/stderr.
pipe_close = True
try:
# If a lot has been printed, the pipe might be full, and the
# proc cannot exit until all of its stdout/stderr has been
# written out and read.
stdout, stderr = proc.communicate(timeout=0.3)
done_read = True
except subprocess.TimeoutExpired:
# The program has not terminated yet. Ignore the timeout; there
# may be more network/tool requests to service.
continue
if time.monotonic() - start > CODE_EXEC_TIMEOUT:
proc.terminate()
raise subprocess.TimeoutExpired(cmd, CODE_EXEC_TIMEOUT)
if not done_read:
# Handle the race where the process terminated before we
# entered the while loop above.
stdout, stderr = proc.communicate(timeout=0.3)
resp_w.close()
req_r.close()
return stdout, stderr, proc.returncode

View file

@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
A custom Matplotlib backend that overrides the show method to return image bytes.
"""
import base64
import io
import json as _json
import matplotlib
from matplotlib.backend_bases import FigureManagerBase
# Import necessary components from Matplotlib
from matplotlib.backends.backend_agg import FigureCanvasAgg
class CustomFigureCanvas(FigureCanvasAgg):
def show(self):
# Save the figure to a BytesIO object
buf = io.BytesIO()
self.print_png(buf)
image_bytes = buf.getvalue()
buf.close()
return image_bytes
class CustomFigureManager(FigureManagerBase):
def __init__(self, canvas, num):
super().__init__(canvas, num)
# Mimic module initialization that integrates with the Matplotlib backend system
def _create_figure_manager(num, *args, **kwargs):
"""
Create a custom figure manager instance.
"""
FigureClass = kwargs.pop("FigureClass", None) # noqa: N806
if FigureClass is None:
from matplotlib.figure import Figure
FigureClass = Figure # noqa: N806
fig = FigureClass(*args, **kwargs)
canvas = CustomFigureCanvas(fig)
manager = CustomFigureManager(canvas, num)
return manager
def show():
"""
Handle all figures and potentially return their images as bytes.
This function iterates over all figures registered with the custom backend,
renders them as images in bytes format, and could return a list of bytes objects,
one for each figure, or handle them as needed.
"""
image_data = []
for manager in matplotlib._pylab_helpers.Gcf.get_all_fig_managers():
# Get the figure from the manager
fig = manager.canvas.figure
buf = io.BytesIO() # Create a buffer for the figure
fig.savefig(buf, format="png") # Save the figure to the buffer in PNG format
buf.seek(0) # Go to the beginning of the buffer
image_bytes = buf.getvalue() # Retrieve bytes value
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
image_data.append({"image_base64": image_base64})
buf.close()
req_con, resp_con = _open_connections()
_json_dump = _json.dumps(
{
"type": "matplotlib",
"image_data": image_data,
}
)
req_con.send_bytes(_json_dump.encode("utf-8"))
resp = _json.loads(resp_con.recv_bytes().decode("utf-8"))
print(resp)
FigureCanvas = CustomFigureCanvas
FigureManager = CustomFigureManager
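
For reference, this backend is selected via the `MPLBACKEND="module://matplotlib_custom_backend"` environment variable that `CodeExecutor` sets; inside the sandbox, user code can then use pyplot normally (a sketch):

# Runs inside the sandbox with MPLBACKEND="module://matplotlib_custom_backend".
import matplotlib.pyplot as plt

plt.plot([1, 2, 3], [4, 5, 6])
plt.show()  # dispatches to show() above, shipping the PNGs to the parent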

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
DIR = os.path.dirname(os.path.realpath(__file__))
CODE_ENV_PREFIX_FILE = os.path.join(DIR, "code_env_prefix.py")
CODE_ENV_PREFIX = None
def get_code_env_prefix() -> str:
global CODE_ENV_PREFIX
if CODE_ENV_PREFIX is None:
with open(CODE_ENV_PREFIX_FILE, "r") as f:
CODE_ENV_PREFIX = f.read()
return CODE_ENV_PREFIX

View file

@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional
from llama_toolchain.agentic_system.safety import ShieldRunnerMixin
from llama_toolchain.inference.api import Message
from llama_toolchain.safety.api.datatypes import ShieldDefinition
from llama_toolchain.safety.api.endpoints import Safety
from .builtin import BaseTool
class SafeTool(BaseTool, ShieldRunnerMixin):
"""A tool that makes other tools safety enabled"""
def __init__(
self,
tool: BaseTool,
safety_api: Safety,
input_shields: Optional[List[ShieldDefinition]] = None,
output_shields: Optional[List[ShieldDefinition]] = None,
):
self._tool = tool
ShieldRunnerMixin.__init__(
self, safety_api, input_shields=input_shields, output_shields=output_shields
)
def get_name(self) -> str:
# return the name of the wrapped tool
return self._tool.get_name()
async def run(self, messages: List[Message]) -> List[Message]:
if self.input_shields:
await self.run_shields(messages, self.input_shields)
# run the underlying tool
res = await self._tool.run(messages)
if self.output_shields:
# shield the tool's output, not the original input messages
await self.run_shields(res, self.output_shields)
return res
def with_safety(
tool: BaseTool,
safety_api: Safety,
input_shields: Optional[List[ShieldDefinition]] = None,
output_shields: Optional[List[ShieldDefinition]] = None,
) -> SafeTool:
return SafeTool(
tool,
safety_api,
input_shields=input_shields,
output_shields=output_shields,
)
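
A minimal sketch of wrapping a builtin tool (the `ShieldDefinition` field name, the shield choice, and the `BraveSearchTool` constructor argument are assumptions for illustration):

from llama_toolchain.agentic_system.tools.builtin import BraveSearchTool
from llama_toolchain.safety.api.datatypes import BuiltinShield

safe_search = with_safety(
    BraveSearchTool(api_key="..."),        # hypothetical constructor arg
    safety_api=safety_impl,                # any Safety implementation
    input_shields=[ShieldDefinition(shield_type=BuiltinShield.llama_guard)],
    output_shields=[ShieldDefinition(shield_type=BuiltinShield.llama_guard)],
)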

View file

@ -144,7 +144,11 @@ def prompt_for_config(
nested_type = get_non_none_type(field_type)
print(f"Entering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(nested_type, existing_value)
elif inspect.isclass(field_type) and issubclass(field_type, BaseModel):
elif (
inspect.isclass(field_type)
and issubclass(field_type, BaseModel)
and len(field_type.__fields__) > 0
):
print(f"\nEntering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(
field_type,

View file

@ -15,6 +15,7 @@ from strong_typing.schema import json_schema_type
class ApiSurface(Enum):
inference = "inference"
safety = "safety"
agentic_system = "agentic_system"
@json_schema_type
@ -39,14 +40,19 @@ class SourceAdapter(Adapter):
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have
a `get_adapter_instance()` method which will be passed a validated config object
of type `config_class`.""",
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the local implementation
""",
)
config_class: str = Field(
...,
description="Fully-qualified classname of the config for this adapter",
)
adapter_dependencies: List[ApiSurface] = Field(
default_factory=list,
description="Higher-level API surfaces may depend on other adapters to provide their functionality",
)
@json_schema_type
@ -56,6 +62,13 @@ class PassthroughApiAdapter(Adapter):
default_factory=dict,
description="Headers (e.g., authorization) to send with the request",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_client_impl(base_url)`: returns a client which can be used to call the remote implementation
""",
)
class Distribution(BaseModel):

View file

@ -7,6 +7,7 @@
import inspect
from typing import Dict, List
from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
from llama_toolchain.inference.api.endpoints import Inference
from llama_toolchain.safety.api.endpoints import Safety
@ -29,6 +30,7 @@ def api_surface_endpoints() -> Dict[ApiSurface, List[ApiSurfaceEndpoint]]:
protocols = {
ApiSurface.inference: Inference,
ApiSurface.safety: Safety,
ApiSurface.agentic_system: AgenticSystem,
}
for surface, protocol in protocols.items():

View file

@ -8,7 +8,7 @@ import asyncio
import importlib
from typing import Any, Dict
from .datatypes import SourceAdapter
from .datatypes import Adapter, PassthroughApiAdapter, SourceAdapter
def instantiate_class_type(fully_qualified_name):
@ -18,9 +18,17 @@ def instantiate_class_type(fully_qualified_name):
# returns a class implementing the protocol corresponding to the ApiSurface
def instantiate_adapter(adapter: SourceAdapter, adapter_config: Dict[str, Any]):
def instantiate_adapter(
adapter: SourceAdapter, adapter_config: Dict[str, Any], deps: Dict[str, Adapter]
):
module = importlib.import_module(adapter.module)
config_type = instantiate_class_type(adapter.config_class)
config = config_type(**adapter_config)
return asyncio.run(module.get_adapter_impl(config))
return asyncio.run(module.get_adapter_impl(config, deps))
def instantiate_client(adapter: PassthroughApiAdapter, base_url: str):
module = importlib.import_module(adapter.module)
return asyncio.run(module.get_client_impl(base_url))

View file

@ -7,6 +7,8 @@
from functools import lru_cache
from typing import List, Optional
from llama_toolchain.agentic_system.adapters import available_agentic_system_adapters
from llama_toolchain.inference.adapters import available_inference_adapters
from llama_toolchain.safety.adapters import available_safety_adapters
@ -43,10 +45,26 @@ COMMON_DEPENDENCIES = [
]
def client_module(api_surface: ApiSurface) -> str:
return f"llama_toolchain.{api_surface.value}.client"
def passthrough(api_surface: ApiSurface, port: int) -> PassthroughApiAdapter:
return PassthroughApiAdapter(
api_surface=api_surface,
adapter_id=f"{api_surface.value}-passthrough",
base_url=f"http://localhost:{port}",
module=client_module(api_surface),
)
@lru_cache()
def available_distributions() -> List[Distribution]:
inference_adapters_by_id = {a.adapter_id: a for a in available_inference_adapters()}
safety_adapters_by_id = {a.adapter_id: a for a in available_safety_adapters()}
agentic_system_adapters_by_id = {
a.adapter_id: a for a in available_agentic_system_adapters()
}
return [
Distribution(
@ -56,6 +74,9 @@ def available_distributions() -> List[Distribution]:
adapters={
ApiSurface.inference: inference_adapters_by_id["meta-reference"],
ApiSurface.safety: safety_adapters_by_id["meta-reference"],
ApiSurface.agentic_system: agentic_system_adapters_by_id[
"meta-reference"
],
},
),
Distribution(
@ -76,16 +97,9 @@ def available_distributions() -> List[Distribution]:
"uvicorn",
],
adapters={
ApiSurface.inference: PassthroughApiAdapter(
api_surface=ApiSurface.inference,
adapter_id="inference-passthrough",
base_url="http://localhost:5001",
),
ApiSurface.safety: PassthroughApiAdapter(
api_surface=ApiSurface.safety,
adapter_id="safety-passthrough",
base_url="http://localhost:5001",
),
ApiSurface.inference: passthrough(ApiSurface.inference, 5001),
ApiSurface.safety: passthrough(ApiSurface.safety, 5001),
ApiSurface.agentic_system: passthrough(ApiSurface.agentic_system, 5001),
},
),
Distribution(
@ -95,6 +109,9 @@ def available_distributions() -> List[Distribution]:
adapters={
ApiSurface.inference: inference_adapters_by_id["meta-ollama"],
ApiSurface.safety: safety_adapters_by_id["meta-reference"],
ApiSurface.agentic_system: agentic_system_adapters_by_id[
"meta-reference"
],
},
),
]

View file

@ -12,7 +12,16 @@ from collections.abc import (
AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Dict,
get_type_hints,
List,
Optional,
Set,
)
import fire
import httpx
@ -27,9 +36,9 @@ from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from .datatypes import PassthroughApiAdapter
from .datatypes import Adapter, ApiSurface, PassthroughApiAdapter
from .distribution import api_surface_endpoints
from .dynamic import instantiate_adapter
from .dynamic import instantiate_adapter, instantiate_client
from .registry import resolve_distribution
@ -213,6 +222,29 @@ def create_dynamic_typed_route(func: Any):
return endpoint
def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
by_id = {x.api_surface: x for x in adapters}
def dfs(a: Adapter, visited: Set[ApiSurface], stack: List[ApiSurface]):
visited.add(a.api_surface)
for surface in a.adapter_dependencies:
if surface not in visited:
dfs(by_id[surface], visited, stack)
stack.append(a.api_surface)
visited = set()
stack = []
for a in adapters:
if a.api_surface not in visited:
dfs(a, visited, stack)
return [by_id[x] for x in stack]
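
A quick sketch of the resolution order this yields, using stand-in objects (`SimpleNamespace` is just for illustration; the real inputs are `Adapter` instances):

from types import SimpleNamespace

inference = SimpleNamespace(api_surface=ApiSurface.inference, adapter_dependencies=[])
safety = SimpleNamespace(api_surface=ApiSurface.safety, adapter_dependencies=[])
agentic = SimpleNamespace(
    api_surface=ApiSurface.agentic_system,
    adapter_dependencies=[ApiSurface.inference, ApiSurface.safety],
)
order = topological_sort([agentic, inference, safety])
# -> [inference, safety, agentic]: dependencies always precede dependents.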
def main(
dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
):
@ -228,7 +260,13 @@ def main(
all_endpoints = api_surface_endpoints()
adapter_configs = config["adapters"]
for surface, adapter in dist.adapters.items():
adapters = topological_sort(dist.adapters.values())
# TODO: split this into two phases: first resolve all impls,
# then create the routes.
impls = {}
for adapter in adapters:
surface = adapter.api_surface
if surface.value not in adapter_configs:
raise ValueError(
f"Could not find adapter config for {surface}. Please add it to the config"
@ -242,8 +280,11 @@ def main(
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)
impls[surface] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
else:
impl = instantiate_adapter(adapter, adapter_config)
deps = {surface: impls[surface] for surface in adapter.adapter_dependencies}
impl = instantiate_adapter(adapter, adapter_config, deps)
impls[surface] = impl
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already

View file

@ -23,6 +23,10 @@ from .api import (
from .event_logger import EventLogger
async def get_client_impl(base_url: str):
return InferenceClient(base_url)
class InferenceClient(Inference):
def __init__(self, base_url: str):
print(f"Initializing client for {base_url}")

View file

@ -6,11 +6,13 @@
import asyncio
from typing import AsyncIterator, Union
from typing import AsyncIterator, Dict, Union
from llama_models.llama3_1.api.datatypes import StopReason
from llama_models.sku_list import resolve_model
from llama_toolchain.distribution.datatypes import Adapter, ApiSurface
from .api.config import MetaReferenceImplConfig
from .api.datatypes import (
ChatCompletionResponseEvent,
@ -27,7 +29,9 @@ from .api.endpoints import (
from .model_parallel import LlamaModelParallelGenerator
async def get_adapter_impl(config: MetaReferenceImplConfig):
async def get_adapter_impl(
config: MetaReferenceImplConfig, _deps: Dict[ApiSurface, Adapter]
):
assert isinstance(
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"

View file

@ -21,6 +21,10 @@ from .api import (
)
async def get_client_impl(base_url: str):
return SafetyClient(base_url)
class SafetyClient(Safety):
def __init__(self, base_url: str):
print(f"Initializing client for {base_url}")

View file

@ -6,6 +6,10 @@
import asyncio
from typing import Dict
from llama_toolchain.distribution.datatypes import Adapter, ApiSurface
from .config import SafetyConfig
from .api.endpoints import * # noqa
from .shields import (
@ -19,7 +23,7 @@ from .shields import (
)
async def get_adapter_impl(config: SafetyConfig):
async def get_adapter_impl(config: SafetyConfig, _deps: Dict[ApiSurface, Adapter]):
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
impl = MetaReferenceSafetyImpl(config)