mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
* [1/n] migrate inference/chat_completion * migrate inference/completion * inference/completion * inference regenerate openapi spec * safety api * migrate agentic system * migrate apis without implementations * re-generate openapi spec * remove hack from openapi generator * fix inference * fix inference * openapi generator rerun * Simplified Telemetry API and tying it to logger (#57) * Simplified Telemetry API and tying it to logger * small update which adds a METRIC type * move span events one level down into structured log events --------- Co-authored-by: Ashwin Bharambe <ashwin@meta.com> * fix api to work with openapi generator * fix agentic calling inference * together adapter inference * update inference adapters --------- Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com> Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
797 lines
30 KiB
Python
797 lines
30 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import asyncio
|
|
import copy
|
|
import os
|
|
import secrets
|
|
import shutil
|
|
import string
|
|
import tempfile
|
|
import uuid
|
|
from datetime import datetime
|
|
from typing import AsyncGenerator, List, Tuple
|
|
from urllib.parse import urlparse
|
|
|
|
import httpx
|
|
|
|
from termcolor import cprint
|
|
|
|
from llama_toolchain.agentic_system.api import * # noqa: F403
|
|
from llama_toolchain.inference.api import * # noqa: F403
|
|
from llama_toolchain.memory.api import * # noqa: F403
|
|
from llama_toolchain.safety.api import * # noqa: F403
|
|
|
|
from llama_toolchain.tools.base import BaseTool
|
|
from llama_toolchain.tools.builtin import (
|
|
interpret_content_as_attachment,
|
|
SingleMessageBuiltinTool,
|
|
)
|
|
|
|
from .rag.context_retriever import generate_rag_query
|
|
from .safety import SafetyException, ShieldRunnerMixin
|
|
|
|
|
|
def make_random_string(length: int = 8):
    """Return a random alphanumeric string of ``length`` characters.

    Characters are drawn with ``secrets.choice`` from ASCII letters and
    digits.
    """
    alphabet = string.ascii_letters + string.digits
    picked = [secrets.choice(alphabet) for _ in range(length)]
    return "".join(picked)
|
|
|
|
|
|
class ChatAgent(ShieldRunnerMixin):
    """Agent that executes chat turns for a set of in-memory sessions.

    A turn runs the configured input shields, optionally retrieves memory
    (RAG) context, loops over inference and builtin tool execution, and
    finally runs the configured output shields. Progress is reported by
    yielding ``AgenticSystemTurnResponseStreamChunk`` objects from async
    generators.
    """

    def __init__(
        self,
        agent_config: AgentConfig,
        inference_api: Inference,
        memory_api: Memory,
        safety_api: Safety,
        builtin_tools: List[SingleMessageBuiltinTool],
        max_infer_iters: int = 10,
    ):
        """Store the APIs and configuration and set up per-agent state.

        Args:
            agent_config: Agent configuration (model, tools, shields,
                sampling params, ...).
            inference_api: Inference API used for chat completions.
            memory_api: Memory API used for memory banks and retrieval.
            safety_api: Safety API handed to ``ShieldRunnerMixin``.
            builtin_tools: Tools, indexed internally by ``get_name()``.
            max_infer_iters: Iteration limit for the inference loop in
                ``_run``.
        """
        self.agent_config = agent_config
        self.inference_api = inference_api
        self.memory_api = memory_api
        self.safety_api = safety_api

        self.max_infer_iters = max_infer_iters
        self.tools_dict = {t.get_name(): t for t in builtin_tools}

        # Scratch directory for downloaded attachments; removed in __del__.
        self.tempdir = tempfile.mkdtemp()
        self.sessions = {}

        ShieldRunnerMixin.__init__(
            self,
            safety_api,
            input_shields=agent_config.input_shields,
            output_shields=agent_config.output_shields,
        )

    def __del__(self):
        """Remove the scratch directory created in ``__init__``."""
        # __del__ also runs when __init__ raised before `tempdir` was set,
        # and the directory may already be gone; neither case should raise
        # from a finalizer.
        tempdir = getattr(self, "tempdir", None)
        if tempdir is not None:
            shutil.rmtree(tempdir, ignore_errors=True)

    def turn_to_messages(self, turn: Turn) -> List[Message]:
        """Flatten a completed turn into the message list it represents.

        Input messages are copied (user messages get their ``context``
        cleared), followed by one message per step: the model response for
        inference steps, a ``ToolResponseMessage`` per tool response for
        tool-execution steps, and a ``CompletionMessage`` carrying the
        violation message for shield calls that flagged a violation.
        """
        messages = []

        # We do not want to keep adding RAG context to the input messages.
        # Maybe this should be a parameter of the agentic instance
        # that can define its behavior in a custom way.
        for m in turn.input_messages:
            msg = m.copy()
            if isinstance(msg, UserMessage):
                msg.context = None
            messages.append(msg)

        for step in turn.steps:
            if step.step_type == StepType.inference.value:
                messages.append(step.model_response)
            elif step.step_type == StepType.tool_execution.value:
                for response in step.tool_responses:
                    messages.append(
                        ToolResponseMessage(
                            call_id=response.call_id,
                            tool_name=response.tool_name,
                            content=response.content,
                        )
                    )
            elif step.step_type == StepType.shield_call.value:
                response = step.response
                if response.is_violation:
                    # CompletionMessage itself in the ShieldResponse
                    messages.append(
                        CompletionMessage(
                            content=response.violation_return_message,
                            stop_reason=StopReason.end_of_turn,
                        )
                    )
        return messages

    def create_session(self, name: str) -> Session:
        """Create a new empty session, register it and return it."""
        session_id = str(uuid.uuid4())
        session = Session(
            session_id=session_id,
            session_name=name,
            turns=[],
            started_at=datetime.now(),
        )
        self.sessions[session_id] = session
        return session

    async def create_and_execute_turn(
        self, request: AgenticSystemTurnCreateRequest
    ) -> AsyncGenerator:
        """Run one turn for ``request.session_id`` and stream its events.

        Yields a turn-start chunk, then every chunk produced by ``run``,
        then a turn-complete chunk carrying the recorded ``Turn``. The turn
        is appended to the session before the final chunk is yielded.

        Raises:
            AssertionError: if the session is unknown or ``run`` finished
                without producing a ``CompletionMessage``.
        """
        assert (
            request.session_id in self.sessions
        ), f"Session {request.session_id} not found"

        session = self.sessions[request.session_id]

        # Replay the whole session history ahead of the new messages.
        messages = []
        for turn in session.turns:
            messages.extend(self.turn_to_messages(turn))

        messages.extend(request.messages)

        turn_id = str(uuid.uuid4())
        start_time = datetime.now()
        yield AgenticSystemTurnResponseStreamChunk(
            event=AgenticSystemTurnResponseEvent(
                payload=AgenticSystemTurnResponseTurnStartPayload(
                    turn_id=turn_id,
                )
            )
        )

        steps = []
        output_message = None
        async for chunk in self.run(
            session=session,
            turn_id=turn_id,
            input_messages=messages,
            attachments=request.attachments or [],
            sampling_params=self.agent_config.sampling_params,
            stream=request.stream,
        ):
            # A CompletionMessage is the turn's output, not a stream event:
            # remember it and do not forward it.
            if isinstance(chunk, CompletionMessage):
                cprint(
                    f"{chunk.role.capitalize()}: {chunk.content}",
                    "white",
                    attrs=["bold"],
                )
                output_message = chunk
                continue

            assert isinstance(
                chunk, AgenticSystemTurnResponseStreamChunk
            ), f"Unexpected type {type(chunk)}"
            event = chunk.event
            if (
                event.payload.event_type
                == AgenticSystemTurnResponseEventType.step_complete.value
            ):
                steps.append(event.payload.step_details)

            yield chunk

        assert output_message is not None

        turn = Turn(
            turn_id=turn_id,
            session_id=request.session_id,
            input_messages=request.messages,
            output_message=output_message,
            started_at=start_time,
            completed_at=datetime.now(),
            steps=steps,
        )
        session.turns.append(turn)

        chunk = AgenticSystemTurnResponseStreamChunk(
            event=AgenticSystemTurnResponseEvent(
                payload=AgenticSystemTurnResponseTurnCompletePayload(
                    turn=turn,
                )
            )
        )
        yield chunk

    async def run(
        self,
        session: Session,
        turn_id: str,
        input_messages: List[Message],
        attachments: List[Attachment],
        sampling_params: SamplingParams,
        stream: bool = False,
    ) -> AsyncGenerator:
        """Run input shields, the inference loop and output shields.

        Yields stream chunks as they are produced and, last, the final
        ``CompletionMessage``. If a stage signals failure by yielding a
        ``bool``, the generator stops right there.
        """
        # Doing async generators makes downstream code much simpler and everything amenable to
        # streaming. However, it also makes things complicated here because AsyncGenerators cannot
        # return a "final value" for the `yield from` statement. We simulate that by yielding a
        # final boolean (to see whether an exception happened) and then explicitly testing for it.

        async for res in self.run_shields_wrapper(
            turn_id, input_messages, self.input_shields, "user-input"
        ):
            if isinstance(res, bool):
                return
            else:
                yield res

        # Initialised up front so that a run which never produces a
        # CompletionMessage trips the assertion below instead of raising
        # UnboundLocalError.
        final_response = None
        async for res in self._run(
            session, turn_id, input_messages, attachments, sampling_params, stream
        ):
            if isinstance(res, bool):
                return
            elif isinstance(res, CompletionMessage):
                final_response = res
                break
            else:
                yield res

        assert final_response is not None
        # for output shields run on the full input and output combination
        messages = input_messages + [final_response]

        async for res in self.run_shields_wrapper(
            turn_id, messages, self.output_shields, "assistant-output"
        ):
            if isinstance(res, bool):
                return
            else:
                yield res

        yield final_response

    async def run_shields_wrapper(
        self,
        turn_id: str,
        messages: List[Message],
        shields: List[ShieldDefinition],
        touchpoint: str,
    ) -> AsyncGenerator:
        """Run ``shields`` over ``messages`` and report it as a shield step.

        Yields nothing when ``shields`` is empty. Otherwise yields a
        step-start chunk followed by a step-complete chunk. On a
        ``SafetyException`` the step-complete chunk carries the exception's
        response, and a ``CompletionMessage`` built from the exception text
        and a final ``False`` are yielded before the generator ends.
        """
        if not shields:
            return

        step_id = str(uuid.uuid4())
        try:
            yield AgenticSystemTurnResponseStreamChunk(
                event=AgenticSystemTurnResponseEvent(
                    payload=AgenticSystemTurnResponseStepStartPayload(
                        step_type=StepType.shield_call.value,
                        step_id=step_id,
                        metadata=dict(touchpoint=touchpoint),
                    )
                )
            )
            await self.run_shields(messages, shields)

        except SafetyException as e:
            yield AgenticSystemTurnResponseStreamChunk(
                event=AgenticSystemTurnResponseEvent(
                    payload=AgenticSystemTurnResponseStepCompletePayload(
                        step_type=StepType.shield_call.value,
                        step_details=ShieldCallStep(
                            step_id=step_id,
                            turn_id=turn_id,
                            response=e.response,
                        ),
                    )
                )
            )

            yield CompletionMessage(
                content=str(e),
                stop_reason=StopReason.end_of_turn,
            )
            yield False
            # Stop here: if the consumer resumes us after `False`, the
            # "no violation" event below must not be emitted.
            return

        yield AgenticSystemTurnResponseStreamChunk(
            event=AgenticSystemTurnResponseEvent(
                payload=AgenticSystemTurnResponseStepCompletePayload(
                    step_type=StepType.shield_call.value,
                    step_details=ShieldCallStep(
                        step_id=step_id,
                        turn_id=turn_id,
                        response=ShieldResponse(
                            # TODO: fix this, give each shield a shield type method and
                            # fire one event for each shield run
                            shield_type=BuiltinShield.llama_guard,
                            is_violation=False,
                        ),
                    ),
                )
            )
        )

    async def _run(
        self,
        session: Session,
        turn_id: str,
        input_messages: List[Message],
        attachments: List[Attachment],
        sampling_params: SamplingParams,
        stream: bool = False,
    ) -> AsyncGenerator:
        """Core loop: optional memory retrieval, then inference and tools.

        Yields step start/progress/complete chunks and finally a
        ``CompletionMessage``. After a ``SafetyException`` during tool
        execution, a ``CompletionMessage`` built from the exception and then
        ``False`` are yielded.

        Note: ``input_messages`` may be mutated. Retrieved context is set on
        the last message and an attachment message may be appended.
        """
        enabled_tools = set(t.type for t in self.agent_config.tools)
        need_rag_context = await self._should_retrieve_context(
            input_messages, attachments
        )
        if need_rag_context:
            # One id for the whole step so that consumers can correlate the
            # start and complete events.
            step_id = str(uuid.uuid4())
            yield AgenticSystemTurnResponseStreamChunk(
                event=AgenticSystemTurnResponseEvent(
                    payload=AgenticSystemTurnResponseStepStartPayload(
                        step_type=StepType.memory_retrieval.value,
                        step_id=step_id,
                    )
                )
            )

            # TODO: find older context from the session and either replace it
            # or append with a sliding window. this is really a very simplistic implementation
            rag_context, bank_ids = await self._retrieve_context(
                session, input_messages, attachments
            )

            yield AgenticSystemTurnResponseStreamChunk(
                event=AgenticSystemTurnResponseEvent(
                    payload=AgenticSystemTurnResponseStepCompletePayload(
                        step_type=StepType.memory_retrieval.value,
                        step_id=step_id,
                        step_details=MemoryRetrievalStep(
                            turn_id=turn_id,
                            step_id=step_id,
                            memory_bank_ids=bank_ids,
                            inserted_context=rag_context or "",
                        ),
                    )
                )
            )

            if rag_context:
                last_message = input_messages[-1]
                last_message.context = "\n".join(rag_context)

        elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
            urls = [a.content for a in attachments if isinstance(a.content, URL)]
            msg = await attachment_message(self.tempdir, urls)
            input_messages.append(msg)

        output_attachments = []

        n_iter = 0
        while True:
            msg = input_messages[-1]
            if msg.role == Role.user.value:
                color = "blue"
            elif msg.role == Role.ipython.value:
                color = "yellow"
            else:
                color = None
            cprint(f"{str(msg)}", color=color)

            step_id = str(uuid.uuid4())
            yield AgenticSystemTurnResponseStreamChunk(
                event=AgenticSystemTurnResponseEvent(
                    payload=AgenticSystemTurnResponseStepStartPayload(
                        step_type=StepType.inference.value,
                        step_id=step_id,
                    )
                )
            )

            tool_calls = []
            content = ""
            stop_reason = None
            # Inference is always consumed as a stream; `stream` only
            # controls whether progress events are forwarded.
            async for chunk in self.inference_api.chat_completion(
                self.agent_config.model,
                input_messages,
                tools=self._get_tools(),
                tool_prompt_format=self.agent_config.tool_prompt_format,
                stream=True,
                sampling_params=sampling_params,
            ):
                event = chunk.event
                if event.event_type == ChatCompletionResponseEventType.start:
                    continue
                elif event.event_type == ChatCompletionResponseEventType.complete:
                    stop_reason = StopReason.end_of_turn
                    continue

                delta = event.delta
                if isinstance(delta, ToolCallDelta):
                    if delta.parse_status == ToolCallParseStatus.success:
                        tool_calls.append(delta.content)

                    if stream:
                        yield AgenticSystemTurnResponseStreamChunk(
                            event=AgenticSystemTurnResponseEvent(
                                payload=AgenticSystemTurnResponseStepProgressPayload(
                                    step_type=StepType.inference.value,
                                    step_id=step_id,
                                    model_response_text_delta="",
                                    tool_call_delta=delta,
                                )
                            )
                        )

                elif isinstance(delta, str):
                    content += delta
                    if stream and event.stop_reason is None:
                        yield AgenticSystemTurnResponseStreamChunk(
                            event=AgenticSystemTurnResponseEvent(
                                payload=AgenticSystemTurnResponseStepProgressPayload(
                                    step_type=StepType.inference.value,
                                    step_id=step_id,
                                    model_response_text_delta=event.delta,
                                )
                            )
                        )
                else:
                    raise ValueError(f"Unexpected delta type {type(delta)}")

                if event.stop_reason is not None:
                    stop_reason = event.stop_reason

            # No stop reason at all is treated as running out of tokens.
            stop_reason = stop_reason or StopReason.out_of_tokens
            message = CompletionMessage(
                content=content,
                stop_reason=stop_reason,
                tool_calls=tool_calls,
            )

            yield AgenticSystemTurnResponseStreamChunk(
                event=AgenticSystemTurnResponseEvent(
                    payload=AgenticSystemTurnResponseStepCompletePayload(
                        step_type=StepType.inference.value,
                        step_id=step_id,
                        step_details=InferenceStep(
                            # somewhere deep, we are re-assigning message or closing over some
                            # variable which causes message to mutate later on. fix with a
                            # `deepcopy` for now, but this is symptomatic of a deeper issue.
                            step_id=step_id,
                            turn_id=turn_id,
                            model_response=copy.deepcopy(message),
                        ),
                    )
                )
            )

            if n_iter >= self.max_infer_iters:
                cprint("Done with MAX iterations, exiting.")
                yield message
                break

            if stop_reason == StopReason.out_of_tokens:
                cprint("Out of token budget, exiting.")
                yield message
                break

            if len(message.tool_calls) == 0:
                if stop_reason == StopReason.end_of_turn:
                    # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
                    if len(output_attachments) > 0:
                        # Attach what the tools generated during this turn,
                        # not the attachments the user sent in.
                        if isinstance(message.content, list):
                            message.content += output_attachments
                        else:
                            message.content = [message.content] + output_attachments
                    yield message
                    # The turn is over; do not start another inference
                    # round if the consumer resumes this generator.
                    break
                else:
                    cprint(f"Partial message: {str(message)}", color="green")
                    input_messages = input_messages + [message]
            else:
                cprint(f"{str(message)}", color="green")
                try:
                    tool_call = message.tool_calls[0]

                    name = tool_call.tool_name
                    if not isinstance(name, BuiltinTool):
                        # Non-builtin tool calls are not executed here; the
                        # message is handed back as the result.
                        yield message
                        return

                    step_id = str(uuid.uuid4())
                    yield AgenticSystemTurnResponseStreamChunk(
                        event=AgenticSystemTurnResponseEvent(
                            payload=AgenticSystemTurnResponseStepStartPayload(
                                step_type=StepType.tool_execution.value,
                                step_id=step_id,
                            )
                        )
                    )
                    yield AgenticSystemTurnResponseStreamChunk(
                        event=AgenticSystemTurnResponseEvent(
                            payload=AgenticSystemTurnResponseStepProgressPayload(
                                step_type=StepType.tool_execution.value,
                                step_id=step_id,
                                tool_call=tool_call,
                            )
                        )
                    )

                    result_messages = await execute_tool_call_maybe(
                        self.tools_dict,
                        [message],
                    )
                    assert (
                        len(result_messages) == 1
                    ), "Currently not supporting multiple messages"
                    result_message = result_messages[0]

                    yield AgenticSystemTurnResponseStreamChunk(
                        event=AgenticSystemTurnResponseEvent(
                            payload=AgenticSystemTurnResponseStepCompletePayload(
                                step_type=StepType.tool_execution.value,
                                step_details=ToolExecutionStep(
                                    step_id=step_id,
                                    turn_id=turn_id,
                                    tool_calls=[tool_call],
                                    tool_responses=[
                                        ToolResponse(
                                            call_id=result_message.call_id,
                                            tool_name=result_message.tool_name,
                                            content=result_message.content,
                                        )
                                    ],
                                ),
                            )
                        )
                    )

                    # TODO: add tool-input touchpoint and a "start" event for this step also
                    # but that needs a lot more refactoring of Tool code potentially
                    yield AgenticSystemTurnResponseStreamChunk(
                        event=AgenticSystemTurnResponseEvent(
                            payload=AgenticSystemTurnResponseStepCompletePayload(
                                step_type=StepType.shield_call.value,
                                step_details=ShieldCallStep(
                                    step_id=str(uuid.uuid4()),
                                    turn_id=turn_id,
                                    response=ShieldResponse(
                                        # TODO: fix this, give each shield a shield type method and
                                        # fire one event for each shield run
                                        shield_type=BuiltinShield.llama_guard,
                                        is_violation=False,
                                    ),
                                ),
                            )
                        )
                    )

                except SafetyException as e:
                    yield AgenticSystemTurnResponseStreamChunk(
                        event=AgenticSystemTurnResponseEvent(
                            payload=AgenticSystemTurnResponseStepCompletePayload(
                                step_type=StepType.shield_call.value,
                                step_details=ShieldCallStep(
                                    step_id=str(uuid.uuid4()),
                                    turn_id=turn_id,
                                    response=e.response,
                                ),
                            )
                        )
                    )

                    yield CompletionMessage(
                        content=str(e),
                        stop_reason=StopReason.end_of_turn,
                    )
                    yield False
                    return

                if out_attachment := interpret_content_as_attachment(
                    result_message.content
                ):
                    # NOTE: when we push this message back to the model, the model may ignore the
                    # attached file path etc. since the model is trained to only provide a user message
                    # with the summary. We keep all generated attachments and then attach them to final message
                    output_attachments.append(out_attachment)

                input_messages = input_messages + [message, result_message]

            n_iter += 1

    async def _ensure_memory_bank(self, session: Session) -> MemoryBank:
        """Return the session's memory bank, creating it on first use."""
        if session.memory_bank is None:
            session.memory_bank = await self.memory_api.create_memory_bank(
                name=f"memory_bank_{session.session_id}",
                config=VectorMemoryBankConfig(
                    embedding_model="sentence-transformer/all-MiniLM-L6-v2",
                    chunk_size_in_tokens=512,
                ),
            )

        return session.memory_bank

    async def _should_retrieve_context(
        self, messages: List[Message], attachments: List[Attachment]
    ) -> bool:
        """Decide whether memory retrieval should run for this turn.

        With attachments, retrieval runs unless the code interpreter tool is
        enabled and ``tool_choice`` is ``ToolChoice.required``. Without
        attachments, it runs only when the memory tool is enabled.
        """
        enabled_tools = set(t.type for t in self.agent_config.tools)
        if attachments:
            if (
                AgenticSystemTool.code_interpreter.value in enabled_tools
                and self.agent_config.tool_choice == ToolChoice.required
            ):
                return False
            else:
                return True

        return AgenticSystemTool.memory.value in enabled_tools

    def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
        """Return the configured memory tool definition, or ``None``."""
        for t in self.agent_config.tools:
            if t.type == AgenticSystemTool.memory.value:
                return t

        return None

    async def _retrieve_context(
        self, session: Session, messages: List[Message], attachments: List[Attachment]
    ) -> Tuple[Optional[List[str]], List[int]]:  # (rag_context, bank_ids)
        """Query the relevant memory banks and build the context to insert.

        Attachments, if any, are first inserted into the session's memory
        bank (created on demand). All configured banks plus the session bank
        are then queried and the chunks ranked by score, highest first.

        Returns:
            ``(rag_context, bank_ids)``. ``rag_context`` is ``None`` when
            there is no bank to query or nothing was retrieved. Otherwise it
            is a list of strings: a header, one entry per picked chunk, and
            a footer.

        Raises:
            AssertionError: if no memory tool is configured.
        """
        bank_ids = []

        memory = self._memory_tool_definition()
        assert memory is not None, "Memory tool not configured"
        bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)

        if attachments:
            bank = await self._ensure_memory_bank(session)
            bank_ids.append(bank.bank_id)

            documents = [
                MemoryBankDocument(
                    document_id=str(uuid.uuid4()),
                    content=a.content,
                    mime_type=a.mime_type,
                    metadata={},
                )
                for a in attachments
            ]
            await self.memory_api.insert_documents(bank.bank_id, documents)
        elif session.memory_bank:
            bank_ids.append(session.memory_bank.bank_id)

        if not bank_ids:
            # this can happen if the per-session memory bank is not yet populated
            # (i.e., no prior turns uploaded an Attachment)
            return None, []

        query = await generate_rag_query(
            memory.query_generator_config, messages, inference_api=self.inference_api
        )
        tasks = [
            self.memory_api.query_documents(
                bank_id=bank_id,
                query=query,
                params={
                    "max_chunks": 5,
                },
            )
            for bank_id in bank_ids
        ]
        results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
        chunks = [c for r in results for c in r.chunks]
        scores = [s for r in results for s in r.scores]

        # Rank by score, highest first. The emptiness check must come before
        # any tuple-unpacking of a zip: `a, b = zip(*[])` raises ValueError
        # when nothing was retrieved.
        ranked = sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
        if not ranked:
            return None, bank_ids
        chunks = [c for c, _ in ranked]

        tokens = 0
        picked = []
        for c in chunks[: memory.max_chunks]:
            tokens += c.token_count
            if tokens > memory.max_tokens_in_context:
                cprint(
                    f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
                    "red",
                )
                break
            picked.append(f"id:{c.document_id}; content:{c.content}")

        return [
            "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
            *picked,
            "\n=== END-RETRIEVED-CONTEXT ===\n",
        ], bank_ids

    def _get_tools(self) -> List[ToolDefinition]:
        """Translate the configured tools into ``ToolDefinition`` objects.

        Tool definitions of any other type (e.g. the memory tool) produce no
        entry.
        """
        ret = []
        for t in self.agent_config.tools:
            if isinstance(t, SearchToolDefinition):
                ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search))
            elif isinstance(t, WolframAlphaToolDefinition):
                ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha))
            elif isinstance(t, PhotogenToolDefinition):
                ret.append(ToolDefinition(tool_name=BuiltinTool.photogen))
            elif isinstance(t, CodeInterpreterToolDefinition):
                ret.append(ToolDefinition(tool_name=BuiltinTool.code_interpreter))
            elif isinstance(t, FunctionCallToolDefinition):
                ret.append(
                    ToolDefinition(
                        tool_name=t.function_name,
                        description=t.description,
                        parameters=t.parameters,
                    )
                )
        return ret
