Avoid deleting temp directory between agent turns

This brings an interesting aspect -- we need to maintain session-level
tempdir state (!) since the model was told there was some resource at a
given location that it needs to maintain
This commit is contained in:
Ashwin Bharambe 2024-12-08 22:25:37 -08:00
parent d7dc69c8a9
commit 5335393fe3
3 changed files with 33 additions and 11 deletions

View file

@ -11,7 +11,7 @@ from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger as AgentEventLogger from llama_stack_client.lib.agents.event_logger import EventLogger as AgentEventLogger
from llama_stack_client.lib.inference.event_logger import EventLogger from llama_stack_client.lib.inference.event_logger import EventLogger
from llama_stack_client.types import UserMessage from llama_stack_client.types import Attachment, UserMessage
from llama_stack_client.types.agent_create_params import AgentConfig from llama_stack_client.types.agent_create_params import AgentConfig
@ -67,9 +67,15 @@ def main(config_path: str):
] ]
if os.getenv("BRAVE_SEARCH_API_KEY") if os.getenv("BRAVE_SEARCH_API_KEY")
else [] else []
)
+ (
[
{
"type": "code_interpreter",
}
]
), ),
tool_choice="auto", tool_choice="required",
tool_prompt_format="json",
input_shields=[], input_shields=[],
output_shields=[], output_shields=[],
enable_session_persistence=False, enable_session_persistence=False,
@ -79,10 +85,27 @@ def main(config_path: str):
"Hello", "Hello",
"Which players played in the winning team of the NBA western conference semifinals of 2024, please use tools", "Which players played in the winning team of the NBA western conference semifinals of 2024, please use tools",
] ]
user_prompts = [
(
"Here is a csv, can you describe it ?",
[
Attachment(
content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
mime_type="test/csv",
)
],
),
("Which year ended with the highest inflation ?", None),
(
"What macro economic situations that led to such high inflation in that period?",
None,
),
("Plot average yearly inflation as a time series", None),
]
session_id = agent.create_session("test-session") session_id = agent.create_session("test-session")
for prompt in user_prompts: for prompt, attachments in user_prompts:
response = agent.create_turn( response = agent.create_turn(
messages=[ messages=[
{ {
@ -90,6 +113,7 @@ def main(config_path: str):
"content": prompt, "content": prompt,
} }
], ],
attachments=attachments,
session_id=session_id, session_id=session_id,
) )

View file

@ -10,9 +10,7 @@ import logging
import os import os
import re import re
import secrets import secrets
import shutil
import string import string
import tempfile
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import AsyncGenerator, List, Tuple from typing import AsyncGenerator, List, Tuple
@ -57,6 +55,7 @@ class ChatAgent(ShieldRunnerMixin):
self, self,
agent_id: str, agent_id: str,
agent_config: AgentConfig, agent_config: AgentConfig,
tempdir: str,
inference_api: Inference, inference_api: Inference,
memory_api: Memory, memory_api: Memory,
memory_banks_api: MemoryBanks, memory_banks_api: MemoryBanks,
@ -65,14 +64,13 @@ class ChatAgent(ShieldRunnerMixin):
): ):
self.agent_id = agent_id self.agent_id = agent_id
self.agent_config = agent_config self.agent_config = agent_config
self.tempdir = tempdir
self.inference_api = inference_api self.inference_api = inference_api
self.memory_api = memory_api self.memory_api = memory_api
self.memory_banks_api = memory_banks_api self.memory_banks_api = memory_banks_api
self.safety_api = safety_api self.safety_api = safety_api
self.storage = AgentPersistence(agent_id, persistence_store) self.storage = AgentPersistence(agent_id, persistence_store)
self.tempdir = tempfile.mkdtemp()
builtin_tools = [] builtin_tools = []
for tool_defn in agent_config.tools: for tool_defn in agent_config.tools:
if isinstance(tool_defn, WolframAlphaToolDefinition): if isinstance(tool_defn, WolframAlphaToolDefinition):
@ -103,9 +101,6 @@ class ChatAgent(ShieldRunnerMixin):
output_shields=agent_config.output_shields, output_shields=agent_config.output_shields,
) )
def __del__(self):
shutil.rmtree(self.tempdir)
def turn_to_messages(self, turn: Turn) -> List[Message]: def turn_to_messages(self, turn: Turn) -> List[Message]:
messages = [] messages = []

View file

@ -7,6 +7,7 @@
import json import json
import logging import logging
import shutil import shutil
import tempfile
import uuid import uuid
from typing import AsyncGenerator from typing import AsyncGenerator
@ -43,6 +44,7 @@ class MetaReferenceAgentsImpl(Agents):
self.memory_banks_api = memory_banks_api self.memory_banks_api = memory_banks_api
self.in_memory_store = InmemoryKVStoreImpl() self.in_memory_store = InmemoryKVStoreImpl()
self.tempdir = tempfile.mkdtemp()
async def initialize(self) -> None: async def initialize(self) -> None:
self.persistence_store = await kvstore_impl(self.config.persistence_store) self.persistence_store = await kvstore_impl(self.config.persistence_store)
@ -94,6 +96,7 @@ class MetaReferenceAgentsImpl(Agents):
return ChatAgent( return ChatAgent(
agent_id=agent_id, agent_id=agent_id,
agent_config=agent_config, agent_config=agent_config,
tempdir=self.tempdir,
inference_api=self.inference_api, inference_api=self.inference_api,
safety_api=self.safety_api, safety_api=self.safety_api,
memory_api=self.memory_api, memory_api=self.memory_api,