InterleavedTextAttachment -> InterleavedTextMedia, introduce memory tool

This commit is contained in:
Ashwin Bharambe 2024-08-22 17:44:56 -07:00
parent 48c6a32edd
commit 31289e3f47
5 changed files with 56 additions and 40 deletions

View file

@ -33,7 +33,6 @@ class ChatAgent(ShieldRunnerMixin):
memory_api: Memory,
safety_api: Safety,
builtin_tools: List[SingleMessageBuiltinTool],
custom_tool_definitions: List[ToolDefinition],
max_infer_iters: int = 10,
):
self.agent_config = agent_config
@ -108,7 +107,6 @@ class ChatAgent(ShieldRunnerMixin):
# print_dialog(messages)
turn_id = str(uuid.uuid4())
params = self.agent_config.sampling_params
start_time = datetime.now()
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
@ -123,10 +121,9 @@ class ChatAgent(ShieldRunnerMixin):
async for chunk in self.run(
turn_id=turn_id,
input_messages=messages,
temperature=params.temperature,
top_p=params.top_p,
attachments=request.attachments or [],
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
max_gen_len=params.max_tokens,
):
if isinstance(chunk, CompletionMessage):
cprint(
@ -175,10 +172,9 @@ class ChatAgent(ShieldRunnerMixin):
self,
turn_id: str,
input_messages: List[Message],
temperature: float,
top_p: float,
attachments: List[Attachment],
sampling_params: SamplingParams,
stream: bool = False,
max_gen_len: Optional[int] = None,
) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
@ -194,7 +190,7 @@ class ChatAgent(ShieldRunnerMixin):
yield res
async for res in self._run(
turn_id, input_messages, temperature, top_p, stream, max_gen_len
turn_id, input_messages, attachments, sampling_params, stream
):
if isinstance(res, bool):
return
@ -279,20 +275,27 @@ class ChatAgent(ShieldRunnerMixin):
)
)
async def _should_retrieve_context(self, messages: List[Message]) -> bool: ...
async def _should_retrieve_context(
self, messages: List[Message], attachments: List[Attachment]
) -> bool:
return self.agent_config.memory_configs or len(attachments) > 0
async def _retrieve_context(
self, messages: List[Message], attachments: List[Attachment]
) -> List[Message]:
return []
async def _run(
self,
turn_id: str,
input_messages: List[Message],
temperature: float,
top_p: float,
attachments: List[Attachment],
sampling_params: SamplingParams,
stream: bool = False,
max_gen_len: Optional[int] = None,
) -> AsyncGenerator:
need_context = await self._should_retrieve_context(input_messages)
need_context = await self._should_retrieve_context(input_messages, attachments)
if need_context:
context = await self._retrieve_context(input_messages)
context_messages = await self._retrieve_context(input_messages)
# input_messages = preprocess_dialog(input_messages, self.prefix_messages)
# input_messages = input_messages + context
input_messages = preprocess_dialog(input_messages)
@ -320,18 +323,13 @@ class ChatAgent(ShieldRunnerMixin):
)
)
# where are the available tools?
req = ChatCompletionRequest(
model=self.agent_config.model,
messages=input_messages,
tools=self.agent_config.available_tools,
tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True,
sampling_params=SamplingParams(
temperature=temperature,
top_p=top_p,
max_tokens=max_gen_len,
),
sampling_params=sampling_params,
)
tool_calls = []
@ -554,6 +552,7 @@ def attachment_message(url: URL) -> ToolResponseMessage:
def preprocess_dialog(messages: List[Message]) -> List[Message]:
# remove system message since those are
"""
Preprocesses the dialog by removing the system message and
adding the system message to the beginning of the dialog.
@ -565,7 +564,7 @@ def preprocess_dialog(messages: List[Message]) -> List[Message]:
continue
# NOTE: the ideal behavior is to use `file_path = ...` but that
# means we need to have stateful execution o f code which we currently
# means we need to have stateful execution of code which we currently
# do not have.
if isinstance(m.content, Attachment):
ret.append(attachment_message(m.content.url))