|
|
|
|
|
|
async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
    """Build the tool message that tells the model which files it can read.

    ``file://`` URLs are used in place. ``http...`` URLs are downloaded into
    ``tempdir`` under a randomised file name.

    Args:
        tempdir: Directory that receives downloaded files.
        urls: Attachment URLs; each must expose a ``uri`` string.

    Returns:
        A ``ToolResponseMessage`` for the code interpreter tool whose content
        has one line per file, giving its local path.

    Raises:
        ValueError: if a URI is neither ``file://`` nor ``http...``.
        httpx.HTTPStatusError: if a download gets an error response.
    """
    content = []

    for url in urls:
        uri = url.uri
        if uri.startswith("file://"):
            filepath = uri[len("file://") :]
        elif uri.startswith("http"):
            path = urlparse(uri).path
            basename = os.path.basename(path)
            filepath = f"{tempdir}/{make_random_string() + basename}"
            print(f"Downloading {url} -> {filepath}")

            async with httpx.AsyncClient() as client:
                r = await client.get(uri)
                # Fail loudly rather than saving an error page as if it were
                # the attachment.
                r.raise_for_status()
                # Write the raw bytes: going through `r.text` and a
                # text-mode file corrupts anything that is not plain text.
                with open(filepath, "wb") as fp:
                    fp.write(r.content)
        else:
            raise ValueError(f"Unsupported URL {url}")

        content.append(f'# There is a file accessible to you at "{filepath}"\n')

    return ToolResponseMessage(
        call_id="",
        tool_name=BuiltinTool.code_interpreter,
        content=content,
    )
