Mirror of https://github.com/meta-llama/llama-stack.git
Avoid deleting temp directory between agent turns
This brings up an interesting aspect -- we need to maintain session-level tempdir state (!) since the model was told there is a resource at a given location, and that location has to stay valid across turns.
Parent: d7dc69c8a9
Commit: 5335393fe3

3 changed files with 33 additions and 11 deletions
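In outline, the diff below touches three places: the library-client example script (the hunks under def main(config_path: str)), ChatAgent, and MetaReferenceAgentsImpl. The core of the change is moving tempdir ownership out of the per-instance agent and into the long-lived provider. A minimal sketch of the resulting pattern (constructor arguments and surrounding code are trimmed, so treat these signatures as illustrative rather than the real ones):

import tempfile


class ChatAgent:
    # Previously each ChatAgent called tempfile.mkdtemp() itself and removed the
    # directory in __del__, so files the model had been pointed at could vanish
    # between turns. Now the directory is injected and never deleted here.
    def __init__(self, agent_id: str, tempdir: str):
        self.agent_id = agent_id
        self.tempdir = tempdir  # shared, session-level state


class MetaReferenceAgentsImpl:
    def __init__(self) -> None:
        # Created once for the lifetime of the provider; every agent instance
        # built from it reuses the same path.
        self.tempdir = tempfile.mkdtemp()

    def build_agent(self, agent_id: str) -> ChatAgent:
        # Illustrative helper name; the real construction site is shown in the
        # last hunk of this diff.
        return ChatAgent(agent_id=agent_id, tempdir=self.tempdir)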
@@ -11,7 +11,7 @@ from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
 from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client.lib.agents.event_logger import EventLogger as AgentEventLogger
 from llama_stack_client.lib.inference.event_logger import EventLogger
-from llama_stack_client.types import UserMessage
+from llama_stack_client.types import Attachment, UserMessage
 from llama_stack_client.types.agent_create_params import AgentConfig


@@ -67,9 +67,15 @@ def main(config_path: str):
             ]
             if os.getenv("BRAVE_SEARCH_API_KEY")
             else []
+        )
+        + (
+            [
+                {
+                    "type": "code_interpreter",
+                }
+            ]
         ),
-        tool_choice="auto",
-        tool_prompt_format="json",
+        tool_choice="required",
         input_shields=[],
         output_shields=[],
         enable_session_persistence=False,
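Read together, the hunk above makes the example agent always carry the code_interpreter tool (brave_search stays conditional on the API key) and forces tool use. Roughly, the resulting config when BRAVE_SEARCH_API_KEY is unset would look like the sketch below; fields outside this hunk (model, instructions, and so on) are omitted, not asserted:

# Sketch only: AgentConfig fields not shown in the hunk are left out.
agent_config = AgentConfig(
    tools=[
        {
            "type": "code_interpreter",
        }
    ],
    tool_choice="required",  # was "auto"; tool_prompt_format="json" is gone
    input_shields=[],
    output_shields=[],
    enable_session_persistence=False,
)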
@@ -79,10 +85,27 @@ def main(config_path: str):
         "Hello",
         "Which players played in the winning team of the NBA western conference semifinals of 2024, please use tools",
     ]
+    user_prompts = [
+        (
+            "Here is a csv, can you describe it ?",
+            [
+                Attachment(
+                    content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
+                    mime_type="test/csv",
+                )
+            ],
+        ),
+        ("Which year ended with the highest inflation ?", None),
+        (
+            "What macro economic situations that led to such high inflation in that period?",
+            None,
+        ),
+        ("Plot average yearly inflation as a time series", None),
+    ]

     session_id = agent.create_session("test-session")

-    for prompt in user_prompts:
+    for prompt, attachments in user_prompts:
         response = agent.create_turn(
             messages=[
                 {
@@ -90,6 +113,7 @@ def main(config_path: str):
                     "content": prompt,
                 }
             ],
+            attachments=attachments,
             session_id=session_id,
         )

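The hunks above only issue the turns; how the streamed response gets consumed is outside this diff. A hedged sketch of the usual pattern, reusing the script's own imports (the AgentEventLogger consumption below is an assumption about llama_stack_client usage, not something this commit adds):

# Assumed consumption of the streaming response; `agent`, `user_prompts`, and
# `session_id` are the objects defined in the example script above.
for prompt, attachments in user_prompts:
    response = agent.create_turn(
        messages=[{"role": "user", "content": prompt}],  # "role" key not shown in the hunk
        attachments=attachments,  # None for the follow-up questions about the same file
        session_id=session_id,
    )
    for log in AgentEventLogger().log(response):
        log.print()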
@@ -10,9 +10,7 @@ import logging
 import os
 import re
 import secrets
-import shutil
 import string
-import tempfile
 import uuid
 from datetime import datetime
 from typing import AsyncGenerator, List, Tuple
@@ -57,6 +55,7 @@ class ChatAgent(ShieldRunnerMixin):
         self,
         agent_id: str,
         agent_config: AgentConfig,
+        tempdir: str,
         inference_api: Inference,
         memory_api: Memory,
         memory_banks_api: MemoryBanks,
@@ -65,14 +64,13 @@ class ChatAgent(ShieldRunnerMixin):
     ):
         self.agent_id = agent_id
         self.agent_config = agent_config
+        self.tempdir = tempdir
         self.inference_api = inference_api
         self.memory_api = memory_api
         self.memory_banks_api = memory_banks_api
         self.safety_api = safety_api
         self.storage = AgentPersistence(agent_id, persistence_store)

-        self.tempdir = tempfile.mkdtemp()
-
         builtin_tools = []
         for tool_defn in agent_config.tools:
             if isinstance(tool_defn, WolframAlphaToolDefinition):
@@ -103,9 +101,6 @@ class ChatAgent(ShieldRunnerMixin):
             output_shields=agent_config.output_shields,
         )

-    def __del__(self):
-        shutil.rmtree(self.tempdir)
-
     def turn_to_messages(self, turn: Turn) -> List[Message]:
         messages = []

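The net effect of the ChatAgent hunks: the agent no longer owns the directory, so nothing tears it down when an agent object goes away between turns. A small, self-contained sketch of the property this buys (file name and contents are illustrative only; the real attachment handling lives elsewhere in ChatAgent):

import os
import tempfile

# Provider-owned directory, as MetaReferenceAgentsImpl now creates it.
tempdir = tempfile.mkdtemp()

# Turn 1: an attachment gets materialized somewhere under the tempdir and the
# model is told that path (the naming scheme here is made up for illustration).
csv_path = os.path.join(tempdir, "inflation.csv")
with open(csv_path, "w") as f:
    f.write("Year,Inflation\n")

# Turn 2..n: code the model generates keeps referring to csv_path. Because no
# __del__ -> shutil.rmtree() runs between turns anymore, the path stays valid
# for as long as the provider (and hence the session) lives.
assert os.path.exists(csv_path)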
@@ -7,6 +7,7 @@
 import json
 import logging
 import shutil
+import tempfile
 import uuid
 from typing import AsyncGenerator

@@ -43,6 +44,7 @@ class MetaReferenceAgentsImpl(Agents):
         self.memory_banks_api = memory_banks_api

         self.in_memory_store = InmemoryKVStoreImpl()
+        self.tempdir = tempfile.mkdtemp()

     async def initialize(self) -> None:
         self.persistence_store = await kvstore_impl(self.config.persistence_store)
@@ -94,6 +96,7 @@ class MetaReferenceAgentsImpl(Agents):
         return ChatAgent(
             agent_id=agent_id,
             agent_config=agent_config,
+            tempdir=self.tempdir,
             inference_api=self.inference_api,
             safety_api=self.safety_api,
             memory_api=self.memory_api,