Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-05 10:13:05 +00:00)

commit 9efe30c9d3 (parent d0e8e1647b): add documents to turn
9 changed files with 887 additions and 381 deletions

Note: one file diff is suppressed because one or more lines are too long.
@@ -3974,6 +3974,41 @@
         "stream": {
           "type": "boolean"
         },
+        "documents": {
+          "type": "array",
+          "items": {
+            "type": "object",
+            "properties": {
+              "content": {
+                "oneOf": [
+                  {
+                    "type": "string"
+                  },
+                  {
+                    "$ref": "#/components/schemas/InterleavedContentItem"
+                  },
+                  {
+                    "type": "array",
+                    "items": {
+                      "$ref": "#/components/schemas/InterleavedContentItem"
+                    }
+                  },
+                  {
+                    "$ref": "#/components/schemas/URL"
+                  }
+                ]
+              },
+              "mime_type": {
+                "type": "string"
+              }
+            },
+            "additionalProperties": false,
+            "required": [
+              "content",
+              "mime_type"
+            ]
+          }
+        },
         "tools": {
           "type": "array",
           "items": {
@@ -618,6 +618,25 @@ components:
       properties:
         agent_id:
           type: string
+        documents:
+          items:
+            additionalProperties: false
+            properties:
+              content:
+                oneOf:
+                - type: string
+                - $ref: '#/components/schemas/InterleavedContentItem'
+                - items:
+                    $ref: '#/components/schemas/InterleavedContentItem'
+                  type: array
+                - $ref: '#/components/schemas/URL'
+              mime_type:
+                type: string
+            required:
+            - content
+            - mime_type
+            type: object
+          type: array
       messages:
         items:
           oneOf:
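Note: the JSON and YAML hunks above are the same schema change rendered in the two generated spec files: an optional documents array on the turn-creation request, where each entry requires content and mime_type. As an illustration only (not taken from the commit; the ids, message text, and URL are placeholders), a request body exercising the new field could look like:

# Illustrative request body for the new "documents" field; a sketch of the
# schema above, not an example from the commit. Ids and URL are placeholders.
payload = {
    "agent_id": "agent-123",
    "session_id": "session-456",
    "messages": [{"role": "user", "content": "Summarize the attached notes"}],
    "documents": [
        {
            # "content" may be a string, an InterleavedContentItem (or a list
            # of them), or a URL object; a plain string uses the first branch.
            "content": "https://example.com/notes.txt",
            "mime_type": "text/plain",
        }
    ],
    "stream": False,
}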
@@ -45,6 +45,11 @@ class Attachment(BaseModel):
     mime_type: str


+class Document(BaseModel):
+    content: InterleavedContent | URL
+    mime_type: str
+
+
 class StepCommon(BaseModel):
     turn_id: str
     step_id: str

@@ -272,6 +277,9 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
         ]
     ]

+    documents: Optional[List[Document]] = None
+    tools: Optional[List[AgentTool]] = None
+
     stream: Optional[bool] = False


@@ -308,6 +316,7 @@ class Agents(Protocol):
             ]
         ],
         stream: Optional[bool] = False,
+        documents: Optional[List[Document]] = None,
         tools: Optional[List[AgentTool]] = None,
     ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
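Note: on the Python API surface, the hunks above add the Document model and a documents parameter to Agents.create_agent_turn. A minimal caller sketch, assuming some implementation agents_api of the Agents protocol and pre-existing agent/session ids (all names here are placeholders, not part of the commit):

from llama_stack.apis.agents import Document
from llama_stack.apis.inference import UserMessage

async def turn_with_documents(agents_api, agent_id: str, session_id: str):
    # A plain string, interleaved content item(s), or a URL are all valid content.
    docs = [
        Document(
            content="https://example.com/report.txt",  # placeholder URL string
            mime_type="text/plain",
        )
    ]
    return await agents_api.create_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        messages=[UserMessage(content="Summarize the attached report")],
        documents=docs,
        stream=False,
    )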
@@ -33,13 +33,18 @@ from llama_stack.apis.agents import (
     AgentTurnResponseTurnCompletePayload,
     AgentTurnResponseTurnStartPayload,
     Attachment,
+    Document,
     InferenceStep,
     ShieldCallStep,
     StepType,
     ToolExecutionStep,
     Turn,
 )
-from llama_stack.apis.common.content_types import TextContentItem, URL
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    TextContentItem,
+    URL,
+)
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
     CompletionMessage,

@@ -55,8 +60,8 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
-from llama_stack.apis.memory import Memory
-from llama_stack.apis.memory_banks import MemoryBanks
+from llama_stack.apis.memory import Memory, MemoryBankDocument
+from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.providers.utils.kvstore import KVStore

@@ -190,6 +195,7 @@ class ChatAgent(ShieldRunnerMixin):
             input_messages=messages,
             sampling_params=self.agent_config.sampling_params,
             stream=request.stream,
+            documents=request.documents,
             tools_for_turn=request.tools,
         ):
             if isinstance(chunk, CompletionMessage):

@@ -240,6 +246,7 @@ class ChatAgent(ShieldRunnerMixin):
         input_messages: List[Message],
         sampling_params: SamplingParams,
         stream: bool = False,
+        documents: Optional[List[Document]] = None,
         tools_for_turn: Optional[List[AgentTool]] = None,
     ) -> AsyncGenerator:
         # Doing async generators makes downstream code much simpler and everything amenable to

@@ -257,7 +264,13 @@ class ChatAgent(ShieldRunnerMixin):
             yield res

         async for res in self._run(
-            session_id, turn_id, input_messages, sampling_params, stream, tools_for_turn
+            session_id,
+            turn_id,
+            input_messages,
+            sampling_params,
+            stream,
+            documents,
+            tools_for_turn,
         ):
             if isinstance(res, bool):
                 return

@@ -352,6 +365,7 @@ class ChatAgent(ShieldRunnerMixin):
         input_messages: List[Message],
         sampling_params: SamplingParams,
         stream: bool = False,
+        documents: Optional[List[Document]] = None,
         tools_for_turn: Optional[List[AgentTool]] = None,
     ) -> AsyncGenerator:
         tool_args = {}

@@ -361,6 +375,7 @@ class ChatAgent(ShieldRunnerMixin):
                 tool_args[tool.name] = tool.args

         tool_defs = await self._get_tool_defs(tools_for_turn)
+        await self.handle_documents(session_id, documents, input_messages, tool_defs)
         if "memory" in tool_defs and len(input_messages) > 0:
             with tracing.span("memory_tool") as span:
                 step_id = str(uuid.uuid4())

@@ -378,6 +393,11 @@ class ChatAgent(ShieldRunnerMixin):
                     "query": input_messages[-1],
                     **extra_args,
                 }
+
+                session_info = await self.storage.get_session_info(session_id)
+                # if the session has a memory bank id, let the memory tool use it
+                if session_info.memory_bank_id:
+                    args["memory_bank_id"] = session_info.memory_bank_id
                 serialized_args = tracing.serialize_value(args)
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(

@@ -732,6 +752,112 @@ class ChatAgent(ShieldRunnerMixin):

         return ret

+    async def handle_documents(
+        self,
+        session_id: str,
+        documents: List[Document],
+        input_messages: List[Message],
+        tool_defs: Dict[str, ToolDefinition],
+    ) -> None:
+        memory_tool = tool_defs.get("memory", None)
+        code_interpreter_tool = tool_defs.get("code_interpreter", None)
+        if documents:
+            content_items = [
+                d for d in documents if isinstance(d.content, InterleavedContent)
+            ]
+            url_items = [d for d in documents if isinstance(d.content, URL)]
+            pattern = re.compile("^(https?://|file://|data:)")
+            url_items = [
+                URL(uri=a.content) for a in url_items if pattern.match(a.content)
+            ]
+            # Save the contents to a tempdir and use its path as a URL if code interpreter is present
+            if code_interpreter_tool:
+                for c in content_items:
+                    temp_file_path = os.path.join(
+                        self.tempdir, f"{make_random_string()}.txt"
+                    )
+                    with open(temp_file_path, "w") as temp_file:
+                        temp_file.write(c.content)
+                    url_items.append(URL(uri=f"file://{temp_file_path}"))
+
+            if memory_tool and code_interpreter_tool:
+                # if both memory and code_interpreter are available, we download the URLs
+                # and attach the data to the last message.
+                msg = await attachment_message(self.tempdir, url_items)
+                input_messages.append(msg)
+                # Since memory is present, add all the data to the memory bank
+                await self.add_to_session_memory_bank(session_id, documents)
+            elif code_interpreter_tool:
+                # if only code_interpreter is available, we download the URLs to a tempdir
+                # and attach the path to them as a message to inference with the
+                # assumption that the model invokes the code_interpreter tool with the path
+                msg = await attachment_message(self.tempdir, url_items)
+                input_messages.append(msg)
+            elif memory_tool:
+                # if only memory is available, we load the data from the URLs and content items to the memory bank
+                await self.add_to_session_memory_bank(session_id, documents)
+            else:
+                # if no memory or code_interpreter tool is available,
+                # we try to load the data from the URLs and content items as a message to inference
+                # and add it to the last message's context
+                input_messages[-1].context = content_items + load_data_from_urls(
+                    url_items
+                )
+
+    async def _ensure_memory_bank(self, session_id: str) -> str:
+        session_info = await self.storage.get_session_info(session_id)
+        if session_info is None:
+            raise ValueError(f"Session {session_id} not found")
+
+        if session_info.memory_bank_id is None:
+            bank_id = f"memory_bank_{session_id}"
+            await self.memory_banks_api.register_memory_bank(
+                memory_bank_id=bank_id,
+                params=VectorMemoryBankParams(
+                    embedding_model="all-MiniLM-L6-v2",
+                    chunk_size_in_tokens=512,
+                ),
+            )
+            await self.storage.add_memory_bank_to_session(session_id, bank_id)
+        else:
+            bank_id = session_info.memory_bank_id
+
+        return bank_id
+
+    async def add_to_session_memory_bank(
+        self, session_id: str, data: List[Document]
+    ) -> None:
+        bank_id = await self._ensure_memory_bank(session_id)
+        documents = [
+            MemoryBankDocument(
+                document_id=str(uuid.uuid4()),
+                content=a.content,
+                mime_type=a.mime_type,
+                metadata={},
+            )
+            for a in data
+        ]
+        await self.memory_api.insert_documents(
+            bank_id=bank_id,
+            documents=documents,
+        )
+
+
+async def load_data_from_urls(urls: List[URL]) -> List[str]:
+    data = []
+    for url in urls:
+        uri = url.uri
+        if uri.startswith("file://"):
+            filepath = uri[len("file://") :]
+            with open(filepath, "r") as f:
+                data.append(f.read())
+        elif uri.startswith("http"):
+            async with httpx.AsyncClient() as client:
+                r = await client.get(uri)
+                resp = r.text
+                data.append(resp)
+    return data
+
+
 async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
     content = []
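Note: handle_documents is the heart of this change; it decides where incoming documents go based on which tools the turn has. Its dispatch can be paraphrased as a standalone function (a summary for review, not code from the commit):

def describe_document_routing(has_memory: bool, has_code_interpreter: bool) -> str:
    # Mirrors the if/elif ladder in handle_documents above.
    if has_memory and has_code_interpreter:
        return "attach tempdir file paths to the last message and index all documents into the session memory bank"
    if has_code_interpreter:
        return "attach tempdir file paths to the last message for the tool to read"
    if has_memory:
        return "index all documents into the session memory bank"
    return "load URL/content data and place it in the last message's context"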
@@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
     AgentStepResponse,
     AgentTool,
     AgentTurnCreateRequest,
+    Document,
     Session,
     Turn,
 )

@@ -147,6 +148,7 @@ class MetaReferenceAgentsImpl(Agents):
             ]
         ],
         tools: Optional[List[AgentTool]] = None,
+        documents: Optional[List[Document]] = None,
         stream: Optional[bool] = False,
     ) -> AsyncGenerator:
         request = AgentTurnCreateRequest(

@@ -155,6 +157,7 @@ class MetaReferenceAgentsImpl(Agents):
             messages=messages,
             stream=True,
             tools=tools,
+            documents=documents,
         )
         if stream:
             return self._create_agent_turn_streaming(request)
@@ -21,6 +21,7 @@ log = logging.getLogger(__name__)
 class AgentSessionInfo(BaseModel):
     session_id: str
     session_name: str
+    memory_bank_id: Optional[str] = None
     started_at: datetime


@@ -51,6 +52,17 @@ class AgentPersistence:

         return AgentSessionInfo(**json.loads(value))

+    async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
+        session_info = await self.get_session_info(session_id)
+        if session_info is None:
+            raise ValueError(f"Session {session_id} not found")
+
+        session_info.memory_bank_id = bank_id
+        await self.kvstore.set(
+            key=f"session:{self.agent_id}:{session_id}",
+            value=session_info.model_dump_json(),
+        )
+
 async def add_turn_to_session(self, session_id: str, turn: Turn):
     await self.kvstore.set(
         key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
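Note: with this hunk, the session's memory bank id is persisted on AgentSessionInfo and the session record is rewritten in place under the session key. A minimal sketch of reading it back, assuming the KVStore interface exposes a get(key) symmetric to the set(key=..., value=...) used above (the helper name is a placeholder):

async def show_session_record(kvstore, agent_id: str, session_id: str) -> None:
    # Session records live under "session:{agent_id}:{session_id}"; once
    # add_memory_bank_to_session has run, the stored JSON carries memory_bank_id.
    value = await kvstore.get(f"session:{agent_id}:{session_id}")
    print(value)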
@@ -15,6 +15,7 @@ from llama_stack.apis.agents import (
     AgentTurnResponseStepCompletePayload,
     AgentTurnResponseStreamChunk,
     AgentTurnResponseTurnCompletePayload,
+    Document,
     ShieldCallStep,
     StepType,
     ToolChoice,

@@ -22,8 +23,6 @@ from llama_stack.apis.agents import (
     Turn,
 )
 from llama_stack.apis.inference import CompletionMessage, SamplingParams, UserMessage
-from llama_stack.apis.memory import MemoryBankDocument
-from llama_stack.apis.memory_banks import VectorMemoryBankParams
 from llama_stack.apis.safety import ViolationLevel
 from llama_stack.providers.datatypes import Api

@@ -232,8 +231,6 @@ class TestAgents:
         common_params,
     ):
         agents_impl = agents_stack.impls[Api.agents]
-        memory_banks_impl = agents_stack.impls[Api.memory_banks]
-        memory_impl = agents_stack.impls[Api.memory]
         urls = [
             "memory_optimizations.rst",
             "chat.rst",

@@ -243,28 +240,12 @@ class TestAgents:
             "lora_finetune.rst",
         ]
         documents = [
-            MemoryBankDocument(
-                document_id=f"num-{i}",
+            Document(
                 content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
                 mime_type="text/plain",
-                metadata={},
             )
             for i, url in enumerate(urls)
         ]
-        await memory_banks_impl.register_memory_bank(
-            memory_bank_id="test_bank",
-            params=VectorMemoryBankParams(
-                embedding_model="all-MiniLM-L6-v2",
-                chunk_size_in_tokens=512,
-                overlap_size_in_tokens=64,
-            ),
-            provider_id="faiss",
-        )
-        memory_impl.insert_documents(
-            bank_id="test_bank",
-            documents=documents,
-        )
-
         agent_config = AgentConfig(
             **{
                 **common_params,

@@ -278,6 +259,7 @@ class TestAgents:
             agent_id=agent_id,
             session_id=session_id,
             messages=attachment_message,
+            documents=documents,
             stream=True,
         )
         turn_response = [
@@ -203,6 +203,79 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
     assert "Tool:code_interpreter Response" in logs_str


+def test_code_execution(llama_stack_client):
+    agent_config = AgentConfig(
+        model="meta-llama/Llama-3.1-70B-Instruct",
+        instructions="You are a helpful assistant",
+        tools=[
+            "brave_search",
+            "code_interpreter",
+        ],
+        tool_choice="required",
+        input_shields=[],
+        output_shields=[],
+        enable_session_persistence=False,
+    )
+
+    memory_bank_id = "inflation_data_memory_bank"
+    llama_stack_client.memory_banks.register(
+        memory_bank_id=memory_bank_id,
+        params={
+            "memory_bank_type": "vector",
+            "embedding_model": "all-MiniLM-L6-v2",
+            "chunk_size_in_tokens": 512,
+            "overlap_size_in_tokens": 64,
+        },
+    )
+    AugmentConfigWithMemoryTool(agent_config, llama_stack_client)
+    codex_agent = Agent(llama_stack_client, agent_config)
+    session_id = codex_agent.create_session("test-session")
+
+    llama_stack_client.memory.insert(
+        bank_id=memory_bank_id,
+        documents=[
+            Document(
+                document_id="inflation",
+                content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
+                mime_type="text/csv",
+                metadata={},
+            )
+        ],
+    )
+
+    user_prompts = [
+        {
+            "prompt": "Can you describe the data in the context?",
+            "tools": [{"name": "memory", "args": {"memory_bank_id": memory_bank_id}}],
+        },
+        {
+            "prompt": "Plot average yearly inflation as a time series",
+            "tools": [
+                {"name": "memory", "args": {"memory_bank_id": memory_bank_id}},
+                "code_interpreter",
+            ],
+        },
+    ]
+
+    for input in user_prompts:
+        print(f'User> {input["prompt"]}')
+        response = codex_agent.create_turn(
+            messages=[
+                {
+                    "role": "user",
+                    "content": input["prompt"],
+                }
+            ],
+            session_id=session_id,
+            tools=input["tools"],
+        )
+        # for chunk in response:
+        #     print(chunk)
+
+        for log in EventLogger().log(response):
+            log.print()
+
+
 def test_custom_tool(llama_stack_client, agent_config):
     client_tool = TestClientTool()
     agent_config = {