|
|
|
|
|
|
async def execute_tool_call_maybe(
    tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage]
) -> List[ToolResponseMessage]:
    """Run the builtin tool named by the first tool call of the message.

    The tool interface accepts a list of messages, but every tool currently
    operates on a single message, so exactly one message is required.
    Whether several messages should later be run one by one and aggregated,
    or aggregated and run once, remains to be seen.

    Raises:
        AssertionError: if more than one message is passed, the tool name is
            not a ``BuiltinTool``, or the tool is missing from
            ``tools_dict``.
    """
    assert len(messages) == 1, "Expected single message"
    (only_message,) = messages

    first_call = only_message.tool_calls[0]
    name = first_call.tool_name
    assert isinstance(name, BuiltinTool)

    # Tools are registered under the enum's string value.
    name = name.value
    assert name in tools_dict, f"Tool {name} not found"

    return await tools_dict[name].run(messages)
|
|
|
|
|
|
def print_dialog(messages: List[Message]):
    """Debug helper: print the first 100 characters of each message.

    Every line is prefixed with the message index and coloured by role;
    unrecognised roles are printed in white.
    """
    role_colors = (
        (Role.user.value, "red"),
        (Role.assistant.value, "white"),
        (Role.ipython.value, "yellow"),
        (Role.system.value, "green"),
    )

    for index, message in enumerate(messages):
        # First role that compares equal wins; fall back to white.
        color = next(
            (shade for role, shade in role_colors if message.role == role),
            "white",
        )
        text = str(message)
        cprint(f"{index} ::: {text[:100]}...", color=color)
|