InterleavedTextAttachment -> InterleavedTextMedia, introduce memory tool

Commit: 31289e3f47
Parent: 48c6a32edd
Author: Ashwin Bharambe, 2024-08-22 17:44:56 -07:00
5 changed files with 56 additions and 40 deletions

File 1 of 5

@@ -10,7 +10,7 @@ from typing import Any, Dict, List, Literal, Optional, Union
 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, validator
 from typing_extensions import Annotated

 from llama_toolchain.common.deployment_types import *  # noqa: F403
@@ -20,7 +20,31 @@ from llama_toolchain.memory.api.datatypes import *  # noqa: F403
 @json_schema_type
-class AgenticSystemToolDefinition(ToolDefinition):
+class Attachment(BaseModel):
+    url: URL
+    mime_type: str
+
+
+class AgenticSystemBuiltinTool(BuiltinTool):
+    memory = "memory"
+
+
+@json_schema_type
+class AgenticSystemToolDefinition(BaseModel):
+    tool_name: Union[AgenticSystemBuiltinTool, str]
+    description: Optional[str] = None
+    parameters: Optional[Dict[str, ToolParamDefinition]] = None
+
+    @validator("tool_name", pre=True)
+    @classmethod
+    def validate_field(cls, v):
+        if isinstance(v, str):
+            try:
+                return AgenticSystemBuiltinTool(v)
+            except ValueError:
+                return v
+        return v
+
     execution_config: Optional[RestAPIExecutionConfig] = None
     input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
     output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
@@ -124,11 +148,7 @@ class MemoryConfig(BaseModel):
 class AgentConfigCommon(BaseModel):
     sampling_params: Optional[SamplingParams] = SamplingParams()

-    # zero-shot or built-in tool configurations as input to the model
-    available_tools: Optional[List[AgenticSystemToolDefinition]] = Field(
-        default_factory=list
-    )
-
-    memory: Optional[List[MemoryConfig]] = Field(default_factory=list)
+    memory_configs: Optional[List[MemoryConfig]] = Field(default_factory=list)

     input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
     output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
@@ -136,6 +156,9 @@ class AgentConfigCommon(BaseModel):
     # if you completely want to replace the messages prefixed by the system,
     # this is debug only
     debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
+
+    tools: Optional[List[AgenticSystemToolDefinition]] = Field(default_factory=list)
+    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
     tool_prompt_format: Optional[ToolPromptFormat] = Field(
         default=ToolPromptFormat.json
     )
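Below is a minimal runnable sketch of the `tool_name` coercion this file introduces. The enum base and `ToolParamDefinition` are simplified stand-ins for the real `llama_toolchain` types, and a pydantic v1-style `validator` is assumed (matching the changed import above):

# Hypothetical standalone sketch of the tool_name coercion added above;
# the names here are simplified stand-ins, not the real llama_toolchain imports.
from enum import Enum
from typing import Dict, Optional, Union

from pydantic import BaseModel, validator


class AgenticSystemBuiltinTool(Enum):
    memory = "memory"


class ToolParamDefinition(BaseModel):
    param_type: str


class AgenticSystemToolDefinition(BaseModel):
    tool_name: Union[AgenticSystemBuiltinTool, str]
    description: Optional[str] = None
    parameters: Optional[Dict[str, ToolParamDefinition]] = None

    @validator("tool_name", pre=True)
    @classmethod
    def validate_field(cls, v):
        # Coerce known builtin names to the enum; leave custom names as strings.
        if isinstance(v, str):
            try:
                return AgenticSystemBuiltinTool(v)
            except ValueError:
                return v
        return v


# "memory" round-trips to the enum member; unknown names stay plain strings.
assert AgenticSystemToolDefinition(tool_name="memory").tool_name is AgenticSystemBuiltinTool.memory
assert AgenticSystemToolDefinition(tool_name="get_weather").tool_name == "get_weather"

This lets callers send either a builtin identifier or an arbitrary custom tool name through the same field, which is why the class now extends plain BaseModel instead of ToolDefinition.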

File 2 of 5

@@ -46,7 +46,7 @@ class AgenticSystemTurnCreateRequest(BaseModel, AgentConfigOverridablePerTurn):
             ToolResponseMessage,
         ]
     ]
-    attachments: List[Attachment]
+    attachments: Optional[List[Attachment]] = None

     stream: Optional[bool] = False
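Since `attachments` is now `Optional` with a `None` default, callers can omit it entirely. A sketch with the request model reduced to the two fields this hunk touches (`Attachment` stubbed with a plain-string URL):

# Simplified stand-in for AgenticSystemTurnCreateRequest; only the fields
# touched by this hunk are modeled.
from typing import List, Optional

from pydantic import BaseModel


class Attachment(BaseModel):
    url: str  # the real model uses a URL type
    mime_type: str


class TurnCreateRequest(BaseModel):
    attachments: Optional[List[Attachment]] = None
    stream: Optional[bool] = False


# Callers no longer have to pass an attachments list on every turn.
req = TurnCreateRequest(stream=True)
assert req.attachments is None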

File 3 of 5

@@ -33,7 +33,6 @@ class ChatAgent(ShieldRunnerMixin):
         memory_api: Memory,
         safety_api: Safety,
         builtin_tools: List[SingleMessageBuiltinTool],
-        custom_tool_definitions: List[ToolDefinition],
         max_infer_iters: int = 10,
     ):
         self.agent_config = agent_config
@@ -108,7 +107,6 @@ class ChatAgent(ShieldRunnerMixin):
        # print_dialog(messages)

        turn_id = str(uuid.uuid4())
-       params = self.agent_config.sampling_params

        start_time = datetime.now()
        yield AgenticSystemTurnResponseStreamChunk(
            event=AgenticSystemTurnResponseEvent(
@@ -123,10 +121,9 @@
        async for chunk in self.run(
            turn_id=turn_id,
            input_messages=messages,
-           temperature=params.temperature,
-           top_p=params.top_p,
+           attachments=request.attachments or [],
+           sampling_params=self.agent_config.sampling_params,
            stream=request.stream,
-           max_gen_len=params.max_tokens,
        ):
            if isinstance(chunk, CompletionMessage):
                cprint(
@@ -175,10 +172,9 @@
        self,
        turn_id: str,
        input_messages: List[Message],
-       temperature: float,
-       top_p: float,
+       attachments: List[Attachment],
+       sampling_params: SamplingParams,
        stream: bool = False,
-       max_gen_len: Optional[int] = None,
    ) -> AsyncGenerator:
        # Doing async generators makes downstream code much simpler and everything amenable to
        # streaming. However, it also makes things complicated here because AsyncGenerators cannot
@@ -194,7 +190,7 @@
                yield res

        async for res in self._run(
-           turn_id, input_messages, temperature, top_p, stream, max_gen_len
+           turn_id, input_messages, attachments, sampling_params, stream
        ):
            if isinstance(res, bool):
                return
@@ -279,20 +275,27 @@
            )
        )

-   async def _should_retrieve_context(self, messages: List[Message]) -> bool: ...
+   async def _should_retrieve_context(
+       self, messages: List[Message], attachments: List[Attachment]
+   ) -> bool:
+       return self.agent_config.memory_configs or len(attachments) > 0
+
+   async def _retrieve_context(
+       self, messages: List[Message], attachments: List[Attachment]
+   ) -> List[Message]:
+       return []

    async def _run(
        self,
        turn_id: str,
        input_messages: List[Message],
-       temperature: float,
-       top_p: float,
+       attachments: List[Attachment],
+       sampling_params: SamplingParams,
        stream: bool = False,
-       max_gen_len: Optional[int] = None,
    ) -> AsyncGenerator:
-       need_context = await self._should_retrieve_context(input_messages)
+       need_context = await self._should_retrieve_context(input_messages, attachments)
        if need_context:
-           context = await self._retrieve_context(input_messages)
+           context_messages = await self._retrieve_context(input_messages)
            # input_messages = preprocess_dialog(input_messages, self.prefix_messages)
            # input_messages = input_messages + context
        input_messages = preprocess_dialog(input_messages)
@@ -320,18 +323,13 @@
            )
        )

-       # where are the available tools?
        req = ChatCompletionRequest(
            model=self.agent_config.model,
            messages=input_messages,
            tools=self.agent_config.available_tools,
            tool_prompt_format=self.agent_config.tool_prompt_format,
            stream=True,
-           sampling_params=SamplingParams(
-               temperature=temperature,
-               top_p=top_p,
-               max_tokens=max_gen_len,
-           ),
+           sampling_params=sampling_params,
        )

        tool_calls = []
@@ -554,6 +552,7 @@ def attachment_message(url: URL) -> ToolResponseMessage:

 def preprocess_dialog(messages: List[Message]) -> List[Message]:
+    # remove system message since those are
     """
     Preprocesses the dialog by removing the system message and
     adding the system message to the beginning of the dialog.
@@ -565,7 +564,7 @@ def preprocess_dialog(messages: List[Message]) -> List[Message]:
            continue
        # NOTE: the ideal behavior is to use `file_path = ...` but that
-       # means we need to have stateful execution o f code which we currently
+       # means we need to have stateful execution of code which we currently
        # do not have.
        if isinstance(m.content, Attachment):
            ret.append(attachment_message(m.content.url))
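The plumbing above collapses the loose `temperature`/`top_p`/`max_gen_len` arguments into a single `SamplingParams` object threaded from `run` through `_run` into `ChatCompletionRequest`. (Note that the hunk still calls `_retrieve_context(input_messages)` with one argument and never uses `context_messages`; the memory path is only stubbed at this point.) A sketch of the before/after call shape, with a simplified stand-in for the real `SamplingParams`:

# Hypothetical sketch contrasting the old and new call shapes; SamplingParams
# here is a stand-in, not the llama-models datatype.
from dataclasses import dataclass
from typing import Optional


@dataclass
class SamplingParams:
    temperature: float = 0.0
    top_p: float = 0.95
    max_tokens: Optional[int] = None


def run_old(temperature: float, top_p: float, max_gen_len: Optional[int] = None):
    ...  # every new sampling knob forces a signature change here and at each call site


def run_new(sampling_params: SamplingParams):
    ...  # one object passes through run -> _run -> ChatCompletionRequest unchanged


params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=512)
run_old(params.temperature, params.top_p, params.max_tokens)
run_new(params)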

File 4 of 5

@@ -16,9 +16,6 @@ from llama_toolchain.inference.api.datatypes import BuiltinTool
 from llama_toolchain.memory.api import Memory
 from llama_toolchain.safety.api import Safety
 from llama_toolchain.agentic_system.api import *  # noqa: F403
-
-from .agent_instance import ChatAgent
-from .config import MetaReferenceImplConfig
 from llama_toolchain.tools.builtin import (
     BraveSearchTool,
     CodeInterpreterTool,
@@ -27,7 +24,8 @@ from llama_toolchain.tools.builtin import (
 )
 from llama_toolchain.tools.safety import with_safety

-from .agent_instance import AgentInstance
+from .agent_instance import AgentInstance, ChatAgent
+from .config import MetaReferenceImplConfig

 logger = logging.getLogger()
@@ -76,7 +74,6 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
        system_id = str(uuid.uuid4())

        builtin_tools = []
-       custom_tool_definitions = []
        cfg = request.agent_config
        for dfn in cfg.available_tools:
            if isinstance(dfn.tool_name, BuiltinTool):
@@ -104,8 +101,6 @@
                        tool, self.safety_api, dfn.input_shields, dfn.output_shields
                    )
                )
-           else:
-               custom_tool_definitions.append(dfn)

        AGENT_INSTANCES_BY_ID[system_id] = ChatAgent(
            agent_config=cfg,
@@ -113,7 +108,6 @@
            safety_api=self.safety_api,
            memory_api=self.memory_api,
            builtin_tools=builtin_tools,
-           custom_tool_definitions=custom_tool_definitions,
        )

        return AgenticSystemCreateResponse(
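With `custom_tool_definitions` gone, this file only instantiates builtin tools; custom definitions now travel inside the agent config (the new `tools` field added in the datatypes file). A hypothetical sketch of that filtering, with the tool classes stubbed:

# Hypothetical sketch: builtin tool names from the config are instantiated here,
# while string-named custom tools are skipped and read from agent_config later.
from enum import Enum
from typing import Callable, Dict, List, Union


class BuiltinTool(Enum):
    brave_search = "brave_search"
    code_interpreter = "code_interpreter"


class Tool:  # stand-in for the SingleMessageBuiltinTool subclasses
    ...


TOOL_FACTORIES: Dict[BuiltinTool, Callable[[], Tool]] = {
    BuiltinTool.brave_search: Tool,
    BuiltinTool.code_interpreter: Tool,
}


def build_builtin_tools(tool_names: List[Union[BuiltinTool, str]]) -> List[Tool]:
    # Only enum-typed names become live tool instances in this layer.
    return [TOOL_FACTORIES[name]() for name in tool_names if isinstance(name, BuiltinTool)]


tools = build_builtin_tools([BuiltinTool.brave_search, "my_custom_tool"])
assert len(tools) == 1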

File 5 of 5

@@ -6,7 +6,7 @@
 from typing import List, Protocol

-from llama_models.llama3_1.api.datatypes import InterleavedTextAttachment
+from llama_models.llama3.api.datatypes import InterleavedTextMedia
 from llama_models.schema_utils import webmethod

 from .datatypes import *  # noqa: F403
@@ -14,7 +14,7 @@ from .datatypes import *  # noqa: F403
 @json_schema_type
 class RetrieveMemoryDocumentsRequest(BaseModel):
-    query: InterleavedTextAttachment
+    query: InterleavedTextMedia
     bank_ids: str
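`InterleavedTextMedia` replaces `InterleavedTextAttachment` as the query type. Assuming it is roughly a union of plain strings and media items (a stand-in union is defined here rather than importing from `llama_models`), a memory query can remain a bare string:

# Sketch assuming InterleavedTextMedia is roughly Union[str, media, list of both];
# MediaItem is a hypothetical stand-in for the real media type.
from typing import List, Union

from pydantic import BaseModel


class MediaItem(BaseModel):
    url: str
    mime_type: str


InterleavedTextMedia = Union[str, MediaItem, List[Union[str, MediaItem]]]


class RetrieveMemoryDocumentsRequest(BaseModel):
    query: InterleavedTextMedia
    bank_ids: str


# A plain-text query keeps working; media items can be interleaved when needed.
req = RetrieveMemoryDocumentsRequest(query="what is our refund policy?", bank_ids="bank-1")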