Avoid deleting temp directory between agent turns

This brings an interesting aspect -- we need to maintain session-level
tempdir state (!) since the model was told there was some resource at a
given location that it needs to maintain
This commit is contained in:
Ashwin Bharambe 2024-12-08 22:25:37 -08:00
parent d7dc69c8a9
commit 5335393fe3
3 changed files with 33 additions and 11 deletions

View file

@ -11,7 +11,7 @@ from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger as AgentEventLogger
from llama_stack_client.lib.inference.event_logger import EventLogger
from llama_stack_client.types import UserMessage
from llama_stack_client.types import Attachment, UserMessage
from llama_stack_client.types.agent_create_params import AgentConfig
@ -67,9 +67,15 @@ def main(config_path: str):
]
if os.getenv("BRAVE_SEARCH_API_KEY")
else []
)
+ (
[
{
"type": "code_interpreter",
}
]
),
tool_choice="auto",
tool_prompt_format="json",
tool_choice="required",
input_shields=[],
output_shields=[],
enable_session_persistence=False,
@ -79,10 +85,27 @@ def main(config_path: str):
"Hello",
"Which players played in the winning team of the NBA western conference semifinals of 2024, please use tools",
]
user_prompts = [
(
"Here is a csv, can you describe it ?",
[
Attachment(
content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
mime_type="test/csv",
)
],
),
("Which year ended with the highest inflation ?", None),
(
"What macro economic situations that led to such high inflation in that period?",
None,
),
("Plot average yearly inflation as a time series", None),
]
session_id = agent.create_session("test-session")
for prompt in user_prompts:
for prompt, attachments in user_prompts:
response = agent.create_turn(
messages=[
{
@ -90,6 +113,7 @@ def main(config_path: str):
"content": prompt,
}
],
attachments=attachments,
session_id=session_id,
)

View file

@ -10,9 +10,7 @@ import logging
import os
import re
import secrets
import shutil
import string
import tempfile
import uuid
from datetime import datetime
from typing import AsyncGenerator, List, Tuple
@ -57,6 +55,7 @@ class ChatAgent(ShieldRunnerMixin):
self,
agent_id: str,
agent_config: AgentConfig,
tempdir: str,
inference_api: Inference,
memory_api: Memory,
memory_banks_api: MemoryBanks,
@ -65,14 +64,13 @@ class ChatAgent(ShieldRunnerMixin):
):
self.agent_id = agent_id
self.agent_config = agent_config
self.tempdir = tempdir
self.inference_api = inference_api
self.memory_api = memory_api
self.memory_banks_api = memory_banks_api
self.safety_api = safety_api
self.storage = AgentPersistence(agent_id, persistence_store)
self.tempdir = tempfile.mkdtemp()
builtin_tools = []
for tool_defn in agent_config.tools:
if isinstance(tool_defn, WolframAlphaToolDefinition):
@ -103,9 +101,6 @@ class ChatAgent(ShieldRunnerMixin):
output_shields=agent_config.output_shields,
)
def __del__(self):
shutil.rmtree(self.tempdir)
def turn_to_messages(self, turn: Turn) -> List[Message]:
messages = []

View file

@ -7,6 +7,7 @@
import json
import logging
import shutil
import tempfile
import uuid
from typing import AsyncGenerator
@ -43,6 +44,7 @@ class MetaReferenceAgentsImpl(Agents):
self.memory_banks_api = memory_banks_api
self.in_memory_store = InmemoryKVStoreImpl()
self.tempdir = tempfile.mkdtemp()
async def initialize(self) -> None:
self.persistence_store = await kvstore_impl(self.config.persistence_store)
@ -94,6 +96,7 @@ class MetaReferenceAgentsImpl(Agents):
return ChatAgent(
agent_id=agent_id,
agent_config=agent_config,
tempdir=self.tempdir,
inference_api=self.inference_api,
safety_api=self.safety_api,
memory_api=self.memory_api,