mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 02:22:25 +00:00
add documents to turn
This commit is contained in:
parent
d0e8e1647b
commit
9efe30c9d3
9 changed files with 887 additions and 381 deletions
|
|
@ -33,13 +33,18 @@ from llama_stack.apis.agents import (
|
|||
AgentTurnResponseTurnCompletePayload,
|
||||
AgentTurnResponseTurnStartPayload,
|
||||
Attachment,
|
||||
Document,
|
||||
InferenceStep,
|
||||
ShieldCallStep,
|
||||
StepType,
|
||||
ToolExecutionStep,
|
||||
Turn,
|
||||
)
|
||||
from llama_stack.apis.common.content_types import TextContentItem, URL
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
TextContentItem,
|
||||
URL,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponseEventType,
|
||||
CompletionMessage,
|
||||
|
|
@ -55,8 +60,8 @@ from llama_stack.apis.inference import (
|
|||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.memory import Memory, MemoryBankDocument
|
||||
from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
|
@ -190,6 +195,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
input_messages=messages,
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
documents=request.documents,
|
||||
tools_for_turn=request.tools,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
|
|
@ -240,6 +246,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
input_messages: List[Message],
|
||||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
tools_for_turn: Optional[List[AgentTool]] = None,
|
||||
) -> AsyncGenerator:
|
||||
# Doing async generators makes downstream code much simpler and everything amenable to
|
||||
|
|
@ -257,7 +264,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield res
|
||||
|
||||
async for res in self._run(
|
||||
session_id, turn_id, input_messages, sampling_params, stream, tools_for_turn
|
||||
session_id,
|
||||
turn_id,
|
||||
input_messages,
|
||||
sampling_params,
|
||||
stream,
|
||||
documents,
|
||||
tools_for_turn,
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
|
|
@ -352,6 +365,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
input_messages: List[Message],
|
||||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
tools_for_turn: Optional[List[AgentTool]] = None,
|
||||
) -> AsyncGenerator:
|
||||
tool_args = {}
|
||||
|
|
@ -361,6 +375,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_args[tool.name] = tool.args
|
||||
|
||||
tool_defs = await self._get_tool_defs(tools_for_turn)
|
||||
await self.handle_documents(session_id, documents, input_messages, tool_defs)
|
||||
if "memory" in tool_defs and len(input_messages) > 0:
|
||||
with tracing.span("memory_tool") as span:
|
||||
step_id = str(uuid.uuid4())
|
||||
|
|
@ -378,6 +393,11 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
"query": input_messages[-1],
|
||||
**extra_args,
|
||||
}
|
||||
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
# if the session has a memory bank id, let the memory tool use it
|
||||
if session_info.memory_bank_id:
|
||||
args["memory_bank_id"] = session_info.memory_bank_id
|
||||
serialized_args = tracing.serialize_value(args)
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
|
@ -732,6 +752,112 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
return ret
|
||||
|
||||
async def handle_documents(
|
||||
self,
|
||||
session_id: str,
|
||||
documents: List[Document],
|
||||
input_messages: List[Message],
|
||||
tool_defs: Dict[str, ToolDefinition],
|
||||
) -> None:
|
||||
memory_tool = tool_defs.get("memory", None)
|
||||
code_interpreter_tool = tool_defs.get("code_interpreter", None)
|
||||
if documents:
|
||||
content_items = [
|
||||
d for d in documents if isinstance(d.content, InterleavedContent)
|
||||
]
|
||||
url_items = [d for d in documents if isinstance(d.content, URL)]
|
||||
pattern = re.compile("^(https?://|file://|data:)")
|
||||
url_items = [
|
||||
URL(uri=a.content) for a in url_items if pattern.match(a.content)
|
||||
]
|
||||
# Save the contents to a tempdir and use its path as a URL if code interpreter is present
|
||||
if code_interpreter_tool:
|
||||
for c in content_items:
|
||||
temp_file_path = os.path.join(
|
||||
self.tempdir, f"{make_random_string()}.txt"
|
||||
)
|
||||
with open(temp_file_path, "w") as temp_file:
|
||||
temp_file.write(c.content)
|
||||
url_items.append(URL(uri=f"file://{temp_file_path}"))
|
||||
|
||||
if memory_tool and code_interpreter_tool:
|
||||
# if both memory and code_interpreter are available, we download the URLs
|
||||
# and attach the data to the last message.
|
||||
msg = await attachment_message(self.tempdir, url_items)
|
||||
input_messages.append(msg)
|
||||
# Since memory is present, add all the data to the memory bank
|
||||
await self.add_to_session_memory_bank(session_id, documents)
|
||||
elif code_interpreter_tool:
|
||||
# if only code_interpreter is available, we download the URLs to a tempdir
|
||||
# and attach the path to them as a message to inference with the
|
||||
# assumption that the model invokes the code_interpreter tool with the path
|
||||
msg = await attachment_message(self.tempdir, url_items)
|
||||
input_messages.append(msg)
|
||||
elif memory_tool:
|
||||
# if only memory is available, we load the data from the URLs and content items to the memory bank
|
||||
await self.add_to_session_memory_bank(session_id, documents)
|
||||
else:
|
||||
# if no memory or code_interpreter tool is available,
|
||||
# we try to load the data from the URLs and content items as a message to inference
|
||||
# and add it to the last message's context
|
||||
input_messages[-1].context = content_items + load_data_from_urls(
|
||||
url_items
|
||||
)
|
||||
|
||||
async def _ensure_memory_bank(self, session_id: str) -> str:
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
if session_info.memory_bank_id is None:
|
||||
bank_id = f"memory_bank_{session_id}"
|
||||
await self.memory_banks_api.register_memory_bank(
|
||||
memory_bank_id=bank_id,
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
),
|
||||
)
|
||||
await self.storage.add_memory_bank_to_session(session_id, bank_id)
|
||||
else:
|
||||
bank_id = session_info.memory_bank_id
|
||||
|
||||
return bank_id
|
||||
|
||||
async def add_to_session_memory_bank(
|
||||
self, session_id: str, data: List[Document]
|
||||
) -> None:
|
||||
bank_id = await self._ensure_memory_bank(session_id)
|
||||
documents = [
|
||||
MemoryBankDocument(
|
||||
document_id=str(uuid.uuid4()),
|
||||
content=a.content,
|
||||
mime_type=a.mime_type,
|
||||
metadata={},
|
||||
)
|
||||
for a in data
|
||||
]
|
||||
await self.memory_api.insert_documents(
|
||||
bank_id=bank_id,
|
||||
documents=documents,
|
||||
)
|
||||
|
||||
|
||||
async def load_data_from_urls(urls: List[URL]) -> List[str]:
|
||||
data = []
|
||||
for url in urls:
|
||||
uri = url.uri
|
||||
if uri.startswith("file://"):
|
||||
filepath = uri[len("file://") :]
|
||||
with open(filepath, "r") as f:
|
||||
data.append(f.read())
|
||||
elif uri.startswith("http"):
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(uri)
|
||||
resp = r.text
|
||||
data.append(resp)
|
||||
return data
|
||||
|
||||
|
||||
async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
|
||||
content = []
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
|
|||
AgentStepResponse,
|
||||
AgentTool,
|
||||
AgentTurnCreateRequest,
|
||||
Document,
|
||||
Session,
|
||||
Turn,
|
||||
)
|
||||
|
|
@ -147,6 +148,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
]
|
||||
],
|
||||
tools: Optional[List[AgentTool]] = None,
|
||||
documents: Optional[List[Document]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnCreateRequest(
|
||||
|
|
@ -155,6 +157,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
messages=messages,
|
||||
stream=True,
|
||||
tools=tools,
|
||||
documents=documents,
|
||||
)
|
||||
if stream:
|
||||
return self._create_agent_turn_streaming(request)
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ log = logging.getLogger(__name__)
|
|||
class AgentSessionInfo(BaseModel):
|
||||
session_id: str
|
||||
session_name: str
|
||||
memory_bank_id: Optional[str] = None
|
||||
started_at: datetime
|
||||
|
||||
|
||||
|
|
@ -51,6 +52,17 @@ class AgentPersistence:
|
|||
|
||||
return AgentSessionInfo(**json.loads(value))
|
||||
|
||||
async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
|
||||
session_info = await self.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
session_info.memory_bank_id = bank_id
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
value=session_info.model_dump_json(),
|
||||
)
|
||||
|
||||
async def add_turn_to_session(self, session_id: str, turn: Turn):
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue