forked from phoenix-oss/llama-stack-mirror
API Updates: fleshing out RAG APIs, introduce "llama stack" CLI command (#51)
* add tools to chat completion request * use templates for generating system prompts * Moved ToolPromptFormat and jinja templates to llama_models.llama3.api * <WIP> memory changes - inlined AgenticSystemInstanceConfig so API feels more ergonomic - renamed it to AgentConfig, AgentInstance -> Agent - added a MemoryConfig and `memory` parameter - added `attachments` to input and `output_attachments` to the response - some naming changes * InterleavedTextAttachment -> InterleavedTextMedia, introduce memory tool * flesh out memory banks API * agentic loop has a RAG implementation * faiss provider implementation * memory client works * re-work tool definitions, fix FastAPI issues, fix tool regressions * fix agentic_system utils * basic RAG seems to work * small bug fixes for inline attachments * Refactor custom tool execution utilities * Bug fix, show memory retrieval steps in EventLogger * No need for api_key for Remote providers * add special unicode character ↵ to showcase newlines in model prompt templates * remove api.endpoints imports * combine datatypes.py and endpoints.py into api.py * Attachment / add TTL api * split batch_inference from inference * minor import fixes * use a single impl for ChatFormat.decode_assistant_mesage * use interleaved_text_media_as_str() utilityt * Fix api.datatypes imports * Add blobfile for tiktoken * Add ToolPromptFormat to ChatFormat.encode_message so that tools are encoded properly * templates take optional --format={json,function_tag} * Rag Updates * Add `api build` subcommand -- WIP * fix * build + run image seems to work * <WIP> adapters * bunch more work to make adapters work * api build works for conda now * ollama remote adapter works * Several smaller fixes to make adapters work Also, reorganized the pattern of __init__ inside providers so configuration can stay lightweight * llama distribution -> llama stack + containers (WIP) * All the new CLI for api + stack work * Make Fireworks and Together into the Adapter format * Some quick fixes to the CLI behavior to make it consistent * Updated README phew * Update cli_reference.md * llama_toolchain/distribution -> llama_toolchain/core * Add termcolor * update paths * Add a log just for consistency * chmod +x scripts * Fix api dependencies not getting added to configuration * missing import lol * Delete utils.py; move to agentic system * Support downloading of URLs for attachments for code interpreter * Simplify and generalize `llama api build` yay * Update `llama stack configure` to be very simple also * Fix stack start * Allow building an "adhoc" distribution * Remote `llama api []` subcommands * Fixes to llama stack commands and update docs * Update documentation again and add error messages to llama stack start * llama stack start -> llama stack run * Change name of build for less confusion * Add pyopenapi fork to the repository, update RFC assets * Remove conflicting annotation * Added a "--raw" option for model template printing --------- Co-authored-by: Hardik Shah <hjshah@fb.com> Co-authored-by: Ashwin Bharambe <ashwin@meta.com> Co-authored-by: Dalton Flanagan <6599399+dltn@users.noreply.github.com>
This commit is contained in:
parent
35093c0b6f
commit
7bc7785b0d
141 changed files with 8252 additions and 4032 deletions
|
@ -4,5 +4,4 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .datatypes import * # noqa
|
||||
from .endpoints import * # noqa
|
||||
from .api import * # noqa: F401 F403
|
||||
|
|
413
llama_toolchain/agentic_system/api/api.py
Normal file
413
llama_toolchain/agentic_system/api/api.py
Normal file
|
@ -0,0 +1,413 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_toolchain.common.deployment_types import * # noqa: F403
|
||||
from llama_toolchain.inference.api import * # noqa: F403
|
||||
from llama_toolchain.safety.api import * # noqa: F403
|
||||
from llama_toolchain.memory.api import * # noqa: F403
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Attachment(BaseModel):
|
||||
content: InterleavedTextMedia | URL
|
||||
mime_type: str
|
||||
|
||||
|
||||
class AgenticSystemTool(Enum):
|
||||
brave_search = "brave_search"
|
||||
wolfram_alpha = "wolfram_alpha"
|
||||
photogen = "photogen"
|
||||
code_interpreter = "code_interpreter"
|
||||
|
||||
function_call = "function_call"
|
||||
memory = "memory"
|
||||
|
||||
|
||||
class ToolDefinitionCommon(BaseModel):
|
||||
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BraveSearchToolDefinition(ToolDefinitionCommon):
|
||||
type: Literal[AgenticSystemTool.brave_search.value] = (
|
||||
AgenticSystemTool.brave_search.value
|
||||
)
|
||||
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class WolframAlphaToolDefinition(ToolDefinitionCommon):
|
||||
type: Literal[AgenticSystemTool.wolfram_alpha.value] = (
|
||||
AgenticSystemTool.wolfram_alpha.value
|
||||
)
|
||||
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PhotogenToolDefinition(ToolDefinitionCommon):
|
||||
type: Literal[AgenticSystemTool.photogen.value] = AgenticSystemTool.photogen.value
|
||||
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CodeInterpreterToolDefinition(ToolDefinitionCommon):
|
||||
type: Literal[AgenticSystemTool.code_interpreter.value] = (
|
||||
AgenticSystemTool.code_interpreter.value
|
||||
)
|
||||
enable_inline_code_execution: bool = True
|
||||
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FunctionCallToolDefinition(ToolDefinitionCommon):
|
||||
type: Literal[AgenticSystemTool.function_call.value] = (
|
||||
AgenticSystemTool.function_call.value
|
||||
)
|
||||
function_name: str
|
||||
description: str
|
||||
parameters: Dict[str, ToolParamDefinition]
|
||||
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||
|
||||
|
||||
class _MemoryBankConfigCommon(BaseModel):
|
||||
bank_id: str
|
||||
|
||||
|
||||
class AgenticSystemVectorMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
||||
|
||||
|
||||
class AgenticSystemKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
|
||||
keys: List[str] # what keys to focus on
|
||||
|
||||
|
||||
class AgenticSystemKeywordMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
|
||||
|
||||
|
||||
class AgenticSystemGraphMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
|
||||
entities: List[str] # what entities to focus on
|
||||
|
||||
|
||||
MemoryBankConfig = Annotated[
|
||||
Union[
|
||||
AgenticSystemVectorMemoryBankConfig,
|
||||
AgenticSystemKeyValueMemoryBankConfig,
|
||||
AgenticSystemKeywordMemoryBankConfig,
|
||||
AgenticSystemGraphMemoryBankConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryToolDefinition(ToolDefinitionCommon):
|
||||
type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
|
||||
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
|
||||
max_tokens_in_context: int = 4096
|
||||
max_chunks: int = 10
|
||||
|
||||
|
||||
AgenticSystemToolDefinition = Annotated[
|
||||
Union[
|
||||
BraveSearchToolDefinition,
|
||||
WolframAlphaToolDefinition,
|
||||
PhotogenToolDefinition,
|
||||
CodeInterpreterToolDefinition,
|
||||
FunctionCallToolDefinition,
|
||||
MemoryToolDefinition,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class StepCommon(BaseModel):
|
||||
turn_id: str
|
||||
step_id: str
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class StepType(Enum):
|
||||
inference = "inference"
|
||||
tool_execution = "tool_execution"
|
||||
shield_call = "shield_call"
|
||||
memory_retrieval = "memory_retrieval"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InferenceStep(StepCommon):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
step_type: Literal[StepType.inference.value] = StepType.inference.value
|
||||
model_response: CompletionMessage
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolExecutionStep(StepCommon):
|
||||
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
|
||||
tool_calls: List[ToolCall]
|
||||
tool_responses: List[ToolResponse]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ShieldCallStep(StepCommon):
|
||||
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
|
||||
response: ShieldResponse
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryRetrievalStep(StepCommon):
|
||||
step_type: Literal[StepType.memory_retrieval.value] = (
|
||||
StepType.memory_retrieval.value
|
||||
)
|
||||
memory_bank_ids: List[str]
|
||||
inserted_context: InterleavedTextMedia
|
||||
|
||||
|
||||
Step = Annotated[
|
||||
Union[
|
||||
InferenceStep,
|
||||
ToolExecutionStep,
|
||||
ShieldCallStep,
|
||||
MemoryRetrievalStep,
|
||||
],
|
||||
Field(discriminator="step_type"),
|
||||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Turn(BaseModel):
|
||||
"""A single turn in an interaction with an Agentic System."""
|
||||
|
||||
turn_id: str
|
||||
session_id: str
|
||||
input_messages: List[
|
||||
Union[
|
||||
UserMessage,
|
||||
ToolResponseMessage,
|
||||
]
|
||||
]
|
||||
steps: List[Step]
|
||||
output_message: CompletionMessage
|
||||
output_attachments: List[Attachment] = Field(default_factory=list)
|
||||
|
||||
started_at: datetime
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Session(BaseModel):
|
||||
"""A single session of an interaction with an Agentic System."""
|
||||
|
||||
session_id: str
|
||||
session_name: str
|
||||
turns: List[Turn]
|
||||
started_at: datetime
|
||||
|
||||
memory_bank: Optional[MemoryBank] = None
|
||||
|
||||
|
||||
class AgentConfigCommon(BaseModel):
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||
|
||||
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||
|
||||
tools: Optional[List[AgenticSystemToolDefinition]] = Field(default_factory=list)
|
||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
default=ToolPromptFormat.json
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentConfig(AgentConfigCommon):
|
||||
model: str
|
||||
instructions: str
|
||||
|
||||
|
||||
class AgentConfigOverridablePerTurn(AgentConfigCommon):
|
||||
instructions: Optional[str] = None
|
||||
|
||||
|
||||
class AgenticSystemTurnResponseEventType(Enum):
|
||||
step_start = "step_start"
|
||||
step_complete = "step_complete"
|
||||
step_progress = "step_progress"
|
||||
|
||||
turn_start = "turn_start"
|
||||
turn_complete = "turn_complete"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemTurnResponseStepStartPayload(BaseModel):
|
||||
event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
|
||||
AgenticSystemTurnResponseEventType.step_start.value
|
||||
)
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
|
||||
event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
|
||||
AgenticSystemTurnResponseEventType.step_complete.value
|
||||
)
|
||||
step_type: StepType
|
||||
step_details: Step
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
|
||||
AgenticSystemTurnResponseEventType.step_progress.value
|
||||
)
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
|
||||
model_response_text_delta: Optional[str] = None
|
||||
tool_call_delta: Optional[ToolCallDelta] = None
|
||||
tool_response_text_delta: Optional[str] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
|
||||
event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
|
||||
AgenticSystemTurnResponseEventType.turn_start.value
|
||||
)
|
||||
turn_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
|
||||
event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
|
||||
AgenticSystemTurnResponseEventType.turn_complete.value
|
||||
)
|
||||
turn: Turn
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemTurnResponseEvent(BaseModel):
|
||||
"""Streamed agent execution response."""
|
||||
|
||||
payload: Annotated[
|
||||
Union[
|
||||
AgenticSystemTurnResponseStepStartPayload,
|
||||
AgenticSystemTurnResponseStepProgressPayload,
|
||||
AgenticSystemTurnResponseStepCompletePayload,
|
||||
AgenticSystemTurnResponseTurnStartPayload,
|
||||
AgenticSystemTurnResponseTurnCompletePayload,
|
||||
],
|
||||
Field(discriminator="event_type"),
|
||||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemCreateResponse(BaseModel):
|
||||
agent_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemSessionCreateResponse(BaseModel):
|
||||
session_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):
|
||||
agent_id: str
|
||||
session_id: str
|
||||
|
||||
# TODO: figure out how we can simplify this and make why
|
||||
# ToolResponseMessage needs to be here (it is function call
|
||||
# execution from outside the system)
|
||||
messages: List[
|
||||
Union[
|
||||
UserMessage,
|
||||
ToolResponseMessage,
|
||||
]
|
||||
]
|
||||
attachments: Optional[List[Attachment]] = None
|
||||
|
||||
stream: Optional[bool] = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemTurnResponseStreamChunk(BaseModel):
|
||||
event: AgenticSystemTurnResponseEvent
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemStepResponse(BaseModel):
|
||||
step: Step
|
||||
|
||||
|
||||
class AgenticSystem(Protocol):
|
||||
@webmethod(route="/agentic_system/create")
|
||||
async def create_agentic_system(
|
||||
self,
|
||||
agent_config: AgentConfig,
|
||||
) -> AgenticSystemCreateResponse: ...
|
||||
|
||||
@webmethod(route="/agentic_system/turn/create")
|
||||
async def create_agentic_system_turn(
|
||||
self,
|
||||
request: AgenticSystemTurnCreateRequest,
|
||||
) -> AgenticSystemTurnResponseStreamChunk: ...
|
||||
|
||||
@webmethod(route="/agentic_system/turn/get")
|
||||
async def get_agentic_system_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
turn_id: str,
|
||||
) -> Turn: ...
|
||||
|
||||
@webmethod(route="/agentic_system/step/get")
|
||||
async def get_agentic_system_step(
|
||||
self, agent_id: str, turn_id: str, step_id: str
|
||||
) -> AgenticSystemStepResponse: ...
|
||||
|
||||
@webmethod(route="/agentic_system/session/create")
|
||||
async def create_agentic_system_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgenticSystemSessionCreateResponse: ...
|
||||
|
||||
@webmethod(route="/agentic_system/session/get")
|
||||
async def get_agentic_system_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_ids: Optional[List[str]] = None,
|
||||
) -> Session: ...
|
||||
|
||||
@webmethod(route="/agentic_system/session/delete")
|
||||
async def delete_agentic_system_session(
|
||||
self, agent_id: str, session_id: str
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/agentic_system/delete")
|
||||
async def delete_agentic_system(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> None: ...
|
|
@ -1,234 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_toolchain.common.deployment_types import * # noqa: F403
|
||||
from llama_toolchain.inference.api import * # noqa: F403
|
||||
from llama_toolchain.safety.api.datatypes import * # noqa: F403
|
||||
from llama_toolchain.memory.api.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemToolDefinition(ToolDefinition):
|
||||
execution_config: Optional[RestAPIExecutionConfig] = None
|
||||
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StepCommon(BaseModel):
|
||||
turn_id: str
|
||||
step_id: str
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class StepType(Enum):
|
||||
inference = "inference"
|
||||
tool_execution = "tool_execution"
|
||||
shield_call = "shield_call"
|
||||
memory_retrieval = "memory_retrieval"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InferenceStep(StepCommon):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
step_type: Literal[StepType.inference.value] = StepType.inference.value
|
||||
model_response: CompletionMessage
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolExecutionStep(StepCommon):
|
||||
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
|
||||
tool_calls: List[ToolCall]
|
||||
tool_responses: List[ToolResponse]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ShieldCallStep(StepCommon):
|
||||
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
|
||||
response: ShieldResponse
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryRetrievalStep(StepCommon):
|
||||
step_type: Literal[StepType.memory_retrieval.value] = (
|
||||
StepType.memory_retrieval.value
|
||||
)
|
||||
memory_bank_ids: List[str]
|
||||
documents: List[MemoryBankDocument]
|
||||
scores: List[float]
|
||||
|
||||
|
||||
Step = Annotated[
|
||||
Union[
|
||||
InferenceStep,
|
||||
ToolExecutionStep,
|
||||
ShieldCallStep,
|
||||
MemoryRetrievalStep,
|
||||
],
|
||||
Field(discriminator="step_type"),
|
||||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Turn(BaseModel):
|
||||
"""A single turn in an interaction with an Agentic System."""
|
||||
|
||||
turn_id: str
|
||||
session_id: str
|
||||
input_messages: List[
|
||||
Union[
|
||||
UserMessage,
|
||||
ToolResponseMessage,
|
||||
]
|
||||
]
|
||||
steps: List[Step]
|
||||
output_message: CompletionMessage
|
||||
started_at: datetime
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Session(BaseModel):
|
||||
"""A single session of an interaction with an Agentic System."""
|
||||
|
||||
session_id: str
|
||||
session_name: str
|
||||
turns: List[Turn]
|
||||
started_at: datetime
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolPromptFormat(Enum):
|
||||
"""This Enum refers to the prompt format for calling zero shot tools
|
||||
|
||||
`json` --
|
||||
Refers to the json format for calling tools.
|
||||
The json format takes the form like
|
||||
{
|
||||
"type": "function",
|
||||
"function" : {
|
||||
"name": "function_name",
|
||||
"description": "function_description",
|
||||
"parameters": {...}
|
||||
}
|
||||
}
|
||||
|
||||
`function_tag` --
|
||||
This is an example of how you could define
|
||||
your own user defined format for making tool calls.
|
||||
The function_tag format looks like this,
|
||||
<function=function_name>(parameters)</function>
|
||||
|
||||
The detailed prompts for each of these formats are defined in `system_prompt.py`
|
||||
"""
|
||||
|
||||
json = "json"
|
||||
function_tag = "function_tag"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemInstanceConfig(BaseModel):
|
||||
instructions: str
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||
# zero-shot or built-in tool configurations as input to the model
|
||||
available_tools: Optional[List[AgenticSystemToolDefinition]] = Field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||
|
||||
# if you completely want to replace the messages prefixed by the system,
|
||||
# this is debug only
|
||||
debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
default=ToolPromptFormat.json
|
||||
)
|
||||
|
||||
|
||||
class AgenticSystemTurnResponseEventType(Enum):
|
||||
step_start = "step_start"
|
||||
step_complete = "step_complete"
|
||||
step_progress = "step_progress"
|
||||
|
||||
turn_start = "turn_start"
|
||||
turn_complete = "turn_complete"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemTurnResponseStepStartPayload(BaseModel):
|
||||
event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
|
||||
AgenticSystemTurnResponseEventType.step_start.value
|
||||
)
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
|
||||
event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
|
||||
AgenticSystemTurnResponseEventType.step_complete.value
|
||||
)
|
||||
step_type: StepType
|
||||
step_details: Step
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
|
||||
AgenticSystemTurnResponseEventType.step_progress.value
|
||||
)
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
|
||||
model_response_text_delta: Optional[str] = None
|
||||
tool_call_delta: Optional[ToolCallDelta] = None
|
||||
tool_response_text_delta: Optional[str] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
|
||||
event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
|
||||
AgenticSystemTurnResponseEventType.turn_start.value
|
||||
)
|
||||
turn_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
|
||||
event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
|
||||
AgenticSystemTurnResponseEventType.turn_complete.value
|
||||
)
|
||||
turn: Turn
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemTurnResponseEvent(BaseModel):
|
||||
"""Streamed agent execution response."""
|
||||
|
||||
payload: Annotated[
|
||||
Union[
|
||||
AgenticSystemTurnResponseStepStartPayload,
|
||||
AgenticSystemTurnResponseStepProgressPayload,
|
||||
AgenticSystemTurnResponseStepCompletePayload,
|
||||
AgenticSystemTurnResponseTurnStartPayload,
|
||||
AgenticSystemTurnResponseTurnCompletePayload,
|
||||
],
|
||||
Field(discriminator="event_type"),
|
||||
]
|
|
@ -1,127 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .datatypes import * # noqa: F403
|
||||
from typing import Protocol
|
||||
|
||||
# this dependency is annoying and we need a forked up version anyway
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemCreateRequest(BaseModel):
|
||||
model: str
|
||||
instance_config: AgenticSystemInstanceConfig
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemCreateResponse(BaseModel):
|
||||
system_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemSessionCreateRequest(BaseModel):
|
||||
system_id: str
|
||||
session_name: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemSessionCreateResponse(BaseModel):
|
||||
session_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
# what's the URI?
|
||||
class AgenticSystemTurnCreateRequest(BaseModel):
|
||||
system_id: str
|
||||
session_id: str
|
||||
|
||||
messages: List[
|
||||
Union[
|
||||
UserMessage,
|
||||
ToolResponseMessage,
|
||||
]
|
||||
]
|
||||
|
||||
stream: Optional[bool] = False
|
||||
override_config: Optional[AgenticSystemInstanceConfig] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemTurnResponseStreamChunk(BaseModel):
|
||||
event: AgenticSystemTurnResponseEvent
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemStepResponse(BaseModel):
|
||||
step: Step
|
||||
|
||||
|
||||
class AgenticSystem(Protocol):
|
||||
@webmethod(route="/agentic_system/create")
|
||||
async def create_agentic_system(
|
||||
self,
|
||||
request: AgenticSystemCreateRequest,
|
||||
) -> AgenticSystemCreateResponse: ...
|
||||
|
||||
@webmethod(route="/agentic_system/turn/create")
|
||||
async def create_agentic_system_turn(
|
||||
self,
|
||||
request: AgenticSystemTurnCreateRequest,
|
||||
) -> AgenticSystemTurnResponseStreamChunk: ...
|
||||
|
||||
@webmethod(route="/agentic_system/turn/get")
|
||||
async def get_agentic_system_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
turn_id: str,
|
||||
) -> Turn: ...
|
||||
|
||||
@webmethod(route="/agentic_system/step/get")
|
||||
async def get_agentic_system_step(
|
||||
self, agent_id: str, turn_id: str, step_id: str
|
||||
) -> AgenticSystemStepResponse: ...
|
||||
|
||||
@webmethod(route="/agentic_system/session/create")
|
||||
async def create_agentic_system_session(
|
||||
self,
|
||||
request: AgenticSystemSessionCreateRequest,
|
||||
) -> AgenticSystemSessionCreateResponse: ...
|
||||
|
||||
@webmethod(route="/agentic_system/memory_bank/attach")
|
||||
async def attach_memory_bank_to_agentic_system(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
memory_bank_ids: List[str],
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/agentic_system/memory_bank/detach")
|
||||
async def detach_memory_bank_from_agentic_system(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
memory_bank_ids: List[str],
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/agentic_system/session/get")
|
||||
async def get_agentic_system_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_ids: Optional[List[str]] = None,
|
||||
) -> Session: ...
|
||||
|
||||
@webmethod(route="/agentic_system/session/delete")
|
||||
async def delete_agentic_system_session(
|
||||
self, agent_id: str, session_id: str
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/agentic_system/delete")
|
||||
async def delete_agentic_system(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> None: ...
|
|
@ -6,38 +6,28 @@
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import fire
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
BuiltinTool,
|
||||
SamplingParams,
|
||||
ToolParamDefinition,
|
||||
UserMessage,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_toolchain.agentic_system.event_logger import EventLogger
|
||||
from .api import (
|
||||
AgenticSystem,
|
||||
AgenticSystemCreateRequest,
|
||||
AgenticSystemCreateResponse,
|
||||
AgenticSystemInstanceConfig,
|
||||
AgenticSystemSessionCreateRequest,
|
||||
AgenticSystemSessionCreateResponse,
|
||||
AgenticSystemToolDefinition,
|
||||
AgenticSystemTurnCreateRequest,
|
||||
AgenticSystemTurnResponseStreamChunk,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_toolchain.core.datatypes import RemoteProviderConfig
|
||||
|
||||
from .api import * # noqa: F403
|
||||
from .event_logger import EventLogger
|
||||
|
||||
|
||||
async def get_client_impl(base_url: str):
|
||||
return AgenticSystemClient(base_url)
|
||||
async def get_client_impl(config: RemoteProviderConfig, _deps):
|
||||
return AgenticSystemClient(config.url)
|
||||
|
||||
|
||||
def encodable_dict(d: BaseModel):
|
||||
return json.loads(d.json())
|
||||
|
||||
|
||||
class AgenticSystemClient(AgenticSystem):
|
||||
|
@ -45,12 +35,14 @@ class AgenticSystemClient(AgenticSystem):
|
|||
self.base_url = base_url
|
||||
|
||||
async def create_agentic_system(
|
||||
self, request: AgenticSystemCreateRequest
|
||||
self, agent_config: AgentConfig
|
||||
) -> AgenticSystemCreateResponse:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/agentic_system/create",
|
||||
data=request.json(),
|
||||
json={
|
||||
"agent_config": encodable_dict(agent_config),
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
@ -58,12 +50,16 @@ class AgenticSystemClient(AgenticSystem):
|
|||
|
||||
async def create_agentic_system_session(
|
||||
self,
|
||||
request: AgenticSystemSessionCreateRequest,
|
||||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgenticSystemSessionCreateResponse:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/agentic_system/session/create",
|
||||
data=request.json(),
|
||||
json={
|
||||
"agent_id": agent_id,
|
||||
"session_name": session_name,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
@ -77,7 +73,9 @@ class AgenticSystemClient(AgenticSystem):
|
|||
async with client.stream(
|
||||
"POST",
|
||||
f"{self.base_url}/agentic_system/turn/create",
|
||||
data=request.json(),
|
||||
json={
|
||||
"request": encodable_dict(request),
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
) as response:
|
||||
|
@ -85,6 +83,10 @@ class AgenticSystemClient(AgenticSystem):
|
|||
if line.startswith("data:"):
|
||||
data = line[len("data: ") :]
|
||||
try:
|
||||
if "error" in data:
|
||||
cprint(data, "red")
|
||||
continue
|
||||
|
||||
yield AgenticSystemTurnResponseStreamChunk(
|
||||
**json.loads(data)
|
||||
)
|
||||
|
@ -93,24 +95,52 @@ class AgenticSystemClient(AgenticSystem):
|
|||
print(f"Error with parsing or validation: {e}")
|
||||
|
||||
|
||||
async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
|
||||
agent_config = AgentConfig(
|
||||
model="Meta-Llama3.1-8B-Instruct",
|
||||
instructions="You are a helpful assistant",
|
||||
sampling_params=SamplingParams(temperature=1.0, top_p=0.9),
|
||||
tools=tool_definitions,
|
||||
tool_choice=ToolChoice.auto,
|
||||
tool_prompt_format=ToolPromptFormat.function_tag,
|
||||
)
|
||||
|
||||
create_response = await api.create_agentic_system(agent_config)
|
||||
session_response = await api.create_agentic_system_session(
|
||||
agent_id=create_response.agent_id,
|
||||
session_name="test_session",
|
||||
)
|
||||
|
||||
for content in user_prompts:
|
||||
cprint(f"User> {content}", color="white", attrs=["bold"])
|
||||
iterator = api.create_agentic_system_turn(
|
||||
AgenticSystemTurnCreateRequest(
|
||||
agent_id=create_response.agent_id,
|
||||
session_id=session_response.session_id,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
attachments=attachments,
|
||||
stream=True,
|
||||
)
|
||||
)
|
||||
|
||||
async for event, log in EventLogger().log(iterator):
|
||||
if log is not None:
|
||||
log.print()
|
||||
|
||||
|
||||
async def run_main(host: str, port: int):
|
||||
# client to test remote impl of agentic system
|
||||
api = AgenticSystemClient(f"http://{host}:{port}")
|
||||
|
||||
tool_definitions = [
|
||||
AgenticSystemToolDefinition(
|
||||
tool_name=BuiltinTool.brave_search,
|
||||
),
|
||||
AgenticSystemToolDefinition(
|
||||
tool_name=BuiltinTool.wolfram_alpha,
|
||||
),
|
||||
AgenticSystemToolDefinition(
|
||||
tool_name=BuiltinTool.code_interpreter,
|
||||
),
|
||||
BraveSearchToolDefinition(),
|
||||
WolframAlphaToolDefinition(),
|
||||
CodeInterpreterToolDefinition(),
|
||||
]
|
||||
tool_definitions += [
|
||||
AgenticSystemToolDefinition(
|
||||
tool_name="get_boiling_point",
|
||||
FunctionCallToolDefinition(
|
||||
function_name="get_boiling_point",
|
||||
description="Get the boiling point of a imaginary liquids (eg. polyjuice)",
|
||||
parameters={
|
||||
"liquid_name": ToolParamDefinition(
|
||||
|
@ -127,30 +157,6 @@ async def run_main(host: str, port: int):
|
|||
),
|
||||
]
|
||||
|
||||
create_request = AgenticSystemCreateRequest(
|
||||
model="Meta-Llama3.1-8B-Instruct",
|
||||
instance_config=AgenticSystemInstanceConfig(
|
||||
instructions="You are a helpful assistant",
|
||||
sampling_params=SamplingParams(),
|
||||
available_tools=tool_definitions,
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
debug_prefix_messages=[],
|
||||
tool_prompt_format=ToolPromptFormat.json,
|
||||
),
|
||||
)
|
||||
|
||||
create_response = await api.create_agentic_system(create_request)
|
||||
print(create_response)
|
||||
|
||||
session_response = await api.create_agentic_system_session(
|
||||
AgenticSystemSessionCreateRequest(
|
||||
system_id=create_response.system_id,
|
||||
session_name="test_session",
|
||||
)
|
||||
)
|
||||
print(session_response)
|
||||
|
||||
user_prompts = [
|
||||
"Who are you?",
|
||||
"what is the 100th prime number?",
|
||||
|
@ -158,26 +164,51 @@ async def run_main(host: str, port: int):
|
|||
"Write code to check if a number is prime. Use that to check if 7 is prime",
|
||||
"What is the boiling point of polyjuicepotion ?",
|
||||
]
|
||||
for content in user_prompts:
|
||||
cprint(f"User> {content}", color="blue")
|
||||
iterator = api.create_agentic_system_turn(
|
||||
AgenticSystemTurnCreateRequest(
|
||||
system_id=create_response.system_id,
|
||||
session_id=session_response.session_id,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
stream=True,
|
||||
)
|
||||
await _run_agent(api, tool_definitions, user_prompts)
|
||||
|
||||
|
||||
async def run_rag(host: str, port: int):
|
||||
api = AgenticSystemClient(f"http://{host}:{port}")
|
||||
|
||||
urls = [
|
||||
"memory_optimizations.rst",
|
||||
"chat.rst",
|
||||
"llama3.rst",
|
||||
"datasets.rst",
|
||||
"qat_finetune.rst",
|
||||
"lora_finetune.rst",
|
||||
]
|
||||
attachments = [
|
||||
Attachment(
|
||||
content=URL(
|
||||
uri=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}"
|
||||
),
|
||||
mime_type="text/plain",
|
||||
)
|
||||
for i, url in enumerate(urls)
|
||||
]
|
||||
|
||||
async for event, log in EventLogger().log(iterator):
|
||||
if log is not None:
|
||||
log.print()
|
||||
# Alternatively, you can pre-populate the memory bank with documents for example,
|
||||
# using `llama_toolchain.memory.client`. Then you can grab the bank_id
|
||||
# from the output of that run.
|
||||
tool_definitions = [
|
||||
MemoryToolDefinition(
|
||||
max_tokens_in_context=2048,
|
||||
memory_bank_configs=[],
|
||||
),
|
||||
]
|
||||
|
||||
user_prompts = [
|
||||
"How do I use Lora?",
|
||||
"Tell me briefly about llama3 and torchtune",
|
||||
]
|
||||
|
||||
await _run_agent(api, tool_definitions, user_prompts, attachments)
|
||||
|
||||
|
||||
def main(host: str, port: int):
|
||||
asyncio.run(run_main(host, port))
|
||||
def main(host: str, port: int, rag: bool = False):
|
||||
fn = run_rag if rag else run_main
|
||||
asyncio.run(fn(host, port))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.llama3.api.datatypes import ToolResponseMessage
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_models.llama3.api.tool_utils import ToolUtils
|
||||
|
||||
from termcolor import cprint
|
||||
|
@ -44,7 +44,12 @@ EventType = AgenticSystemTurnResponseEventType
|
|||
|
||||
|
||||
class EventLogger:
|
||||
async def log(self, event_generator, stream=True):
|
||||
async def log(
|
||||
self,
|
||||
event_generator,
|
||||
stream=True,
|
||||
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
|
||||
):
|
||||
previous_event_type = None
|
||||
previous_step_type = None
|
||||
|
||||
|
@ -132,7 +137,9 @@ class EventLogger:
|
|||
if event_type == EventType.step_complete.value:
|
||||
response = event.payload.step_details.model_response
|
||||
if response.tool_calls:
|
||||
content = ToolUtils.encode_tool_call(response.tool_calls[0])
|
||||
content = ToolUtils.encode_tool_call(
|
||||
response.tool_calls[0], tool_prompt_format
|
||||
)
|
||||
else:
|
||||
content = response.content
|
||||
yield event, LogEvent(
|
||||
|
@ -162,5 +169,19 @@ class EventLogger:
|
|||
color="green",
|
||||
)
|
||||
|
||||
if (
|
||||
step_type == StepType.memory_retrieval
|
||||
and event_type == EventType.step_complete.value
|
||||
):
|
||||
details = event.payload.step_details
|
||||
content = interleaved_text_media_as_str(details.inserted_context)
|
||||
content = content[:200] + "..." if len(content) > 200 else content
|
||||
|
||||
yield event, LogEvent(
|
||||
role=step_type,
|
||||
content=f"Retrieved context from banks: {details.memory_bank_ids}.\n====\n{content}\n>",
|
||||
color="cyan",
|
||||
)
|
||||
|
||||
preivous_event_type = event_type
|
||||
previous_step_type = step_type
|
||||
|
|
96
llama_toolchain/agentic_system/execute_with_custom_tools.py
Normal file
96
llama_toolchain/agentic_system/execute_with_custom_tools.py
Normal file
|
@ -0,0 +1,96 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator, List
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_toolchain.agentic_system.api import * # noqa: F403
|
||||
from llama_toolchain.memory.api import * # noqa: F403
|
||||
from llama_toolchain.safety.api import * # noqa: F403
|
||||
|
||||
from llama_toolchain.agentic_system.api import (
|
||||
AgenticSystemTurnResponseEventType as EventType,
|
||||
)
|
||||
from llama_toolchain.tools.custom.datatypes import CustomTool
|
||||
|
||||
|
||||
class AgentWithCustomToolExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
api: AgenticSystem,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
agent_config: AgentConfig,
|
||||
custom_tools: List[CustomTool],
|
||||
):
|
||||
self.api = api
|
||||
self.agent_id = agent_id
|
||||
self.session_id = session_id
|
||||
self.agent_config = agent_config
|
||||
self.custom_tools = custom_tools
|
||||
|
||||
async def execute_turn(
|
||||
self,
|
||||
messages: List[Message],
|
||||
attachments: Optional[List[Attachment]] = None,
|
||||
max_iters: int = 5,
|
||||
stream: bool = True,
|
||||
) -> AsyncGenerator:
|
||||
tools_dict = {t.get_name(): t for t in self.custom_tools}
|
||||
|
||||
current_messages = messages.copy()
|
||||
n_iter = 0
|
||||
while n_iter < max_iters:
|
||||
n_iter += 1
|
||||
|
||||
request = AgenticSystemTurnCreateRequest(
|
||||
agent_id=self.agent_id,
|
||||
session_id=self.session_id,
|
||||
messages=current_messages,
|
||||
attachments=attachments,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
turn = None
|
||||
async for chunk in self.api.create_agentic_system_turn(request):
|
||||
if chunk.event.payload.event_type != EventType.turn_complete.value:
|
||||
yield chunk
|
||||
else:
|
||||
turn = chunk.event.payload.turn
|
||||
|
||||
message = turn.output_message
|
||||
if len(message.tool_calls) == 0:
|
||||
yield chunk
|
||||
return
|
||||
|
||||
if message.stop_reason == StopReason.out_of_tokens:
|
||||
yield chunk
|
||||
return
|
||||
|
||||
tool_call = message.tool_calls[0]
|
||||
if tool_call.tool_name not in tools_dict:
|
||||
m = ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
tool_name=tool_call.tool_name,
|
||||
content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else",
|
||||
)
|
||||
next_message = m
|
||||
else:
|
||||
tool = tools_dict[tool_call.tool_name]
|
||||
result_messages = await execute_custom_tool(tool, message)
|
||||
next_message = result_messages[0]
|
||||
|
||||
yield next_message
|
||||
current_messages = [next_message]
|
||||
|
||||
|
||||
async def execute_custom_tool(tool: CustomTool, message: Message) -> List[Message]:
|
||||
result_messages = await tool.run([message])
|
||||
assert (
|
||||
len(result_messages) == 1
|
||||
), f"Expected single message, got {len(result_messages)}"
|
||||
|
||||
return result_messages
|
|
@ -4,5 +4,27 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .agentic_system import get_provider_impl # noqa
|
||||
from .config import AgenticSystemConfig # noqa
|
||||
from typing import Dict
|
||||
|
||||
from llama_toolchain.core.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import MetaReferenceImplConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
|
||||
):
|
||||
from .agentic_system import MetaReferenceAgenticSystemImpl
|
||||
|
||||
assert isinstance(
|
||||
config, MetaReferenceImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = MetaReferenceAgenticSystemImpl(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.memory],
|
||||
deps[Api.safety],
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -4,111 +4,111 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import os
|
||||
import secrets
|
||||
import shutil
|
||||
import string
|
||||
import tempfile
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
from typing import AsyncGenerator, List, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_toolchain.agentic_system.api.datatypes import (
|
||||
AgenticSystemInstanceConfig,
|
||||
AgenticSystemTurnResponseEvent,
|
||||
AgenticSystemTurnResponseEventType,
|
||||
AgenticSystemTurnResponseStepCompletePayload,
|
||||
AgenticSystemTurnResponseStepProgressPayload,
|
||||
AgenticSystemTurnResponseStepStartPayload,
|
||||
AgenticSystemTurnResponseTurnCompletePayload,
|
||||
AgenticSystemTurnResponseTurnStartPayload,
|
||||
InferenceStep,
|
||||
Session,
|
||||
ShieldCallStep,
|
||||
StepType,
|
||||
ToolExecutionStep,
|
||||
ToolPromptFormat,
|
||||
Turn,
|
||||
)
|
||||
from llama_toolchain.agentic_system.api import * # noqa: F403
|
||||
from llama_toolchain.inference.api import * # noqa: F403
|
||||
from llama_toolchain.memory.api import * # noqa: F403
|
||||
from llama_toolchain.safety.api import * # noqa: F403
|
||||
|
||||
from llama_toolchain.inference.api import ChatCompletionRequest, Inference
|
||||
|
||||
from llama_toolchain.inference.api.datatypes import (
|
||||
Attachment,
|
||||
BuiltinTool,
|
||||
ChatCompletionResponseEventType,
|
||||
CompletionMessage,
|
||||
Message,
|
||||
Role,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
ToolDefinition,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
URL,
|
||||
from llama_toolchain.tools.base import BaseTool
|
||||
from llama_toolchain.tools.builtin import (
|
||||
interpret_content_as_attachment,
|
||||
SingleMessageBuiltinTool,
|
||||
)
|
||||
from llama_toolchain.safety.api import Safety
|
||||
from llama_toolchain.safety.api.datatypes import (
|
||||
BuiltinShield,
|
||||
ShieldDefinition,
|
||||
ShieldResponse,
|
||||
)
|
||||
from llama_toolchain.agentic_system.api.endpoints import * # noqa
|
||||
|
||||
from .safety import SafetyException, ShieldRunnerMixin
|
||||
from .system_prompt import get_agentic_prefix_messages
|
||||
from .tools.base import BaseTool
|
||||
from .tools.builtin import SingleMessageBuiltinTool
|
||||
|
||||
|
||||
class AgentInstance(ShieldRunnerMixin):
|
||||
def make_random_string(length: int = 8):
|
||||
return "".join(
|
||||
secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
|
||||
)
|
||||
|
||||
|
||||
class ChatAgent(ShieldRunnerMixin):
|
||||
def __init__(
|
||||
self,
|
||||
system_id: int,
|
||||
instance_config: AgenticSystemInstanceConfig,
|
||||
model: str,
|
||||
agent_config: AgentConfig,
|
||||
inference_api: Inference,
|
||||
memory_api: Memory,
|
||||
safety_api: Safety,
|
||||
builtin_tools: List[SingleMessageBuiltinTool],
|
||||
custom_tool_definitions: List[ToolDefinition],
|
||||
input_shields: List[ShieldDefinition],
|
||||
output_shields: List[ShieldDefinition],
|
||||
max_infer_iters: int = 10,
|
||||
prefix_messages: Optional[List[Message]] = None,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
):
|
||||
self.system_id = system_id
|
||||
self.instance_config = instance_config
|
||||
|
||||
self.model = model
|
||||
self.agent_config = agent_config
|
||||
self.inference_api = inference_api
|
||||
self.memory_api = memory_api
|
||||
self.safety_api = safety_api
|
||||
|
||||
if prefix_messages is not None and len(prefix_messages) > 0:
|
||||
self.prefix_messages = prefix_messages
|
||||
else:
|
||||
self.prefix_messages = get_agentic_prefix_messages(
|
||||
builtin_tools,
|
||||
custom_tool_definitions,
|
||||
tool_prompt_format,
|
||||
)
|
||||
|
||||
for m in self.prefix_messages:
|
||||
print(m.content)
|
||||
|
||||
self.max_infer_iters = max_infer_iters
|
||||
self.tools_dict = {t.get_name(): t for t in builtin_tools}
|
||||
|
||||
self.tempdir = tempfile.mkdtemp()
|
||||
self.sessions = {}
|
||||
|
||||
ShieldRunnerMixin.__init__(
|
||||
self,
|
||||
safety_api,
|
||||
input_shields=input_shields,
|
||||
output_shields=output_shields,
|
||||
input_shields=agent_config.input_shields,
|
||||
output_shields=agent_config.output_shields,
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
shutil.rmtree(self.tempdir)
|
||||
|
||||
def turn_to_messages(self, turn: Turn) -> List[Message]:
|
||||
messages = []
|
||||
|
||||
# We do not want to keep adding RAG context to the input messages
|
||||
# May be this should be a parameter of the agentic instance
|
||||
# that can define its behavior in a custom way
|
||||
for m in turn.input_messages:
|
||||
msg = m.copy()
|
||||
if isinstance(msg, UserMessage):
|
||||
msg.context = None
|
||||
messages.append(msg)
|
||||
|
||||
# messages.extend(turn.input_messages)
|
||||
for step in turn.steps:
|
||||
if step.step_type == StepType.inference.value:
|
||||
messages.append(step.model_response)
|
||||
elif step.step_type == StepType.tool_execution.value:
|
||||
for response in step.tool_responses:
|
||||
messages.append(
|
||||
ToolResponseMessage(
|
||||
call_id=response.call_id,
|
||||
tool_name=response.tool_name,
|
||||
content=response.content,
|
||||
)
|
||||
)
|
||||
elif step.step_type == StepType.shield_call.value:
|
||||
response = step.response
|
||||
if response.is_violation:
|
||||
# CompletionMessage itself in the ShieldResponse
|
||||
messages.append(
|
||||
CompletionMessage(
|
||||
content=response.violation_return_message,
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
)
|
||||
# print_dialog(messages)
|
||||
return messages
|
||||
|
||||
def create_session(self, name: str) -> Session:
|
||||
session_id = str(uuid.uuid4())
|
||||
session = Session(
|
||||
|
@ -131,32 +131,7 @@ class AgentInstance(ShieldRunnerMixin):
|
|||
|
||||
messages = []
|
||||
for i, turn in enumerate(session.turns):
|
||||
# print(f"turn {i}")
|
||||
# print_dialog(turn.input_messages)
|
||||
messages.extend(turn.input_messages)
|
||||
for step in turn.steps:
|
||||
if step.step_type == StepType.inference.value:
|
||||
messages.append(step.model_response)
|
||||
elif step.step_type == StepType.tool_execution.value:
|
||||
for response in step.tool_responses:
|
||||
messages.append(
|
||||
ToolResponseMessage(
|
||||
call_id=response.call_id,
|
||||
tool_name=response.tool_name,
|
||||
content=response.content,
|
||||
)
|
||||
)
|
||||
elif step.step_type == StepType.shield_call.value:
|
||||
response = step.response
|
||||
if response.is_violation:
|
||||
# TODO: Properly persist the
|
||||
# CompletionMessage itself in the ShieldResponse
|
||||
messages.append(
|
||||
CompletionMessage(
|
||||
content=response.violation_return_message,
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
)
|
||||
messages.extend(self.turn_to_messages(turn))
|
||||
|
||||
messages.extend(request.messages)
|
||||
|
||||
|
@ -164,7 +139,6 @@ class AgentInstance(ShieldRunnerMixin):
|
|||
# print_dialog(messages)
|
||||
|
||||
turn_id = str(uuid.uuid4())
|
||||
params = self.instance_config.sampling_params
|
||||
start_time = datetime.now()
|
||||
yield AgenticSystemTurnResponseStreamChunk(
|
||||
event=AgenticSystemTurnResponseEvent(
|
||||
|
@ -177,12 +151,12 @@ class AgentInstance(ShieldRunnerMixin):
|
|||
steps = []
|
||||
output_message = None
|
||||
async for chunk in self.run(
|
||||
session=session,
|
||||
turn_id=turn_id,
|
||||
input_messages=messages,
|
||||
temperature=params.temperature,
|
||||
top_p=params.top_p,
|
||||
attachments=request.attachments or [],
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
max_gen_len=params.max_tokens,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
cprint(
|
||||
|
@ -227,6 +201,53 @@ class AgentInstance(ShieldRunnerMixin):
|
|||
)
|
||||
yield chunk
|
||||
|
||||
async def run(
|
||||
self,
|
||||
session: Session,
|
||||
turn_id: str,
|
||||
input_messages: List[Message],
|
||||
attachments: List[Attachment],
|
||||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
) -> AsyncGenerator:
|
||||
# Doing async generators makes downstream code much simpler and everything amenable to
|
||||
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
||||
# return a "final value" for the `yield from` statement. we simulate that by yielding a
|
||||
# final boolean (to see whether an exception happened) and then explicitly testing for it.
|
||||
|
||||
async for res in self.run_shields_wrapper(
|
||||
turn_id, input_messages, self.input_shields, "user-input"
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
else:
|
||||
yield res
|
||||
|
||||
async for res in self._run(
|
||||
session, turn_id, input_messages, attachments, sampling_params, stream
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
elif isinstance(res, CompletionMessage):
|
||||
final_response = res
|
||||
break
|
||||
else:
|
||||
yield res
|
||||
|
||||
assert final_response is not None
|
||||
# for output shields run on the full input and output combination
|
||||
messages = input_messages + [final_response]
|
||||
|
||||
async for res in self.run_shields_wrapper(
|
||||
turn_id, messages, self.output_shields, "assistant-output"
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
else:
|
||||
yield res
|
||||
|
||||
yield final_response
|
||||
|
||||
async def run_shields_wrapper(
|
||||
self,
|
||||
turn_id: str,
|
||||
|
@ -288,65 +309,62 @@ class AgentInstance(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
turn_id: str,
|
||||
input_messages: List[Message],
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
stream: bool = False,
|
||||
max_gen_len: Optional[int] = None,
|
||||
) -> AsyncGenerator:
|
||||
# Doing async generators makes downstream code much simpler and everything amenable to
|
||||
# stremaing. However, it also makes things complicated here because AsyncGenerators cannot
|
||||
# return a "final value" for the `yield from` statement. we simulate that by yielding a
|
||||
# final boolean (to see whether an exception happened) and then explicitly testing for it.
|
||||
|
||||
async for res in self.run_shields_wrapper(
|
||||
turn_id, input_messages, self.input_shields, "user-input"
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
else:
|
||||
yield res
|
||||
|
||||
async for res in self._run(
|
||||
turn_id, input_messages, temperature, top_p, stream, max_gen_len
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
elif isinstance(res, CompletionMessage):
|
||||
final_response = res
|
||||
break
|
||||
else:
|
||||
yield res
|
||||
|
||||
assert final_response is not None
|
||||
# for output shields run on the full input and output combination
|
||||
messages = input_messages + [final_response]
|
||||
|
||||
async for res in self.run_shields_wrapper(
|
||||
turn_id, messages, self.output_shields, "assistant-output"
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
else:
|
||||
yield res
|
||||
|
||||
yield final_response
|
||||
|
||||
async def _run(
|
||||
self,
|
||||
session: Session,
|
||||
turn_id: str,
|
||||
input_messages: List[Message],
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
attachments: List[Attachment],
|
||||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
max_gen_len: Optional[int] = None,
|
||||
) -> AsyncGenerator:
|
||||
input_messages = preprocess_dialog(input_messages, self.prefix_messages)
|
||||
enabled_tools = set(t.type for t in self.agent_config.tools)
|
||||
need_rag_context = await self._should_retrieve_context(
|
||||
input_messages, attachments
|
||||
)
|
||||
if need_rag_context:
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgenticSystemTurnResponseStreamChunk(
|
||||
event=AgenticSystemTurnResponseEvent(
|
||||
payload=AgenticSystemTurnResponseStepStartPayload(
|
||||
step_type=StepType.memory_retrieval.value,
|
||||
step_id=step_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
attachments = []
|
||||
# TODO: find older context from the session and either replace it
|
||||
# or append with a sliding window. this is really a very simplistic implementation
|
||||
rag_context, bank_ids = await self._retrieve_context(
|
||||
session, input_messages, attachments
|
||||
)
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgenticSystemTurnResponseStreamChunk(
|
||||
event=AgenticSystemTurnResponseEvent(
|
||||
payload=AgenticSystemTurnResponseStepCompletePayload(
|
||||
step_type=StepType.memory_retrieval.value,
|
||||
step_id=step_id,
|
||||
step_details=MemoryRetrievalStep(
|
||||
turn_id=turn_id,
|
||||
step_id=step_id,
|
||||
memory_bank_ids=bank_ids,
|
||||
inserted_context=rag_context or "",
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if rag_context:
|
||||
last_message = input_messages[-1]
|
||||
last_message.context = "\n".join(rag_context)
|
||||
|
||||
elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
|
||||
urls = [a.content for a in attachments if isinstance(a.content, URL)]
|
||||
msg = await attachment_message(self.tempdir, urls)
|
||||
input_messages.append(msg)
|
||||
|
||||
output_attachments = []
|
||||
|
||||
n_iter = 0
|
||||
while True:
|
||||
|
@ -369,17 +387,13 @@ class AgentInstance(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
|
||||
# where are the available tools?
|
||||
req = ChatCompletionRequest(
|
||||
model=self.model,
|
||||
model=self.agent_config.model,
|
||||
messages=input_messages,
|
||||
available_tools=self.instance_config.available_tools,
|
||||
tools=self._get_tools(),
|
||||
tool_prompt_format=self.agent_config.tool_prompt_format,
|
||||
stream=True,
|
||||
sampling_params=SamplingParams(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_tokens=max_gen_len,
|
||||
),
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
|
||||
tool_calls = []
|
||||
|
@ -464,7 +478,8 @@ class AgentInstance(ShieldRunnerMixin):
|
|||
|
||||
if len(message.tool_calls) == 0:
|
||||
if stop_reason == StopReason.end_of_turn:
|
||||
if len(attachments) > 0:
|
||||
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
|
||||
if len(output_attachments) > 0:
|
||||
if isinstance(message.content, list):
|
||||
message.content += attachments
|
||||
else:
|
||||
|
@ -572,63 +587,175 @@ class AgentInstance(ShieldRunnerMixin):
|
|||
yield False
|
||||
return
|
||||
|
||||
if isinstance(result_message.content, Attachment):
|
||||
if out_attachment := interpret_content_as_attachment(
|
||||
result_message.content
|
||||
):
|
||||
# NOTE: when we push this message back to the model, the model may ignore the
|
||||
# attached file path etc. since the model is trained to only provide a user message
|
||||
# with the summary. We keep all generated attachments and then attach them to final message
|
||||
attachments.append(result_message.content)
|
||||
elif isinstance(result_message.content, list) or isinstance(
|
||||
result_message.content, tuple
|
||||
):
|
||||
for c in result_message.content:
|
||||
if isinstance(c, Attachment):
|
||||
attachments.append(c)
|
||||
output_attachments.append(out_attachment)
|
||||
|
||||
input_messages = input_messages + [message, result_message]
|
||||
|
||||
n_iter += 1
|
||||
|
||||
async def _ensure_memory_bank(self, session: Session) -> MemoryBank:
|
||||
if session.memory_bank is None:
|
||||
session.memory_bank = await self.memory_api.create_memory_bank(
|
||||
name=f"memory_bank_{session.session_id}",
|
||||
config=VectorMemoryBankConfig(
|
||||
embedding_model="sentence-transformer/all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
),
|
||||
)
|
||||
|
||||
def attachment_message(url: URL) -> ToolResponseMessage:
|
||||
uri = url.uri
|
||||
assert uri.startswith("file://")
|
||||
filepath = uri[len("file://") :]
|
||||
return session.memory_bank
|
||||
|
||||
async def _should_retrieve_context(
|
||||
self, messages: List[Message], attachments: List[Attachment]
|
||||
) -> bool:
|
||||
enabled_tools = set(t.type for t in self.agent_config.tools)
|
||||
if attachments:
|
||||
if (
|
||||
AgenticSystemTool.code_interpreter.value in enabled_tools
|
||||
and self.agent_config.tool_choice == ToolChoice.required
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
return AgenticSystemTool.memory.value in enabled_tools
|
||||
|
||||
def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
|
||||
for t in self.agent_config.tools:
|
||||
if t.type == AgenticSystemTool.memory.value:
|
||||
return t
|
||||
|
||||
return None
|
||||
|
||||
async def _retrieve_context(
|
||||
self, session: Session, messages: List[Message], attachments: List[Attachment]
|
||||
) -> Tuple[List[str], List[int]]: # (rag_context, bank_ids)
|
||||
bank_ids = []
|
||||
|
||||
memory = self._memory_tool_definition()
|
||||
assert memory is not None, "Memory tool not configured"
|
||||
bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
|
||||
|
||||
if attachments:
|
||||
bank = await self._ensure_memory_bank(session)
|
||||
bank_ids.append(bank.bank_id)
|
||||
|
||||
documents = [
|
||||
MemoryBankDocument(
|
||||
document_id=str(uuid.uuid4()),
|
||||
content=a.content,
|
||||
mime_type=a.mime_type,
|
||||
metadata={},
|
||||
)
|
||||
for a in attachments
|
||||
]
|
||||
await self.memory_api.insert_documents(bank.bank_id, documents)
|
||||
elif session.memory_bank:
|
||||
bank_ids.append(session.memory_bank.bank_id)
|
||||
|
||||
if not bank_ids:
|
||||
# this can happen if the per-session memory bank is not yet populated
|
||||
# (i.e., no prior turns uploaded an Attachment)
|
||||
return None, []
|
||||
|
||||
query = " ".join(m.content for m in messages)
|
||||
tasks = [
|
||||
self.memory_api.query_documents(
|
||||
bank_id=bank_id,
|
||||
query=query,
|
||||
params={
|
||||
"max_chunks": 5,
|
||||
},
|
||||
)
|
||||
for bank_id in bank_ids
|
||||
]
|
||||
results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
|
||||
chunks = [c for r in results for c in r.chunks]
|
||||
scores = [s for r in results for s in r.scores]
|
||||
|
||||
# sort by score
|
||||
chunks, scores = zip(
|
||||
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
|
||||
)
|
||||
if not chunks:
|
||||
return None, bank_ids
|
||||
|
||||
tokens = 0
|
||||
picked = []
|
||||
for c in chunks[: memory.max_chunks]:
|
||||
tokens += c.token_count
|
||||
if tokens > memory.max_tokens_in_context:
|
||||
cprint(
|
||||
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
|
||||
"red",
|
||||
)
|
||||
break
|
||||
picked.append(f"id:{c.document_id}; content:{c.content}")
|
||||
|
||||
return [
|
||||
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
|
||||
*picked,
|
||||
"\n=== END-RETRIEVED-CONTEXT ===\n",
|
||||
], bank_ids
|
||||
|
||||
def _get_tools(self) -> List[ToolDefinition]:
|
||||
ret = []
|
||||
for t in self.agent_config.tools:
|
||||
if isinstance(t, BraveSearchToolDefinition):
|
||||
ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search))
|
||||
elif isinstance(t, WolframAlphaToolDefinition):
|
||||
ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha))
|
||||
elif isinstance(t, PhotogenToolDefinition):
|
||||
ret.append(ToolDefinition(tool_name=BuiltinTool.photogen))
|
||||
elif isinstance(t, CodeInterpreterToolDefinition):
|
||||
ret.append(ToolDefinition(tool_name=BuiltinTool.code_interpreter))
|
||||
elif isinstance(t, FunctionCallToolDefinition):
|
||||
ret.append(
|
||||
ToolDefinition(
|
||||
tool_name=t.function_name,
|
||||
description=t.description,
|
||||
parameters=t.parameters,
|
||||
)
|
||||
)
|
||||
return ret
|
||||
|
||||
|
||||
async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
|
||||
content = []
|
||||
|
||||
for url in urls:
|
||||
uri = url.uri
|
||||
if uri.startswith("file://"):
|
||||
filepath = uri[len("file://") :]
|
||||
elif uri.startswith("http"):
|
||||
path = urlparse(uri).path
|
||||
basename = os.path.basename(path)
|
||||
filepath = f"{tempdir}/{make_random_string() + basename}"
|
||||
print(f"Downloading {url} -> {filepath}")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(uri)
|
||||
resp = r.text
|
||||
with open(filepath, "w") as fp:
|
||||
fp.write(resp)
|
||||
else:
|
||||
raise ValueError(f"Unsupported URL {url}")
|
||||
|
||||
content.append(f'# There is a file accessible to you at "{filepath}"\n')
|
||||
|
||||
return ToolResponseMessage(
|
||||
call_id="",
|
||||
tool_name=BuiltinTool.code_interpreter,
|
||||
content=f'# There is a file accessible to you at "{filepath}"',
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
def preprocess_dialog(
|
||||
messages: List[Message], prefix_messages: List[Message]
|
||||
) -> List[Message]:
|
||||
"""
|
||||
Preprocesses the dialog by removing the system message and
|
||||
adding the system message to the beginning of the dialog.
|
||||
"""
|
||||
ret = prefix_messages.copy()
|
||||
|
||||
for m in messages:
|
||||
if m.role == Role.system.value:
|
||||
continue
|
||||
|
||||
# NOTE: the ideal behavior is to use `file_path = ...` but that
|
||||
# means we need to have stateful execution o f code which we currently
|
||||
# do not have.
|
||||
if isinstance(m.content, Attachment):
|
||||
ret.append(attachment_message(m.content.url))
|
||||
elif isinstance(m.content, list):
|
||||
for c in m.content:
|
||||
if isinstance(c, Attachment):
|
||||
ret.append(attachment_message(c.url))
|
||||
|
||||
ret.append(m)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
async def execute_tool_call_maybe(
|
||||
tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage]
|
||||
) -> List[ToolResponseMessage]:
|
||||
|
|
|
@ -8,62 +8,42 @@
|
|||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import AsyncGenerator, Dict
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_toolchain.inference.api import Inference
|
||||
from llama_toolchain.inference.api.datatypes import BuiltinTool
|
||||
from llama_toolchain.memory.api import Memory
|
||||
from llama_toolchain.safety.api import Safety
|
||||
from llama_toolchain.agentic_system.api.endpoints import * # noqa
|
||||
from llama_toolchain.agentic_system.api import (
|
||||
AgenticSystem,
|
||||
AgenticSystemCreateRequest,
|
||||
AgenticSystemCreateResponse,
|
||||
AgenticSystemSessionCreateRequest,
|
||||
AgenticSystemSessionCreateResponse,
|
||||
AgenticSystemTurnCreateRequest,
|
||||
)
|
||||
|
||||
from .agent_instance import AgentInstance
|
||||
|
||||
from .config import AgenticSystemConfig
|
||||
|
||||
from .tools.builtin import (
|
||||
from llama_toolchain.agentic_system.api import * # noqa: F403
|
||||
from llama_toolchain.tools.builtin import (
|
||||
BraveSearchTool,
|
||||
CodeInterpreterTool,
|
||||
PhotogenTool,
|
||||
WolframAlphaTool,
|
||||
)
|
||||
from .tools.safety import with_safety
|
||||
from llama_toolchain.tools.safety import with_safety
|
||||
|
||||
from .agent_instance import ChatAgent
|
||||
from .config import MetaReferenceImplConfig
|
||||
|
||||
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
async def get_provider_impl(config: AgenticSystemConfig, deps: Dict[Api, ProviderSpec]):
|
||||
assert isinstance(
|
||||
config, AgenticSystemConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = MetaReferenceAgenticSystemImpl(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.safety],
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
AGENT_INSTANCES_BY_ID = {}
|
||||
|
||||
|
||||
class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
||||
def __init__(
|
||||
self, config: AgenticSystemConfig, inference_api: Inference, safety_api: Safety
|
||||
self,
|
||||
config: MetaReferenceImplConfig,
|
||||
inference_api: Inference,
|
||||
memory_api: Memory,
|
||||
safety_api: Safety,
|
||||
):
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.memory_api = memory_api
|
||||
self.safety_api = safety_api
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
@ -71,69 +51,61 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
|||
|
||||
async def create_agentic_system(
|
||||
self,
|
||||
request: AgenticSystemCreateRequest,
|
||||
agent_config: AgentConfig,
|
||||
) -> AgenticSystemCreateResponse:
|
||||
system_id = str(uuid.uuid4())
|
||||
agent_id = str(uuid.uuid4())
|
||||
|
||||
builtin_tools = []
|
||||
custom_tool_definitions = []
|
||||
cfg = request.instance_config
|
||||
for dfn in cfg.available_tools:
|
||||
if isinstance(dfn.tool_name, BuiltinTool):
|
||||
if dfn.tool_name == BuiltinTool.wolfram_alpha:
|
||||
key = self.config.wolfram_api_key
|
||||
if not key:
|
||||
raise ValueError("Wolfram API key not defined in config")
|
||||
tool = WolframAlphaTool(key)
|
||||
elif dfn.tool_name == BuiltinTool.brave_search:
|
||||
key = self.config.brave_search_api_key
|
||||
if not key:
|
||||
raise ValueError("Brave API key not defined in config")
|
||||
tool = BraveSearchTool(key)
|
||||
elif dfn.tool_name == BuiltinTool.code_interpreter:
|
||||
tool = CodeInterpreterTool()
|
||||
elif dfn.tool_name == BuiltinTool.photogen:
|
||||
tool = PhotogenTool(
|
||||
dump_dir="/tmp/photogen_dump_" + os.environ["USER"],
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown builtin tool: {dfn.tool_name}")
|
||||
|
||||
builtin_tools.append(
|
||||
with_safety(
|
||||
tool, self.safety_api, dfn.input_shields, dfn.output_shields
|
||||
)
|
||||
for tool_defn in agent_config.tools:
|
||||
if isinstance(tool_defn, WolframAlphaToolDefinition):
|
||||
key = self.config.wolfram_api_key
|
||||
if not key:
|
||||
raise ValueError("Wolfram API key not defined in config")
|
||||
tool = WolframAlphaTool(key)
|
||||
elif isinstance(tool_defn, BraveSearchToolDefinition):
|
||||
key = self.config.brave_search_api_key
|
||||
if not key:
|
||||
raise ValueError("Brave API key not defined in config")
|
||||
tool = BraveSearchTool(key)
|
||||
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
|
||||
tool = CodeInterpreterTool()
|
||||
elif isinstance(tool_defn, PhotogenToolDefinition):
|
||||
tool = PhotogenTool(
|
||||
dump_dir="/tmp/photogen_dump_" + os.environ["USER"],
|
||||
)
|
||||
else:
|
||||
custom_tool_definitions.append(dfn)
|
||||
continue
|
||||
|
||||
AGENT_INSTANCES_BY_ID[system_id] = AgentInstance(
|
||||
system_id=system_id,
|
||||
instance_config=request.instance_config,
|
||||
model=request.model,
|
||||
builtin_tools.append(
|
||||
with_safety(
|
||||
tool,
|
||||
self.safety_api,
|
||||
tool_defn.input_shields,
|
||||
tool_defn.output_shields,
|
||||
)
|
||||
)
|
||||
|
||||
AGENT_INSTANCES_BY_ID[agent_id] = ChatAgent(
|
||||
agent_config=agent_config,
|
||||
inference_api=self.inference_api,
|
||||
builtin_tools=builtin_tools,
|
||||
custom_tool_definitions=custom_tool_definitions,
|
||||
safety_api=self.safety_api,
|
||||
input_shields=cfg.input_shields,
|
||||
output_shields=cfg.output_shields,
|
||||
prefix_messages=cfg.debug_prefix_messages,
|
||||
tool_prompt_format=cfg.tool_prompt_format,
|
||||
memory_api=self.memory_api,
|
||||
builtin_tools=builtin_tools,
|
||||
)
|
||||
|
||||
return AgenticSystemCreateResponse(
|
||||
system_id=system_id,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
async def create_agentic_system_session(
|
||||
self,
|
||||
request: AgenticSystemSessionCreateRequest,
|
||||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgenticSystemSessionCreateResponse:
|
||||
system_id = request.system_id
|
||||
assert system_id in AGENT_INSTANCES_BY_ID, f"System {system_id} not found"
|
||||
agent = AGENT_INSTANCES_BY_ID[system_id]
|
||||
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
|
||||
agent = AGENT_INSTANCES_BY_ID[agent_id]
|
||||
|
||||
session = agent.create_session(request.session_name)
|
||||
session = agent.create_session(session_name)
|
||||
return AgenticSystemSessionCreateResponse(
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
@ -142,9 +114,9 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
|||
self,
|
||||
request: AgenticSystemTurnCreateRequest,
|
||||
) -> AsyncGenerator:
|
||||
system_id = request.system_id
|
||||
assert system_id in AGENT_INSTANCES_BY_ID, f"System {system_id} not found"
|
||||
agent = AGENT_INSTANCES_BY_ID[system_id]
|
||||
agent_id = request.agent_id
|
||||
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
|
||||
agent = AGENT_INSTANCES_BY_ID[agent_id]
|
||||
|
||||
assert (
|
||||
request.session_id in agent.sessions
|
||||
|
|
|
@ -9,6 +9,6 @@ from typing import Optional
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AgenticSystemConfig(BaseModel):
|
||||
class MetaReferenceImplConfig(BaseModel):
|
||||
brave_search_api_key: Optional[str] = None
|
||||
wolfram_api_key: Optional[str] = None
|
||||
|
|
|
@ -9,12 +9,13 @@ from typing import List
|
|||
from llama_models.llama3.api.datatypes import Message, Role, UserMessage
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_toolchain.safety.api.datatypes import (
|
||||
from llama_toolchain.safety.api import (
|
||||
OnViolationAction,
|
||||
RunShieldRequest,
|
||||
Safety,
|
||||
ShieldDefinition,
|
||||
ShieldResponse,
|
||||
)
|
||||
from llama_toolchain.safety.api.endpoints import RunShieldRequest, Safety
|
||||
|
||||
|
||||
class SafetyException(Exception): # noqa: N818
|
||||
|
|
|
@ -1,180 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import textwrap
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
|
||||
|
||||
from llama_toolchain.inference.api import (
|
||||
BuiltinTool,
|
||||
Message,
|
||||
SystemMessage,
|
||||
ToolDefinition,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from .tools.builtin import SingleMessageBuiltinTool
|
||||
|
||||
|
||||
def get_agentic_prefix_messages(
|
||||
builtin_tools: List[SingleMessageBuiltinTool],
|
||||
custom_tools: List[ToolDefinition],
|
||||
tool_prompt_format: ToolPromptFormat,
|
||||
) -> List[Message]:
|
||||
messages = []
|
||||
content = ""
|
||||
if builtin_tools:
|
||||
content += "Environment: ipython\n"
|
||||
|
||||
tool_str = ", ".join(
|
||||
[
|
||||
t.get_name()
|
||||
for t in builtin_tools
|
||||
if t.get_name() != BuiltinTool.code_interpreter.value
|
||||
]
|
||||
)
|
||||
if tool_str:
|
||||
content += f"Tools: {tool_str}"
|
||||
|
||||
current_date = datetime.now()
|
||||
formatted_date = current_date.strftime("%d %B %Y")
|
||||
date_str = f"""
|
||||
Cutting Knowledge Date: December 2023
|
||||
Today Date: {formatted_date}\n"""
|
||||
content += date_str
|
||||
messages.append(SystemMessage(content=content))
|
||||
|
||||
if custom_tools:
|
||||
if tool_prompt_format == ToolPromptFormat.function_tag:
|
||||
text = prompt_for_function_tag(custom_tools)
|
||||
messages.append(UserMessage(content=text))
|
||||
elif tool_prompt_format == ToolPromptFormat.json:
|
||||
text = prompt_for_json(custom_tools)
|
||||
messages.append(UserMessage(content=text))
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Tool prompt format {tool_prompt_format} is not supported"
|
||||
)
|
||||
else:
|
||||
messages.append(SystemMessage(content=content))
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def prompt_for_json(custom_tools: List[ToolDefinition]) -> str:
|
||||
tool_defs = "\n".join(
|
||||
translate_custom_tool_definition_to_json(t) for t in custom_tools
|
||||
)
|
||||
content = textwrap.dedent(
|
||||
"""
|
||||
Answer the user's question by making use of the following functions if needed.
|
||||
If none of the function can be used, please say so.
|
||||
Here is a list of functions in JSON format:
|
||||
{tool_defs}
|
||||
|
||||
Return function calls in JSON format.
|
||||
"""
|
||||
)
|
||||
content = content.lstrip("\n").format(tool_defs=tool_defs)
|
||||
return content
|
||||
|
||||
|
||||
def prompt_for_function_tag(custom_tools: List[ToolDefinition]) -> str:
|
||||
custom_tool_params = ""
|
||||
for t in custom_tools:
|
||||
custom_tool_params += get_instruction_string(t) + "\n"
|
||||
custom_tool_params += get_parameters_string(t) + "\n\n"
|
||||
|
||||
content = f"""
|
||||
You have access to the following functions:
|
||||
|
||||
{custom_tool_params}
|
||||
Think very carefully before calling functions.
|
||||
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
||||
|
||||
<function=example_function_name>{{"example_name": "example_value"}}</function>
|
||||
|
||||
Reminder:
|
||||
- If looking for real time information use relevant functions before falling back to brave_search
|
||||
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
||||
- Required parameters MUST be specified
|
||||
- Only call one function at a time
|
||||
- Put the entire function call reply on one line
|
||||
"""
|
||||
return content
|
||||
|
||||
|
||||
def get_instruction_string(custom_tool_definition) -> str:
|
||||
return f"Use the function '{custom_tool_definition.tool_name}' to '{custom_tool_definition.description}'"
|
||||
|
||||
|
||||
def get_parameters_string(custom_tool_definition) -> str:
|
||||
return json.dumps(
|
||||
{
|
||||
"name": custom_tool_definition.tool_name,
|
||||
"description": custom_tool_definition.description,
|
||||
"parameters": {
|
||||
name: definition.__dict__
|
||||
for name, definition in custom_tool_definition.parameters.items()
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def translate_custom_tool_definition_to_json(tool_def):
|
||||
"""Translates ToolDefinition to json as expected by model
|
||||
eg. output for a function
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "conv_int",
|
||||
"description": "Convert serialized fract24 integer into int value.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": [
|
||||
{
|
||||
"data": {
|
||||
"type": "object",
|
||||
"description": ""
|
||||
}
|
||||
}
|
||||
],
|
||||
"required": ["data"]
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
assert isinstance(tool_def.tool_name, str)
|
||||
func_def = {"type": "function", "function": {}}
|
||||
func_def["function"]["name"] = tool_def.tool_name
|
||||
func_def["function"]["description"] = tool_def.description or ""
|
||||
if tool_def.parameters:
|
||||
required = []
|
||||
properties = []
|
||||
for p_name, p_def in tool_def.parameters.items():
|
||||
properties.append(
|
||||
{
|
||||
p_name: {
|
||||
# TODO: see if this should not always be object
|
||||
"type": "object",
|
||||
"description": p_def.description or "",
|
||||
}
|
||||
}
|
||||
)
|
||||
if p_def.required:
|
||||
required.append(p_name)
|
||||
func_def["function"]["parameters"] = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
}
|
||||
else:
|
||||
func_def["function"]["parameters"] = {}
|
||||
|
||||
return json.dumps(func_def, indent=4)
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -1,20 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from llama_toolchain.inference.api import Message
|
||||
|
||||
|
||||
class BaseTool(ABC):
|
||||
@abstractmethod
|
||||
def get_name(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, messages: List[Message]) -> List[Message]:
|
||||
raise NotImplementedError
|
|
@ -1,322 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
import requests
|
||||
from termcolor import cprint
|
||||
|
||||
from .ipython_tool.code_execution import (
|
||||
CodeExecutionContext,
|
||||
CodeExecutionRequest,
|
||||
CodeExecutor,
|
||||
TOOLS_ATTACHMENT_KEY_REGEX,
|
||||
)
|
||||
|
||||
from llama_toolchain.inference.api import * # noqa: F403
|
||||
|
||||
from .base import BaseTool
|
||||
|
||||
|
||||
def interpret_content_as_attachment(content: str) -> Optional[Attachment]:
|
||||
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
|
||||
if match:
|
||||
snippet = match.group(1)
|
||||
data = json.loads(snippet)
|
||||
return Attachment(
|
||||
url=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"]
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class SingleMessageBuiltinTool(BaseTool):
|
||||
async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
|
||||
assert len(messages) == 1, f"Expected single message, got {len(messages)}"
|
||||
|
||||
message = messages[0]
|
||||
assert len(message.tool_calls) == 1, "Expected a single tool call"
|
||||
|
||||
tool_call = messages[0].tool_calls[0]
|
||||
|
||||
query = tool_call.arguments["query"]
|
||||
response: str = await self.run_impl(query)
|
||||
|
||||
message = ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
tool_name=tool_call.tool_name,
|
||||
content=response,
|
||||
)
|
||||
if attachment := interpret_content_as_attachment(response):
|
||||
message.content = attachment
|
||||
|
||||
return [message]
|
||||
|
||||
@abstractmethod
|
||||
async def run_impl(self, query: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class PhotogenTool(SingleMessageBuiltinTool):
|
||||
def __init__(self, dump_dir: str) -> None:
|
||||
self.dump_dir = dump_dir
|
||||
|
||||
def get_name(self) -> str:
|
||||
return BuiltinTool.photogen.value
|
||||
|
||||
async def run_impl(self, query: str) -> str:
|
||||
"""
|
||||
Implement this to give the model an ability to generate images.
|
||||
|
||||
Return:
|
||||
info = {
|
||||
"filepath": str(image_filepath),
|
||||
"mimetype": "image/png",
|
||||
}
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class BraveSearchTool(SingleMessageBuiltinTool):
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
|
||||
def get_name(self) -> str:
|
||||
return BuiltinTool.brave_search.value
|
||||
|
||||
async def run_impl(self, query: str) -> str:
|
||||
url = "https://api.search.brave.com/res/v1/web/search"
|
||||
headers = {
|
||||
"X-Subscription-Token": self.api_key,
|
||||
"Accept-Encoding": "gzip",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
payload = {"q": query}
|
||||
response = requests.get(url=url, params=payload, headers=headers)
|
||||
return json.dumps(self._clean_brave_response(response.json()))
|
||||
|
||||
def _clean_brave_response(self, search_response, top_k=3):
|
||||
query = None
|
||||
clean_response = []
|
||||
if "query" in search_response:
|
||||
if "original" in search_response["query"]:
|
||||
query = search_response["query"]["original"]
|
||||
if "mixed" in search_response:
|
||||
mixed_results = search_response["mixed"]
|
||||
for m in mixed_results["main"][:top_k]:
|
||||
r_type = m["type"]
|
||||
results = search_response[r_type]["results"]
|
||||
if r_type == "web":
|
||||
# For web data - add a single output from the search
|
||||
idx = m["index"]
|
||||
selected_keys = [
|
||||
"type",
|
||||
"title",
|
||||
"url",
|
||||
"description",
|
||||
"date",
|
||||
"extra_snippets",
|
||||
]
|
||||
cleaned = {
|
||||
k: v for k, v in results[idx].items() if k in selected_keys
|
||||
}
|
||||
elif r_type == "faq":
|
||||
# For faw data - take a list of all the questions & answers
|
||||
selected_keys = ["type", "question", "answer", "title", "url"]
|
||||
cleaned = []
|
||||
for q in results:
|
||||
cleaned.append(
|
||||
{k: v for k, v in q.items() if k in selected_keys}
|
||||
)
|
||||
elif r_type == "infobox":
|
||||
idx = m["index"]
|
||||
selected_keys = [
|
||||
"type",
|
||||
"title",
|
||||
"url",
|
||||
"description",
|
||||
"long_desc",
|
||||
]
|
||||
cleaned = {
|
||||
k: v for k, v in results[idx].items() if k in selected_keys
|
||||
}
|
||||
elif r_type == "videos":
|
||||
selected_keys = [
|
||||
"type",
|
||||
"url",
|
||||
"title",
|
||||
"description",
|
||||
"date",
|
||||
]
|
||||
cleaned = []
|
||||
for q in results:
|
||||
cleaned.append(
|
||||
{k: v for k, v in q.items() if k in selected_keys}
|
||||
)
|
||||
elif r_type == "locations":
|
||||
# For faw data - take a list of all the questions & answers
|
||||
selected_keys = [
|
||||
"type",
|
||||
"title",
|
||||
"url",
|
||||
"description",
|
||||
"coordinates",
|
||||
"postal_address",
|
||||
"contact",
|
||||
"rating",
|
||||
"distance",
|
||||
"zoom_level",
|
||||
]
|
||||
cleaned = []
|
||||
for q in results:
|
||||
cleaned.append(
|
||||
{k: v for k, v in q.items() if k in selected_keys}
|
||||
)
|
||||
elif r_type == "news":
|
||||
# For faw data - take a list of all the questions & answers
|
||||
selected_keys = [
|
||||
"type",
|
||||
"title",
|
||||
"url",
|
||||
"description",
|
||||
]
|
||||
cleaned = []
|
||||
for q in results:
|
||||
cleaned.append(
|
||||
{k: v for k, v in q.items() if k in selected_keys}
|
||||
)
|
||||
else:
|
||||
cleaned = []
|
||||
|
||||
clean_response.append(cleaned)
|
||||
|
||||
return {"query": query, "top_k": clean_response}
|
||||
|
||||
|
||||
class WolframAlphaTool(SingleMessageBuiltinTool):
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
self.url = "https://api.wolframalpha.com/v2/query"
|
||||
|
||||
def get_name(self) -> str:
|
||||
return BuiltinTool.wolfram_alpha.value
|
||||
|
||||
async def run_impl(self, query: str) -> str:
|
||||
params = {
|
||||
"input": query,
|
||||
"appid": self.api_key,
|
||||
"format": "plaintext",
|
||||
"output": "json",
|
||||
}
|
||||
response = requests.get(
|
||||
self.url,
|
||||
params=params,
|
||||
)
|
||||
|
||||
return json.dumps(self._clean_wolfram_alpha_response(response.json()))
|
||||
|
||||
def _clean_wolfram_alpha_response(self, wa_response):
|
||||
remove = {
|
||||
"queryresult": [
|
||||
"datatypes",
|
||||
"error",
|
||||
"timedout",
|
||||
"timedoutpods",
|
||||
"numpods",
|
||||
"timing",
|
||||
"parsetiming",
|
||||
"parsetimedout",
|
||||
"recalculate",
|
||||
"id",
|
||||
"host",
|
||||
"server",
|
||||
"related",
|
||||
"version",
|
||||
{
|
||||
"pods": [
|
||||
"scanner",
|
||||
"id",
|
||||
"error",
|
||||
"expressiontypes",
|
||||
"states",
|
||||
"infos",
|
||||
"position",
|
||||
"numsubpods",
|
||||
]
|
||||
},
|
||||
"assumptions",
|
||||
],
|
||||
}
|
||||
for main_key in remove:
|
||||
for key_to_remove in remove[main_key]:
|
||||
try:
|
||||
if key_to_remove == "assumptions":
|
||||
if "assumptions" in wa_response[main_key]:
|
||||
del wa_response[main_key][key_to_remove]
|
||||
if isinstance(key_to_remove, dict):
|
||||
for sub_key in key_to_remove:
|
||||
if sub_key == "pods":
|
||||
for i in range(len(wa_response[main_key][sub_key])):
|
||||
if (
|
||||
wa_response[main_key][sub_key][i]["title"]
|
||||
== "Result"
|
||||
):
|
||||
del wa_response[main_key][sub_key][i + 1 :]
|
||||
break
|
||||
sub_items = wa_response[main_key][sub_key]
|
||||
for i in range(len(sub_items)):
|
||||
for sub_key_to_remove in key_to_remove[sub_key]:
|
||||
if sub_key_to_remove in sub_items[i]:
|
||||
del sub_items[i][sub_key_to_remove]
|
||||
elif key_to_remove in wa_response[main_key]:
|
||||
del wa_response[main_key][key_to_remove]
|
||||
except KeyError:
|
||||
pass
|
||||
return wa_response
|
||||
|
||||
|
||||
class CodeInterpreterTool(BaseTool):
|
||||
def __init__(self) -> None:
|
||||
ctx = CodeExecutionContext(
|
||||
matplotlib_dump_dir=f"/tmp/{os.environ['USER']}_matplotlib_dump",
|
||||
)
|
||||
self.code_executor = CodeExecutor(ctx)
|
||||
|
||||
def get_name(self) -> str:
|
||||
return BuiltinTool.code_interpreter.value
|
||||
|
||||
async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
|
||||
message = messages[0]
|
||||
assert len(message.tool_calls) == 1, "Expected a single tool call"
|
||||
|
||||
tool_call = messages[0].tool_calls[0]
|
||||
script = tool_call.arguments["code"]
|
||||
|
||||
req = CodeExecutionRequest(scripts=[script])
|
||||
res = self.code_executor.execute(req)
|
||||
|
||||
pieces = [res["process_status"]]
|
||||
for out_type in ["stdout", "stderr"]:
|
||||
res_out = res[out_type]
|
||||
if res_out != "":
|
||||
pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
|
||||
if out_type == "stderr":
|
||||
cprint(f"ipython tool error: ↓\n{res_out}", color="red")
|
||||
|
||||
message = ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
tool_name=tool_call.tool_name,
|
||||
content="\n".join(pieces),
|
||||
)
|
||||
if attachment := interpret_content_as_attachment(res["stdout"]):
|
||||
message.content = attachment
|
||||
|
||||
return [message]
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -1,133 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import errno
|
||||
|
||||
# Disabling potentially dangerous functions
|
||||
import os as _os
|
||||
from functools import partial
|
||||
|
||||
os_funcs_to_disable = [
|
||||
"kill",
|
||||
"system",
|
||||
"putenv",
|
||||
"remove",
|
||||
"removedirs",
|
||||
"rmdir",
|
||||
"fchdir",
|
||||
"setuid",
|
||||
"fork",
|
||||
"forkpty",
|
||||
"killpg",
|
||||
"rename",
|
||||
"renames",
|
||||
"truncate",
|
||||
"replace",
|
||||
# "unlink", # Commenting as this was blocking matpltlib from rendering plots correctly
|
||||
"fchmod",
|
||||
"fchown",
|
||||
"chmod",
|
||||
"chown",
|
||||
"chroot",
|
||||
"fchdir",
|
||||
"lchflags",
|
||||
"lchmod",
|
||||
"lchown",
|
||||
"chdir",
|
||||
]
|
||||
|
||||
|
||||
def call_not_allowed(*args, **kwargs):
|
||||
raise OSError(errno.EPERM, "Call are not permitted in this environment")
|
||||
|
||||
|
||||
for func_name in os_funcs_to_disable:
|
||||
if hasattr(_os, func_name):
|
||||
setattr(_os, func_name, partial(call_not_allowed, _func_name=f"os.{func_name}"))
|
||||
|
||||
import shutil as _shutil
|
||||
|
||||
for func_name in ["rmtree", "move", "chown"]:
|
||||
if hasattr(_shutil, func_name):
|
||||
setattr(
|
||||
_shutil,
|
||||
func_name,
|
||||
partial(call_not_allowed, _func_name=f"shutil.{func_name}"),
|
||||
)
|
||||
|
||||
import subprocess as _subprocess
|
||||
|
||||
|
||||
def popen_not_allowed(*args, **kwargs):
|
||||
raise _subprocess.CalledProcessError(
|
||||
-1,
|
||||
args[0] if args else "unknown",
|
||||
stderr="subprocess.Popen is not allowed in this environment",
|
||||
)
|
||||
|
||||
|
||||
_subprocess.Popen = popen_not_allowed
|
||||
|
||||
|
||||
import atexit as _atexit
|
||||
import builtins as _builtins
|
||||
import io as _io
|
||||
import json as _json
|
||||
import sys as _sys
|
||||
|
||||
# NB! The following "unused" imports crucial, make sure not not to remove
|
||||
# them with linters - they're used in code_execution.py
|
||||
from contextlib import ( # noqa
|
||||
contextmanager as _contextmanager,
|
||||
redirect_stderr as _redirect_stderr,
|
||||
redirect_stdout as _redirect_stdout,
|
||||
)
|
||||
from multiprocessing.connection import Connection as _Connection
|
||||
|
||||
# Mangle imports to avoid polluting model execution namespace.
|
||||
|
||||
_IO_SINK = _io.StringIO()
|
||||
_NETWORK_TIMEOUT = 5
|
||||
_NETWORK_CONNECTIONS = None
|
||||
|
||||
|
||||
def _open_connections():
|
||||
global _NETWORK_CONNECTIONS
|
||||
if _NETWORK_CONNECTIONS is not None:
|
||||
# Ensure connections only opened once.
|
||||
return _NETWORK_CONNECTIONS
|
||||
req_w_fd, resp_r_fd = _sys.argv[1], _sys.argv[2]
|
||||
req_con = _Connection(int(req_w_fd), readable=False)
|
||||
resp_con = _Connection(int(resp_r_fd), writable=False)
|
||||
_NETWORK_CONNECTIONS = (req_con, resp_con)
|
||||
return _NETWORK_CONNECTIONS
|
||||
|
||||
|
||||
_builtins._open_connections = _open_connections
|
||||
|
||||
|
||||
@_atexit.register
|
||||
def _close_connections():
|
||||
global _NETWORK_CONNECTIONS
|
||||
if _NETWORK_CONNECTIONS is None:
|
||||
return
|
||||
for con in _NETWORK_CONNECTIONS:
|
||||
con.close()
|
||||
del _NETWORK_CONNECTIONS
|
||||
|
||||
|
||||
def _network_call(request):
|
||||
# NOTE: We communicate with the parent process in json, encoded
|
||||
# in raw bytes. We do this because native send/recv methods use
|
||||
# pickle which involves execution of arbitrary code.
|
||||
_open_connections()
|
||||
req_con, resp_con = _NETWORK_CONNECTIONS
|
||||
|
||||
req_con.send_bytes(_json.dumps(request).encode("utf-8"))
|
||||
if resp_con.poll(timeout=_NETWORK_TIMEOUT) is None:
|
||||
raise Exception(f"Network request timed out: {_json.dumps(request)}")
|
||||
else:
|
||||
return _json.loads(resp_con.recv_bytes().decode("utf-8"))
|
|
@ -1,256 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from .utils import get_code_env_prefix
|
||||
|
||||
TOOLS_ATTACHMENT_KEY = "__tools_attachment__"
|
||||
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
||||
|
||||
DIRNAME = Path(__file__).parent
|
||||
|
||||
CODE_EXEC_TIMEOUT = 20
|
||||
CODE_ENV_PREFIX = get_code_env_prefix()
|
||||
|
||||
STDOUTERR_SINK_WRAPPER_TEMPLATE = """\
|
||||
with _redirect_stdout(_IO_SINK), _redirect_stderr(_IO_SINK):
|
||||
{code}\
|
||||
"""
|
||||
|
||||
TRYEXCEPT_WRAPPER_TEMPLATE = """\
|
||||
try:
|
||||
{code}
|
||||
except:
|
||||
pass\
|
||||
"""
|
||||
|
||||
|
||||
def generate_bwrap_command(bind_dirs: List[str]) -> str:
|
||||
"""
|
||||
Generate the bwrap command string for binding all
|
||||
directories in the current directory read-only.
|
||||
"""
|
||||
bwrap_args = ""
|
||||
bwrap_args += "--ro-bind / / "
|
||||
# Add the --dev flag to mount device files
|
||||
bwrap_args += "--dev /dev "
|
||||
for d in bind_dirs:
|
||||
bwrap_args += f"--bind {d} {d} "
|
||||
|
||||
# Add the --unshare-all flag to isolate the sandbox from the rest of the system
|
||||
bwrap_args += "--unshare-all "
|
||||
# Add the --die-with-parent flag to ensure the child process dies when bwrap's parent dies
|
||||
bwrap_args += "--die-with-parent "
|
||||
return bwrap_args
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeExecutionContext:
|
||||
matplotlib_dump_dir: str
|
||||
use_proxy: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeExecutionRequest:
|
||||
scripts: List[str]
|
||||
only_last_cell_stdouterr: bool = True
|
||||
only_last_cell_fail: bool = True
|
||||
seed: int = 0
|
||||
strip_fpaths_in_stderr: bool = True
|
||||
|
||||
|
||||
class CodeExecutor:
|
||||
def __init__(self, context: CodeExecutionContext):
|
||||
self.context = context
|
||||
|
||||
def execute(self, req: CodeExecutionRequest) -> dict:
|
||||
scripts = req.scripts
|
||||
for i in range(len(scripts) - 1):
|
||||
if req.only_last_cell_stdouterr:
|
||||
scripts[i] = STDOUTERR_SINK_WRAPPER_TEMPLATE.format(
|
||||
code=textwrap.indent(scripts[i], " " * 4)
|
||||
)
|
||||
if req.only_last_cell_fail:
|
||||
scripts[i] = TRYEXCEPT_WRAPPER_TEMPLATE.format(
|
||||
code=textwrap.indent(scripts[i], " " * 4)
|
||||
)
|
||||
|
||||
# Seeds prefix:
|
||||
seed = req.seed
|
||||
seeds_prefix = f"""\
|
||||
def _set_seeds():
|
||||
import random
|
||||
random.seed({seed})
|
||||
import numpy as np
|
||||
np.random.seed({seed})
|
||||
_set_seeds()\
|
||||
"""
|
||||
|
||||
script = "\n\n".join([seeds_prefix] + [CODE_ENV_PREFIX] + scripts)
|
||||
with tempfile.TemporaryDirectory() as dpath:
|
||||
bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath])
|
||||
cmd = [*bwrap_prefix.split(), sys.executable, "-c", script]
|
||||
code_fpath = os.path.join(dpath, "code.py")
|
||||
with open(code_fpath, "w") as f:
|
||||
f.write(script)
|
||||
|
||||
try:
|
||||
python_path = os.environ.get("PYTHONPATH", "")
|
||||
env = dict(
|
||||
os.environ,
|
||||
PYTHONHASHSEED=str(seed),
|
||||
MPLCONFIGDIR=dpath,
|
||||
MPLBACKEND="module://matplotlib_custom_backend",
|
||||
PYTHONPATH=f"{DIRNAME}:{python_path}",
|
||||
)
|
||||
stdout, stderr, returncode = do_subprocess(
|
||||
cmd=cmd,
|
||||
env=env,
|
||||
ctx=self.context,
|
||||
)
|
||||
|
||||
stderr = stderr.strip()
|
||||
if req.strip_fpaths_in_stderr:
|
||||
pattern = r'File "([^"]+)", line (\d+)'
|
||||
stderr = re.sub(pattern, r"line \2", stderr)
|
||||
|
||||
return {
|
||||
"process_status": "completed",
|
||||
"returncode": returncode,
|
||||
"stdout": stdout.strip(),
|
||||
"stderr": stderr,
|
||||
}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {
|
||||
"process_status": "timeout",
|
||||
"stdout": "Timed out",
|
||||
"stderr": "Timed out",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"process_status": "error",
|
||||
"error_type": type(e).__name__,
|
||||
"stderr": str(e),
|
||||
"stdout": str(e),
|
||||
}
|
||||
|
||||
|
||||
def process_matplotlib_response(response, matplotlib_dump_dir: str):
|
||||
image_data = response["image_data"]
|
||||
# Convert the base64 string to a bytes object
|
||||
images = [base64.b64decode(d["image_base64"]) for d in image_data]
|
||||
# Create a list of PIL images from the bytes objects
|
||||
images = [Image.open(BytesIO(img)) for img in images]
|
||||
# Create a list of image paths
|
||||
image_paths = []
|
||||
for i, img in enumerate(images):
|
||||
# create new directory for each day to better organize data:
|
||||
dump_dname = datetime.today().strftime("%Y-%m-%d")
|
||||
dump_dpath = Path(matplotlib_dump_dir, dump_dname)
|
||||
dump_dpath.mkdir(parents=True, exist_ok=True)
|
||||
# save image into a file
|
||||
dump_fname = f"matplotlib_{str(time.time()).replace('.', '_')}_{i}.png"
|
||||
dump_fpath = dump_dpath / dump_fname
|
||||
img.save(dump_fpath, "PNG")
|
||||
image_paths.append(str(dump_fpath))
|
||||
|
||||
# this is kind of convoluted, we send back this response to the subprocess which
|
||||
# prints it out
|
||||
info = {
|
||||
"filepath": str(image_paths[-1]),
|
||||
"mimetype": "image/png",
|
||||
}
|
||||
return f"{TOOLS_ATTACHMENT_KEY}={json.dumps(info)}"
|
||||
|
||||
|
||||
def execute_subprocess_request(request, ctx: CodeExecutionContext):
|
||||
"Route requests from the subprocess (via network Pipes) to the internet/tools."
|
||||
if request["type"] == "matplotlib":
|
||||
return process_matplotlib_response(request, ctx.matplotlib_dump_dir)
|
||||
else:
|
||||
raise Exception(f'Unrecognised network request type: {request["type"]}')
|
||||
|
||||
|
||||
def do_subprocess(*, cmd: list, env: dict, ctx: CodeExecutionContext):
|
||||
# Create Pipes to be used for any external tool/network requests.
|
||||
req_r, req_w = multiprocessing.Pipe(duplex=False)
|
||||
resp_r, resp_w = multiprocessing.Pipe(duplex=False)
|
||||
|
||||
cmd += [str(req_w.fileno()), str(resp_r.fileno())]
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
pass_fds=(req_w.fileno(), resp_r.fileno()),
|
||||
text=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
close_fds=True,
|
||||
env=env,
|
||||
)
|
||||
|
||||
# Close unnecessary fds.
|
||||
req_w.close()
|
||||
resp_r.close()
|
||||
|
||||
pipe_close = False
|
||||
done_read = False
|
||||
start = time.monotonic()
|
||||
while proc.poll() is None and not pipe_close:
|
||||
if req_r.poll(0.1):
|
||||
# NB: Python pipe semantics for poll and recv mean that
|
||||
# poll() returns True is a pipe is closed.
|
||||
# CF old school PEP from '09
|
||||
# https://bugs.python.org/issue5573
|
||||
try:
|
||||
request = json.loads(req_r.recv_bytes().decode("utf-8"))
|
||||
response = execute_subprocess_request(request, ctx)
|
||||
|
||||
resp_w.send_bytes(json.dumps(response).encode("utf-8"))
|
||||
except EOFError:
|
||||
# The request pipe is closed - set a marker to exit
|
||||
# after the next attempt at reading stdout/stderr.
|
||||
pipe_close = True
|
||||
|
||||
try:
|
||||
# If lots has been printed, pipe might be full but
|
||||
# proc cannot exit until all the stdout/stderr
|
||||
# been written/read.
|
||||
stdout, stderr = proc.communicate(timeout=0.3)
|
||||
done_read = True
|
||||
except subprocess.TimeoutExpired:
|
||||
# The program has not terminated. Ignore it, there
|
||||
# may be more network/tool requests.
|
||||
continue
|
||||
if time.monotonic() - start > CODE_EXEC_TIMEOUT:
|
||||
proc.terminate()
|
||||
raise subprocess.TimeoutExpired(cmd, CODE_EXEC_TIMEOUT)
|
||||
|
||||
if not done_read:
|
||||
# Solve race condition where process terminates before
|
||||
# we hit the while loop.
|
||||
stdout, stderr = proc.communicate(timeout=0.3)
|
||||
|
||||
resp_w.close()
|
||||
req_r.close()
|
||||
return stdout, stderr, proc.returncode
|
|
@ -1,87 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
A custom Matplotlib backend that overrides the show method to return image bytes.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json as _json
|
||||
|
||||
import matplotlib
|
||||
from matplotlib.backend_bases import FigureManagerBase
|
||||
|
||||
# Import necessary components from Matplotlib
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
||||
|
||||
|
||||
class CustomFigureCanvas(FigureCanvasAgg):
|
||||
def show(self):
|
||||
# Save the figure to a BytesIO object
|
||||
buf = io.BytesIO()
|
||||
self.print_png(buf)
|
||||
image_bytes = buf.getvalue()
|
||||
buf.close()
|
||||
return image_bytes
|
||||
|
||||
|
||||
class CustomFigureManager(FigureManagerBase):
|
||||
def __init__(self, canvas, num):
|
||||
super().__init__(canvas, num)
|
||||
|
||||
|
||||
# Mimic module initialization that integrates with the Matplotlib backend system
|
||||
def _create_figure_manager(num, *args, **kwargs):
|
||||
"""
|
||||
Create a custom figure manager instance.
|
||||
"""
|
||||
FigureClass = kwargs.pop("FigureClass", None) # noqa: N806
|
||||
if FigureClass is None:
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
FigureClass = Figure # noqa: N806
|
||||
fig = FigureClass(*args, **kwargs)
|
||||
canvas = CustomFigureCanvas(fig)
|
||||
manager = CustomFigureManager(canvas, num)
|
||||
return manager
|
||||
|
||||
|
||||
def show():
|
||||
"""
|
||||
Handle all figures and potentially return their images as bytes.
|
||||
|
||||
This function iterates over all figures registered with the custom backend,
|
||||
renders them as images in bytes format, and could return a list of bytes objects,
|
||||
one for each figure, or handle them as needed.
|
||||
"""
|
||||
image_data = []
|
||||
for manager in matplotlib._pylab_helpers.Gcf.get_all_fig_managers():
|
||||
# Get the figure from the manager
|
||||
fig = manager.canvas.figure
|
||||
buf = io.BytesIO() # Create a buffer for the figure
|
||||
fig.savefig(buf, format="png") # Save the figure to the buffer in PNG format
|
||||
buf.seek(0) # Go to the beginning of the buffer
|
||||
image_bytes = buf.getvalue() # Retrieve bytes value
|
||||
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
image_data.append({"image_base64": image_base64})
|
||||
buf.close()
|
||||
|
||||
req_con, resp_con = _open_connections()
|
||||
|
||||
_json_dump = _json.dumps(
|
||||
{
|
||||
"type": "matplotlib",
|
||||
"image_data": image_data,
|
||||
}
|
||||
)
|
||||
req_con.send_bytes(_json_dump.encode("utf-8"))
|
||||
resp = _json.loads(resp_con.recv_bytes().decode("utf-8"))
|
||||
print(resp)
|
||||
|
||||
|
||||
FigureCanvas = CustomFigureCanvas
|
||||
FigureManager = CustomFigureManager
|
|
@ -1,21 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
CODE_ENV_PREFIX_FILE = os.path.join(DIR, "code_env_prefix.py")
|
||||
CODE_ENV_PREFIX = None
|
||||
|
||||
|
||||
def get_code_env_prefix() -> str:
|
||||
global CODE_ENV_PREFIX
|
||||
|
||||
if CODE_ENV_PREFIX is None:
|
||||
with open(CODE_ENV_PREFIX_FILE, "r") as f:
|
||||
CODE_ENV_PREFIX = f.read()
|
||||
|
||||
return CODE_ENV_PREFIX
|
|
@ -1,59 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_toolchain.agentic_system.meta_reference.safety import ShieldRunnerMixin
|
||||
|
||||
from llama_toolchain.inference.api import Message
|
||||
from llama_toolchain.safety.api.datatypes import ShieldDefinition
|
||||
from llama_toolchain.safety.api.endpoints import Safety
|
||||
|
||||
from .builtin import BaseTool
|
||||
|
||||
|
||||
class SafeTool(BaseTool, ShieldRunnerMixin):
|
||||
"""A tool that makes other tools safety enabled"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool: BaseTool,
|
||||
safety_api: Safety,
|
||||
input_shields: List[ShieldDefinition] = None,
|
||||
output_shields: List[ShieldDefinition] = None,
|
||||
):
|
||||
self._tool = tool
|
||||
ShieldRunnerMixin.__init__(
|
||||
self, safety_api, input_shields=input_shields, output_shields=output_shields
|
||||
)
|
||||
|
||||
def get_name(self) -> str:
|
||||
# return the name of the wrapped tool
|
||||
return self._tool.get_name()
|
||||
|
||||
async def run(self, messages: List[Message]) -> List[Message]:
|
||||
if self.input_shields:
|
||||
await self.run_shields(messages, self.input_shields)
|
||||
# run the underlying tool
|
||||
res = await self._tool.run(messages)
|
||||
if self.output_shields:
|
||||
await self.run_shields(messages, self.output_shields)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def with_safety(
|
||||
tool: BaseTool,
|
||||
safety_api: Safety,
|
||||
input_shields: List[ShieldDefinition] = None,
|
||||
output_shields: List[ShieldDefinition] = None,
|
||||
) -> SafeTool:
|
||||
return SafeTool(
|
||||
tool,
|
||||
safety_api,
|
||||
input_shields=input_shields,
|
||||
output_shields=output_shields,
|
||||
)
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import List
|
||||
|
||||
from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||
from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||
|
||||
|
||||
def available_agentic_system_providers() -> List[ProviderSpec]:
|
||||
|
@ -16,15 +16,19 @@ def available_agentic_system_providers() -> List[ProviderSpec]:
|
|||
provider_id="meta-reference",
|
||||
pip_packages=[
|
||||
"codeshield",
|
||||
"matplotlib",
|
||||
"pillow",
|
||||
"pandas",
|
||||
"scikit-learn",
|
||||
"torch",
|
||||
"transformers",
|
||||
],
|
||||
module="llama_toolchain.agentic_system.meta_reference",
|
||||
config_class="llama_toolchain.agentic_system.meta_reference.AgenticSystemConfig",
|
||||
config_class="llama_toolchain.agentic_system.meta_reference.MetaReferenceImplConfig",
|
||||
api_dependencies=[
|
||||
Api.inference,
|
||||
Api.safety,
|
||||
Api.memory,
|
||||
],
|
||||
),
|
||||
]
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -1,106 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, List
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_toolchain.agentic_system.api import * # noqa: F403
|
||||
|
||||
# TODO: this is symptomatic of us needing to pull more tooling related utilities
|
||||
from llama_toolchain.agentic_system.meta_reference.tools.builtin import (
|
||||
interpret_content_as_attachment,
|
||||
)
|
||||
|
||||
|
||||
class CustomTool:
|
||||
"""
|
||||
Developers can define their custom tools that models can use
|
||||
by extending this class.
|
||||
|
||||
Developers need to provide
|
||||
- name
|
||||
- description
|
||||
- params_definition
|
||||
- implement tool's behavior in `run_impl` method
|
||||
|
||||
NOTE: The return of the `run` method needs to be json serializable
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_description(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_params_definition(self) -> Dict[str, ToolParamDefinition]:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_instruction_string(self) -> str:
|
||||
return f"Use the function '{self.get_name()}' to: {self.get_description()}"
|
||||
|
||||
def parameters_for_system_prompt(self) -> str:
|
||||
return json.dumps(
|
||||
{
|
||||
"name": self.get_name(),
|
||||
"description": self.get_description(),
|
||||
"parameters": {
|
||||
name: definition.__dict__
|
||||
for name, definition in self.get_params_definition().items()
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
def get_tool_definition(self) -> AgenticSystemToolDefinition:
|
||||
return AgenticSystemToolDefinition(
|
||||
tool_name=self.get_name(),
|
||||
description=self.get_description(),
|
||||
parameters=self.get_params_definition(),
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, messages: List[Message]) -> List[Message]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SingleMessageCustomTool(CustomTool):
|
||||
"""
|
||||
Helper class to handle custom tools that take a single message
|
||||
Extending this class and implementing the `run_impl` method will
|
||||
allow for the tool be called by the model and the necessary plumbing.
|
||||
"""
|
||||
|
||||
async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
|
||||
assert len(messages) == 1, "Expected single message"
|
||||
|
||||
message = messages[0]
|
||||
|
||||
tool_call = message.tool_calls[0]
|
||||
|
||||
try:
|
||||
response = await self.run_impl(**tool_call.arguments)
|
||||
response_str = json.dumps(response, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
response_str = f"Error when running tool: {e}"
|
||||
|
||||
message = ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
tool_name=tool_call.tool_name,
|
||||
content=response_str,
|
||||
)
|
||||
if attachment := interpret_content_as_attachment(response_str):
|
||||
message.content = attachment
|
||||
|
||||
return [message]
|
||||
|
||||
@abstractmethod
|
||||
async def run_impl(self, *args, **kwargs):
|
||||
raise NotImplementedError()
|
|
@ -1,83 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, List
|
||||
|
||||
from llama_models.llama3.api.datatypes import StopReason, ToolResponseMessage
|
||||
|
||||
from llama_toolchain.agentic_system.api import (
|
||||
AgenticSystem,
|
||||
AgenticSystemTurnCreateRequest,
|
||||
AgenticSystemTurnResponseEventType as EventType,
|
||||
)
|
||||
|
||||
from llama_toolchain.inference.api import Message
|
||||
|
||||
|
||||
async def execute_with_custom_tools(
|
||||
system: AgenticSystem,
|
||||
system_id: str,
|
||||
session_id: str,
|
||||
messages: List[Message],
|
||||
custom_tools: List[Any],
|
||||
max_iters: int = 5,
|
||||
stream: bool = True,
|
||||
) -> AsyncGenerator:
|
||||
# first create a session, or do you keep a persistent session?
|
||||
tools_dict = {t.get_name(): t for t in custom_tools}
|
||||
|
||||
current_messages = messages.copy()
|
||||
n_iter = 0
|
||||
while n_iter < max_iters:
|
||||
n_iter += 1
|
||||
|
||||
request = AgenticSystemTurnCreateRequest(
|
||||
system_id=system_id,
|
||||
session_id=session_id,
|
||||
messages=current_messages,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
turn = None
|
||||
async for chunk in system.create_agentic_system_turn(request):
|
||||
if chunk.event.payload.event_type != EventType.turn_complete.value:
|
||||
yield chunk
|
||||
else:
|
||||
turn = chunk.event.payload.turn
|
||||
|
||||
message = turn.output_message
|
||||
if len(message.tool_calls) == 0:
|
||||
yield chunk
|
||||
return
|
||||
|
||||
if message.stop_reason == StopReason.out_of_tokens:
|
||||
yield chunk
|
||||
return
|
||||
|
||||
tool_call = message.tool_calls[0]
|
||||
if tool_call.tool_name not in tools_dict:
|
||||
m = ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
tool_name=tool_call.tool_name,
|
||||
content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else",
|
||||
)
|
||||
next_message = m
|
||||
else:
|
||||
tool = tools_dict[tool_call.tool_name]
|
||||
result_messages = await execute_custom_tool(tool, message)
|
||||
next_message = result_messages[0]
|
||||
|
||||
yield next_message
|
||||
current_messages = [next_message]
|
||||
|
||||
|
||||
async def execute_custom_tool(tool: Any, message: Message) -> List[Message]:
|
||||
result_messages = await tool.run([message])
|
||||
assert (
|
||||
len(result_messages) == 1
|
||||
), f"Expected single message, got {len(result_messages)}"
|
||||
|
||||
return result_messages
|
|
@ -1,122 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from llama_models.llama3.api.datatypes import BuiltinTool, Message, SamplingParams
|
||||
|
||||
from llama_toolchain.agentic_system.api import (
|
||||
AgenticSystemCreateRequest,
|
||||
AgenticSystemInstanceConfig,
|
||||
AgenticSystemSessionCreateRequest,
|
||||
AgenticSystemToolDefinition,
|
||||
)
|
||||
from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
|
||||
from llama_toolchain.agentic_system.client import AgenticSystemClient
|
||||
|
||||
from llama_toolchain.agentic_system.tools.custom.execute import (
|
||||
execute_with_custom_tools,
|
||||
)
|
||||
from llama_toolchain.safety.api.datatypes import BuiltinShield, ShieldDefinition
|
||||
|
||||
|
||||
# TODO: this should move back to the llama-agentic-system repo
|
||||
|
||||
|
||||
class AgenticSystemClientWrapper:
|
||||
def __init__(self, api, system_id, custom_tools):
|
||||
self.api = api
|
||||
self.system_id = system_id
|
||||
self.custom_tools = custom_tools
|
||||
self.session_id = None
|
||||
|
||||
async def create_session(self, name: str = None):
|
||||
if name is None:
|
||||
name = f"Session-{uuid.uuid4()}"
|
||||
|
||||
response = await self.api.create_agentic_system_session(
|
||||
AgenticSystemSessionCreateRequest(
|
||||
system_id=self.system_id,
|
||||
session_name=name,
|
||||
)
|
||||
)
|
||||
self.session_id = response.session_id
|
||||
return self.session_id
|
||||
|
||||
async def run(self, messages: List[Message], stream: bool = True):
|
||||
async for chunk in execute_with_custom_tools(
|
||||
self.api,
|
||||
self.system_id,
|
||||
self.session_id,
|
||||
messages,
|
||||
self.custom_tools,
|
||||
stream=stream,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
||||
async def get_agent_system_instance(
|
||||
host: str,
|
||||
port: int,
|
||||
custom_tools: Optional[List[Any]] = None,
|
||||
disable_safety: bool = False,
|
||||
model: str = "Meta-Llama3.1-8B-Instruct",
|
||||
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
|
||||
) -> AgenticSystemClientWrapper:
|
||||
custom_tools = custom_tools or []
|
||||
|
||||
api = AgenticSystemClient(base_url=f"http://{host}:{port}")
|
||||
|
||||
tool_definitions = [
|
||||
AgenticSystemToolDefinition(
|
||||
tool_name=BuiltinTool.brave_search,
|
||||
),
|
||||
AgenticSystemToolDefinition(
|
||||
tool_name=BuiltinTool.wolfram_alpha,
|
||||
),
|
||||
AgenticSystemToolDefinition(
|
||||
tool_name=BuiltinTool.photogen,
|
||||
),
|
||||
AgenticSystemToolDefinition(
|
||||
tool_name=BuiltinTool.code_interpreter,
|
||||
),
|
||||
] + [t.get_tool_definition() for t in custom_tools]
|
||||
|
||||
if not disable_safety:
|
||||
for t in tool_definitions:
|
||||
t.input_shields = [ShieldDefinition(shield_type=BuiltinShield.llama_guard)]
|
||||
t.output_shields = [
|
||||
ShieldDefinition(shield_type=BuiltinShield.llama_guard),
|
||||
ShieldDefinition(shield_type=BuiltinShield.injection_shield),
|
||||
]
|
||||
|
||||
create_request = AgenticSystemCreateRequest(
|
||||
model=model,
|
||||
instance_config=AgenticSystemInstanceConfig(
|
||||
instructions="You are a helpful assistant",
|
||||
available_tools=tool_definitions,
|
||||
input_shields=(
|
||||
[]
|
||||
if disable_safety
|
||||
else [
|
||||
ShieldDefinition(shield_type=BuiltinShield.llama_guard),
|
||||
ShieldDefinition(shield_type=BuiltinShield.jailbreak_shield),
|
||||
]
|
||||
),
|
||||
output_shields=(
|
||||
[]
|
||||
if disable_safety
|
||||
else [
|
||||
ShieldDefinition(shield_type=BuiltinShield.llama_guard),
|
||||
]
|
||||
),
|
||||
sampling_params=SamplingParams(),
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
),
|
||||
)
|
||||
create_response = await api.create_agentic_system(create_request)
|
||||
return AgenticSystemClientWrapper(api, create_response.system_id, custom_tools)
|
Loading…
Add table
Add a link
Reference in a new issue