agents to use tools api

Dinesh Yeduguru 2024-12-20 14:46:32 -08:00
parent 596afc6497
commit f90e9c2003
21 changed files with 538 additions and 329 deletions


@@ -14,18 +14,16 @@ from typing import (
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
runtime_checkable,
)
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
from llama_stack.apis.inference import (
CompletionMessage,
@@ -40,7 +38,6 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.memory import MemoryBank
from llama_stack.apis.safety import SafetyViolation
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@@ -110,85 +107,6 @@ class FunctionCallToolDefinition(ToolDefinitionCommon):
remote_execution: Optional[RestAPIExecutionConfig] = None
class _MemoryBankConfigCommon(BaseModel):
bank_id: str
class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["vector"] = "vector"
class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["keyvalue"] = "keyvalue"
keys: List[str] # what keys to focus on
class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["keyword"] = "keyword"
class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["graph"] = "graph"
entities: List[str] # what entities to focus on
MemoryBankConfig = Annotated[
Union[
AgentVectorMemoryBankConfig,
AgentKeyValueMemoryBankConfig,
AgentKeywordMemoryBankConfig,
AgentGraphMemoryBankConfig,
],
Field(discriminator="type"),
]
class MemoryQueryGenerator(Enum):
default = "default"
llm = "llm"
custom = "custom"
class DefaultMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.default.value] = (
MemoryQueryGenerator.default.value
)
sep: str = " "
class LLMMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
model: str
template: str
class CustomMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
MemoryQueryGeneratorConfig = Annotated[
Union[
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
CustomMemoryQueryGeneratorConfig,
],
Field(discriminator="type"),
]
@json_schema_type
class MemoryToolDefinition(ToolDefinitionCommon):
type: Literal[AgentTool.memory.value] = AgentTool.memory.value
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
# This config defines how a query is generated using the messages
# for memory bank retrieval.
query_generator_config: MemoryQueryGeneratorConfig = Field(
default=DefaultMemoryQueryGeneratorConfig()
)
max_tokens_in_context: int = 4096
max_chunks: int = 10
AgentToolDefinition = Annotated[
Union[
SearchToolDefinition,
@@ -196,7 +114,6 @@ AgentToolDefinition = Annotated[
PhotogenToolDefinition,
CodeInterpreterToolDefinition,
FunctionCallToolDefinition,
MemoryToolDefinition,
],
Field(discriminator="type"),
]
@@ -295,7 +212,11 @@ class AgentConfigCommon(BaseModel):
input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list)
tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
tools: Optional[List[AgentToolDefinition]] = Field(
default_factory=list, deprecated=True
)
available_tools: Optional[List[str]] = Field(default_factory=list)
preprocessing_tools: Optional[List[str]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
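
A note on the config change above: AgentConfig.tools is now deprecated in favor of two plain string lists: available_tools (tools the agent may invoke through the Tools API) and preprocessing_tools (tools run over the input before inference, such as the new memory tool). A minimal sketch of an agent configuration against the new fields; the model identifier and tool names are illustrative placeholders, not values introduced by this commit:

from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.inference import ToolChoice

# Sketch only: model id and tool names are placeholders.
agent_config = AgentConfig(
    model="Llama3.1-8B-Instruct",
    instructions="You are a helpful assistant.",
    enable_session_persistence=False,
    tools=[],                          # deprecated; left empty
    available_tools=["brave_search"],  # tools the agent may call mid-turn
    preprocessing_tools=["memory"],    # run against the input before inference
    tool_choice=ToolChoice.auto,
)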


@@ -68,10 +68,16 @@ ToolGroupDef = register_schema(
Annotated[
Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type")
],
name="ToolGroup",
name="ToolGroupDef",
)
class ToolGroupInput(BaseModel):
tool_group_id: str
tool_group: ToolGroupDef
provider_id: Optional[str] = None
class ToolGroup(Resource):
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
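
ToolGroupInput above is the pre-registration form of a tool group: a run config (or test harness) lists them, and the stack registers each one at startup through the tool_groups API. A hedged sketch of an equivalent programmatic registration, assuming tool_groups_api is a ToolGroups implementation and that register_tool_group mirrors ToolGroupInput's fields (the method name comes from the RESOURCES entry added further down in this commit):

from llama_stack.apis.tools import ToolDef, UserDefinedToolGroupDef

# Assumed usage; signature inferred from ToolGroupInput and the
# ("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups") resource entry.
await tool_groups_api.register_tool_group(
    tool_group_id="memory_group",
    tool_group=UserDefinedToolGroupDef(
        tools=[ToolDef(name="memory", description="memory", parameters=[], metadata={})],
    ),
    provider_id="memory-runtime",
)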


@@ -161,6 +161,7 @@ a default SQLite store will be used.""",
datasets: List[DatasetInput] = Field(default_factory=list)
scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
class BuildConfig(BaseModel):


@@ -5,9 +5,7 @@
# the root directory of this source tree.
import importlib
import inspect
import logging
from typing import Any, Dict, List, Set
from llama_stack.apis.agents import Agents
@@ -28,7 +26,6 @@ from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.distribution.client import get_client_impl
from llama_stack.distribution.datatypes import (
AutoRoutedProviderSpec,
Provider,
@@ -38,7 +35,7 @@ from llama_stack.distribution.datatypes import (
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.providers.datatypes import (
Api,
DatasetsProtocolPrivate,


@@ -523,6 +523,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
)
provider_id = list(self.impls_by_provider_id.keys())[0]
# parse the tool group into the typed ToolGroupDef if it was passed as a dict
tool_group = parse_obj_as(ToolGroupDef, tool_group)
if isinstance(tool_group, MCPToolGroupDef):
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
tool_group


@@ -12,7 +12,7 @@ from typing import Any, Dict, Optional
import pkg_resources
import yaml
from llama_models.llama3.api.datatypes import * # noqa: F403
from termcolor import colored
from llama_stack.apis.agents import Agents
@@ -33,14 +33,12 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
LLAMA_STACK_API_VERSION = "alpha"
@@ -81,6 +79,7 @@ RESOURCES = [
"list_scoring_functions",
),
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
]


@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@@ -22,6 +22,8 @@ async def get_provider_impl(
deps[Api.memory],
deps[Api.safety],
deps[Api.memory_banks],
deps[Api.tool_runtime],
deps[Api.tool_groups],
)
await impl.initialize()
return impl


@@ -4,25 +4,21 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import copy
import logging
import os
import re
import secrets
import string
import uuid
from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional, Tuple
from typing import AsyncGenerator, Dict, List
from urllib.parse import urlparse
import httpx
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.agents import (
AgentConfig,
AgentTool,
AgentTurnCreateRequest,
AgentTurnResponseEvent,
AgentTurnResponseEventType,
@@ -36,8 +32,6 @@ from llama_stack.apis.agents import (
CodeInterpreterToolDefinition,
FunctionCallToolDefinition,
InferenceStep,
MemoryRetrievalStep,
MemoryToolDefinition,
PhotogenToolDefinition,
SearchToolDefinition,
ShieldCallStep,
@@ -46,11 +40,9 @@ from llama_stack.apis.agents import (
Turn,
WolframAlphaToolDefinition,
)
from llama_stack.apis.common.content_types import (
InterleavedContent,
TextContentItem,
URL,
TextContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
@@ -62,30 +54,26 @@ from llama_stack.apis.inference import (
SystemMessage,
ToolCallDelta,
ToolCallParseStatus,
ToolChoice,
ToolDefinition,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.safety import Safety
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
from llama_stack.providers.utils.telemetry import tracing
from .persistence import AgentPersistence
from .rag.context_retriever import generate_rag_query
from .safety import SafetyException, ShieldRunnerMixin
from .tools.base import BaseTool
from .tools.builtin import (
CodeInterpreterTool,
interpret_content_as_attachment,
PhotogenTool,
SearchTool,
WolframAlphaTool,
interpret_content_as_attachment,
)
from .tools.safety import SafeTool
@@ -108,6 +96,8 @@ class ChatAgent(ShieldRunnerMixin):
memory_api: Memory,
memory_banks_api: MemoryBanks,
safety_api: Safety,
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
persistence_store: KVStore,
):
self.agent_id = agent_id
@@ -118,6 +108,8 @@ class ChatAgent(ShieldRunnerMixin):
self.memory_banks_api = memory_banks_api
self.safety_api = safety_api
self.storage = AgentPersistence(agent_id, persistence_store)
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
builtin_tools = []
for tool_defn in agent_config.tools:
@@ -392,62 +384,50 @@
sampling_params: SamplingParams,
stream: bool = False,
) -> AsyncGenerator:
enabled_tools = set(t.type for t in self.agent_config.tools)
need_rag_context = await self._should_retrieve_context(
input_messages, attachments
)
if need_rag_context:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.memory_retrieval.value,
step_id=step_id,
if self.agent_config.preprocessing_tools:
with tracing.span("preprocessing_tools") as span:
for tool_name in self.agent_config.preprocessing_tools:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=str(uuid.uuid4()),
)
)
)
)
)
# TODO: find older context from the session and either replace it
# or append with a sliding window. this is really a very simplistic implementation
with tracing.span("retrieve_rag_context") as span:
rag_context, bank_ids = await self._retrieve_context(
session_id, input_messages, attachments
)
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
span.set_attribute("output", rag_context)
span.set_attribute("bank_ids", bank_ids)
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.memory_retrieval.value,
step_id=step_id,
step_details=MemoryRetrievalStep(
turn_id=turn_id,
step_id=step_id,
memory_bank_ids=bank_ids,
inserted_context=rag_context or "",
),
args = dict(
session_id=session_id,
input_messages=input_messages,
attachments=attachments,
)
)
)
if rag_context:
last_message = input_messages[-1]
last_message.context = rag_context
elif attachments and AgentTool.code_interpreter.value in enabled_tools:
urls = [a.content for a in attachments if isinstance(a.content, URL)]
# TODO: we need to migrate URL away from str type
pattern = re.compile("^(https?://|file://|data:)")
urls += [
URL(uri=a.content) for a in attachments if pattern.match(a.content)
]
msg = await attachment_message(self.tempdir, urls)
input_messages.append(msg)
result = await self.tool_runtime_api.invoke_tool(
tool_name=tool_name,
args=args,
)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=str(uuid.uuid4()),
tool_call_delta=ToolCallDelta(
parse_status=ToolCallParseStatus.success,
content=ToolCall(
call_id="", tool_name=tool_name, arguments={}
),
),
)
)
)
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
span.set_attribute("output", result.content)
span.set_attribute("error_code", result.error_code)
span.set_attribute("error_message", result.error_message)
span.set_attribute("tool_name", tool_name)
if result.error_code != 0 and result.content:
last_message = input_messages[-1]
last_message.context = result.content
output_attachments = []
@@ -659,129 +639,6 @@
n_iter += 1
async def _ensure_memory_bank(self, session_id: str) -> str:
session_info = await self.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
if session_info.memory_bank_id is None:
bank_id = f"memory_bank_{session_id}"
await self.memory_banks_api.register_memory_bank(
memory_bank_id=bank_id,
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
),
)
await self.storage.add_memory_bank_to_session(session_id, bank_id)
else:
bank_id = session_info.memory_bank_id
return bank_id
async def _should_retrieve_context(
self, messages: List[Message], attachments: List[Attachment]
) -> bool:
enabled_tools = set(t.type for t in self.agent_config.tools)
if attachments:
if (
AgentTool.code_interpreter.value in enabled_tools
and self.agent_config.tool_choice == ToolChoice.required
):
return False
else:
return True
return AgentTool.memory.value in enabled_tools
def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
for t in self.agent_config.tools:
if t.type == AgentTool.memory.value:
return t
return None
async def _retrieve_context(
self, session_id: str, messages: List[Message], attachments: List[Attachment]
) -> Tuple[Optional[InterleavedContent], List[int]]: # (rag_context, bank_ids)
bank_ids = []
memory = self._memory_tool_definition()
assert memory is not None, "Memory tool not configured"
bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
if attachments:
bank_id = await self._ensure_memory_bank(session_id)
bank_ids.append(bank_id)
documents = [
MemoryBankDocument(
document_id=str(uuid.uuid4()),
content=a.content,
mime_type=a.mime_type,
metadata={},
)
for a in attachments
]
with tracing.span("insert_documents"):
await self.memory_api.insert_documents(bank_id, documents)
else:
session_info = await self.storage.get_session_info(session_id)
if session_info.memory_bank_id:
bank_ids.append(session_info.memory_bank_id)
if not bank_ids:
# this can happen if the per-session memory bank is not yet populated
# (i.e., no prior turns uploaded an Attachment)
return None, []
query = await generate_rag_query(
memory.query_generator_config, messages, inference_api=self.inference_api
)
tasks = [
self.memory_api.query_documents(
bank_id=bank_id,
query=query,
params={
"max_chunks": 5,
},
)
for bank_id in bank_ids
]
results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores]
if not chunks:
return None, bank_ids
# sort by score
chunks, scores = zip(
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
)
tokens = 0
picked = []
for c in chunks[: memory.max_chunks]:
tokens += c.token_count
if tokens > memory.max_tokens_in_context:
log.error(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
)
break
picked.append(f"id:{c.document_id}; content:{c.content}")
return (
concat_interleaved_content(
[
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
*picked,
"\n=== END-RETRIEVED-CONTEXT ===\n",
]
),
bank_ids,
)
def _get_tools(self) -> List[ToolDefinition]:
ret = []
for t in self.agent_config.tools:


@@ -24,12 +24,11 @@ from llama_stack.apis.agents import (
Session,
Turn,
)
from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from .agent_instance import ChatAgent
@@ -47,12 +46,16 @@ class MetaReferenceAgentsImpl(Agents):
memory_api: Memory,
safety_api: Safety,
memory_banks_api: MemoryBanks,
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
):
self.config = config
self.inference_api = inference_api
self.memory_api = memory_api
self.safety_api = safety_api
self.memory_banks_api = memory_banks_api
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
self.in_memory_store = InmemoryKVStoreImpl()
self.tempdir = tempfile.mkdtemp()
@@ -112,6 +115,8 @@ class MetaReferenceAgentsImpl(Agents):
safety_api=self.safety_api,
memory_api=self.memory_api,
memory_banks_api=self.memory_banks_api,
tool_runtime_api=self.tool_runtime_api,
tool_groups_api=self.tool_groups_api,
persistence_store=(
self.persistence_store
if agent_config.enable_session_persistence


@@ -8,13 +8,11 @@ import json
import logging
import uuid
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel
from llama_stack.apis.agents import Turn
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
@@ -23,7 +21,6 @@ log = logging.getLogger(__name__)
class AgentSessionInfo(BaseModel):
session_id: str
session_name: str
memory_bank_id: Optional[str] = None
started_at: datetime
@@ -54,17 +51,6 @@ class AgentPersistence:
return AgentSessionInfo(**json.loads(value))
async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
session_info = await self.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
session_info.memory_bank_id = bank_id
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.model_dump_json(),
)
async def add_turn_to_session(self, session_id: str, turn: Turn):
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",


@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_stack.providers.datatypes import Api
from .config import MemoryToolConfig
from .memory import MemoryToolRuntimeImpl
async def get_provider_impl(config: MemoryToolConfig, deps: Dict[str, Any]):
impl = MemoryToolRuntimeImpl(
config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
)
await impl.initialize()
return impl


@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Annotated, List, Literal, Union
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore import KVStoreConfig, SqliteKVStoreConfig
from pydantic import BaseModel, Field
class _MemoryBankConfigCommon(BaseModel):
bank_id: str
class VectorMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["vector"] = "vector"
class KeyValueMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["keyvalue"] = "keyvalue"
keys: List[str] # what keys to focus on
class KeywordMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["keyword"] = "keyword"
class GraphMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["graph"] = "graph"
entities: List[str] # what entities to focus on
MemoryBankConfig = Annotated[
Union[
VectorMemoryBankConfig,
KeyValueMemoryBankConfig,
KeywordMemoryBankConfig,
GraphMemoryBankConfig,
],
Field(discriminator="type"),
]
class MemoryQueryGenerator(Enum):
default = "default"
llm = "llm"
custom = "custom"
class DefaultMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.default.value] = (
MemoryQueryGenerator.default.value
)
sep: str = " "
class LLMMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
model: str
template: str
class CustomMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
MemoryQueryGeneratorConfig = Annotated[
Union[
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
CustomMemoryQueryGeneratorConfig,
],
Field(discriminator="type"),
]
class MemoryToolConfig(BaseModel):
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
# This config defines how a query is generated using the messages
# for memory bank retrieval.
query_generator_config: MemoryQueryGeneratorConfig = Field(
default=DefaultMemoryQueryGeneratorConfig()
)
max_tokens_in_context: int = 4096
max_chunks: int = 10
kvstore_config: KVStoreConfig = SqliteKVStoreConfig(
db_path=(RUNTIME_BASE_DIR / "memory.db").as_posix()
)
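
MemoryToolConfig above carries the knobs that previously lived on MemoryToolDefinition in the agents API. A sketch of a non-default configuration built from the classes defined in this file; the bank id, model name, and query template are illustrative placeholders:

# Sketch only: values are placeholders, not defaults shipped by this commit.
config = MemoryToolConfig(
    memory_bank_configs=[VectorMemoryBankConfig(bank_id="docs")],
    query_generator_config=LLMMemoryQueryGeneratorConfig(
        model="Llama3.1-8B-Instruct",
        template="Rewrite the conversation as a search query: {{ messages }}",
    ),
    max_tokens_in_context=2048,
    max_chunks=5,
)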


@@ -8,16 +8,17 @@ from typing import List
from jinja2 import Template
from llama_stack.apis.agents import (
from llama_stack.apis.inference import Message, UserMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from .config import (
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
MemoryQueryGenerator,
MemoryQueryGeneratorConfig,
)
from llama_stack.apis.inference import Message, UserMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
async def generate_rag_query(


@@ -0,0 +1,253 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
import logging
import os
import re
import secrets
import string
import tempfile
import uuid
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
import httpx
from llama_stack.apis.agents import Attachment
from llama_stack.apis.common.content_types import TextContentItem, URL
from llama_stack.apis.inference import Inference, InterleavedContent, Message
from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
from llama_stack.apis.tools import (
ToolDef,
ToolGroupDef,
ToolInvocationResult,
ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
from pydantic import BaseModel
from .config import MemoryToolConfig
from .context_retriever import generate_rag_query
log = logging.getLogger(__name__)
class MemorySessionInfo(BaseModel):
session_id: str
session_name: str
memory_bank_id: Optional[str] = None
def make_random_string(length: int = 8):
return "".join(
secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
)
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
def __init__(
self,
config: MemoryToolConfig,
memory_api: Memory,
memory_banks_api: MemoryBanks,
inference_api: Inference,
):
self.config = config
self.memory_api = memory_api
self.memory_banks_api = memory_banks_api
self.tempdir = tempfile.mkdtemp()
self.inference_api = inference_api
async def initialize(self):
self.kvstore = await kvstore_impl(self.config.kvstore_config)
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
return []
async def create_session(self, session_id: str) -> MemorySessionInfo:
session_info = MemorySessionInfo(
session_id=session_id,
session_name=f"session_{session_id}",
)
await self.kvstore.set(
key=f"memory::session:{session_id}",
value=session_info.model_dump_json(),
)
return session_info
async def get_session_info(self, session_id: str) -> Optional[MemorySessionInfo]:
value = await self.kvstore.get(
key=f"memory::session:{session_id}",
)
if not value:
session_info = await self.create_session(session_id)
return session_info
return MemorySessionInfo(**json.loads(value))
async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
session_info = await self.get_session_info(session_id)
session_info.memory_bank_id = bank_id
await self.kvstore.set(
key=f"memory::session:{session_id}",
value=session_info.model_dump_json(),
)
async def _ensure_memory_bank(self, session_id: str) -> str:
session_info = await self.get_session_info(session_id)
if session_info.memory_bank_id is None:
bank_id = f"memory_bank_{session_id}"
await self.memory_banks_api.register_memory_bank(
memory_bank_id=bank_id,
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
),
)
await self.add_memory_bank_to_session(session_id, bank_id)
else:
bank_id = session_info.memory_bank_id
return bank_id
async def attachment_message(
self, tempdir: str, urls: List[URL]
) -> List[TextContentItem]:
content = []
for url in urls:
uri = url.uri
if uri.startswith("file://"):
filepath = uri[len("file://") :]
elif uri.startswith("http"):
path = urlparse(uri).path
basename = os.path.basename(path)
filepath = f"{tempdir}/{make_random_string() + basename}"
log.info(f"Downloading {url} -> {filepath}")
async with httpx.AsyncClient() as client:
r = await client.get(uri)
resp = r.text
with open(filepath, "w") as fp:
fp.write(resp)
else:
raise ValueError(f"Unsupported URL {url}")
content.append(
TextContentItem(
text=f'# There is a file accessible to you at "{filepath}"\n'
)
)
return content
async def _retrieve_context(
self, session_id: str, messages: List[Message]
) -> Optional[List[InterleavedContent]]:
bank_ids = []
bank_ids.extend(c.bank_id for c in self.config.memory_bank_configs)
session_info = await self.get_session_info(session_id)
if session_info.memory_bank_id:
bank_ids.append(session_info.memory_bank_id)
if not bank_ids:
# this can happen if the per-session memory bank is not yet populated
# (i.e., no prior turns uploaded an Attachment)
return None
query = await generate_rag_query(
self.config.query_generator_config,
messages,
inference_api=self.inference_api,
)
tasks = [
self.memory_api.query_documents(
bank_id=bank_id,
query=query,
params={
"max_chunks": 5,
},
)
for bank_id in bank_ids
]
results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores]
if not chunks:
return None
# sort by score
chunks, scores = zip(
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
)
tokens = 0
picked = []
for c in chunks[: self.config.max_chunks]:
tokens += c.token_count
if tokens > self.config.max_tokens_in_context:
log.error(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
)
break
picked.append(f"id:{c.document_id}; content:{c.content}")
return [
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
*picked,
"\n=== END-RETRIEVED-CONTEXT ===\n",
]
async def _process_attachments(
self, session_id: str, attachments: List[Attachment]
):
bank_id = await self._ensure_memory_bank(session_id)
documents = [
MemoryBankDocument(
document_id=str(uuid.uuid4()),
content=a.content,
mime_type=a.mime_type,
metadata={},
)
for a in attachments
if isinstance(a.content, str)
]
await self.memory_api.insert_documents(bank_id, documents)
urls = [a.content for a in attachments if isinstance(a.content, URL)]
# TODO: we need to migrate URL away from str type
pattern = re.compile("^(https?://|file://|data:)")
urls += [URL(uri=a.content) for a in attachments if pattern.match(a.content)]
return await self.attachment_message(self.tempdir, urls)
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
if args["session_id"] is None:
raise ValueError("session_id is required")
context = await self._retrieve_context(
args["session_id"], args["input_messages"]
)
if context is None:
context = []
attachments = args["attachments"]
if attachments and len(attachments) > 0:
context += await self._process_attachments(args["session_id"], attachments)
return ToolInvocationResult(
content=concat_interleaved_content(context), error_code=0
)
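
invoke_tool above is the entry point the agent's preprocessing loop reaches through tool_runtime_api: it expects session_id, input_messages, and attachments in args and returns the retrieved context as a ToolInvocationResult. A hedged sketch of a direct call, assuming impl is an initialized MemoryToolRuntimeImpl and msgs is the list of Message objects for the current turn:

# Sketch only: in the agent path this call goes through tool_runtime_api.invoke_tool.
result = await impl.invoke_tool(
    tool_name="memory",
    args={
        "session_id": "session-123",  # placeholder id; must be non-None
        "input_messages": msgs,
        "attachments": [],            # nothing uploaded this turn
    },
)
if result.error_code == 0 and result.content:
    context = result.content  # retrieved documents wrapped in START/END-RETRIEVED-CONTEXT markers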


@@ -35,6 +35,8 @@ def available_providers() -> List[ProviderSpec]:
Api.safety,
Api.memory,
Api.memory_banks,
Api.tool_runtime,
Api.tool_groups,
],
),
remote_provider_spec(


@@ -25,6 +25,14 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig",
provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
),
InlineProviderSpec(
api=Api.tool_runtime,
provider_type="inline::memory-runtime",
pip_packages=[],
module="llama_stack.providers.inline.tool_runtime.memory",
config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolConfig",
api_dependencies=[Api.memory, Api.memory_banks, Api.inference],
),
remote_provider_spec(
api=Api.tool_runtime,
adapter=AdapterSpec(


@@ -7,12 +7,10 @@
import pytest
from ..conftest import get_provider_fixture_overrides
from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
from .fixtures import AGENTS_FIXTURES
from .fixtures import AGENTS_FIXTURES, TOOL_RUNTIME_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
@@ -21,6 +19,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"safety": "llama_guard",
"memory": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory",
},
id="meta_reference",
marks=pytest.mark.meta_reference,
@@ -31,6 +30,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"safety": "llama_guard",
"memory": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory",
},
id="ollama",
marks=pytest.mark.ollama,
@@ -42,6 +42,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
# make this work with Weaviate which is what the together distro supports
"memory": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory",
},
id="together",
marks=pytest.mark.together,
@@ -52,6 +53,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"safety": "llama_guard",
"memory": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory",
},
id="fireworks",
marks=pytest.mark.fireworks,
@@ -62,6 +64,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"safety": "remote",
"memory": "remote",
"agents": "remote",
"tool_runtime": "memory",
},
id="remote",
marks=pytest.mark.remote,
@@ -117,6 +120,7 @@ def pytest_generate_tests(metafunc):
"safety": SAFETY_FIXTURES,
"memory": MEMORY_FIXTURES,
"agents": AGENTS_FIXTURES,
"tool_runtime": TOOL_RUNTIME_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)


@@ -10,14 +10,19 @@ import pytest
import pytest_asyncio
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.apis.tools import (
ToolDef,
ToolGroupInput,
ToolParameter,
UserDefinedToolGroupDef,
)
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.agents.meta_reference import (
MetaReferenceAgentsImplConfig,
)
from llama_stack.providers.tests.resolver import construct_stack_for_test
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from ..conftest import ProviderFixture, remote_stack_fixture
@@ -55,7 +60,21 @@ def agents_meta_reference() -> ProviderFixture:
)
@pytest.fixture(scope="session")
def tool_runtime_memory() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="memory-runtime",
provider_type="inline::memory-runtime",
config={},
)
],
)
AGENTS_FIXTURES = ["meta_reference", "remote"]
TOOL_RUNTIME_FIXTURES = ["memory"]
@pytest_asyncio.fixture(scope="session")
@@ -64,7 +83,7 @@ async def agents_stack(request, inference_model, safety_shield):
providers = {}
provider_data = {}
for key in ["inference", "safety", "memory", "agents"]:
for key in ["inference", "safety", "memory", "agents", "tool_runtime"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if key == "inference":
@@ -111,12 +130,48 @@ async def agents_stack(request, inference_model, safety_shield):
metadata={"embedding_dimension": 384},
)
)
tool_groups = [
ToolGroupInput(
tool_group_id="memory_group",
tool_group=UserDefinedToolGroupDef(
tools=[
ToolDef(
name="memory",
description="memory",
parameters=[
ToolParameter(
name="session_id",
description="session id",
parameter_type="string",
required=True,
),
ToolParameter(
name="input_messages",
description="messages",
parameter_type="list",
required=True,
),
ToolParameter(
name="attachments",
description="attachments",
parameter_type="list",
required=False,
),
],
metadata={},
)
],
),
provider_id="memory-runtime",
)
]
test_stack = await construct_stack_for_test(
[Api.agents, Api.inference, Api.safety, Api.memory],
[Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime],
providers,
provider_data,
models=models,
shields=[safety_shield] if safety_shield else [],
tool_groups=tool_groups,
)
return test_stack


@@ -35,7 +35,6 @@ from llama_stack.providers.datatypes import Api
#
# pytest -v -s llama_stack/providers/tests/agents/test_agents.py
# -m "meta_reference"
from .fixtures import pick_inference_model
from .utils import create_agent_session
@@ -255,17 +254,8 @@ class TestAgents:
agent_config = AgentConfig(
**{
**common_params,
"tools": [
MemoryToolDefinition(
memory_bank_configs=[],
query_generator_config={
"type": "default",
"sep": " ",
},
max_tokens_in_context=4096,
max_chunks=10,
),
],
"tools": [],
"preprocessing_tools": ["memory"],
"tool_choice": ToolChoice.auto,
}
)


@@ -16,7 +16,7 @@ from llama_stack.apis.memory_banks import MemoryBankInput
from llama_stack.apis.models import ModelInput
from llama_stack.apis.scoring_functions import ScoringFnInput
from llama_stack.apis.shields import ShieldInput
from llama_stack.apis.tools import ToolGroupInput
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Provider, StackRunConfig
@@ -43,6 +43,7 @@ async def construct_stack_for_test(
datasets: Optional[List[DatasetInput]] = None,
scoring_fns: Optional[List[ScoringFnInput]] = None,
eval_tasks: Optional[List[EvalTaskInput]] = None,
tool_groups: Optional[List[ToolGroupInput]] = None,
) -> TestStack:
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
run_config = dict(
@@ -56,6 +57,7 @@
datasets=datasets or [],
scoring_fns=scoring_fns or [],
eval_tasks=eval_tasks or [],
tool_groups=tool_groups or [],
)
run_config = parse_and_maybe_upgrade_config(run_config)
try: