diff --git a/llama_stack/distribution/tests/library_client_test.py b/llama_stack/distribution/tests/library_client_test.py index 5e7b997f3..955640c2b 100644 --- a/llama_stack/distribution/tests/library_client_test.py +++ b/llama_stack/distribution/tests/library_client_test.py @@ -11,7 +11,7 @@ from llama_stack.distribution.library_client import LlamaStackAsLibraryClient from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.event_logger import EventLogger as AgentEventLogger from llama_stack_client.lib.inference.event_logger import EventLogger -from llama_stack_client.types import UserMessage +from llama_stack_client.types import Attachment, UserMessage from llama_stack_client.types.agent_create_params import AgentConfig @@ -67,9 +67,15 @@ def main(config_path: str): ] if os.getenv("BRAVE_SEARCH_API_KEY") else [] + ) + + ( + [ + { + "type": "code_interpreter", + } + ] ), - tool_choice="auto", - tool_prompt_format="json", + tool_choice="required", input_shields=[], output_shields=[], enable_session_persistence=False, @@ -79,10 +85,27 @@ def main(config_path: str): "Hello", "Which players played in the winning team of the NBA western conference semifinals of 2024, please use tools", ] + user_prompts = [ + ( + "Here is a csv, can you describe it ?", + [ + Attachment( + content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv", + mime_type="test/csv", + ) + ], + ), + ("Which year ended with the highest inflation ?", None), + ( + "What macro economic situations that led to such high inflation in that period?", + None, + ), + ("Plot average yearly inflation as a time series", None), + ] session_id = agent.create_session("test-session") - for prompt in user_prompts: + for prompt, attachments in user_prompts: response = agent.create_turn( messages=[ { @@ -90,6 +113,7 @@ def main(config_path: str): "content": prompt, } ], + attachments=attachments, session_id=session_id, ) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 7df5d3bd4..e367f3c41 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -10,9 +10,7 @@ import logging import os import re import secrets -import shutil import string -import tempfile import uuid from datetime import datetime from typing import AsyncGenerator, List, Tuple @@ -57,6 +55,7 @@ class ChatAgent(ShieldRunnerMixin): self, agent_id: str, agent_config: AgentConfig, + tempdir: str, inference_api: Inference, memory_api: Memory, memory_banks_api: MemoryBanks, @@ -65,14 +64,13 @@ class ChatAgent(ShieldRunnerMixin): ): self.agent_id = agent_id self.agent_config = agent_config + self.tempdir = tempdir self.inference_api = inference_api self.memory_api = memory_api self.memory_banks_api = memory_banks_api self.safety_api = safety_api self.storage = AgentPersistence(agent_id, persistence_store) - self.tempdir = tempfile.mkdtemp() - builtin_tools = [] for tool_defn in agent_config.tools: if isinstance(tool_defn, WolframAlphaToolDefinition): @@ -103,9 +101,6 @@ class ChatAgent(ShieldRunnerMixin): output_shields=agent_config.output_shields, ) - def __del__(self): - shutil.rmtree(self.tempdir) - def turn_to_messages(self, turn: Turn) -> List[Message]: messages = [] diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 0b0bb6e27..dec5ec960 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -7,6 +7,7 @@ import json import logging import shutil +import tempfile import uuid from typing import AsyncGenerator @@ -43,6 +44,7 @@ class MetaReferenceAgentsImpl(Agents): self.memory_banks_api = memory_banks_api self.in_memory_store = InmemoryKVStoreImpl() + self.tempdir = tempfile.mkdtemp() async def initialize(self) -> None: self.persistence_store = await kvstore_impl(self.config.persistence_store) @@ -94,6 +96,7 @@ class MetaReferenceAgentsImpl(Agents): return ChatAgent( agent_id=agent_id, agent_config=agent_config, + tempdir=self.tempdir, inference_api=self.inference_api, safety_api=self.safety_api, memory_api=self.memory_api,