mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 17:29:01 +00:00
remove attachments, move memory bank to tool metadata
This commit is contained in:
parent
97798c8442
commit
f408fd3aca
9 changed files with 45 additions and 180 deletions
|
@ -39,7 +39,6 @@ from llama_stack.apis.safety import SafetyViolation
|
|||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Attachment(BaseModel):
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str
|
||||
|
@ -258,7 +257,6 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
|||
ToolResponseMessage,
|
||||
]
|
||||
]
|
||||
attachments: Optional[List[Attachment]] = None
|
||||
|
||||
stream: Optional[bool] = False
|
||||
|
||||
|
@ -295,7 +293,6 @@ class Agents(Protocol):
|
|||
ToolResponseMessage,
|
||||
]
|
||||
],
|
||||
attachments: Optional[List[Attachment]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
||||
|
||||
|
|
|
@ -188,7 +188,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
session_id=request.session_id,
|
||||
turn_id=turn_id,
|
||||
input_messages=messages,
|
||||
attachments=request.attachments or [],
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
):
|
||||
|
@ -238,7 +237,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
session_id: str,
|
||||
turn_id: str,
|
||||
input_messages: List[Message],
|
||||
attachments: List[Attachment],
|
||||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
) -> AsyncGenerator:
|
||||
|
@ -257,7 +255,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield res
|
||||
|
||||
async for res in self._run(
|
||||
session_id, turn_id, input_messages, attachments, sampling_params, stream
|
||||
session_id, turn_id, input_messages, sampling_params, stream
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
|
@ -350,7 +348,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
session_id: str,
|
||||
turn_id: str,
|
||||
input_messages: List[Message],
|
||||
attachments: List[Attachment],
|
||||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
) -> AsyncGenerator:
|
||||
|
@ -370,7 +367,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
session_id=session_id,
|
||||
turn_id=turn_id,
|
||||
input_messages=input_messages,
|
||||
attachments=attachments,
|
||||
)
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
@ -423,7 +419,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
span.set_attribute("output", result.content)
|
||||
span.set_attribute("error_code", result.error_code)
|
||||
span.set_attribute("error_message", result.error_message)
|
||||
span.set_attribute("tool_name", tool_name)
|
||||
if isinstance(tool_name, BuiltinTool):
|
||||
span.set_attribute("tool_name", tool_name.value)
|
||||
else:
|
||||
span.set_attribute("tool_name", tool_name)
|
||||
if result.error_code == 0:
|
||||
last_message = input_messages[-1]
|
||||
last_message.context = result.content
|
||||
|
@ -553,9 +552,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
|
||||
if len(output_attachments) > 0:
|
||||
if isinstance(message.content, list):
|
||||
message.content += attachments
|
||||
message.content += output_attachments
|
||||
else:
|
||||
message.content = [message.content] + attachments
|
||||
message.content = [message.content] + output_attachments
|
||||
yield message
|
||||
else:
|
||||
log.info(f"Partial message: {str(message)}")
|
||||
|
@ -586,10 +585,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
|
||||
tool_name = tool_call.tool_name
|
||||
if isinstance(tool_name, BuiltinTool):
|
||||
tool_name = tool_name.value
|
||||
with tracing.span(
|
||||
"tool_execution",
|
||||
{
|
||||
"tool_name": tool_call.tool_name,
|
||||
"tool_name": tool_name,
|
||||
"input": message.model_dump_json(),
|
||||
},
|
||||
) as span:
|
||||
|
@ -608,6 +610,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
step_details=ToolExecutionStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
|
|
|
@ -146,14 +146,12 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
ToolResponseMessage,
|
||||
]
|
||||
],
|
||||
attachments: Optional[List[Attachment]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=messages,
|
||||
attachments=attachments,
|
||||
stream=True,
|
||||
)
|
||||
if stream:
|
||||
|
|
|
@ -8,11 +8,11 @@ from typing import Any, Dict
|
|||
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import MemoryToolConfig
|
||||
from .config import MemoryToolRuntimeConfig
|
||||
from .memory import MemoryToolRuntimeImpl
|
||||
|
||||
|
||||
async def get_provider_impl(config: MemoryToolConfig, deps: Dict[str, Any]):
|
||||
async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
|
||||
impl = MemoryToolRuntimeImpl(
|
||||
config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
|
||||
)
|
||||
|
|
|
@ -7,9 +7,6 @@
|
|||
from enum import Enum
|
||||
from typing import Annotated, List, Literal, Union
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore import KVStoreConfig, SqliteKVStoreConfig
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
|
@ -81,13 +78,13 @@ MemoryQueryGeneratorConfig = Annotated[
|
|||
|
||||
class MemoryToolConfig(BaseModel):
|
||||
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
class MemoryToolRuntimeConfig(BaseModel):
|
||||
# This config defines how a query is generated using the messages
|
||||
# for memory bank retrieval.
|
||||
query_generator_config: MemoryQueryGeneratorConfig = Field(
|
||||
default=DefaultMemoryQueryGeneratorConfig()
|
||||
)
|
||||
max_tokens_in_context: int = 4096
|
||||
max_chunks: int = 10
|
||||
kvstore_config: KVStoreConfig = SqliteKVStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "memory.db").as_posix()
|
||||
)
|
||||
max_chunks: int = 5
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from jinja2 import Template
|
||||
|
||||
|
@ -23,7 +22,7 @@ from .config import (
|
|||
|
||||
async def generate_rag_query(
|
||||
config: MemoryQueryGeneratorConfig,
|
||||
messages: List[Message],
|
||||
message: Message,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
@ -31,9 +30,9 @@ async def generate_rag_query(
|
|||
retrieving relevant information from the memory bank.
|
||||
"""
|
||||
if config.type == MemoryQueryGenerator.default.value:
|
||||
query = await default_rag_query_generator(config, messages, **kwargs)
|
||||
query = await default_rag_query_generator(config, message, **kwargs)
|
||||
elif config.type == MemoryQueryGenerator.llm.value:
|
||||
query = await llm_rag_query_generator(config, messages, **kwargs)
|
||||
query = await llm_rag_query_generator(config, message, **kwargs)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
|
||||
return query
|
||||
|
@ -41,21 +40,21 @@ async def generate_rag_query(
|
|||
|
||||
async def default_rag_query_generator(
|
||||
config: DefaultMemoryQueryGeneratorConfig,
|
||||
messages: List[Message],
|
||||
message: Message,
|
||||
**kwargs,
|
||||
):
|
||||
return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
|
||||
return interleaved_content_as_str(message.content)
|
||||
|
||||
|
||||
async def llm_rag_query_generator(
|
||||
config: LLMMemoryQueryGeneratorConfig,
|
||||
messages: List[Message],
|
||||
message: Message,
|
||||
**kwargs,
|
||||
):
|
||||
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
|
||||
inference_api = kwargs["inference_api"]
|
||||
|
||||
m_dict = {"messages": [m.model_dump() for m in messages]}
|
||||
m_dict = {"messages": [message.model_dump()]}
|
||||
|
||||
template = Template(config.template)
|
||||
content = template.render(m_dict)
|
||||
|
|
|
@ -5,24 +5,14 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import secrets
|
||||
import string
|
||||
import tempfile
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_stack.apis.agents import Attachment
|
||||
from llama_stack.apis.common.content_types import TextContentItem, URL
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent, Message
|
||||
from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
|
||||
from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
|
||||
from llama_stack.apis.memory import Memory, QueryDocumentsResponse
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.tools import (
|
||||
ToolDef,
|
||||
ToolGroupDef,
|
||||
|
@ -30,22 +20,14 @@ from llama_stack.apis.tools import (
|
|||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import MemoryToolConfig
|
||||
from .config import MemoryToolConfig, MemoryToolRuntimeConfig
|
||||
from .context_retriever import generate_rag_query
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemorySessionInfo(BaseModel):
|
||||
session_id: str
|
||||
session_name: str
|
||||
memory_bank_id: Optional[str] = None
|
||||
|
||||
|
||||
def make_random_string(length: int = 8):
|
||||
return "".join(
|
||||
secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
|
||||
|
@ -55,7 +37,7 @@ def make_random_string(length: int = 8):
|
|||
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||
def __init__(
|
||||
self,
|
||||
config: MemoryToolConfig,
|
||||
config: MemoryToolRuntimeConfig,
|
||||
memory_api: Memory,
|
||||
memory_banks_api: MemoryBanks,
|
||||
inference_api: Inference,
|
||||
|
@ -63,113 +45,26 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
self.config = config
|
||||
self.memory_api = memory_api
|
||||
self.memory_banks_api = memory_banks_api
|
||||
self.tempdir = tempfile.mkdtemp()
|
||||
self.inference_api = inference_api
|
||||
|
||||
async def initialize(self):
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore_config)
|
||||
pass
|
||||
|
||||
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
|
||||
return []
|
||||
|
||||
async def create_session(self, session_id: str) -> MemorySessionInfo:
|
||||
session_info = MemorySessionInfo(
|
||||
session_id=session_id,
|
||||
session_name=f"session_{session_id}",
|
||||
)
|
||||
await self.kvstore.set(
|
||||
key=f"memory::session:{session_id}",
|
||||
value=session_info.model_dump_json(),
|
||||
)
|
||||
return session_info
|
||||
|
||||
async def get_session_info(self, session_id: str) -> Optional[MemorySessionInfo]:
|
||||
value = await self.kvstore.get(
|
||||
key=f"memory::session:{session_id}",
|
||||
)
|
||||
if not value:
|
||||
session_info = await self.create_session(session_id)
|
||||
return session_info
|
||||
|
||||
return MemorySessionInfo(**json.loads(value))
|
||||
|
||||
async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
|
||||
session_info = await self.get_session_info(session_id)
|
||||
|
||||
session_info.memory_bank_id = bank_id
|
||||
await self.kvstore.set(
|
||||
key=f"memory::session:{session_id}",
|
||||
value=session_info.model_dump_json(),
|
||||
)
|
||||
|
||||
async def _ensure_memory_bank(self, session_id: str) -> str:
|
||||
session_info = await self.get_session_info(session_id)
|
||||
|
||||
if session_info.memory_bank_id is None:
|
||||
bank_id = f"memory_bank_{session_id}"
|
||||
await self.memory_banks_api.register_memory_bank(
|
||||
memory_bank_id=bank_id,
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
),
|
||||
)
|
||||
await self.add_memory_bank_to_session(session_id, bank_id)
|
||||
else:
|
||||
bank_id = session_info.memory_bank_id
|
||||
|
||||
return bank_id
|
||||
|
||||
async def attachment_message(
|
||||
self, tempdir: str, urls: List[URL]
|
||||
) -> List[TextContentItem]:
|
||||
content = []
|
||||
|
||||
for url in urls:
|
||||
uri = url.uri
|
||||
if uri.startswith("file://"):
|
||||
filepath = uri[len("file://") :]
|
||||
elif uri.startswith("http"):
|
||||
path = urlparse(uri).path
|
||||
basename = os.path.basename(path)
|
||||
filepath = f"{tempdir}/{make_random_string() + basename}"
|
||||
log.info(f"Downloading {url} -> {filepath}")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(uri)
|
||||
resp = r.text
|
||||
with open(filepath, "w") as fp:
|
||||
fp.write(resp)
|
||||
else:
|
||||
raise ValueError(f"Unsupported URL {url}")
|
||||
|
||||
content.append(
|
||||
TextContentItem(
|
||||
text=f'# There is a file accessible to you at "{filepath}"\n'
|
||||
)
|
||||
)
|
||||
|
||||
return content
|
||||
|
||||
async def _retrieve_context(
|
||||
self, session_id: str, messages: List[Message]
|
||||
self, messages: List[Message], bank_ids: List[str]
|
||||
) -> Optional[List[InterleavedContent]]:
|
||||
bank_ids = []
|
||||
|
||||
bank_ids.extend(c.bank_id for c in self.config.memory_bank_configs)
|
||||
|
||||
session_info = await self.get_session_info(session_id)
|
||||
if session_info.memory_bank_id:
|
||||
bank_ids.append(session_info.memory_bank_id)
|
||||
|
||||
if not bank_ids:
|
||||
# this can happen if the per-session memory bank is not yet populated
|
||||
# (i.e., no prior turns uploaded an Attachment)
|
||||
return None
|
||||
if len(messages) == 0:
|
||||
return None
|
||||
|
||||
message = messages[-1] # only use the last message as input to the query
|
||||
query = await generate_rag_query(
|
||||
self.config.query_generator_config,
|
||||
messages,
|
||||
message,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
tasks = [
|
||||
|
@ -177,7 +72,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
bank_id=bank_id,
|
||||
query=query,
|
||||
params={
|
||||
"max_chunks": 5,
|
||||
"max_chunks": self.config.max_chunks,
|
||||
},
|
||||
)
|
||||
for bank_id in bank_ids
|
||||
|
@ -211,43 +106,20 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
"\n=== END-RETRIEVED-CONTEXT ===\n",
|
||||
]
|
||||
|
||||
async def _process_attachments(
|
||||
self, session_id: str, attachments: List[Attachment]
|
||||
):
|
||||
bank_id = await self._ensure_memory_bank(session_id)
|
||||
|
||||
documents = [
|
||||
MemoryBankDocument(
|
||||
document_id=str(uuid.uuid4()),
|
||||
content=a.content,
|
||||
mime_type=a.mime_type,
|
||||
metadata={},
|
||||
)
|
||||
for a in attachments
|
||||
if isinstance(a.content, str)
|
||||
]
|
||||
await self.memory_api.insert_documents(bank_id, documents)
|
||||
|
||||
urls = [a.content for a in attachments if isinstance(a.content, URL)]
|
||||
# TODO: we need to migrate URL away from str type
|
||||
pattern = re.compile("^(https?://|file://|data:)")
|
||||
urls += [URL(uri=a.content) for a in attachments if pattern.match(a.content)]
|
||||
return await self.attachment_message(self.tempdir, urls)
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
if args["session_id"] is None:
|
||||
raise ValueError("session_id is required")
|
||||
tool = await self.tool_store.get_tool(tool_name)
|
||||
config = MemoryToolConfig()
|
||||
if tool.metadata.get("config") is not None:
|
||||
config = MemoryToolConfig(**tool.metadata["config"])
|
||||
|
||||
context = await self._retrieve_context(
|
||||
args["session_id"], args["input_messages"]
|
||||
args["input_messages"],
|
||||
[bank_config.bank_id for bank_config in config.memory_bank_configs],
|
||||
)
|
||||
if context is None:
|
||||
context = []
|
||||
attachments = args["attachments"]
|
||||
if attachments and len(attachments) > 0:
|
||||
context += await self._process_attachments(args["session_id"], attachments)
|
||||
return ToolInvocationResult(
|
||||
content=concat_interleaved_content(context), error_code=0
|
||||
)
|
||||
|
|
|
@ -22,7 +22,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
provider_type="inline::memory-runtime",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.inline.tool_runtime.memory",
|
||||
config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolConfig",
|
||||
config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
|
||||
api_dependencies=[Api.memory, Api.memory_banks, Api.inference],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
|
|
|
@ -21,7 +21,7 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage
|
|||
|
||||
class TestCustomTool(CustomTool):
|
||||
"""Tool to give boiling point of a liquid
|
||||
Returns the correct value for water in Celcius and Fahrenheit
|
||||
Returns the correct value for polyjuice in Celcius and Fahrenheit
|
||||
and returns -1 for other liquids
|
||||
"""
|
||||
|
||||
|
@ -50,7 +50,7 @@ class TestCustomTool(CustomTool):
|
|||
return "get_boiling_point"
|
||||
|
||||
def get_description(self) -> str:
|
||||
return "Get the boiling point of a imaginary liquids (eg. polyjuice)"
|
||||
return "Get the boiling point of imaginary liquids (eg. polyjuice)"
|
||||
|
||||
def get_params_definition(self) -> Dict[str, Parameter]:
|
||||
return {
|
||||
|
@ -279,7 +279,6 @@ def test_rag_agent(llama_stack_client, agent_config):
|
|||
"What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.",
|
||||
"Was anything related to 'Llama3' discussed, if so what?",
|
||||
"Tell me how to use LoRA",
|
||||
"What about Quantization?",
|
||||
]
|
||||
|
||||
for prompt in user_prompts:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue