diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 5748b4e41..65be92348 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -14,18 +14,16 @@ from typing import ( Literal, Optional, Protocol, - runtime_checkable, Union, + runtime_checkable, ) from llama_models.llama3.api.datatypes import ToolParamDefinition - from llama_models.schema_utils import json_schema_type, webmethod - from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated -from llama_stack.apis.common.content_types import InterleavedContent, URL +from llama_stack.apis.common.content_types import URL, InterleavedContent from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig from llama_stack.apis.inference import ( CompletionMessage, @@ -40,7 +38,6 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.memory import MemoryBank from llama_stack.apis.safety import SafetyViolation - from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol @@ -110,85 +107,6 @@ class FunctionCallToolDefinition(ToolDefinitionCommon): remote_execution: Optional[RestAPIExecutionConfig] = None -class _MemoryBankConfigCommon(BaseModel): - bank_id: str - - -class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon): - type: Literal["vector"] = "vector" - - -class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon): - type: Literal["keyvalue"] = "keyvalue" - keys: List[str] # what keys to focus on - - -class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon): - type: Literal["keyword"] = "keyword" - - -class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon): - type: Literal["graph"] = "graph" - entities: List[str] # what entities to focus on - - -MemoryBankConfig = Annotated[ - Union[ - AgentVectorMemoryBankConfig, - AgentKeyValueMemoryBankConfig, - AgentKeywordMemoryBankConfig, - AgentGraphMemoryBankConfig, - ], - Field(discriminator="type"), -] - - -class MemoryQueryGenerator(Enum): - default = "default" - llm = "llm" - custom = "custom" - - -class DefaultMemoryQueryGeneratorConfig(BaseModel): - type: Literal[MemoryQueryGenerator.default.value] = ( - MemoryQueryGenerator.default.value - ) - sep: str = " " - - -class LLMMemoryQueryGeneratorConfig(BaseModel): - type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value - model: str - template: str - - -class CustomMemoryQueryGeneratorConfig(BaseModel): - type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value - - -MemoryQueryGeneratorConfig = Annotated[ - Union[ - DefaultMemoryQueryGeneratorConfig, - LLMMemoryQueryGeneratorConfig, - CustomMemoryQueryGeneratorConfig, - ], - Field(discriminator="type"), -] - - -@json_schema_type -class MemoryToolDefinition(ToolDefinitionCommon): - type: Literal[AgentTool.memory.value] = AgentTool.memory.value - memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list) - # This config defines how a query is generated using the messages - # for memory bank retrieval. 
- query_generator_config: MemoryQueryGeneratorConfig = Field( - default=DefaultMemoryQueryGeneratorConfig() - ) - max_tokens_in_context: int = 4096 - max_chunks: int = 10 - - AgentToolDefinition = Annotated[ Union[ SearchToolDefinition, @@ -196,7 +114,6 @@ AgentToolDefinition = Annotated[ PhotogenToolDefinition, CodeInterpreterToolDefinition, FunctionCallToolDefinition, - MemoryToolDefinition, ], Field(discriminator="type"), ] @@ -295,7 +212,11 @@ class AgentConfigCommon(BaseModel): input_shields: Optional[List[str]] = Field(default_factory=list) output_shields: Optional[List[str]] = Field(default_factory=list) - tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list) + tools: Optional[List[AgentToolDefinition]] = Field( + default_factory=list, deprecated=True + ) + available_tools: Optional[List[str]] = Field(default_factory=list) + preprocessing_tools: Optional[List[str]] = Field(default_factory=list) tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) tool_prompt_format: Optional[ToolPromptFormat] = Field( default=ToolPromptFormat.json diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 23110543b..60b2bdab9 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -68,10 +68,16 @@ ToolGroupDef = register_schema( Annotated[ Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type") ], - name="ToolGroup", + name="ToolGroupDef", ) +class ToolGroupInput(BaseModel): + tool_group_id: str + tool_group: ToolGroupDef + provider_id: Optional[str] = None + + class ToolGroup(Resource): type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index dec62bfae..ba7ba62bd 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -161,6 +161,7 @@ a default SQLite store will be used.""", datasets: List[DatasetInput] = Field(default_factory=list) scoring_fns: List[ScoringFnInput] = Field(default_factory=list) eval_tasks: List[EvalTaskInput] = Field(default_factory=list) + tool_groups: List[ToolGroupInput] = Field(default_factory=list) class BuildConfig(BaseModel): diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 0a6eed345..3ea93301f 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -5,9 +5,7 @@ # the root directory of this source tree. 
import importlib import inspect - import logging - from typing import Any, Dict, List, Set from llama_stack.apis.agents import Agents @@ -28,7 +26,6 @@ from llama_stack.apis.shields import Shields from llama_stack.apis.telemetry import Telemetry from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.distribution.client import get_client_impl - from llama_stack.distribution.datatypes import ( AutoRoutedProviderSpec, Provider, @@ -38,7 +35,7 @@ from llama_stack.distribution.datatypes import ( from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.utils.dynamic import instantiate_class_type - +from llama_stack.providers.datatypes import * # noqa: F403 from llama_stack.providers.datatypes import ( Api, DatasetsProtocolPrivate, diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index ab1becfdd..8d622a5c2 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -523,6 +523,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): ) provider_id = list(self.impls_by_provider_id.keys())[0] + # parse tool group to the type if dict + tool_group = parse_obj_as(ToolGroupDef, tool_group) if isinstance(tool_group, MCPToolGroupDef): tool_defs = await self.impls_by_provider_id[provider_id].discover_tools( tool_group diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 7fc2c7650..9d12303c9 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -12,7 +12,7 @@ from typing import Any, Dict, Optional import pkg_resources import yaml - +from llama_models.llama3.api.datatypes import * # noqa: F403 from termcolor import colored from llama_stack.apis.agents import Agents @@ -33,14 +33,12 @@ from llama_stack.apis.scoring_functions import ScoringFunctions from llama_stack.apis.shields import Shields from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration from llama_stack.apis.telemetry import Telemetry - from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls from llama_stack.distribution.store.registry import create_dist_registry from llama_stack.providers.datatypes import Api - log = logging.getLogger(__name__) LLAMA_STACK_API_VERSION = "alpha" @@ -81,6 +79,7 @@ RESOURCES = [ "list_scoring_functions", ), ("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"), + ("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"), ] diff --git a/llama_stack/llama_stack/providers/tests/agents/conftest.py b/llama_stack/llama_stack/providers/tests/agents/conftest.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/llama_stack/providers/tests/agents/conftest.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
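Taken together, the API-level changes above move memory out of the agents API: MemoryToolDefinition and its bank/query-generator configs are deleted from llama_stack/apis/agents/agents.py, AgentConfigCommon.tools is marked deprecated, and agents instead reference tools by name through the new available_tools and preprocessing_tools string lists, while ToolGroupInput entries in the run config are registered as stack resources alongside models, shields, datasets, and eval tasks. A minimal sketch of the resulting configuration surface, modeled on the test fixtures later in this diff (the model id, instructions, and the exact set of required AgentConfig fields are illustrative, not part of this patch):

# Sketch only: mirrors the ToolGroupInput/AgentConfig usage in the tests below.
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.tools import (
    ToolDef,
    ToolGroupInput,
    ToolParameter,
    UserDefinedToolGroupDef,
)

# Register the memory tool group as a stack resource (run config / fixtures).
memory_group = ToolGroupInput(
    tool_group_id="memory_group",
    tool_group=UserDefinedToolGroupDef(
        tools=[
            ToolDef(
                name="memory",
                description="memory",
                parameters=[
                    ToolParameter(
                        name="session_id",
                        description="session id",
                        parameter_type="string",
                        required=True,
                    ),
                    # ... plus input_messages/attachments, as in the fixtures
                ],
                metadata={},
            )
        ],
    ),
    provider_id="memory-runtime",
)

# Agents now name the tool instead of embedding its configuration.
agent_config = AgentConfig(
    model="Llama3.1-8B-Instruct",    # illustrative model id
    instructions="You are a helpful assistant.",  # illustrative
    tools=[],                        # deprecated field, kept for compatibility
    preprocessing_tools=["memory"],  # invoked before inference on each turn
    enable_session_persistence=False,
)

The parse_obj_as(ToolGroupDef, tool_group) call added to the routing table is what lets the tool_group field arrive as a plain dict from a YAML run config and still be dispatched through the discriminated union.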
diff --git a/llama_stack/providers/inline/agents/meta_reference/__init__.py b/llama_stack/providers/inline/agents/meta_reference/__init__.py index 156de9a17..50f61fb42 100644 --- a/llama_stack/providers/inline/agents/meta_reference/__init__.py +++ b/llama_stack/providers/inline/agents/meta_reference/__init__.py @@ -22,6 +22,8 @@ async def get_provider_impl( deps[Api.memory], deps[Api.safety], deps[Api.memory_banks], + deps[Api.tool_runtime], + deps[Api.tool_groups], ) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 09738d7b7..00d8bbd36 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -4,25 +4,21 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import asyncio import copy import logging import os -import re import secrets import string import uuid from datetime import datetime -from typing import AsyncGenerator, Dict, List, Optional, Tuple +from typing import AsyncGenerator, Dict, List from urllib.parse import urlparse import httpx - from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.agents import ( AgentConfig, - AgentTool, AgentTurnCreateRequest, AgentTurnResponseEvent, AgentTurnResponseEventType, @@ -36,8 +32,6 @@ from llama_stack.apis.agents import ( CodeInterpreterToolDefinition, FunctionCallToolDefinition, InferenceStep, - MemoryRetrievalStep, - MemoryToolDefinition, PhotogenToolDefinition, SearchToolDefinition, ShieldCallStep, @@ -46,11 +40,9 @@ from llama_stack.apis.agents import ( Turn, WolframAlphaToolDefinition, ) - from llama_stack.apis.common.content_types import ( - InterleavedContent, - TextContentItem, URL, + TextContentItem, ) from llama_stack.apis.inference import ( ChatCompletionResponseEventType, @@ -62,30 +54,26 @@ from llama_stack.apis.inference import ( SystemMessage, ToolCallDelta, ToolCallParseStatus, - ToolChoice, ToolDefinition, ToolResponse, ToolResponseMessage, UserMessage, ) -from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse -from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams +from llama_stack.apis.memory import Memory +from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.safety import Safety - from llama_stack.providers.utils.kvstore import KVStore -from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content from llama_stack.providers.utils.telemetry import tracing from .persistence import AgentPersistence -from .rag.context_retriever import generate_rag_query from .safety import SafetyException, ShieldRunnerMixin from .tools.base import BaseTool from .tools.builtin import ( CodeInterpreterTool, - interpret_content_as_attachment, PhotogenTool, SearchTool, WolframAlphaTool, + interpret_content_as_attachment, ) from .tools.safety import SafeTool @@ -108,6 +96,8 @@ class ChatAgent(ShieldRunnerMixin): memory_api: Memory, memory_banks_api: MemoryBanks, safety_api: Safety, + tool_runtime_api: ToolRuntime, + tool_groups_api: ToolGroups, persistence_store: KVStore, ): self.agent_id = agent_id @@ -118,6 +108,8 @@ class ChatAgent(ShieldRunnerMixin): self.memory_banks_api = memory_banks_api self.safety_api = safety_api self.storage = AgentPersistence(agent_id, persistence_store) + self.tool_runtime_api 
= tool_runtime_api + self.tool_groups_api = tool_groups_api builtin_tools = [] for tool_defn in agent_config.tools: @@ -392,62 +384,50 @@ class ChatAgent(ShieldRunnerMixin): sampling_params: SamplingParams, stream: bool = False, ) -> AsyncGenerator: - enabled_tools = set(t.type for t in self.agent_config.tools) - need_rag_context = await self._should_retrieve_context( - input_messages, attachments - ) - if need_rag_context: - step_id = str(uuid.uuid4()) - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepStartPayload( - step_type=StepType.memory_retrieval.value, - step_id=step_id, + if self.agent_config.preprocessing_tools: + with tracing.span("preprocessing_tools") as span: + for tool_name in self.agent_config.preprocessing_tools: + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepStartPayload( + step_type=StepType.tool_execution.value, + step_id=str(uuid.uuid4()), + ) + ) ) - ) - ) - - # TODO: find older context from the session and either replace it - # or append with a sliding window. this is really a very simplistic implementation - with tracing.span("retrieve_rag_context") as span: - rag_context, bank_ids = await self._retrieve_context( - session_id, input_messages, attachments - ) - span.set_attribute( - "input", [m.model_dump_json() for m in input_messages] - ) - span.set_attribute("output", rag_context) - span.set_attribute("bank_ids", bank_ids) - - step_id = str(uuid.uuid4()) - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepCompletePayload( - step_type=StepType.memory_retrieval.value, - step_id=step_id, - step_details=MemoryRetrievalStep( - turn_id=turn_id, - step_id=step_id, - memory_bank_ids=bank_ids, - inserted_context=rag_context or "", - ), + args = dict( + session_id=session_id, + input_messages=input_messages, + attachments=attachments, ) - ) - ) - - if rag_context: - last_message = input_messages[-1] - last_message.context = rag_context - - elif attachments and AgentTool.code_interpreter.value in enabled_tools: - urls = [a.content for a in attachments if isinstance(a.content, URL)] - # TODO: we need to migrate URL away from str type - pattern = re.compile("^(https?://|file://|data:)") - urls += [ - URL(uri=a.content) for a in attachments if pattern.match(a.content) - ] - msg = await attachment_message(self.tempdir, urls) - input_messages.append(msg) + result = await self.tool_runtime_api.invoke_tool( + tool_name=tool_name, + args=args, + ) + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepProgressPayload( + step_type=StepType.tool_execution.value, + step_id=str(uuid.uuid4()), + tool_call_delta=ToolCallDelta( + parse_status=ToolCallParseStatus.success, + content=ToolCall( + call_id="", tool_name=tool_name, arguments={} + ), + ), + ) + ) + ) + span.set_attribute( + "input", [m.model_dump_json() for m in input_messages] + ) + span.set_attribute("output", result.content) + span.set_attribute("error_code", result.error_code) + span.set_attribute("error_message", result.error_message) + span.set_attribute("tool_name", tool_name) + if result.error_code != 0 and result.content: + last_message = input_messages[-1] + last_message.context = result.content output_attachments = [] @@ -659,129 +639,6 @@ class ChatAgent(ShieldRunnerMixin): n_iter += 1 - async def _ensure_memory_bank(self, session_id: str) -> str: - session_info = await self.storage.get_session_info(session_id) - if 
session_info is None: - raise ValueError(f"Session {session_id} not found") - - if session_info.memory_bank_id is None: - bank_id = f"memory_bank_{session_id}" - await self.memory_banks_api.register_memory_bank( - memory_bank_id=bank_id, - params=VectorMemoryBankParams( - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - ), - ) - await self.storage.add_memory_bank_to_session(session_id, bank_id) - else: - bank_id = session_info.memory_bank_id - - return bank_id - - async def _should_retrieve_context( - self, messages: List[Message], attachments: List[Attachment] - ) -> bool: - enabled_tools = set(t.type for t in self.agent_config.tools) - if attachments: - if ( - AgentTool.code_interpreter.value in enabled_tools - and self.agent_config.tool_choice == ToolChoice.required - ): - return False - else: - return True - - return AgentTool.memory.value in enabled_tools - - def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]: - for t in self.agent_config.tools: - if t.type == AgentTool.memory.value: - return t - - return None - - async def _retrieve_context( - self, session_id: str, messages: List[Message], attachments: List[Attachment] - ) -> Tuple[Optional[InterleavedContent], List[int]]: # (rag_context, bank_ids) - bank_ids = [] - - memory = self._memory_tool_definition() - assert memory is not None, "Memory tool not configured" - bank_ids.extend(c.bank_id for c in memory.memory_bank_configs) - - if attachments: - bank_id = await self._ensure_memory_bank(session_id) - bank_ids.append(bank_id) - - documents = [ - MemoryBankDocument( - document_id=str(uuid.uuid4()), - content=a.content, - mime_type=a.mime_type, - metadata={}, - ) - for a in attachments - ] - with tracing.span("insert_documents"): - await self.memory_api.insert_documents(bank_id, documents) - else: - session_info = await self.storage.get_session_info(session_id) - if session_info.memory_bank_id: - bank_ids.append(session_info.memory_bank_id) - - if not bank_ids: - # this can happen if the per-session memory bank is not yet populated - # (i.e., no prior turns uploaded an Attachment) - return None, [] - - query = await generate_rag_query( - memory.query_generator_config, messages, inference_api=self.inference_api - ) - tasks = [ - self.memory_api.query_documents( - bank_id=bank_id, - query=query, - params={ - "max_chunks": 5, - }, - ) - for bank_id in bank_ids - ] - results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks) - chunks = [c for r in results for c in r.chunks] - scores = [s for r in results for s in r.scores] - - if not chunks: - return None, bank_ids - - # sort by score - chunks, scores = zip( - *sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True) - ) - - tokens = 0 - picked = [] - for c in chunks[: memory.max_chunks]: - tokens += c.token_count - if tokens > memory.max_tokens_in_context: - log.error( - f"Using {len(picked)} chunks; reached max tokens in context: {tokens}", - ) - break - picked.append(f"id:{c.document_id}; content:{c.content}") - - return ( - concat_interleaved_content( - [ - "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n", - *picked, - "\n=== END-RETRIEVED-CONTEXT ===\n", - ] - ), - bank_ids, - ) - def _get_tools(self) -> List[ToolDefinition]: ret = [] for t in self.agent_config.tools: diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 93bfab5f4..89b38a7fc 100644 --- 
a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -24,12 +24,11 @@ from llama_stack.apis.agents import ( Session, Turn, ) - from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage from llama_stack.apis.memory import Memory from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.safety import Safety - +from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from .agent_instance import ChatAgent @@ -47,12 +46,16 @@ class MetaReferenceAgentsImpl(Agents): memory_api: Memory, safety_api: Safety, memory_banks_api: MemoryBanks, + tool_runtime_api: ToolRuntime, + tool_groups_api: ToolGroups, ): self.config = config self.inference_api = inference_api self.memory_api = memory_api self.safety_api = safety_api self.memory_banks_api = memory_banks_api + self.tool_runtime_api = tool_runtime_api + self.tool_groups_api = tool_groups_api self.in_memory_store = InmemoryKVStoreImpl() self.tempdir = tempfile.mkdtemp() @@ -112,6 +115,8 @@ class MetaReferenceAgentsImpl(Agents): safety_api=self.safety_api, memory_api=self.memory_api, memory_banks_api=self.memory_banks_api, + tool_runtime_api=self.tool_runtime_api, + tool_groups_api=self.tool_groups_api, persistence_store=( self.persistence_store if agent_config.enable_session_persistence diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index a4b1af616..144f65863 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -8,13 +8,11 @@ import json import logging import uuid from datetime import datetime - from typing import List, Optional from pydantic import BaseModel from llama_stack.apis.agents import Turn - from llama_stack.providers.utils.kvstore import KVStore log = logging.getLogger(__name__) @@ -23,7 +21,6 @@ log = logging.getLogger(__name__) class AgentSessionInfo(BaseModel): session_id: str session_name: str - memory_bank_id: Optional[str] = None started_at: datetime @@ -54,17 +51,6 @@ class AgentPersistence: return AgentSessionInfo(**json.loads(value)) - async def add_memory_bank_to_session(self, session_id: str, bank_id: str): - session_info = await self.get_session_info(session_id) - if session_info is None: - raise ValueError(f"Session {session_id} not found") - - session_info.memory_bank_id = bank_id - await self.kvstore.set( - key=f"session:{self.agent_id}:{session_id}", - value=session_info.model_dump_json(), - ) - async def add_turn_to_session(self, session_id: str, turn: Turn): await self.kvstore.set( key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}", diff --git a/llama_stack/providers/inline/tool_runtime/memory/__init__.py b/llama_stack/providers/inline/tool_runtime/memory/__init__.py new file mode 100644 index 000000000..36377f147 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/memory/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any, Dict + +from llama_stack.providers.datatypes import Api + +from .config import MemoryToolConfig +from .memory import MemoryToolRuntimeImpl + + +async def get_provider_impl(config: MemoryToolConfig, deps: Dict[str, Any]): + impl = MemoryToolRuntimeImpl( + config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference] + ) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/tool_runtime/memory/config.py b/llama_stack/providers/inline/tool_runtime/memory/config.py new file mode 100644 index 000000000..cb24883dc --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/memory/config.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import Enum +from typing import Annotated, List, Literal, Union + +from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.providers.utils.kvstore import KVStoreConfig, SqliteKVStoreConfig + +from pydantic import BaseModel, Field + + +class _MemoryBankConfigCommon(BaseModel): + bank_id: str + + +class VectorMemoryBankConfig(_MemoryBankConfigCommon): + type: Literal["vector"] = "vector" + + +class KeyValueMemoryBankConfig(_MemoryBankConfigCommon): + type: Literal["keyvalue"] = "keyvalue" + keys: List[str] # what keys to focus on + + +class KeywordMemoryBankConfig(_MemoryBankConfigCommon): + type: Literal["keyword"] = "keyword" + + +class GraphMemoryBankConfig(_MemoryBankConfigCommon): + type: Literal["graph"] = "graph" + entities: List[str] # what entities to focus on + + +MemoryBankConfig = Annotated[ + Union[ + VectorMemoryBankConfig, + KeyValueMemoryBankConfig, + KeywordMemoryBankConfig, + GraphMemoryBankConfig, + ], + Field(discriminator="type"), +] + + +class MemoryQueryGenerator(Enum): + default = "default" + llm = "llm" + custom = "custom" + + +class DefaultMemoryQueryGeneratorConfig(BaseModel): + type: Literal[MemoryQueryGenerator.default.value] = ( + MemoryQueryGenerator.default.value + ) + sep: str = " " + + +class LLMMemoryQueryGeneratorConfig(BaseModel): + type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value + model: str + template: str + + +class CustomMemoryQueryGeneratorConfig(BaseModel): + type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value + + +MemoryQueryGeneratorConfig = Annotated[ + Union[ + DefaultMemoryQueryGeneratorConfig, + LLMMemoryQueryGeneratorConfig, + CustomMemoryQueryGeneratorConfig, + ], + Field(discriminator="type"), +] + + +class MemoryToolConfig(BaseModel): + memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list) + # This config defines how a query is generated using the messages + # for memory bank retrieval. 
+ query_generator_config: MemoryQueryGeneratorConfig = Field( + default=DefaultMemoryQueryGeneratorConfig() + ) + max_tokens_in_context: int = 4096 + max_chunks: int = 10 + kvstore_config: KVStoreConfig = SqliteKVStoreConfig( + db_path=(RUNTIME_BASE_DIR / "memory.db").as_posix() + ) diff --git a/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py similarity index 98% rename from llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py rename to llama_stack/providers/inline/tool_runtime/memory/context_retriever.py index 74eb91c53..da97cb3a3 100644 --- a/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py +++ b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py @@ -8,16 +8,17 @@ from typing import List from jinja2 import Template -from llama_stack.apis.agents import ( +from llama_stack.apis.inference import Message, UserMessage +from llama_stack.providers.utils.inference.prompt_adapter import ( + interleaved_content_as_str, +) + +from .config import ( DefaultMemoryQueryGeneratorConfig, LLMMemoryQueryGeneratorConfig, MemoryQueryGenerator, MemoryQueryGeneratorConfig, ) -from llama_stack.apis.inference import Message, UserMessage -from llama_stack.providers.utils.inference.prompt_adapter import ( - interleaved_content_as_str, -) async def generate_rag_query( diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/memory/memory.py new file mode 100644 index 000000000..3a08bf1f9 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/memory/memory.py @@ -0,0 +1,253 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import asyncio +import json +import logging +import os +import re +import secrets +import string +import tempfile +import uuid +from typing import Any, Dict, List, Optional +from urllib.parse import urlparse + +import httpx + +from llama_stack.apis.agents import Attachment +from llama_stack.apis.common.content_types import TextContentItem, URL +from llama_stack.apis.inference import Inference, InterleavedContent, Message +from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse +from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams +from llama_stack.apis.tools import ( + ToolDef, + ToolGroupDef, + ToolInvocationResult, + ToolRuntime, +) +from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.utils.kvstore import kvstore_impl +from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content +from pydantic import BaseModel + +from .config import MemoryToolConfig +from .context_retriever import generate_rag_query + +log = logging.getLogger(__name__) + + +class MemorySessionInfo(BaseModel): + session_id: str + session_name: str + memory_bank_id: Optional[str] = None + + +def make_random_string(length: int = 8): + return "".join( + secrets.choice(string.ascii_letters + string.digits) for _ in range(length) + ) + + +class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): + def __init__( + self, + config: MemoryToolConfig, + memory_api: Memory, + memory_banks_api: MemoryBanks, + inference_api: Inference, + ): + self.config = config + self.memory_api = memory_api + self.memory_banks_api = memory_banks_api + self.tempdir = tempfile.mkdtemp() + self.inference_api = inference_api + + async def initialize(self): + self.kvstore = await kvstore_impl(self.config.kvstore_config) + + async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: + return [] + + async def create_session(self, session_id: str) -> MemorySessionInfo: + session_info = MemorySessionInfo( + session_id=session_id, + session_name=f"session_{session_id}", + ) + await self.kvstore.set( + key=f"memory::session:{session_id}", + value=session_info.model_dump_json(), + ) + return session_info + + async def get_session_info(self, session_id: str) -> Optional[MemorySessionInfo]: + value = await self.kvstore.get( + key=f"memory::session:{session_id}", + ) + if not value: + session_info = await self.create_session(session_id) + return session_info + + return MemorySessionInfo(**json.loads(value)) + + async def add_memory_bank_to_session(self, session_id: str, bank_id: str): + session_info = await self.get_session_info(session_id) + + session_info.memory_bank_id = bank_id + await self.kvstore.set( + key=f"memory::session:{session_id}", + value=session_info.model_dump_json(), + ) + + async def _ensure_memory_bank(self, session_id: str) -> str: + session_info = await self.get_session_info(session_id) + + if session_info.memory_bank_id is None: + bank_id = f"memory_bank_{session_id}" + await self.memory_banks_api.register_memory_bank( + memory_bank_id=bank_id, + params=VectorMemoryBankParams( + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + ), + ) + await self.add_memory_bank_to_session(session_id, bank_id) + else: + bank_id = session_info.memory_bank_id + + return bank_id + + async def attachment_message( + self, tempdir: str, urls: List[URL] + ) -> List[TextContentItem]: + content = [] + + for url in urls: + uri = url.uri + if uri.startswith("file://"): + filepath = uri[len("file://") :] + 
elif uri.startswith("http"): + path = urlparse(uri).path + basename = os.path.basename(path) + filepath = f"{tempdir}/{make_random_string() + basename}" + log.info(f"Downloading {url} -> {filepath}") + + async with httpx.AsyncClient() as client: + r = await client.get(uri) + resp = r.text + with open(filepath, "w") as fp: + fp.write(resp) + else: + raise ValueError(f"Unsupported URL {url}") + + content.append( + TextContentItem( + text=f'# There is a file accessible to you at "{filepath}"\n' + ) + ) + + return content + + async def _retrieve_context( + self, session_id: str, messages: List[Message] + ) -> Optional[List[InterleavedContent]]: + bank_ids = [] + + bank_ids.extend(c.bank_id for c in self.config.memory_bank_configs) + + session_info = await self.get_session_info(session_id) + if session_info.memory_bank_id: + bank_ids.append(session_info.memory_bank_id) + + if not bank_ids: + # this can happen if the per-session memory bank is not yet populated + # (i.e., no prior turns uploaded an Attachment) + return None + + query = await generate_rag_query( + self.config.query_generator_config, + messages, + inference_api=self.inference_api, + ) + tasks = [ + self.memory_api.query_documents( + bank_id=bank_id, + query=query, + params={ + "max_chunks": 5, + }, + ) + for bank_id in bank_ids + ] + results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks) + chunks = [c for r in results for c in r.chunks] + scores = [s for r in results for s in r.scores] + + if not chunks: + return None + + # sort by score + chunks, scores = zip( + *sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True) + ) + + tokens = 0 + picked = [] + for c in chunks[: self.config.max_chunks]: + tokens += c.token_count + if tokens > self.config.max_tokens_in_context: + log.error( + f"Using {len(picked)} chunks; reached max tokens in context: {tokens}", + ) + break + picked.append(f"id:{c.document_id}; content:{c.content}") + + return [ + "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n", + *picked, + "\n=== END-RETRIEVED-CONTEXT ===\n", + ] + + async def _process_attachments( + self, session_id: str, attachments: List[Attachment] + ): + bank_id = await self._ensure_memory_bank(session_id) + + documents = [ + MemoryBankDocument( + document_id=str(uuid.uuid4()), + content=a.content, + mime_type=a.mime_type, + metadata={}, + ) + for a in attachments + if isinstance(a.content, str) + ] + await self.memory_api.insert_documents(bank_id, documents) + + urls = [a.content for a in attachments if isinstance(a.content, URL)] + # TODO: we need to migrate URL away from str type + pattern = re.compile("^(https?://|file://|data:)") + urls += [URL(uri=a.content) for a in attachments if pattern.match(a.content)] + return await self.attachment_message(self.tempdir, urls) + + async def invoke_tool( + self, tool_name: str, args: Dict[str, Any] + ) -> ToolInvocationResult: + if args["session_id"] is None: + raise ValueError("session_id is required") + + context = await self._retrieve_context( + args["session_id"], args["input_messages"] + ) + if context is None: + context = [] + attachments = args["attachments"] + if attachments and len(attachments) > 0: + context += await self._process_attachments(args["session_id"], attachments) + return ToolInvocationResult( + content=concat_interleaved_content(context), error_code=0 + ) diff --git a/llama_stack/providers/registry/agents.py b/llama_stack/providers/registry/agents.py index 6595b1955..3e38b1adc 100644 --- 
a/llama_stack/providers/registry/agents.py +++ b/llama_stack/providers/registry/agents.py @@ -35,6 +35,8 @@ def available_providers() -> List[ProviderSpec]: Api.safety, Api.memory, Api.memory_banks, + Api.tool_runtime, + Api.tool_groups, ], ), remote_provider_spec( diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index 042aef9d9..d0493810c 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -25,6 +25,14 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig", provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator", ), + InlineProviderSpec( + api=Api.tool_runtime, + provider_type="inline::memory-runtime", + pip_packages=[], + module="llama_stack.providers.inline.tool_runtime.memory", + config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolConfig", + api_dependencies=[Api.memory, Api.memory_banks, Api.inference], + ), remote_provider_spec( api=Api.tool_runtime, adapter=AdapterSpec( diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py index dbf79e713..d80013fae 100644 --- a/llama_stack/providers/tests/agents/conftest.py +++ b/llama_stack/providers/tests/agents/conftest.py @@ -7,12 +7,10 @@ import pytest from ..conftest import get_provider_fixture_overrides - from ..inference.fixtures import INFERENCE_FIXTURES from ..memory.fixtures import MEMORY_FIXTURES from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield -from .fixtures import AGENTS_FIXTURES - +from .fixtures import AGENTS_FIXTURES, TOOL_RUNTIME_FIXTURES DEFAULT_PROVIDER_COMBINATIONS = [ pytest.param( @@ -21,6 +19,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ "safety": "llama_guard", "memory": "faiss", "agents": "meta_reference", + "tool_runtime": "memory", }, id="meta_reference", marks=pytest.mark.meta_reference, @@ -31,6 +30,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ "safety": "llama_guard", "memory": "faiss", "agents": "meta_reference", + "tool_runtime": "memory", }, id="ollama", marks=pytest.mark.ollama, @@ -42,6 +42,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ # make this work with Weaviate which is what the together distro supports "memory": "faiss", "agents": "meta_reference", + "tool_runtime": "memory", }, id="together", marks=pytest.mark.together, @@ -52,6 +53,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ "safety": "llama_guard", "memory": "faiss", "agents": "meta_reference", + "tool_runtime": "memory", }, id="fireworks", marks=pytest.mark.fireworks, @@ -62,6 +64,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ "safety": "remote", "memory": "remote", "agents": "remote", + "tool_runtime": "memory", }, id="remote", marks=pytest.mark.remote, @@ -117,6 +120,7 @@ def pytest_generate_tests(metafunc): "safety": SAFETY_FIXTURES, "memory": MEMORY_FIXTURES, "agents": AGENTS_FIXTURES, + "tool_runtime": TOOL_RUNTIME_FIXTURES, } combinations = ( get_provider_fixture_overrides(metafunc.config, available_fixtures) diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index 9f8e7a12b..dd9882aa6 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -10,14 +10,19 @@ import pytest import pytest_asyncio from llama_stack.apis.models import ModelInput, ModelType +from llama_stack.apis.tools 
import ( + ToolDef, + ToolGroupInput, + ToolParameter, + UserDefinedToolGroupDef, +) from llama_stack.distribution.datatypes import Api, Provider - from llama_stack.providers.inline.agents.meta_reference import ( MetaReferenceAgentsImplConfig, ) - from llama_stack.providers.tests.resolver import construct_stack_for_test from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig + from ..conftest import ProviderFixture, remote_stack_fixture @@ -55,7 +60,21 @@ def agents_meta_reference() -> ProviderFixture: ) +@pytest.fixture(scope="session") +def tool_runtime_memory() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="memory-runtime", + provider_type="inline::memory-runtime", + config={}, + ) + ], + ) + + AGENTS_FIXTURES = ["meta_reference", "remote"] +TOOL_RUNTIME_FIXTURES = ["memory"] @pytest_asyncio.fixture(scope="session") @@ -64,7 +83,7 @@ async def agents_stack(request, inference_model, safety_shield): providers = {} provider_data = {} - for key in ["inference", "safety", "memory", "agents"]: + for key in ["inference", "safety", "memory", "agents", "tool_runtime"]: fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") providers[key] = fixture.providers if key == "inference": @@ -111,12 +130,48 @@ async def agents_stack(request, inference_model, safety_shield): metadata={"embedding_dimension": 384}, ) ) + tool_groups = [ + ToolGroupInput( + tool_group_id="memory_group", + tool_group=UserDefinedToolGroupDef( + tools=[ + ToolDef( + name="memory", + description="memory", + parameters=[ + ToolParameter( + name="session_id", + description="session id", + parameter_type="string", + required=True, + ), + ToolParameter( + name="input_messages", + description="messages", + parameter_type="list", + required=True, + ), + ToolParameter( + name="attachments", + description="attachments", + parameter_type="list", + required=False, + ), + ], + metadata={}, + ) + ], + ), + provider_id="memory-runtime", + ) + ] test_stack = await construct_stack_for_test( - [Api.agents, Api.inference, Api.safety, Api.memory], + [Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime], providers, provider_data, models=models, shields=[safety_shield] if safety_shield else [], + tool_groups=tool_groups, ) return test_stack diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index dc95fa6a6..4ff94e4fe 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -35,7 +35,6 @@ from llama_stack.providers.datatypes import Api # # pytest -v -s llama_stack/providers/tests/agents/test_agents.py # -m "meta_reference" - from .fixtures import pick_inference_model from .utils import create_agent_session @@ -255,17 +254,8 @@ class TestAgents: agent_config = AgentConfig( **{ **common_params, - "tools": [ - MemoryToolDefinition( - memory_bank_configs=[], - query_generator_config={ - "type": "default", - "sep": " ", - }, - max_tokens_in_context=4096, - max_chunks=10, - ), - ], + "tools": [], + "preprocessing_tools": ["memory"], "tool_choice": ToolChoice.auto, } ) diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py index 5a38aaecc..6f3733408 100644 --- a/llama_stack/providers/tests/resolver.py +++ b/llama_stack/providers/tests/resolver.py @@ -16,7 +16,7 @@ from llama_stack.apis.memory_banks import MemoryBankInput from llama_stack.apis.models import ModelInput from llama_stack.apis.scoring_functions 
import ScoringFnInput from llama_stack.apis.shields import ShieldInput - +from llama_stack.apis.tools import ToolGroupInput from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.datatypes import Provider, StackRunConfig @@ -43,6 +43,7 @@ async def construct_stack_for_test( datasets: Optional[List[DatasetInput]] = None, scoring_fns: Optional[List[ScoringFnInput]] = None, eval_tasks: Optional[List[EvalTaskInput]] = None, + tool_groups: Optional[List[ToolGroupInput]] = None, ) -> TestStack: sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") run_config = dict( @@ -56,6 +57,7 @@ async def construct_stack_for_test( datasets=datasets or [], scoring_fns=scoring_fns or [], eval_tasks=eval_tasks or [], + tool_groups=tool_groups or [], ) run_config = parse_and_maybe_upgrade_config(run_config) try:
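The contract between the agent's new preprocessing loop and the memory runtime is a keyword one: both sides hard-code session_id, input_messages, and attachments, and the test fixtures register matching ToolParameter definitions. Below is a sketch of the call as the agent now makes it (illustrative glue only; the step start/progress events and tracing spans are elided). One caveat: the agent hunk attaches context only when result.error_code != 0, yet MemoryToolRuntimeImpl.invoke_tool always returns error_code=0, so the sketch assumes the success path is what is intended:

# Sketch of the new call path; argument names match the ToolParameter
# definitions registered in the test fixtures.
result = await self.tool_runtime_api.invoke_tool(
    tool_name="memory",
    args={
        "session_id": session_id,          # required by MemoryToolRuntimeImpl
        "input_messages": input_messages,  # List[Message]; drives the RAG query
        "attachments": attachments,        # optional; ingested into the session bank
    },
)
# The runtime returns the retrieved context as interleaved content; the agent
# attaches it to the last input message before running inference.
if result.error_code == 0 and result.content:
    input_messages[-1].context = result.content

This replaces the old _retrieve_context / MemoryRetrievalStep flow end to end: attachments are ingested into a per-session vector bank inside the runtime (_process_attachments), and the agent only sees the concatenated interleaved content that comes back.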