mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00
add documents to turn
This commit is contained in:
parent
d0e8e1647b
commit
9efe30c9d3
9 changed files with 887 additions and 381 deletions
File diff suppressed because one or more lines are too long
|
@ -3974,6 +3974,41 @@
|
|||
"stream": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"documents": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/InterleavedContentItem"
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/InterleavedContentItem"
|
||||
}
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/URL"
|
||||
}
|
||||
]
|
||||
},
|
||||
"mime_type": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"content",
|
||||
"mime_type"
|
||||
]
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
|
|
|
@ -618,6 +618,25 @@ components:
|
|||
properties:
|
||||
agent_id:
|
||||
type: string
|
||||
documents:
|
||||
items:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
content:
|
||||
oneOf:
|
||||
- type: string
|
||||
- $ref: '#/components/schemas/InterleavedContentItem'
|
||||
- items:
|
||||
$ref: '#/components/schemas/InterleavedContentItem'
|
||||
type: array
|
||||
- $ref: '#/components/schemas/URL'
|
||||
mime_type:
|
||||
type: string
|
||||
required:
|
||||
- content
|
||||
- mime_type
|
||||
type: object
|
||||
type: array
|
||||
messages:
|
||||
items:
|
||||
oneOf:
|
||||
|
|
|
@ -45,6 +45,11 @@ class Attachment(BaseModel):
|
|||
mime_type: str
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str
|
||||
|
||||
|
||||
class StepCommon(BaseModel):
|
||||
turn_id: str
|
||||
step_id: str
|
||||
|
@ -272,6 +277,9 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
|||
]
|
||||
]
|
||||
|
||||
documents: Optional[List[Document]] = None
|
||||
tools: Optional[List[AgentTool]] = None
|
||||
|
||||
stream: Optional[bool] = False
|
||||
|
||||
|
||||
|
@ -308,6 +316,7 @@ class Agents(Protocol):
|
|||
]
|
||||
],
|
||||
stream: Optional[bool] = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
tools: Optional[List[AgentTool]] = None,
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
||||
|
||||
|
|
|
@ -33,13 +33,18 @@ from llama_stack.apis.agents import (
|
|||
AgentTurnResponseTurnCompletePayload,
|
||||
AgentTurnResponseTurnStartPayload,
|
||||
Attachment,
|
||||
Document,
|
||||
InferenceStep,
|
||||
ShieldCallStep,
|
||||
StepType,
|
||||
ToolExecutionStep,
|
||||
Turn,
|
||||
)
|
||||
from llama_stack.apis.common.content_types import TextContentItem, URL
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
TextContentItem,
|
||||
URL,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponseEventType,
|
||||
CompletionMessage,
|
||||
|
@ -55,8 +60,8 @@ from llama_stack.apis.inference import (
|
|||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.memory import Memory, MemoryBankDocument
|
||||
from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
@ -190,6 +195,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
input_messages=messages,
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
documents=request.documents,
|
||||
tools_for_turn=request.tools,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
|
@ -240,6 +246,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
input_messages: List[Message],
|
||||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
tools_for_turn: Optional[List[AgentTool]] = None,
|
||||
) -> AsyncGenerator:
|
||||
# Doing async generators makes downstream code much simpler and everything amenable to
|
||||
|
@ -257,7 +264,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield res
|
||||
|
||||
async for res in self._run(
|
||||
session_id, turn_id, input_messages, sampling_params, stream, tools_for_turn
|
||||
session_id,
|
||||
turn_id,
|
||||
input_messages,
|
||||
sampling_params,
|
||||
stream,
|
||||
documents,
|
||||
tools_for_turn,
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
|
@ -352,6 +365,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
input_messages: List[Message],
|
||||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
tools_for_turn: Optional[List[AgentTool]] = None,
|
||||
) -> AsyncGenerator:
|
||||
tool_args = {}
|
||||
|
@ -361,6 +375,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_args[tool.name] = tool.args
|
||||
|
||||
tool_defs = await self._get_tool_defs(tools_for_turn)
|
||||
await self.handle_documents(session_id, documents, input_messages, tool_defs)
|
||||
if "memory" in tool_defs and len(input_messages) > 0:
|
||||
with tracing.span("memory_tool") as span:
|
||||
step_id = str(uuid.uuid4())
|
||||
|
@ -378,6 +393,11 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
"query": input_messages[-1],
|
||||
**extra_args,
|
||||
}
|
||||
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
# if the session has a memory bank id, let the memory tool use it
|
||||
if session_info.memory_bank_id:
|
||||
args["memory_bank_id"] = session_info.memory_bank_id
|
||||
serialized_args = tracing.serialize_value(args)
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
@ -732,6 +752,112 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
return ret
|
||||
|
||||
async def handle_documents(
|
||||
self,
|
||||
session_id: str,
|
||||
documents: List[Document],
|
||||
input_messages: List[Message],
|
||||
tool_defs: Dict[str, ToolDefinition],
|
||||
) -> None:
|
||||
memory_tool = tool_defs.get("memory", None)
|
||||
code_interpreter_tool = tool_defs.get("code_interpreter", None)
|
||||
if documents:
|
||||
content_items = [
|
||||
d for d in documents if isinstance(d.content, InterleavedContent)
|
||||
]
|
||||
url_items = [d for d in documents if isinstance(d.content, URL)]
|
||||
pattern = re.compile("^(https?://|file://|data:)")
|
||||
url_items = [
|
||||
URL(uri=a.content) for a in url_items if pattern.match(a.content)
|
||||
]
|
||||
# Save the contents to a tempdir and use its path as a URL if code interpreter is present
|
||||
if code_interpreter_tool:
|
||||
for c in content_items:
|
||||
temp_file_path = os.path.join(
|
||||
self.tempdir, f"{make_random_string()}.txt"
|
||||
)
|
||||
with open(temp_file_path, "w") as temp_file:
|
||||
temp_file.write(c.content)
|
||||
url_items.append(URL(uri=f"file://{temp_file_path}"))
|
||||
|
||||
if memory_tool and code_interpreter_tool:
|
||||
# if both memory and code_interpreter are available, we download the URLs
|
||||
# and attach the data to the last message.
|
||||
msg = await attachment_message(self.tempdir, url_items)
|
||||
input_messages.append(msg)
|
||||
# Since memory is present, add all the data to the memory bank
|
||||
await self.add_to_session_memory_bank(session_id, documents)
|
||||
elif code_interpreter_tool:
|
||||
# if only code_interpreter is available, we download the URLs to a tempdir
|
||||
# and attach the path to them as a message to inference with the
|
||||
# assumption that the model invokes the code_interpreter tool with the path
|
||||
msg = await attachment_message(self.tempdir, url_items)
|
||||
input_messages.append(msg)
|
||||
elif memory_tool:
|
||||
# if only memory is available, we load the data from the URLs and content items to the memory bank
|
||||
await self.add_to_session_memory_bank(session_id, documents)
|
||||
else:
|
||||
# if no memory or code_interpreter tool is available,
|
||||
# we try to load the data from the URLs and content items as a message to inference
|
||||
# and add it to the last message's context
|
||||
input_messages[-1].context = content_items + load_data_from_urls(
|
||||
url_items
|
||||
)
|
||||
|
||||
async def _ensure_memory_bank(self, session_id: str) -> str:
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
if session_info.memory_bank_id is None:
|
||||
bank_id = f"memory_bank_{session_id}"
|
||||
await self.memory_banks_api.register_memory_bank(
|
||||
memory_bank_id=bank_id,
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
),
|
||||
)
|
||||
await self.storage.add_memory_bank_to_session(session_id, bank_id)
|
||||
else:
|
||||
bank_id = session_info.memory_bank_id
|
||||
|
||||
return bank_id
|
||||
|
||||
async def add_to_session_memory_bank(
|
||||
self, session_id: str, data: List[Document]
|
||||
) -> None:
|
||||
bank_id = await self._ensure_memory_bank(session_id)
|
||||
documents = [
|
||||
MemoryBankDocument(
|
||||
document_id=str(uuid.uuid4()),
|
||||
content=a.content,
|
||||
mime_type=a.mime_type,
|
||||
metadata={},
|
||||
)
|
||||
for a in data
|
||||
]
|
||||
await self.memory_api.insert_documents(
|
||||
bank_id=bank_id,
|
||||
documents=documents,
|
||||
)
|
||||
|
||||
|
||||
async def load_data_from_urls(urls: List[URL]) -> List[str]:
|
||||
data = []
|
||||
for url in urls:
|
||||
uri = url.uri
|
||||
if uri.startswith("file://"):
|
||||
filepath = uri[len("file://") :]
|
||||
with open(filepath, "r") as f:
|
||||
data.append(f.read())
|
||||
elif uri.startswith("http"):
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(uri)
|
||||
resp = r.text
|
||||
data.append(resp)
|
||||
return data
|
||||
|
||||
|
||||
async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
|
||||
content = []
|
||||
|
|
|
@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
|
|||
AgentStepResponse,
|
||||
AgentTool,
|
||||
AgentTurnCreateRequest,
|
||||
Document,
|
||||
Session,
|
||||
Turn,
|
||||
)
|
||||
|
@ -147,6 +148,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
]
|
||||
],
|
||||
tools: Optional[List[AgentTool]] = None,
|
||||
documents: Optional[List[Document]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnCreateRequest(
|
||||
|
@ -155,6 +157,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
messages=messages,
|
||||
stream=True,
|
||||
tools=tools,
|
||||
documents=documents,
|
||||
)
|
||||
if stream:
|
||||
return self._create_agent_turn_streaming(request)
|
||||
|
|
|
@ -21,6 +21,7 @@ log = logging.getLogger(__name__)
|
|||
class AgentSessionInfo(BaseModel):
|
||||
session_id: str
|
||||
session_name: str
|
||||
memory_bank_id: Optional[str] = None
|
||||
started_at: datetime
|
||||
|
||||
|
||||
|
@ -51,6 +52,17 @@ class AgentPersistence:
|
|||
|
||||
return AgentSessionInfo(**json.loads(value))
|
||||
|
||||
async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
|
||||
session_info = await self.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
session_info.memory_bank_id = bank_id
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
value=session_info.model_dump_json(),
|
||||
)
|
||||
|
||||
async def add_turn_to_session(self, session_id: str, turn: Turn):
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
|
||||
|
|
|
@ -15,6 +15,7 @@ from llama_stack.apis.agents import (
|
|||
AgentTurnResponseStepCompletePayload,
|
||||
AgentTurnResponseStreamChunk,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
Document,
|
||||
ShieldCallStep,
|
||||
StepType,
|
||||
ToolChoice,
|
||||
|
@ -22,8 +23,6 @@ from llama_stack.apis.agents import (
|
|||
Turn,
|
||||
)
|
||||
from llama_stack.apis.inference import CompletionMessage, SamplingParams, UserMessage
|
||||
from llama_stack.apis.memory import MemoryBankDocument
|
||||
from llama_stack.apis.memory_banks import VectorMemoryBankParams
|
||||
from llama_stack.apis.safety import ViolationLevel
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
@ -232,8 +231,6 @@ class TestAgents:
|
|||
common_params,
|
||||
):
|
||||
agents_impl = agents_stack.impls[Api.agents]
|
||||
memory_banks_impl = agents_stack.impls[Api.memory_banks]
|
||||
memory_impl = agents_stack.impls[Api.memory]
|
||||
urls = [
|
||||
"memory_optimizations.rst",
|
||||
"chat.rst",
|
||||
|
@ -243,28 +240,12 @@ class TestAgents:
|
|||
"lora_finetune.rst",
|
||||
]
|
||||
documents = [
|
||||
MemoryBankDocument(
|
||||
document_id=f"num-{i}",
|
||||
Document(
|
||||
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
||||
mime_type="text/plain",
|
||||
metadata={},
|
||||
)
|
||||
for i, url in enumerate(urls)
|
||||
]
|
||||
await memory_banks_impl.register_memory_bank(
|
||||
memory_bank_id="test_bank",
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
provider_id="faiss",
|
||||
)
|
||||
memory_impl.insert_documents(
|
||||
bank_id="test_bank",
|
||||
documents=documents,
|
||||
)
|
||||
|
||||
agent_config = AgentConfig(
|
||||
**{
|
||||
**common_params,
|
||||
|
@ -278,6 +259,7 @@ class TestAgents:
|
|||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=attachment_message,
|
||||
documents=documents,
|
||||
stream=True,
|
||||
)
|
||||
turn_response = [
|
||||
|
|
|
@ -203,6 +203,79 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
|
|||
assert "Tool:code_interpreter Response" in logs_str
|
||||
|
||||
|
||||
def test_code_execution(llama_stack_client):
|
||||
agent_config = AgentConfig(
|
||||
model="meta-llama/Llama-3.1-70B-Instruct",
|
||||
instructions="You are a helpful assistant",
|
||||
tools=[
|
||||
"brave_search",
|
||||
"code_interpreter",
|
||||
],
|
||||
tool_choice="required",
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
enable_session_persistence=False,
|
||||
)
|
||||
|
||||
memory_bank_id = "inflation_data_memory_bank"
|
||||
llama_stack_client.memory_banks.register(
|
||||
memory_bank_id=memory_bank_id,
|
||||
params={
|
||||
"memory_bank_type": "vector",
|
||||
"embedding_model": "all-MiniLM-L6-v2",
|
||||
"chunk_size_in_tokens": 512,
|
||||
"overlap_size_in_tokens": 64,
|
||||
},
|
||||
)
|
||||
AugmentConfigWithMemoryTool(agent_config, llama_stack_client)
|
||||
codex_agent = Agent(llama_stack_client, agent_config)
|
||||
session_id = codex_agent.create_session("test-session")
|
||||
|
||||
llama_stack_client.memory.insert(
|
||||
bank_id=memory_bank_id,
|
||||
documents=[
|
||||
Document(
|
||||
document_id="inflation",
|
||||
content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
|
||||
mime_type="text/csv",
|
||||
metadata={},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
user_prompts = [
|
||||
{
|
||||
"prompt": "Can you describe the data in the context?",
|
||||
"tools": [{"name": "memory", "args": {"memory_bank_id": memory_bank_id}}],
|
||||
},
|
||||
{
|
||||
"prompt": "Plot average yearly inflation as a time series",
|
||||
"tools": [
|
||||
{"name": "memory", "args": {"memory_bank_id": memory_bank_id}},
|
||||
"code_interpreter",
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
for input in user_prompts:
|
||||
print(f'User> {input["prompt"]}')
|
||||
response = codex_agent.create_turn(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input["prompt"],
|
||||
}
|
||||
],
|
||||
session_id=session_id,
|
||||
tools=input["tools"],
|
||||
)
|
||||
# for chunk in response:
|
||||
# print(chunk)
|
||||
|
||||
for log in EventLogger().log(response):
|
||||
log.print()
|
||||
|
||||
|
||||
def test_custom_tool(llama_stack_client, agent_config):
|
||||
client_tool = TestClientTool()
|
||||
agent_config = {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue