add documents to turn

Dinesh Yeduguru 2025-01-06 11:40:22 -08:00
parent d0e8e1647b
commit 9efe30c9d3
9 changed files with 887 additions and 381 deletions

File diff suppressed because one or more lines are too long


@@ -3974,6 +3974,41 @@
        "stream": {
          "type": "boolean"
        },
        "documents": {
          "type": "array",
          "items": {
            "type": "object",
            "properties": {
              "content": {
                "oneOf": [
                  {
                    "type": "string"
                  },
                  {
                    "$ref": "#/components/schemas/InterleavedContentItem"
                  },
                  {
                    "type": "array",
                    "items": {
                      "$ref": "#/components/schemas/InterleavedContentItem"
                    }
                  },
                  {
                    "$ref": "#/components/schemas/URL"
                  }
                ]
              },
              "mime_type": {
                "type": "string"
              }
            },
            "additionalProperties": false,
            "required": [
              "content",
              "mime_type"
            ]
          }
        },
        "tools": {
          "type": "array",
          "items": {

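For illustration, a turn-creation request body matching the new documents schema might look like the sketch below; the IDs and URL are placeholder assumptions, not values from this commit:

# Sketch of a request body satisfying the "documents" schema above.
# All concrete values (agent/session ids, URL) are illustrative placeholders.
payload = {
    "agent_id": "agent-123",
    "session_id": "session-456",
    "stream": False,
    "messages": [{"role": "user", "content": "Summarize the attached file."}],
    "documents": [
        {
            # per the oneOf, "content" may be a plain string, an
            # InterleavedContentItem, a list of such items, or a URL object
            "content": {"uri": "https://example.com/report.txt"},
            "mime_type": "text/plain",
        }
    ],
}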

@@ -618,6 +618,25 @@ components:
    properties:
      agent_id:
        type: string
      documents:
        items:
          additionalProperties: false
          properties:
            content:
              oneOf:
              - type: string
              - $ref: '#/components/schemas/InterleavedContentItem'
              - items:
                  $ref: '#/components/schemas/InterleavedContentItem'
                type: array
              - $ref: '#/components/schemas/URL'
            mime_type:
              type: string
          required:
          - content
          - mime_type
          type: object
        type: array
      messages:
        items:
          oneOf:


@@ -45,6 +45,11 @@ class Attachment(BaseModel):
    mime_type: str


class Document(BaseModel):
    content: InterleavedContent | URL
    mime_type: str


class StepCommon(BaseModel):
    turn_id: str
    step_id: str
@@ -272,6 +277,9 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
        ]
    ]

    documents: Optional[List[Document]] = None
    tools: Optional[List[AgentTool]] = None
    stream: Optional[bool] = False
@@ -308,6 +316,7 @@ class Agents(Protocol):
            ]
        ],
        stream: Optional[bool] = False,
        documents: Optional[List[Document]] = None,
        tools: Optional[List[AgentTool]] = None,
    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...

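A minimal sketch of how a caller might exercise the extended method; it assumes `agents` is some implementation of the Agents protocol and that the method also takes agent_id and session_id (the IDs and URL here are placeholders):

from llama_stack.apis.agents import Document
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.inference import UserMessage

# Sketch only: pass per-turn documents alongside the user messages.
turn = await agents.create_agent_turn(
    agent_id="agent-123",      # placeholder
    session_id="session-456",  # placeholder
    messages=[UserMessage(content="Summarize the attached file.")],
    stream=False,
    documents=[
        Document(
            content=URL(uri="https://example.com/notes.txt"),
            mime_type="text/plain",
        )
    ],
)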

@@ -33,13 +33,18 @@ from llama_stack.apis.agents import (
    AgentTurnResponseTurnCompletePayload,
    AgentTurnResponseTurnStartPayload,
    Attachment,
    Document,
    InferenceStep,
    ShieldCallStep,
    StepType,
    ToolExecutionStep,
    Turn,
)
from llama_stack.apis.common.content_types import TextContentItem, URL
from llama_stack.apis.common.content_types import (
    InterleavedContent,
    TextContentItem,
    URL,
)
from llama_stack.apis.inference import (
    ChatCompletionResponseEventType,
    CompletionMessage,
@@ -55,8 +60,8 @@ from llama_stack.apis.inference import (
    ToolResponseMessage,
    UserMessage,
)
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.memory import Memory, MemoryBankDocument
from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.providers.utils.kvstore import KVStore
@@ -190,6 +195,7 @@ class ChatAgent(ShieldRunnerMixin):
            input_messages=messages,
            sampling_params=self.agent_config.sampling_params,
            stream=request.stream,
            documents=request.documents,
            tools_for_turn=request.tools,
        ):
            if isinstance(chunk, CompletionMessage):
@@ -240,6 +246,7 @@
        input_messages: List[Message],
        sampling_params: SamplingParams,
        stream: bool = False,
        documents: Optional[List[Document]] = None,
        tools_for_turn: Optional[List[AgentTool]] = None,
    ) -> AsyncGenerator:
        # Doing async generators makes downstream code much simpler and everything amenable to
@@ -257,7 +264,13 @@
                yield res

        async for res in self._run(
            session_id, turn_id, input_messages, sampling_params, stream, tools_for_turn
            session_id,
            turn_id,
            input_messages,
            sampling_params,
            stream,
            documents,
            tools_for_turn,
        ):
            if isinstance(res, bool):
                return
@@ -352,6 +365,7 @@
        input_messages: List[Message],
        sampling_params: SamplingParams,
        stream: bool = False,
        documents: Optional[List[Document]] = None,
        tools_for_turn: Optional[List[AgentTool]] = None,
    ) -> AsyncGenerator:
        tool_args = {}
@@ -361,6 +375,7 @@
                    tool_args[tool.name] = tool.args

        tool_defs = await self._get_tool_defs(tools_for_turn)
        await self.handle_documents(session_id, documents, input_messages, tool_defs)
        if "memory" in tool_defs and len(input_messages) > 0:
            with tracing.span("memory_tool") as span:
                step_id = str(uuid.uuid4())
@@ -378,6 +393,11 @@
                    "query": input_messages[-1],
                    **extra_args,
                }
                session_info = await self.storage.get_session_info(session_id)
                # if the session has a memory bank id, let the memory tool use it
                if session_info.memory_bank_id:
                    args["memory_bank_id"] = session_info.memory_bank_id

                serialized_args = tracing.serialize_value(args)
                yield AgentTurnResponseStreamChunk(
                    event=AgentTurnResponseEvent(
@@ -732,6 +752,112 @@
        return ret

    async def handle_documents(
        self,
        session_id: str,
        documents: List[Document],
        input_messages: List[Message],
        tool_defs: Dict[str, ToolDefinition],
    ) -> None:
        memory_tool = tool_defs.get("memory", None)
        code_interpreter_tool = tool_defs.get("code_interpreter", None)
        if documents:
            content_items = [
                d for d in documents if isinstance(d.content, InterleavedContent)
            ]
            url_items = [d for d in documents if isinstance(d.content, URL)]
            pattern = re.compile("^(https?://|file://|data:)")
            url_items = [
                URL(uri=a.content) for a in url_items if pattern.match(a.content)
            ]

            # Save the contents to a tempdir and use its path as a URL if code interpreter is present
            if code_interpreter_tool:
                for c in content_items:
                    temp_file_path = os.path.join(
                        self.tempdir, f"{make_random_string()}.txt"
                    )
                    with open(temp_file_path, "w") as temp_file:
                        temp_file.write(c.content)
                    url_items.append(URL(uri=f"file://{temp_file_path}"))

            if memory_tool and code_interpreter_tool:
                # if both memory and code_interpreter are available, we download the URLs
                # and attach the data to the last message
                msg = await attachment_message(self.tempdir, url_items)
                input_messages.append(msg)
                # since memory is present, also add all the data to the memory bank
                await self.add_to_session_memory_bank(session_id, documents)
            elif code_interpreter_tool:
                # if only code_interpreter is available, we download the URLs to a tempdir
                # and attach the paths as a message for inference, with the assumption
                # that the model invokes the code_interpreter tool with the path
                msg = await attachment_message(self.tempdir, url_items)
                input_messages.append(msg)
            elif memory_tool:
                # if only memory is available, we load the data from the URLs and
                # content items into the memory bank
                await self.add_to_session_memory_bank(session_id, documents)
            else:
                # if neither memory nor code_interpreter is available, we load the data
                # from the URLs and content items and add it to the last message's
                # context for inference
                input_messages[-1].context = content_items + load_data_from_urls(
                    url_items
                )

    async def _ensure_memory_bank(self, session_id: str) -> str:
        session_info = await self.storage.get_session_info(session_id)
        if session_info is None:
            raise ValueError(f"Session {session_id} not found")

        if session_info.memory_bank_id is None:
            bank_id = f"memory_bank_{session_id}"
            await self.memory_banks_api.register_memory_bank(
                memory_bank_id=bank_id,
                params=VectorMemoryBankParams(
                    embedding_model="all-MiniLM-L6-v2",
                    chunk_size_in_tokens=512,
                ),
            )
            await self.storage.add_memory_bank_to_session(session_id, bank_id)
        else:
            bank_id = session_info.memory_bank_id

        return bank_id

    async def add_to_session_memory_bank(
        self, session_id: str, data: List[Document]
    ) -> None:
        bank_id = await self._ensure_memory_bank(session_id)
        documents = [
            MemoryBankDocument(
                document_id=str(uuid.uuid4()),
                content=a.content,
                mime_type=a.mime_type,
                metadata={},
            )
            for a in data
        ]
        await self.memory_api.insert_documents(
            bank_id=bank_id,
            documents=documents,
        )


async def load_data_from_urls(urls: List[URL]) -> List[str]:
    data = []
    for url in urls:
        uri = url.uri
        if uri.startswith("file://"):
            filepath = uri[len("file://") :]
            with open(filepath, "r") as f:
                data.append(f.read())
        elif uri.startswith("http"):
            async with httpx.AsyncClient() as client:
                r = await client.get(uri)
                resp = r.text
                data.append(resp)
    return data


async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
    content = []

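The branching in handle_documents reduces to a small decision table. The sketch below paraphrases that logic for reference; it is not code from this commit:

def route_documents(has_memory: bool, has_code_interpreter: bool) -> str:
    # Paraphrase of ChatAgent.handle_documents' dispatch above.
    if has_memory and has_code_interpreter:
        # attach downloaded file paths to the last message AND index the
        # documents into the session's memory bank
        return "attach file paths + index into memory bank"
    if has_code_interpreter:
        # write contents to a tempdir and attach file:// paths so the model
        # can hand them to the code_interpreter tool
        return "attach file:// paths"
    if has_memory:
        # index everything into the session's vector memory bank
        return "index into memory bank"
    # no relevant tool: inline the loaded data into the last message's context
    return "inline into message context"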

@@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
    AgentStepResponse,
    AgentTool,
    AgentTurnCreateRequest,
    Document,
    Session,
    Turn,
)
@@ -147,6 +148,7 @@ class MetaReferenceAgentsImpl(Agents):
            ]
        ],
        tools: Optional[List[AgentTool]] = None,
        documents: Optional[List[Document]] = None,
        stream: Optional[bool] = False,
    ) -> AsyncGenerator:
        request = AgentTurnCreateRequest(
@@ -155,6 +157,7 @@
            messages=messages,
            stream=True,
            tools=tools,
            documents=documents,
        )
        if stream:
            return self._create_agent_turn_streaming(request)


@@ -21,6 +21,7 @@ log = logging.getLogger(__name__)
class AgentSessionInfo(BaseModel):
    session_id: str
    session_name: str
    memory_bank_id: Optional[str] = None
    started_at: datetime
@@ -51,6 +52,17 @@ class AgentPersistence:
        return AgentSessionInfo(**json.loads(value))

    async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
        session_info = await self.get_session_info(session_id)
        if session_info is None:
            raise ValueError(f"Session {session_id} not found")

        session_info.memory_bank_id = bank_id
        await self.kvstore.set(
            key=f"session:{self.agent_id}:{session_id}",
            value=session_info.model_dump_json(),
        )

    async def add_turn_to_session(self, session_id: str, turn: Turn):
        await self.kvstore.set(
            key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",

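A sketch of the resulting round trip through the KV store; `storage` is assumed to be an AgentPersistence instance and the IDs are placeholders:

# Persist a bank id on the session, then read it back (sketch only).
await storage.add_memory_bank_to_session("session-456", "bank-789")
session_info = await storage.get_session_info("session-456")
assert session_info.memory_bank_id == "bank-789"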

@@ -15,6 +15,7 @@ from llama_stack.apis.agents import (
    AgentTurnResponseStepCompletePayload,
    AgentTurnResponseStreamChunk,
    AgentTurnResponseTurnCompletePayload,
    Document,
    ShieldCallStep,
    StepType,
    ToolChoice,
@@ -22,8 +23,6 @@ from llama_stack.apis.agents import (
    Turn,
)
from llama_stack.apis.inference import CompletionMessage, SamplingParams, UserMessage
from llama_stack.apis.memory import MemoryBankDocument
from llama_stack.apis.memory_banks import VectorMemoryBankParams
from llama_stack.apis.safety import ViolationLevel
from llama_stack.providers.datatypes import Api
@@ -232,8 +231,6 @@ class TestAgents:
        common_params,
    ):
        agents_impl = agents_stack.impls[Api.agents]
        memory_banks_impl = agents_stack.impls[Api.memory_banks]
        memory_impl = agents_stack.impls[Api.memory]
        urls = [
            "memory_optimizations.rst",
            "chat.rst",
@@ -243,28 +240,12 @@
            "lora_finetune.rst",
        ]
        documents = [
            MemoryBankDocument(
                document_id=f"num-{i}",
            Document(
                content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
                mime_type="text/plain",
                metadata={},
            )
            for i, url in enumerate(urls)
        ]
        await memory_banks_impl.register_memory_bank(
            memory_bank_id="test_bank",
            params=VectorMemoryBankParams(
                embedding_model="all-MiniLM-L6-v2",
                chunk_size_in_tokens=512,
                overlap_size_in_tokens=64,
            ),
            provider_id="faiss",
        )
        memory_impl.insert_documents(
            bank_id="test_bank",
            documents=documents,
        )
        agent_config = AgentConfig(
            **{
                **common_params,
@@ -278,6 +259,7 @@ class TestAgents:
            agent_id=agent_id,
            session_id=session_id,
            messages=attachment_message,
            documents=documents,
            stream=True,
        )
        turn_response = [


@@ -203,6 +203,79 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
    assert "Tool:code_interpreter Response" in logs_str


def test_code_execution(llama_stack_client):
    agent_config = AgentConfig(
        model="meta-llama/Llama-3.1-70B-Instruct",
        instructions="You are a helpful assistant",
        tools=[
            "brave_search",
            "code_interpreter",
        ],
        tool_choice="required",
        input_shields=[],
        output_shields=[],
        enable_session_persistence=False,
    )
    memory_bank_id = "inflation_data_memory_bank"
    llama_stack_client.memory_banks.register(
        memory_bank_id=memory_bank_id,
        params={
            "memory_bank_type": "vector",
            "embedding_model": "all-MiniLM-L6-v2",
            "chunk_size_in_tokens": 512,
            "overlap_size_in_tokens": 64,
        },
    )
    AugmentConfigWithMemoryTool(agent_config, llama_stack_client)
    codex_agent = Agent(llama_stack_client, agent_config)
    session_id = codex_agent.create_session("test-session")
    llama_stack_client.memory.insert(
        bank_id=memory_bank_id,
        documents=[
            Document(
                document_id="inflation",
                content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
                mime_type="text/csv",
                metadata={},
            )
        ],
    )

    user_prompts = [
        {
            "prompt": "Can you describe the data in the context?",
            "tools": [{"name": "memory", "args": {"memory_bank_id": memory_bank_id}}],
        },
        {
            "prompt": "Plot average yearly inflation as a time series",
            "tools": [
                {"name": "memory", "args": {"memory_bank_id": memory_bank_id}},
                "code_interpreter",
            ],
        },
    ]

    for input in user_prompts:
        print(f'User> {input["prompt"]}')
        response = codex_agent.create_turn(
            messages=[
                {
                    "role": "user",
                    "content": input["prompt"],
                }
            ],
            session_id=session_id,
            tools=input["tools"],
        )
        # for chunk in response:
        #     print(chunk)
        for log in EventLogger().log(response):
            log.print()


def test_custom_tool(llama_stack_client, agent_config):
    client_tool = TestClientTool()
    agent_config = {