remove attachments, move memory bank to tool metadata

This commit is contained in:
Dinesh Yeduguru 2024-12-26 15:48:52 -08:00
parent 97798c8442
commit f408fd3aca
9 changed files with 45 additions and 180 deletions

View file

@@ -39,7 +39,6 @@ from llama_stack.apis.safety import SafetyViolation
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 
-@json_schema_type
 class Attachment(BaseModel):
     content: InterleavedContent | URL
     mime_type: str
@@ -258,7 +257,6 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
            ToolResponseMessage,
        ]
    ]
-    attachments: Optional[List[Attachment]] = None
     stream: Optional[bool] = False
@@ -295,7 +293,6 @@ class Agents(Protocol):
                ToolResponseMessage,
            ]
        ],
-        attachments: Optional[List[Attachment]] = None,
         stream: Optional[bool] = False,
     ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
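After this change a turn is created from messages alone; attachments no longer ride along with the request. A hedged sketch of the narrowed request, using only fields visible in this diff (the identifier values are placeholders):

    # Sketch: constructing the narrowed request, mirroring the meta-reference
    # implementation further down. The IDs are illustrative placeholders.
    request = AgentTurnCreateRequest(
        agent_id="agent-123",
        session_id="session-456",
        messages=[UserMessage(content="Tell me how to use LoRA")],
        stream=True,
    )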

View file

@@ -188,7 +188,6 @@ class ChatAgent(ShieldRunnerMixin):
            session_id=request.session_id,
            turn_id=turn_id,
            input_messages=messages,
-            attachments=request.attachments or [],
            sampling_params=self.agent_config.sampling_params,
            stream=request.stream,
        ):
@@ -238,7 +237,6 @@ class ChatAgent(ShieldRunnerMixin):
        session_id: str,
        turn_id: str,
        input_messages: List[Message],
-        attachments: List[Attachment],
        sampling_params: SamplingParams,
        stream: bool = False,
    ) -> AsyncGenerator:
@@ -257,7 +255,7 @@ class ChatAgent(ShieldRunnerMixin):
                yield res
 
        async for res in self._run(
-            session_id, turn_id, input_messages, attachments, sampling_params, stream
+            session_id, turn_id, input_messages, sampling_params, stream
        ):
            if isinstance(res, bool):
                return
@@ -350,7 +348,6 @@ class ChatAgent(ShieldRunnerMixin):
        session_id: str,
        turn_id: str,
        input_messages: List[Message],
-        attachments: List[Attachment],
        sampling_params: SamplingParams,
        stream: bool = False,
    ) -> AsyncGenerator:
@@ -370,7 +367,6 @@ class ChatAgent(ShieldRunnerMixin):
            session_id=session_id,
            turn_id=turn_id,
            input_messages=input_messages,
-            attachments=attachments,
        )
        yield AgentTurnResponseStreamChunk(
            event=AgentTurnResponseEvent(
@@ -423,7 +419,10 @@ class ChatAgent(ShieldRunnerMixin):
                span.set_attribute("output", result.content)
                span.set_attribute("error_code", result.error_code)
                span.set_attribute("error_message", result.error_message)
-                span.set_attribute("tool_name", tool_name)
+                if isinstance(tool_name, BuiltinTool):
+                    span.set_attribute("tool_name", tool_name.value)
+                else:
+                    span.set_attribute("tool_name", tool_name)
            if result.error_code == 0:
                last_message = input_messages[-1]
                last_message.context = result.content
@@ -553,9 +552,9 @@ class ChatAgent(ShieldRunnerMixin):
                # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
                if len(output_attachments) > 0:
                    if isinstance(message.content, list):
-                        message.content += attachments
+                        message.content += output_attachments
                    else:
-                        message.content = [message.content] + attachments
+                        message.content = [message.content] + output_attachments
                yield message
            else:
                log.info(f"Partial message: {str(message)}")
@@ -586,10 +585,13 @@ class ChatAgent(ShieldRunnerMixin):
                )
            )
 
+            tool_name = tool_call.tool_name
+            if isinstance(tool_name, BuiltinTool):
+                tool_name = tool_name.value
            with tracing.span(
                "tool_execution",
                {
-                    "tool_name": tool_call.tool_name,
+                    "tool_name": tool_name,
                    "input": message.model_dump_json(),
                },
            ) as span:
@@ -608,6 +610,7 @@ class ChatAgent(ShieldRunnerMixin):
                event=AgentTurnResponseEvent(
                    payload=AgentTurnResponseStepCompletePayload(
                        step_type=StepType.tool_execution.value,
+                        step_id=step_id,
                        step_details=ToolExecutionStep(
                            step_id=step_id,
                            turn_id=turn_id,
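For context on the tool_name handling in the two hunks above: BuiltinTool is an enum, so recording it directly as a span attribute would store an enum member rather than a string; both sites normalize it to its .value first. A minimal illustration (the member name brave_search is an assumption, not taken from this diff):

    # Illustrative only: span attributes want primitives, not enum members.
    name = BuiltinTool.brave_search  # enum member, not a str
    name = name.value                # "brave_search", safe to record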

View file

@@ -146,14 +146,12 @@ class MetaReferenceAgentsImpl(Agents):
                ToolResponseMessage,
            ]
        ],
-        attachments: Optional[List[Attachment]] = None,
        stream: Optional[bool] = False,
    ) -> AsyncGenerator:
        request = AgentTurnCreateRequest(
            agent_id=agent_id,
            session_id=session_id,
            messages=messages,
-            attachments=attachments,
            stream=True,
        )
        if stream:

View file

@@ -8,11 +8,11 @@ from typing import Any, Dict
 from llama_stack.providers.datatypes import Api
 
-from .config import MemoryToolConfig
+from .config import MemoryToolRuntimeConfig
 from .memory import MemoryToolRuntimeImpl
 
 
-async def get_provider_impl(config: MemoryToolConfig, deps: Dict[str, Any]):
+async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
    impl = MemoryToolRuntimeImpl(
        config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
    )

View file

@@ -7,9 +7,6 @@
 from enum import Enum
 from typing import Annotated, List, Literal, Union
 
-from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
-from llama_stack.providers.utils.kvstore import KVStoreConfig, SqliteKVStoreConfig
-
 from pydantic import BaseModel, Field
@@ -81,13 +78,13 @@ MemoryQueryGeneratorConfig = Annotated[
 class MemoryToolConfig(BaseModel):
    memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
 
+
+class MemoryToolRuntimeConfig(BaseModel):
    # This config defines how a query is generated using the messages
    # for memory bank retrieval.
    query_generator_config: MemoryQueryGeneratorConfig = Field(
        default=DefaultMemoryQueryGeneratorConfig()
    )
    max_tokens_in_context: int = 4096
-    max_chunks: int = 10
-    kvstore_config: KVStoreConfig = SqliteKVStoreConfig(
-        db_path=(RUNTIME_BASE_DIR / "memory.db").as_posix()
-    )
+    max_chunks: int = 5
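A hedged sketch of how the two classes split apart by this hunk might be instantiated (field names and defaults come from the diff; everything else is assumed):

    # Provider-level settings: query generation plus retrieval limits.
    # kvstore_config is gone because session state is no longer kept here.
    runtime_config = MemoryToolRuntimeConfig(max_tokens_in_context=4096, max_chunks=5)

    # Per-tool settings: which memory banks to search. This now travels in
    # tool metadata (see the memory runtime changes below).
    tool_config = MemoryToolConfig(memory_bank_configs=[])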

View file

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List
 
 from jinja2 import Template
@@ -23,7 +22,7 @@ from .config import (
 async def generate_rag_query(
    config: MemoryQueryGeneratorConfig,
-    messages: List[Message],
+    message: Message,
    **kwargs,
 ):
    """
@@ -31,9 +30,9 @@ async def generate_rag_query(
    retrieving relevant information from the memory bank.
    """
    if config.type == MemoryQueryGenerator.default.value:
-        query = await default_rag_query_generator(config, messages, **kwargs)
+        query = await default_rag_query_generator(config, message, **kwargs)
    elif config.type == MemoryQueryGenerator.llm.value:
-        query = await llm_rag_query_generator(config, messages, **kwargs)
+        query = await llm_rag_query_generator(config, message, **kwargs)
    else:
        raise NotImplementedError(f"Unsupported memory query generator {config.type}")
    return query
@@ -41,21 +40,21 @@ async def generate_rag_query(
 async def default_rag_query_generator(
    config: DefaultMemoryQueryGeneratorConfig,
-    messages: List[Message],
+    message: Message,
    **kwargs,
 ):
-    return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
+    return interleaved_content_as_str(message.content)
 
 
 async def llm_rag_query_generator(
    config: LLMMemoryQueryGeneratorConfig,
-    messages: List[Message],
+    message: Message,
    **kwargs,
 ):
    assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
    inference_api = kwargs["inference_api"]
 
-    m_dict = {"messages": [m.model_dump() for m in messages]}
+    m_dict = {"messages": [message.model_dump()]}
    template = Template(config.template)
    content = template.render(m_dict)
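A sketch of the narrowed query-generation call, assuming UserMessage from llama_stack.apis.inference and an enclosing coroutine; the generators now take one Message instead of joining a list:

    # Sketch: with the default generator this returns the message text as-is.
    query = await generate_rag_query(
        DefaultMemoryQueryGeneratorConfig(),
        UserMessage(content="Tell me how to use LoRA"),
    )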

View file

@@ -5,24 +5,14 @@
 # the root directory of this source tree.
 
 import asyncio
-import json
 import logging
-import os
-import re
 import secrets
 import string
-import tempfile
-import uuid
 from typing import Any, Dict, List, Optional
-from urllib.parse import urlparse
 
-import httpx
-from llama_stack.apis.agents import Attachment
-from llama_stack.apis.common.content_types import TextContentItem, URL
 from llama_stack.apis.inference import Inference, InterleavedContent, Message
-from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
-from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
+from llama_stack.apis.memory import Memory, QueryDocumentsResponse
+from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.tools import (
    ToolDef,
    ToolGroupDef,
@@ -30,22 +20,14 @@ from llama_stack.apis.tools import (
    ToolRuntime,
 )
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
-from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
-from pydantic import BaseModel
 
-from .config import MemoryToolConfig
+from .config import MemoryToolConfig, MemoryToolRuntimeConfig
 from .context_retriever import generate_rag_query
 
 log = logging.getLogger(__name__)
 
 
-class MemorySessionInfo(BaseModel):
-    session_id: str
-    session_name: str
-    memory_bank_id: Optional[str] = None
-
-
 def make_random_string(length: int = 8):
    return "".join(
        secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
@@ -55,7 +37,7 @@ def make_random_string(length: int = 8):
 class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
    def __init__(
        self,
-        config: MemoryToolConfig,
+        config: MemoryToolRuntimeConfig,
        memory_api: Memory,
        memory_banks_api: MemoryBanks,
        inference_api: Inference,
@@ -63,113 +45,26 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
        self.config = config
        self.memory_api = memory_api
        self.memory_banks_api = memory_banks_api
-        self.tempdir = tempfile.mkdtemp()
        self.inference_api = inference_api
 
    async def initialize(self):
-        self.kvstore = await kvstore_impl(self.config.kvstore_config)
+        pass
 
    async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
        return []
 
-    async def create_session(self, session_id: str) -> MemorySessionInfo:
-        session_info = MemorySessionInfo(
-            session_id=session_id,
-            session_name=f"session_{session_id}",
-        )
-        await self.kvstore.set(
-            key=f"memory::session:{session_id}",
-            value=session_info.model_dump_json(),
-        )
-        return session_info
-
-    async def get_session_info(self, session_id: str) -> Optional[MemorySessionInfo]:
-        value = await self.kvstore.get(
-            key=f"memory::session:{session_id}",
-        )
-        if not value:
-            session_info = await self.create_session(session_id)
-            return session_info
-        return MemorySessionInfo(**json.loads(value))
-
-    async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
-        session_info = await self.get_session_info(session_id)
-        session_info.memory_bank_id = bank_id
-        await self.kvstore.set(
-            key=f"memory::session:{session_id}",
-            value=session_info.model_dump_json(),
-        )
-
-    async def _ensure_memory_bank(self, session_id: str) -> str:
-        session_info = await self.get_session_info(session_id)
-        if session_info.memory_bank_id is None:
-            bank_id = f"memory_bank_{session_id}"
-            await self.memory_banks_api.register_memory_bank(
-                memory_bank_id=bank_id,
-                params=VectorMemoryBankParams(
-                    embedding_model="all-MiniLM-L6-v2",
-                    chunk_size_in_tokens=512,
-                ),
-            )
-            await self.add_memory_bank_to_session(session_id, bank_id)
-        else:
-            bank_id = session_info.memory_bank_id
-        return bank_id
-
-    async def attachment_message(
-        self, tempdir: str, urls: List[URL]
-    ) -> List[TextContentItem]:
-        content = []
-        for url in urls:
-            uri = url.uri
-            if uri.startswith("file://"):
-                filepath = uri[len("file://") :]
-            elif uri.startswith("http"):
-                path = urlparse(uri).path
-                basename = os.path.basename(path)
-                filepath = f"{tempdir}/{make_random_string() + basename}"
-                log.info(f"Downloading {url} -> {filepath}")
-                async with httpx.AsyncClient() as client:
-                    r = await client.get(uri)
-                resp = r.text
-                with open(filepath, "w") as fp:
-                    fp.write(resp)
-            else:
-                raise ValueError(f"Unsupported URL {url}")
-
-            content.append(
-                TextContentItem(
-                    text=f'# There is a file accessible to you at "{filepath}"\n'
-                )
-            )
-        return content
-
    async def _retrieve_context(
-        self, session_id: str, messages: List[Message]
+        self, messages: List[Message], bank_ids: List[str]
    ) -> Optional[List[InterleavedContent]]:
-        bank_ids = []
-
-        bank_ids.extend(c.bank_id for c in self.config.memory_bank_configs)
-
-        session_info = await self.get_session_info(session_id)
-        if session_info.memory_bank_id:
-            bank_ids.append(session_info.memory_bank_id)
-
        if not bank_ids:
-            # this can happen if the per-session memory bank is not yet populated
-            # (i.e., no prior turns uploaded an Attachment)
            return None
+        if len(messages) == 0:
+            return None
+        message = messages[-1]  # only use the last message as input to the query
        query = await generate_rag_query(
            self.config.query_generator_config,
-            messages,
+            message,
            inference_api=self.inference_api,
        )
        tasks = [
@@ -177,7 +72,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
                bank_id=bank_id,
                query=query,
                params={
-                    "max_chunks": 5,
+                    "max_chunks": self.config.max_chunks,
                },
            )
            for bank_id in bank_ids
@@ -211,43 +106,20 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
            "\n=== END-RETRIEVED-CONTEXT ===\n",
        ]
 
-    async def _process_attachments(
-        self, session_id: str, attachments: List[Attachment]
-    ):
-        bank_id = await self._ensure_memory_bank(session_id)
-
-        documents = [
-            MemoryBankDocument(
-                document_id=str(uuid.uuid4()),
-                content=a.content,
-                mime_type=a.mime_type,
-                metadata={},
-            )
-            for a in attachments
-            if isinstance(a.content, str)
-        ]
-        await self.memory_api.insert_documents(bank_id, documents)
-
-        urls = [a.content for a in attachments if isinstance(a.content, URL)]
-        # TODO: we need to migrate URL away from str type
-        pattern = re.compile("^(https?://|file://|data:)")
-        urls += [URL(uri=a.content) for a in attachments if pattern.match(a.content)]
-        return await self.attachment_message(self.tempdir, urls)
-
    async def invoke_tool(
        self, tool_name: str, args: Dict[str, Any]
    ) -> ToolInvocationResult:
-        if args["session_id"] is None:
-            raise ValueError("session_id is required")
+        tool = await self.tool_store.get_tool(tool_name)
+        config = MemoryToolConfig()
+        if tool.metadata.get("config") is not None:
+            config = MemoryToolConfig(**tool.metadata["config"])
        context = await self._retrieve_context(
-            args["session_id"], args["input_messages"]
+            args["input_messages"],
+            [bank_config.bank_id for bank_config in config.memory_bank_configs],
        )
        if context is None:
            context = []
-        attachments = args["attachments"]
-        if attachments and len(attachments) > 0:
-            context += await self._process_attachments(args["session_id"], attachments)
        return ToolInvocationResult(
            content=concat_interleaved_content(context), error_code=0
        )
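The banks to query are now resolved from tool metadata at invocation time instead of from per-session state. A hedged sketch of metadata that would satisfy the new invoke_tool path; only bank_id is actually read in this diff, and the "type" discriminator is an assumption about how MemoryBankConfig validates:

    # Hypothetical registration-side metadata; invoke_tool parses
    # metadata["config"] into MemoryToolConfig and collects the bank ids.
    metadata = {
        "config": {
            "memory_bank_configs": [
                {"bank_id": "docs_bank", "type": "vector"},  # "type" assumed
            ]
        }
    }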

View file

@@ -22,7 +22,7 @@ def available_providers() -> List[ProviderSpec]:
            provider_type="inline::memory-runtime",
            pip_packages=[],
            module="llama_stack.providers.inline.tool_runtime.memory",
-            config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolConfig",
+            config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
            api_dependencies=[Api.memory, Api.memory_banks, Api.inference],
        ),
        InlineProviderSpec(

View file

@@ -21,7 +21,7 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage
 class TestCustomTool(CustomTool):
    """Tool to give boiling point of a liquid
-    Returns the correct value for water in Celcius and Fahrenheit
+    Returns the correct value for polyjuice in Celcius and Fahrenheit
    and returns -1 for other liquids
    """
@@ -50,7 +50,7 @@ class TestCustomTool(CustomTool):
        return "get_boiling_point"
 
    def get_description(self) -> str:
-        return "Get the boiling point of a imaginary liquids (eg. polyjuice)"
+        return "Get the boiling point of imaginary liquids (eg. polyjuice)"
 
    def get_params_definition(self) -> Dict[str, Parameter]:
        return {
@@ -279,7 +279,6 @@ def test_rag_agent(llama_stack_client, agent_config):
        "What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.",
        "Was anything related to 'Llama3' discussed, if so what?",
        "Tell me how to use LoRA",
-        "What about Quantization?",
    ]
    for prompt in user_prompts: