remove attachments, move memory bank to tool metadata

Dinesh Yeduguru 2024-12-26 15:48:52 -08:00
parent 97798c8442
commit f408fd3aca
9 changed files with 45 additions and 180 deletions
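In effect, two things change for callers: create_agent_turn no longer accepts attachments, and the memory tool reads its bank configuration from the registered tool's metadata instead of creating per-session banks. A rough before/after sketch (key names follow the diffs below; everything else is illustrative):

# Before this commit: attachments rode along with each turn and were
# ingested into a lazily created per-session memory bank.
#   create_agent_turn(..., messages=[...], attachments=[Attachment(...)])
#
# After this commit: banks are declared once in the tool's metadata and
# the turn carries only messages.
#   create_agent_turn(..., messages=[...])
#   tool.metadata["config"] == {"memory_bank_configs": [...]}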


@@ -39,7 +39,6 @@ from llama_stack.apis.safety import SafetyViolation
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@json_schema_type
class Attachment(BaseModel):
content: InterleavedContent | URL
mime_type: str
@@ -258,7 +257,6 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
ToolResponseMessage,
]
]
attachments: Optional[List[Attachment]] = None
stream: Optional[bool] = False
@@ -295,7 +293,6 @@ class Agents(Protocol):
ToolResponseMessage,
]
],
attachments: Optional[List[Attachment]] = None,
stream: Optional[bool] = False,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
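With attachments removed from the protocol, a turn request is built from messages alone. A minimal sketch of the slimmed-down request, assuming the AgentTurnCreateRequest fields shown above and the UserMessage type from the inference API (the ids are made up):

# Sketch only: field names follow the diff above.
request = AgentTurnCreateRequest(
    agent_id="agent-123",
    session_id="session-456",
    messages=[UserMessage(role="user", content="Tell me how to use LoRA")],
    stream=True,  # attachments=... is no longer an accepted argument
)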


@@ -188,7 +188,6 @@ class ChatAgent(ShieldRunnerMixin):
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
attachments=request.attachments or [],
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
):
@@ -238,7 +237,6 @@ class ChatAgent(ShieldRunnerMixin):
session_id: str,
turn_id: str,
input_messages: List[Message],
attachments: List[Attachment],
sampling_params: SamplingParams,
stream: bool = False,
) -> AsyncGenerator:
@@ -257,7 +255,7 @@ class ChatAgent(ShieldRunnerMixin):
yield res
async for res in self._run(
session_id, turn_id, input_messages, attachments, sampling_params, stream
session_id, turn_id, input_messages, sampling_params, stream
):
if isinstance(res, bool):
return
@@ -350,7 +348,6 @@ class ChatAgent(ShieldRunnerMixin):
session_id: str,
turn_id: str,
input_messages: List[Message],
attachments: List[Attachment],
sampling_params: SamplingParams,
stream: bool = False,
) -> AsyncGenerator:
@@ -370,7 +367,6 @@ class ChatAgent(ShieldRunnerMixin):
session_id=session_id,
turn_id=turn_id,
input_messages=input_messages,
attachments=attachments,
)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@@ -423,6 +419,9 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("output", result.content)
span.set_attribute("error_code", result.error_code)
span.set_attribute("error_message", result.error_message)
if isinstance(tool_name, BuiltinTool):
span.set_attribute("tool_name", tool_name.value)
else:
span.set_attribute("tool_name", tool_name)
if result.error_code == 0:
last_message = input_messages[-1]
@@ -553,9 +552,9 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
if len(output_attachments) > 0:
if isinstance(message.content, list):
message.content += attachments
message.content += output_attachments
else:
message.content = [message.content] + attachments
message.content = [message.content] + output_attachments
yield message
else:
log.info(f"Partial message: {str(message)}")
@@ -586,10 +585,13 @@ class ChatAgent(ShieldRunnerMixin):
)
)
tool_name = tool_call.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
with tracing.span(
"tool_execution",
{
"tool_name": tool_call.tool_name,
"tool_name": tool_name,
"input": message.model_dump_json(),
},
) as span:
@@ -608,6 +610,7 @@ class ChatAgent(ShieldRunnerMixin):
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
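Both span-attribute hunks above normalize tool_name, which can be either a BuiltinTool enum member or a plain string, so telemetry always records a string attribute. A self-contained sketch of that pattern (the enum here is a stand-in for the real BuiltinTool):

from enum import Enum

class BuiltinTool(Enum):  # stand-in for llama_stack's BuiltinTool
    brave_search = "brave_search"
    code_interpreter = "code_interpreter"

def span_tool_name(tool_name) -> str:
    # Unwrap enum members to their string value; custom tool names are
    # already plain strings and pass through untouched.
    return tool_name.value if isinstance(tool_name, BuiltinTool) else tool_name

assert span_tool_name(BuiltinTool.brave_search) == "brave_search"
assert span_tool_name("get_boiling_point") == "get_boiling_point"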


@@ -146,14 +146,12 @@ class MetaReferenceAgentsImpl(Agents):
ToolResponseMessage,
]
],
attachments: Optional[List[Attachment]] = None,
stream: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
agent_id=agent_id,
session_id=session_id,
messages=messages,
attachments=attachments,
stream=True,
)
if stream:


@@ -8,11 +8,11 @@ from typing import Any, Dict
from llama_stack.providers.datatypes import Api
from .config import MemoryToolConfig
from .config import MemoryToolRuntimeConfig
from .memory import MemoryToolRuntimeImpl
async def get_provider_impl(config: MemoryToolConfig, deps: Dict[str, Any]):
async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
impl = MemoryToolRuntimeImpl(
config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
)


@@ -7,9 +7,6 @@
from enum import Enum
from typing import Annotated, List, Literal, Union
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore import KVStoreConfig, SqliteKVStoreConfig
from pydantic import BaseModel, Field
@@ -81,13 +78,13 @@ MemoryQueryGeneratorConfig = Annotated[
class MemoryToolConfig(BaseModel):
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
class MemoryToolRuntimeConfig(BaseModel):
# This config defines how a query is generated using the messages
# for memory bank retrieval.
query_generator_config: MemoryQueryGeneratorConfig = Field(
default=DefaultMemoryQueryGeneratorConfig()
)
max_tokens_in_context: int = 4096
max_chunks: int = 10
kvstore_config: KVStoreConfig = SqliteKVStoreConfig(
db_path=(RUNTIME_BASE_DIR / "memory.db").as_posix()
)
max_chunks: int = 5


@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from jinja2 import Template
@@ -23,7 +22,7 @@ from .config import (
async def generate_rag_query(
config: MemoryQueryGeneratorConfig,
messages: List[Message],
message: Message,
**kwargs,
):
"""
@@ -31,9 +30,9 @@
retrieving relevant information from the memory bank.
"""
if config.type == MemoryQueryGenerator.default.value:
query = await default_rag_query_generator(config, messages, **kwargs)
query = await default_rag_query_generator(config, message, **kwargs)
elif config.type == MemoryQueryGenerator.llm.value:
query = await llm_rag_query_generator(config, messages, **kwargs)
query = await llm_rag_query_generator(config, message, **kwargs)
else:
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
return query
@@ -41,21 +40,21 @@
async def default_rag_query_generator(
config: DefaultMemoryQueryGeneratorConfig,
messages: List[Message],
message: Message,
**kwargs,
):
return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
return interleaved_content_as_str(message.content)
async def llm_rag_query_generator(
config: LLMMemoryQueryGeneratorConfig,
messages: List[Message],
message: Message,
**kwargs,
):
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
inference_api = kwargs["inference_api"]
m_dict = {"messages": [m.model_dump() for m in messages]}
m_dict = {"messages": [message.model_dump()]}
template = Template(config.template)
content = template.render(m_dict)
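Since only the most recent message now seeds retrieval, callers pass messages[-1] rather than the full history, and the default generator reduces to stringifying that one message's content. A short caller-side sketch mirroring _retrieve_context in the memory tool below (variable names are illustrative):

# Sketch: only the latest message drives the retrieval query.
message = input_messages[-1]
query = await generate_rag_query(
    config.query_generator_config,
    message,
    inference_api=inference_api,
)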


@@ -5,24 +5,14 @@
# the root directory of this source tree.
import asyncio
import json
import logging
import os
import re
import secrets
import string
import tempfile
import uuid
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
import httpx
from llama_stack.apis.agents import Attachment
from llama_stack.apis.common.content_types import TextContentItem, URL
from llama_stack.apis.inference import Inference, InterleavedContent, Message
from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
from llama_stack.apis.memory import Memory, QueryDocumentsResponse
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.tools import (
ToolDef,
ToolGroupDef,
@@ -30,22 +20,14 @@ from llama_stack.apis.tools import (
ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
from pydantic import BaseModel
from .config import MemoryToolConfig
from .config import MemoryToolConfig, MemoryToolRuntimeConfig
from .context_retriever import generate_rag_query
log = logging.getLogger(__name__)
class MemorySessionInfo(BaseModel):
session_id: str
session_name: str
memory_bank_id: Optional[str] = None
def make_random_string(length: int = 8):
return "".join(
secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
@@ -55,7 +37,7 @@ def make_random_string(length: int = 8):
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
def __init__(
self,
config: MemoryToolConfig,
config: MemoryToolRuntimeConfig,
memory_api: Memory,
memory_banks_api: MemoryBanks,
inference_api: Inference,
@@ -63,113 +45,26 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
self.config = config
self.memory_api = memory_api
self.memory_banks_api = memory_banks_api
self.tempdir = tempfile.mkdtemp()
self.inference_api = inference_api
async def initialize(self):
self.kvstore = await kvstore_impl(self.config.kvstore_config)
pass
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
return []
async def create_session(self, session_id: str) -> MemorySessionInfo:
session_info = MemorySessionInfo(
session_id=session_id,
session_name=f"session_{session_id}",
)
await self.kvstore.set(
key=f"memory::session:{session_id}",
value=session_info.model_dump_json(),
)
return session_info
async def get_session_info(self, session_id: str) -> Optional[MemorySessionInfo]:
value = await self.kvstore.get(
key=f"memory::session:{session_id}",
)
if not value:
session_info = await self.create_session(session_id)
return session_info
return MemorySessionInfo(**json.loads(value))
async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
session_info = await self.get_session_info(session_id)
session_info.memory_bank_id = bank_id
await self.kvstore.set(
key=f"memory::session:{session_id}",
value=session_info.model_dump_json(),
)
async def _ensure_memory_bank(self, session_id: str) -> str:
session_info = await self.get_session_info(session_id)
if session_info.memory_bank_id is None:
bank_id = f"memory_bank_{session_id}"
await self.memory_banks_api.register_memory_bank(
memory_bank_id=bank_id,
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
),
)
await self.add_memory_bank_to_session(session_id, bank_id)
else:
bank_id = session_info.memory_bank_id
return bank_id
async def attachment_message(
self, tempdir: str, urls: List[URL]
) -> List[TextContentItem]:
content = []
for url in urls:
uri = url.uri
if uri.startswith("file://"):
filepath = uri[len("file://") :]
elif uri.startswith("http"):
path = urlparse(uri).path
basename = os.path.basename(path)
filepath = f"{tempdir}/{make_random_string() + basename}"
log.info(f"Downloading {url} -> {filepath}")
async with httpx.AsyncClient() as client:
r = await client.get(uri)
resp = r.text
with open(filepath, "w") as fp:
fp.write(resp)
else:
raise ValueError(f"Unsupported URL {url}")
content.append(
TextContentItem(
text=f'# There is a file accessible to you at "{filepath}"\n'
)
)
return content
async def _retrieve_context(
self, session_id: str, messages: List[Message]
self, messages: List[Message], bank_ids: List[str]
) -> Optional[List[InterleavedContent]]:
bank_ids = []
bank_ids.extend(c.bank_id for c in self.config.memory_bank_configs)
session_info = await self.get_session_info(session_id)
if session_info.memory_bank_id:
bank_ids.append(session_info.memory_bank_id)
if not bank_ids:
# this can happen if the per-session memory bank is not yet populated
# (i.e., no prior turns uploaded an Attachment)
return None
if len(messages) == 0:
return None
message = messages[-1] # only use the last message as input to the query
query = await generate_rag_query(
self.config.query_generator_config,
messages,
message,
inference_api=self.inference_api,
)
tasks = [
@@ -177,7 +72,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
bank_id=bank_id,
query=query,
params={
"max_chunks": 5,
"max_chunks": self.config.max_chunks,
},
)
for bank_id in bank_ids
@@ -211,43 +106,20 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
"\n=== END-RETRIEVED-CONTEXT ===\n",
]
async def _process_attachments(
self, session_id: str, attachments: List[Attachment]
):
bank_id = await self._ensure_memory_bank(session_id)
documents = [
MemoryBankDocument(
document_id=str(uuid.uuid4()),
content=a.content,
mime_type=a.mime_type,
metadata={},
)
for a in attachments
if isinstance(a.content, str)
]
await self.memory_api.insert_documents(bank_id, documents)
urls = [a.content for a in attachments if isinstance(a.content, URL)]
# TODO: we need to migrate URL away from str type
pattern = re.compile("^(https?://|file://|data:)")
urls += [URL(uri=a.content) for a in attachments if pattern.match(a.content)]
return await self.attachment_message(self.tempdir, urls)
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
if args["session_id"] is None:
raise ValueError("session_id is required")
tool = await self.tool_store.get_tool(tool_name)
config = MemoryToolConfig()
if tool.metadata.get("config") is not None:
config = MemoryToolConfig(**tool.metadata["config"])
context = await self._retrieve_context(
args["session_id"], args["input_messages"]
args["input_messages"],
[bank_config.bank_id for bank_config in config.memory_bank_configs],
)
if context is None:
context = []
attachments = args["attachments"]
if attachments and len(attachments) > 0:
context += await self._process_attachments(args["session_id"], attachments)
return ToolInvocationResult(
content=concat_interleaved_content(context), error_code=0
)
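invoke_tool now resolves its banks from the registered tool's metadata: a "config" entry that deserializes into MemoryToolConfig. A hedged sketch of what that metadata and call might look like (the bank-config fields and the tool name are assumptions; only the "config" key and the args dict mirror the code above):

# Sketch: bank selection supplied via tool metadata. The dict shape
# feeds MemoryToolConfig(**tool.metadata["config"]) above; the
# bank-config fields are assumed.
metadata = {
    "config": {
        "memory_bank_configs": [
            {"bank_id": "docs_bank", "type": "vector"},
        ],
    }
}

# Invocation still requires a session_id and the turn's messages:
# result = await impl.invoke_tool(
#     "memory",  # assumed tool name
#     {"session_id": "session-456", "input_messages": [user_message]},
# )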


@@ -22,7 +22,7 @@ def available_providers() -> List[ProviderSpec]:
provider_type="inline::memory-runtime",
pip_packages=[],
module="llama_stack.providers.inline.tool_runtime.memory",
config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolConfig",
config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
api_dependencies=[Api.memory, Api.memory_banks, Api.inference],
),
InlineProviderSpec(


@@ -21,7 +21,7 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage
class TestCustomTool(CustomTool):
"""Tool to give boiling point of a liquid
Returns the correct value for water in Celcius and Fahrenheit
Returns the correct value for polyjuice in Celcius and Fahrenheit
and returns -1 for other liquids
"""
@@ -50,7 +50,7 @@ class TestCustomTool(CustomTool):
return "get_boiling_point"
def get_description(self) -> str:
return "Get the boiling point of a imaginary liquids (eg. polyjuice)"
return "Get the boiling point of imaginary liquids (eg. polyjuice)"
def get_params_definition(self) -> Dict[str, Parameter]:
return {
@@ -279,7 +279,6 @@ def test_rag_agent(llama_stack_client, agent_config):
"What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.",
"Was anything related to 'Llama3' discussed, if so what?",
"Tell me how to use LoRA",
"What about Quantization?",
]
for prompt in user_prompts: