diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index 14278b803..88ae91906 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -39,7 +39,6 @@ from llama_stack.apis.safety import SafetyViolation
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 
 
-@json_schema_type
 class Attachment(BaseModel):
     content: InterleavedContent | URL
     mime_type: str
@@ -258,7 +257,6 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
             ToolResponseMessage,
         ]
     ]
-    attachments: Optional[List[Attachment]] = None
 
     stream: Optional[bool] = False
 
@@ -295,7 +293,6 @@ class Agents(Protocol):
                 ToolResponseMessage,
             ]
         ],
-        attachments: Optional[List[Attachment]] = None,
        stream: Optional[bool] = False,
    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
 
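With `attachments` dropped from `AgentTurnCreateRequest` and `create_agent_turn`, a turn request now carries only the messages and the stream flag. A minimal sketch of the new request shape, mirroring the construction in `meta_reference/agents.py` below; the agent/session IDs are placeholders, and `UserMessage` is assumed importable from `llama_stack.apis.inference`:

```python
from llama_stack.apis.agents import AgentTurnCreateRequest
from llama_stack.apis.inference import UserMessage

# Hypothetical IDs for illustration; note there is no attachments field anymore.
request = AgentTurnCreateRequest(
    agent_id="agent-123",
    session_id="session-456",
    messages=[UserMessage(content="Tell me how to use LoRA")],
    stream=True,
)
```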
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 1ecb95e68..6cf031bf7 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -188,7 +188,6 @@ class ChatAgent(ShieldRunnerMixin):
             session_id=request.session_id,
             turn_id=turn_id,
             input_messages=messages,
-            attachments=request.attachments or [],
             sampling_params=self.agent_config.sampling_params,
             stream=request.stream,
         ):
@@ -238,7 +237,6 @@ class ChatAgent(ShieldRunnerMixin):
         session_id: str,
         turn_id: str,
         input_messages: List[Message],
-        attachments: List[Attachment],
         sampling_params: SamplingParams,
         stream: bool = False,
     ) -> AsyncGenerator:
@@ -257,7 +255,7 @@ class ChatAgent(ShieldRunnerMixin):
                 yield res
 
         async for res in self._run(
-            session_id, turn_id, input_messages, attachments, sampling_params, stream
+            session_id, turn_id, input_messages, sampling_params, stream
         ):
             if isinstance(res, bool):
                 return
@@ -350,7 +348,6 @@ class ChatAgent(ShieldRunnerMixin):
         session_id: str,
         turn_id: str,
         input_messages: List[Message],
-        attachments: List[Attachment],
         sampling_params: SamplingParams,
         stream: bool = False,
     ) -> AsyncGenerator:
@@ -370,7 +367,6 @@ class ChatAgent(ShieldRunnerMixin):
                     session_id=session_id,
                     turn_id=turn_id,
                     input_messages=input_messages,
-                    attachments=attachments,
                 )
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
@@ -423,7 +419,10 @@ class ChatAgent(ShieldRunnerMixin):
                 span.set_attribute("output", result.content)
                 span.set_attribute("error_code", result.error_code)
                 span.set_attribute("error_message", result.error_message)
-                span.set_attribute("tool_name", tool_name)
+                if isinstance(tool_name, BuiltinTool):
+                    span.set_attribute("tool_name", tool_name.value)
+                else:
+                    span.set_attribute("tool_name", tool_name)
             if result.error_code == 0:
                 last_message = input_messages[-1]
                 last_message.context = result.content
@@ -553,9 +552,9 @@ class ChatAgent(ShieldRunnerMixin):
                 # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
                 if len(output_attachments) > 0:
                     if isinstance(message.content, list):
-                        message.content += attachments
+                        message.content += output_attachments
                     else:
-                        message.content = [message.content] + attachments
+                        message.content = [message.content] + output_attachments
                 yield message
             else:
                 log.info(f"Partial message: {str(message)}")
@@ -586,10 +585,13 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
 
+            tool_name = tool_call.tool_name
+            if isinstance(tool_name, BuiltinTool):
+                tool_name = tool_name.value
             with tracing.span(
                 "tool_execution",
                 {
-                    "tool_name": tool_call.tool_name,
+                    "tool_name": tool_name,
                     "input": message.model_dump_json(),
                 },
             ) as span:
@@ -608,6 +610,7 @@ class ChatAgent(ShieldRunnerMixin):
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseStepCompletePayload(
                             step_type=StepType.tool_execution.value,
+                            step_id=step_id,
                             step_details=ToolExecutionStep(
                                 step_id=step_id,
                                 turn_id=turn_id,
diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py
index 89b38a7fc..5769c42e5 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -146,14 +146,12 @@ class MetaReferenceAgentsImpl(Agents):
                 ToolResponseMessage,
             ]
         ],
-        attachments: Optional[List[Attachment]] = None,
         stream: Optional[bool] = False,
     ) -> AsyncGenerator:
         request = AgentTurnCreateRequest(
             agent_id=agent_id,
             session_id=session_id,
             messages=messages,
-            attachments=attachments,
             stream=True,
         )
         if stream:
diff --git a/llama_stack/providers/inline/tool_runtime/memory/__init__.py b/llama_stack/providers/inline/tool_runtime/memory/__init__.py
index 36377f147..928afa484 100644
--- a/llama_stack/providers/inline/tool_runtime/memory/__init__.py
+++ b/llama_stack/providers/inline/tool_runtime/memory/__init__.py
@@ -8,11 +8,11 @@ from typing import Any, Dict
 
 from llama_stack.providers.datatypes import Api
 
-from .config import MemoryToolConfig
+from .config import MemoryToolRuntimeConfig
 from .memory import MemoryToolRuntimeImpl
 
 
-async def get_provider_impl(config: MemoryToolConfig, deps: Dict[str, Any]):
+async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
     impl = MemoryToolRuntimeImpl(
         config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
     )
diff --git a/llama_stack/providers/inline/tool_runtime/memory/config.py b/llama_stack/providers/inline/tool_runtime/memory/config.py
index cb24883dc..6ff242c6b 100644
--- a/llama_stack/providers/inline/tool_runtime/memory/config.py
+++ b/llama_stack/providers/inline/tool_runtime/memory/config.py
@@ -7,9 +7,6 @@
 
 from enum import Enum
 from typing import Annotated, List, Literal, Union
 
-from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
-from llama_stack.providers.utils.kvstore import KVStoreConfig, SqliteKVStoreConfig
-
 from pydantic import BaseModel, Field
 
@@ -81,13 +78,13 @@ MemoryQueryGeneratorConfig = Annotated[
 
 class MemoryToolConfig(BaseModel):
     memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
+
+
+class MemoryToolRuntimeConfig(BaseModel):
     # This config defines how a query is generated using the messages
     # for memory bank retrieval.
     query_generator_config: MemoryQueryGeneratorConfig = Field(
         default=DefaultMemoryQueryGeneratorConfig()
     )
     max_tokens_in_context: int = 4096
-    max_chunks: int = 10
-    kvstore_config: KVStoreConfig = SqliteKVStoreConfig(
-        db_path=(RUNTIME_BASE_DIR / "memory.db").as_posix()
-    )
+    max_chunks: int = 5
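The previous single config class is now split in two: `MemoryToolRuntimeConfig` holds the provider-level knobs (query generation strategy, context token budget, chunk count), while `MemoryToolConfig` keeps only the per-tool memory bank list. A quick sketch of the distinction, assuming the defaults defined above:

```python
from llama_stack.providers.inline.tool_runtime.memory.config import (
    MemoryToolConfig,
    MemoryToolRuntimeConfig,
)

# Provider-level settings; kvstore_config is gone, and max_chunks now defaults to 5.
runtime_config = MemoryToolRuntimeConfig()
assert runtime_config.max_tokens_in_context == 4096
assert runtime_config.max_chunks == 5

# Per-tool settings: just the list of memory banks to query (empty by default).
tool_config = MemoryToolConfig()
assert tool_config.memory_bank_configs == []
```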
diff --git a/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py
index da97cb3a3..7ee751a17 100644
--- a/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py
+++ b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List
 
 from jinja2 import Template
 
@@ -23,7 +22,7 @@ from .config import (
 
 async def generate_rag_query(
     config: MemoryQueryGeneratorConfig,
-    messages: List[Message],
+    message: Message,
     **kwargs,
 ):
     """
@@ -31,9 +30,9 @@ async def generate_rag_query(
     retrieving relevant information from the memory bank.
     """
     if config.type == MemoryQueryGenerator.default.value:
-        query = await default_rag_query_generator(config, messages, **kwargs)
+        query = await default_rag_query_generator(config, message, **kwargs)
     elif config.type == MemoryQueryGenerator.llm.value:
-        query = await llm_rag_query_generator(config, messages, **kwargs)
+        query = await llm_rag_query_generator(config, message, **kwargs)
     else:
         raise NotImplementedError(f"Unsupported memory query generator {config.type}")
     return query
@@ -41,21 +40,21 @@ async def generate_rag_query(
 
 async def default_rag_query_generator(
     config: DefaultMemoryQueryGeneratorConfig,
-    messages: List[Message],
+    message: Message,
     **kwargs,
 ):
-    return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
+    return interleaved_content_as_str(message.content)
 
 
 async def llm_rag_query_generator(
     config: LLMMemoryQueryGeneratorConfig,
-    messages: List[Message],
+    message: Message,
     **kwargs,
 ):
     assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
     inference_api = kwargs["inference_api"]
 
-    m_dict = {"messages": [m.model_dump() for m in messages]}
+    m_dict = {"messages": [message.model_dump()]}
 
     template = Template(config.template)
     content = template.render(m_dict)
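`generate_rag_query` now takes a single `Message` instead of a list, so the default generator reduces to stringifying that one message's content (the `sep`-join over a message list is gone). A hedged usage sketch, assuming `UserMessage` as the message type:

```python
import asyncio

from llama_stack.apis.inference import UserMessage
from llama_stack.providers.inline.tool_runtime.memory.config import (
    DefaultMemoryQueryGeneratorConfig,
)
from llama_stack.providers.inline.tool_runtime.memory.context_retriever import (
    generate_rag_query,
)


async def main() -> None:
    # With the default config, the query is just the message content as a string.
    query = await generate_rag_query(
        DefaultMemoryQueryGeneratorConfig(),
        UserMessage(content="Tell me how to use LoRA"),
    )
    print(query)  # "Tell me how to use LoRA"


asyncio.run(main())
```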
diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/memory/memory.py
index 3a08bf1f9..d492309cd 100644
--- a/llama_stack/providers/inline/tool_runtime/memory/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/memory/memory.py
@@ -5,24 +5,14 @@
 # the root directory of this source tree.
 
 import asyncio
-import json
 import logging
-import os
-import re
 import secrets
 import string
-import tempfile
-import uuid
 from typing import Any, Dict, List, Optional
-from urllib.parse import urlparse
 
-import httpx
-
-from llama_stack.apis.agents import Attachment
-from llama_stack.apis.common.content_types import TextContentItem, URL
 from llama_stack.apis.inference import Inference, InterleavedContent, Message
-from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
-from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
+from llama_stack.apis.memory import Memory, QueryDocumentsResponse
+from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.tools import (
     ToolDef,
     ToolGroupDef,
@@ -30,22 +20,14 @@ from llama_stack.apis.tools import (
     ToolRuntime,
 )
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
-from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
-from pydantic import BaseModel
 
-from .config import MemoryToolConfig
+from .config import MemoryToolConfig, MemoryToolRuntimeConfig
 from .context_retriever import generate_rag_query
 
 log = logging.getLogger(__name__)
 
 
-class MemorySessionInfo(BaseModel):
-    session_id: str
-    session_name: str
-    memory_bank_id: Optional[str] = None
-
-
 def make_random_string(length: int = 8):
     return "".join(
         secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
@@ -55,7 +37,7 @@ def make_random_string(length: int = 8):
 class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
     def __init__(
         self,
-        config: MemoryToolConfig,
+        config: MemoryToolRuntimeConfig,
         memory_api: Memory,
         memory_banks_api: MemoryBanks,
         inference_api: Inference,
@@ -63,113 +45,26 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         self.config = config
         self.memory_api = memory_api
         self.memory_banks_api = memory_banks_api
-        self.tempdir = tempfile.mkdtemp()
         self.inference_api = inference_api
 
     async def initialize(self):
-        self.kvstore = await kvstore_impl(self.config.kvstore_config)
+        pass
 
     async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
         return []
 
-    async def create_session(self, session_id: str) -> MemorySessionInfo:
-        session_info = MemorySessionInfo(
-            session_id=session_id,
-            session_name=f"session_{session_id}",
-        )
-        await self.kvstore.set(
-            key=f"memory::session:{session_id}",
-            value=session_info.model_dump_json(),
-        )
-        return session_info
-
-    async def get_session_info(self, session_id: str) -> Optional[MemorySessionInfo]:
-        value = await self.kvstore.get(
-            key=f"memory::session:{session_id}",
-        )
-        if not value:
-            session_info = await self.create_session(session_id)
-            return session_info
-
-        return MemorySessionInfo(**json.loads(value))
-
-    async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
-        session_info = await self.get_session_info(session_id)
-
-        session_info.memory_bank_id = bank_id
-        await self.kvstore.set(
-            key=f"memory::session:{session_id}",
-            value=session_info.model_dump_json(),
-        )
-
-    async def _ensure_memory_bank(self, session_id: str) -> str:
-        session_info = await self.get_session_info(session_id)
-
-        if session_info.memory_bank_id is None:
-            bank_id = f"memory_bank_{session_id}"
-            await self.memory_banks_api.register_memory_bank(
-                memory_bank_id=bank_id,
-                params=VectorMemoryBankParams(
-                    embedding_model="all-MiniLM-L6-v2",
-                    chunk_size_in_tokens=512,
-                ),
-            )
-            await self.add_memory_bank_to_session(session_id, bank_id)
-        else:
-            bank_id = session_info.memory_bank_id
-
-        return bank_id
-
-    async def attachment_message(
-        self, tempdir: str, urls: List[URL]
-    ) -> List[TextContentItem]:
-        content = []
-
-        for url in urls:
-            uri = url.uri
-            if uri.startswith("file://"):
-                filepath = uri[len("file://") :]
-            elif uri.startswith("http"):
-                path = urlparse(uri).path
-                basename = os.path.basename(path)
-                filepath = f"{tempdir}/{make_random_string() + basename}"
-                log.info(f"Downloading {url} -> {filepath}")
-
-                async with httpx.AsyncClient() as client:
-                    r = await client.get(uri)
-                    resp = r.text
-                    with open(filepath, "w") as fp:
-                        fp.write(resp)
-            else:
-                raise ValueError(f"Unsupported URL {url}")
-
-            content.append(
-                TextContentItem(
-                    text=f'# There is a file accessible to you at "{filepath}"\n'
-                )
-            )
-
-        return content
-
     async def _retrieve_context(
-        self, session_id: str, messages: List[Message]
+        self, messages: List[Message], bank_ids: List[str]
     ) -> Optional[List[InterleavedContent]]:
-        bank_ids = []
-
-        bank_ids.extend(c.bank_id for c in self.config.memory_bank_configs)
-
-        session_info = await self.get_session_info(session_id)
-        if session_info.memory_bank_id:
-            bank_ids.append(session_info.memory_bank_id)
-
         if not bank_ids:
-            # this can happen if the per-session memory bank is not yet populated
-            # (i.e., no prior turns uploaded an Attachment)
+            return None
+        if len(messages) == 0:
             return None
 
+        message = messages[-1]  # only use the last message as input to the query
         query = await generate_rag_query(
             self.config.query_generator_config,
-            messages,
+            message,
             inference_api=self.inference_api,
         )
         tasks = [
@@ -177,7 +72,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
                 bank_id=bank_id,
                 query=query,
                 params={
-                    "max_chunks": 5,
+                    "max_chunks": self.config.max_chunks,
                 },
             )
             for bank_id in bank_ids
@@ -211,43 +106,20 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
             "\n=== END-RETRIEVED-CONTEXT ===\n",
         ]
 
-    async def _process_attachments(
-        self, session_id: str, attachments: List[Attachment]
-    ):
-        bank_id = await self._ensure_memory_bank(session_id)
-
-        documents = [
-            MemoryBankDocument(
-                document_id=str(uuid.uuid4()),
-                content=a.content,
-                mime_type=a.mime_type,
-                metadata={},
-            )
-            for a in attachments
-            if isinstance(a.content, str)
-        ]
-        await self.memory_api.insert_documents(bank_id, documents)
-
-        urls = [a.content for a in attachments if isinstance(a.content, URL)]
-        # TODO: we need to migrate URL away from str type
-        pattern = re.compile("^(https?://|file://|data:)")
-        urls += [URL(uri=a.content) for a in attachments if pattern.match(a.content)]
-        return await self.attachment_message(self.tempdir, urls)
-
     async def invoke_tool(
         self, tool_name: str, args: Dict[str, Any]
    ) -> ToolInvocationResult:
-        if args["session_id"] is None:
-            raise ValueError("session_id is required")
+        tool = await self.tool_store.get_tool(tool_name)
+        config = MemoryToolConfig()
+        if tool.metadata.get("config") is not None:
+            config = MemoryToolConfig(**tool.metadata["config"])
         context = await self._retrieve_context(
-            args["session_id"], args["input_messages"]
+            args["input_messages"],
+            [bank_config.bank_id for bank_config in config.memory_bank_configs],
         )
         if context is None:
             context = []
-        attachments = args["attachments"]
-        if attachments and len(attachments) > 0:
-            context += await self._process_attachments(args["session_id"], attachments)
         return ToolInvocationResult(
             content=concat_interleaved_content(context), error_code=0
         )
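`invoke_tool` no longer threads a `session_id` through a kvstore; the banks to query now come from the tool's registered `metadata["config"]`. A rough sketch of that resolution step in isolation; the metadata dict below is a stand-in, and the exact `MemoryBankConfig` entry shape (a `bank_id` plus a `type` discriminator) is an assumption inferred from the surrounding code:

```python
from llama_stack.providers.inline.tool_runtime.memory.config import MemoryToolConfig

# Stand-in for the metadata on the record returned by tool_store.get_tool(tool_name).
metadata = {
    "config": {"memory_bank_configs": [{"bank_id": "docs-bank", "type": "vector"}]}
}

# Mirrors the fallback in invoke_tool: use defaults when no config is registered.
config = (
    MemoryToolConfig(**metadata["config"])
    if metadata.get("config") is not None
    else MemoryToolConfig()
)
bank_ids = [bank_config.bank_id for bank_config in config.memory_bank_configs]
print(bank_ids)  # ["docs-bank"]
```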
diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py
index b6b34edf0..d6e892599 100644
--- a/llama_stack/providers/registry/tool_runtime.py
+++ b/llama_stack/providers/registry/tool_runtime.py
@@ -22,7 +22,7 @@ def available_providers() -> List[ProviderSpec]:
             provider_type="inline::memory-runtime",
             pip_packages=[],
             module="llama_stack.providers.inline.tool_runtime.memory",
-            config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolConfig",
+            config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
             api_dependencies=[Api.memory, Api.memory_banks, Api.inference],
         ),
         InlineProviderSpec(
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index 10aaa09b5..7f8b5b26b 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -21,7 +21,7 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage
 
 class TestCustomTool(CustomTool):
     """Tool to give boiling point of a liquid
-    Returns the correct value for water in Celcius and Fahrenheit
+    Returns the correct value for polyjuice in Celcius and Fahrenheit
     and returns -1 for other liquids
     """
 
@@ -50,7 +50,7 @@ class TestCustomTool(CustomTool):
         return "get_boiling_point"
 
     def get_description(self) -> str:
-        return "Get the boiling point of a imaginary liquids (eg. polyjuice)"
+        return "Get the boiling point of imaginary liquids (eg. polyjuice)"
 
     def get_params_definition(self) -> Dict[str, Parameter]:
         return {
@@ -279,7 +279,6 @@ def test_rag_agent(llama_stack_client, agent_config):
         "What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.",
         "Was anything related to 'Llama3' discussed, if so what?",
         "Tell me how to use LoRA",
-        "What about Quantization?",
     ]
 
     for prompt in user_prompts: