Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 17:29:01 +00:00)

Commit f408fd3aca (parent 97798c8442): remove attachments, move memory bank to tool metadata

9 changed files, 45 additions and 180 deletions.
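In effect, the per-turn `attachments` argument disappears from the agent turn API, and the memory tool instead discovers which memory banks to query from the tool's registered metadata at invocation time. Below is a minimal sketch of that lookup; `Tool` and `MemoryBankConfig` here are simplified stand-ins, not the real llama-stack classes:

```python
from typing import Any, Dict, List
from pydantic import BaseModel, Field


class MemoryBankConfig(BaseModel):  # simplified stand-in
    bank_id: str


class MemoryToolConfig(BaseModel):  # mirrors the config class in this diff
    memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)


class Tool(BaseModel):  # stand-in carrying only the field used here
    metadata: Dict[str, Any] = Field(default_factory=dict)


def resolve_config(tool: Tool) -> MemoryToolConfig:
    # same fallback pattern as the new invoke_tool below: default config
    # when the tool was registered without a "config" metadata entry
    if tool.metadata.get("config") is not None:
        return MemoryToolConfig(**tool.metadata["config"])
    return MemoryToolConfig()


tool = Tool(metadata={"config": {"memory_bank_configs": [{"bank_id": "docs"}]}})
print([c.bank_id for c in resolve_config(tool).memory_bank_configs])  # ['docs']
```

The real `invoke_tool` change (in the memory runtime hunks further down) applies the same fallback: an empty `MemoryToolConfig` when no `"config"` metadata entry is present.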
Agents API (llama_stack.apis.agents):

@@ -39,7 +39,6 @@ from llama_stack.apis.safety import SafetyViolation
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol


-@json_schema_type
 class Attachment(BaseModel):
     content: InterleavedContent | URL
     mime_type: str

@@ -258,7 +257,6 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
             ToolResponseMessage,
         ]
     ]
-    attachments: Optional[List[Attachment]] = None

     stream: Optional[bool] = False

@@ -295,7 +293,6 @@ class Agents(Protocol):
                 ToolResponseMessage,
             ]
         ],
-        attachments: Optional[List[Attachment]] = None,
         stream: Optional[bool] = False,
     ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
ChatAgent (meta-reference agents provider):

@@ -188,7 +188,6 @@ class ChatAgent(ShieldRunnerMixin):
             session_id=request.session_id,
             turn_id=turn_id,
             input_messages=messages,
-            attachments=request.attachments or [],
             sampling_params=self.agent_config.sampling_params,
             stream=request.stream,
         ):

@@ -238,7 +237,6 @@ class ChatAgent(ShieldRunnerMixin):
         session_id: str,
         turn_id: str,
         input_messages: List[Message],
-        attachments: List[Attachment],
         sampling_params: SamplingParams,
         stream: bool = False,
     ) -> AsyncGenerator:

@@ -257,7 +255,7 @@ class ChatAgent(ShieldRunnerMixin):
             yield res

         async for res in self._run(
-            session_id, turn_id, input_messages, attachments, sampling_params, stream
+            session_id, turn_id, input_messages, sampling_params, stream
         ):
             if isinstance(res, bool):
                 return

@@ -350,7 +348,6 @@ class ChatAgent(ShieldRunnerMixin):
         session_id: str,
         turn_id: str,
         input_messages: List[Message],
-        attachments: List[Attachment],
         sampling_params: SamplingParams,
         stream: bool = False,
     ) -> AsyncGenerator:

@@ -370,7 +367,6 @@ class ChatAgent(ShieldRunnerMixin):
             session_id=session_id,
             turn_id=turn_id,
             input_messages=input_messages,
-            attachments=attachments,
         )
         yield AgentTurnResponseStreamChunk(
             event=AgentTurnResponseEvent(
@@ -423,7 +419,10 @@ class ChatAgent(ShieldRunnerMixin):
                 span.set_attribute("output", result.content)
                 span.set_attribute("error_code", result.error_code)
                 span.set_attribute("error_message", result.error_message)
-                span.set_attribute("tool_name", tool_name)
+                if isinstance(tool_name, BuiltinTool):
+                    span.set_attribute("tool_name", tool_name.value)
+                else:
+                    span.set_attribute("tool_name", tool_name)
                 if result.error_code == 0:
                     last_message = input_messages[-1]
                     last_message.context = result.content
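The branch above exists because `tool_name` may be either a plain string (custom tools) or a `BuiltinTool` enum member, and telemetry span attributes want plain strings. A small sketch with a stand-in enum (not the real llama-models type):

```python
from enum import Enum


class BuiltinTool(Enum):  # stand-in for the real builtin-tool enum
    brave_search = "brave_search"


def span_tool_name(tool_name) -> str:
    # span attributes expect plain strings; builtin tools arrive as Enum
    # members, custom tools as str
    if isinstance(tool_name, BuiltinTool):
        return tool_name.value
    return tool_name


print(span_tool_name(BuiltinTool.brave_search))  # brave_search
print(span_tool_name("get_boiling_point"))       # get_boiling_point
```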
@@ -553,9 +552,9 @@ class ChatAgent(ShieldRunnerMixin):
         # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
         if len(output_attachments) > 0:
             if isinstance(message.content, list):
-                message.content += attachments
+                message.content += output_attachments
             else:
-                message.content = [message.content] + attachments
+                message.content = [message.content] + output_attachments
             yield message
         else:
             log.info(f"Partial message: {str(message)}")

@@ -586,10 +585,13 @@ class ChatAgent(ShieldRunnerMixin):
                 )
             )

+            tool_name = tool_call.tool_name
+            if isinstance(tool_name, BuiltinTool):
+                tool_name = tool_name.value
             with tracing.span(
                 "tool_execution",
                 {
-                    "tool_name": tool_call.tool_name,
+                    "tool_name": tool_name,
                     "input": message.model_dump_json(),
                 },
             ) as span:

@@ -608,6 +610,7 @@ class ChatAgent(ShieldRunnerMixin):
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepCompletePayload(
                         step_type=StepType.tool_execution.value,
+                        step_id=step_id,
                         step_details=ToolExecutionStep(
                             step_id=step_id,
                             turn_id=turn_id,
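Note that the step-complete payload now carries `step_id` at the top level as well as inside `step_details`, so event consumers can route on it directly. A hedged sketch with stand-in pydantic models (the real event types live in `llama_stack.apis.agents`):

```python
from pydantic import BaseModel


class ToolExecutionStep(BaseModel):  # stand-in
    step_id: str
    turn_id: str


class StepCompletePayload(BaseModel):  # stand-in
    step_type: str
    step_id: str  # newly duplicated at the top level in this commit
    step_details: ToolExecutionStep


payload = StepCompletePayload(
    step_type="tool_execution",
    step_id="step-1",
    step_details=ToolExecutionStep(step_id="step-1", turn_id="turn-1"),
)
print(payload.step_id)  # route on this without unpacking step_details
```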
MetaReferenceAgentsImpl (meta-reference agents provider):

@@ -146,14 +146,12 @@ class MetaReferenceAgentsImpl(Agents):
                 ToolResponseMessage,
             ]
         ],
-        attachments: Optional[List[Attachment]] = None,
         stream: Optional[bool] = False,
     ) -> AsyncGenerator:
         request = AgentTurnCreateRequest(
             agent_id=agent_id,
             session_id=session_id,
             messages=messages,
-            attachments=attachments,
             stream=True,
         )
         if stream:
Memory tool runtime, provider entry point (llama_stack.providers.inline.tool_runtime.memory):

@@ -8,11 +8,11 @@ from typing import Any, Dict

 from llama_stack.providers.datatypes import Api

-from .config import MemoryToolConfig
+from .config import MemoryToolRuntimeConfig
 from .memory import MemoryToolRuntimeImpl


-async def get_provider_impl(config: MemoryToolConfig, deps: Dict[str, Any]):
+async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
     impl = MemoryToolRuntimeImpl(
         config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
     )
Memory tool runtime, config:

@@ -7,9 +7,6 @@
 from enum import Enum
 from typing import Annotated, List, Literal, Union

-from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
-from llama_stack.providers.utils.kvstore import KVStoreConfig, SqliteKVStoreConfig
-
 from pydantic import BaseModel, Field

@@ -81,13 +78,13 @@ MemoryQueryGeneratorConfig = Annotated[

 class MemoryToolConfig(BaseModel):
     memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
+
+
+class MemoryToolRuntimeConfig(BaseModel):
     # This config defines how a query is generated using the messages
     # for memory bank retrieval.
     query_generator_config: MemoryQueryGeneratorConfig = Field(
         default=DefaultMemoryQueryGeneratorConfig()
     )
     max_tokens_in_context: int = 4096
-    max_chunks: int = 10
-    kvstore_config: KVStoreConfig = SqliteKVStoreConfig(
-        db_path=(RUNTIME_BASE_DIR / "memory.db").as_posix()
-    )
+    max_chunks: int = 5
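The config thus splits in two: `MemoryToolConfig` holds per-tool settings (which banks to query) and now travels in tool metadata, while `MemoryToolRuntimeConfig` holds provider-wide retrieval knobs and is what the provider is constructed with. A minimal sketch with stand-in types:

```python
from typing import List
from pydantic import BaseModel, Field


class MemoryBankConfig(BaseModel):  # simplified stand-in
    bank_id: str


class MemoryToolConfig(BaseModel):
    # per-tool: travels in tool metadata, set at registration time
    memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)


class MemoryToolRuntimeConfig(BaseModel):
    # provider-wide: comes from the stack run config
    max_tokens_in_context: int = 4096
    max_chunks: int = 5  # default lowered from 10 in this commit


print(MemoryToolRuntimeConfig().max_chunks)  # 5
print(MemoryToolConfig(memory_bank_configs=[{"bank_id": "docs"}]).memory_bank_configs)
```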
Memory tool runtime, context retriever (generate_rag_query):

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List

 from jinja2 import Template

@@ -23,7 +22,7 @@ from .config import (

 async def generate_rag_query(
     config: MemoryQueryGeneratorConfig,
-    messages: List[Message],
+    message: Message,
     **kwargs,
 ):
     """

@@ -31,9 +30,9 @@ async def generate_rag_query(
     retrieving relevant information from the memory bank.
     """
     if config.type == MemoryQueryGenerator.default.value:
-        query = await default_rag_query_generator(config, messages, **kwargs)
+        query = await default_rag_query_generator(config, message, **kwargs)
     elif config.type == MemoryQueryGenerator.llm.value:
-        query = await llm_rag_query_generator(config, messages, **kwargs)
+        query = await llm_rag_query_generator(config, message, **kwargs)
     else:
         raise NotImplementedError(f"Unsupported memory query generator {config.type}")
     return query

@@ -41,21 +40,21 @@ async def generate_rag_query(

 async def default_rag_query_generator(
     config: DefaultMemoryQueryGeneratorConfig,
-    messages: List[Message],
+    message: Message,
     **kwargs,
 ):
-    return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
+    return interleaved_content_as_str(message.content)


 async def llm_rag_query_generator(
     config: LLMMemoryQueryGeneratorConfig,
-    messages: List[Message],
+    message: Message,
     **kwargs,
 ):
     assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
     inference_api = kwargs["inference_api"]

-    m_dict = {"messages": [m.model_dump() for m in messages]}
+    m_dict = {"messages": [message.model_dump()]}

     template = Template(config.template)
     content = template.render(m_dict)
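The generators' contract narrows from the whole message history to a single `Message`; as the memory runtime below shows, the caller now passes `messages[-1]`. A sketch of the default generator under that contract (`Message` and `interleaved_content_as_str` are simplified stand-ins for the llama-stack types):

```python
import asyncio
from pydantic import BaseModel


class Message(BaseModel):  # simplified stand-in
    content: str


def interleaved_content_as_str(content) -> str:
    # stand-in: the real helper also flattens lists of content items
    return content if isinstance(content, str) else "".join(str(c) for c in content)


async def default_rag_query_generator(message: Message, **kwargs) -> str:
    # no config.sep join over the history anymore: one message is the query
    return interleaved_content_as_str(message.content)


history = [Message(content="Tell me how to use LoRA"),
           Message(content="What about Quantization?")]
# the caller (_retrieve_context below) passes only the last message
print(asyncio.run(default_rag_query_generator(history[-1])))
```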
Memory tool runtime, implementation (MemoryToolRuntimeImpl, memory.py):

@@ -5,24 +5,14 @@
 # the root directory of this source tree.

 import asyncio
-import json
 import logging
-import os
-import re
 import secrets
 import string
-import tempfile
-import uuid
 from typing import Any, Dict, List, Optional
-from urllib.parse import urlparse
-
-import httpx

-from llama_stack.apis.agents import Attachment
-from llama_stack.apis.common.content_types import TextContentItem, URL
 from llama_stack.apis.inference import Inference, InterleavedContent, Message
-from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
-from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
+from llama_stack.apis.memory import Memory, QueryDocumentsResponse
+from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.tools import (
     ToolDef,
     ToolGroupDef,

@@ -30,22 +20,14 @@ from llama_stack.apis.tools import (
     ToolRuntime,
 )
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
-from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
-from pydantic import BaseModel

-from .config import MemoryToolConfig
+from .config import MemoryToolConfig, MemoryToolRuntimeConfig
 from .context_retriever import generate_rag_query

 log = logging.getLogger(__name__)


-class MemorySessionInfo(BaseModel):
-    session_id: str
-    session_name: str
-    memory_bank_id: Optional[str] = None
-
-
 def make_random_string(length: int = 8):
     return "".join(
         secrets.choice(string.ascii_letters + string.digits) for _ in range(length)

@@ -55,7 +37,7 @@ def make_random_string(length: int = 8):
 class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
     def __init__(
         self,
-        config: MemoryToolConfig,
+        config: MemoryToolRuntimeConfig,
         memory_api: Memory,
         memory_banks_api: MemoryBanks,
         inference_api: Inference,

@@ -63,113 +45,26 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         self.config = config
         self.memory_api = memory_api
         self.memory_banks_api = memory_banks_api
-        self.tempdir = tempfile.mkdtemp()
         self.inference_api = inference_api

     async def initialize(self):
-        self.kvstore = await kvstore_impl(self.config.kvstore_config)
+        pass

     async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
         return []

-    async def create_session(self, session_id: str) -> MemorySessionInfo:
-        session_info = MemorySessionInfo(
-            session_id=session_id,
-            session_name=f"session_{session_id}",
-        )
-        await self.kvstore.set(
-            key=f"memory::session:{session_id}",
-            value=session_info.model_dump_json(),
-        )
-        return session_info
-
-    async def get_session_info(self, session_id: str) -> Optional[MemorySessionInfo]:
-        value = await self.kvstore.get(
-            key=f"memory::session:{session_id}",
-        )
-        if not value:
-            session_info = await self.create_session(session_id)
-            return session_info
-
-        return MemorySessionInfo(**json.loads(value))
-
-    async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
-        session_info = await self.get_session_info(session_id)
-
-        session_info.memory_bank_id = bank_id
-        await self.kvstore.set(
-            key=f"memory::session:{session_id}",
-            value=session_info.model_dump_json(),
-        )
-
-    async def _ensure_memory_bank(self, session_id: str) -> str:
-        session_info = await self.get_session_info(session_id)
-
-        if session_info.memory_bank_id is None:
-            bank_id = f"memory_bank_{session_id}"
-            await self.memory_banks_api.register_memory_bank(
-                memory_bank_id=bank_id,
-                params=VectorMemoryBankParams(
-                    embedding_model="all-MiniLM-L6-v2",
-                    chunk_size_in_tokens=512,
-                ),
-            )
-            await self.add_memory_bank_to_session(session_id, bank_id)
-        else:
-            bank_id = session_info.memory_bank_id
-
-        return bank_id
-
-    async def attachment_message(
-        self, tempdir: str, urls: List[URL]
-    ) -> List[TextContentItem]:
-        content = []
-
-        for url in urls:
-            uri = url.uri
-            if uri.startswith("file://"):
-                filepath = uri[len("file://") :]
-            elif uri.startswith("http"):
-                path = urlparse(uri).path
-                basename = os.path.basename(path)
-                filepath = f"{tempdir}/{make_random_string() + basename}"
-                log.info(f"Downloading {url} -> {filepath}")
-
-                async with httpx.AsyncClient() as client:
-                    r = await client.get(uri)
-                    resp = r.text
-                with open(filepath, "w") as fp:
-                    fp.write(resp)
-            else:
-                raise ValueError(f"Unsupported URL {url}")
-
-            content.append(
-                TextContentItem(
-                    text=f'# There is a file accessible to you at "{filepath}"\n'
-                )
-            )
-
-        return content
-
     async def _retrieve_context(
-        self, session_id: str, messages: List[Message]
+        self, messages: List[Message], bank_ids: List[str]
     ) -> Optional[List[InterleavedContent]]:
-        bank_ids = []
-
-        bank_ids.extend(c.bank_id for c in self.config.memory_bank_configs)
-
-        session_info = await self.get_session_info(session_id)
-        if session_info.memory_bank_id:
-            bank_ids.append(session_info.memory_bank_id)
-
         if not bank_ids:
-            # this can happen if the per-session memory bank is not yet populated
-            # (i.e., no prior turns uploaded an Attachment)
+            return None
+        if len(messages) == 0:
             return None

+        message = messages[-1]  # only use the last message as input to the query
         query = await generate_rag_query(
             self.config.query_generator_config,
-            messages,
+            message,
             inference_api=self.inference_api,
         )
         tasks = [

@@ -177,7 +72,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
                 bank_id=bank_id,
                 query=query,
                 params={
-                    "max_chunks": 5,
+                    "max_chunks": self.config.max_chunks,
                 },
             )
             for bank_id in bank_ids
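The `tasks` list fans out one query per configured bank, with `max_chunks` now taken from the runtime config rather than hard-coded at 5. A sketch of that fan-out pattern; the stand-in functions and the `asyncio.gather` step are assumptions, since the hunk does not show how the tasks are awaited:

```python
import asyncio
from typing import Dict, List


async def query_documents(bank_id: str, query: str, params: Dict) -> List[str]:
    # stand-in for the Memory API call; pretend each bank returns one chunk
    return [f"{bank_id}: chunk for {query!r} (max_chunks={params['max_chunks']})"]


async def retrieve(bank_ids: List[str], query: str, max_chunks: int) -> List[str]:
    tasks = [
        query_documents(bank_id=b, query=query, params={"max_chunks": max_chunks})
        for b in bank_ids
    ]
    results = await asyncio.gather(*tasks)  # one result list per bank
    return [chunk for chunks in results for chunk in chunks]


print(asyncio.run(retrieve(["docs", "faq"], "how to use LoRA", max_chunks=5)))
```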
@@ -211,43 +106,20 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
             "\n=== END-RETRIEVED-CONTEXT ===\n",
         ]

-    async def _process_attachments(
-        self, session_id: str, attachments: List[Attachment]
-    ):
-        bank_id = await self._ensure_memory_bank(session_id)
-
-        documents = [
-            MemoryBankDocument(
-                document_id=str(uuid.uuid4()),
-                content=a.content,
-                mime_type=a.mime_type,
-                metadata={},
-            )
-            for a in attachments
-            if isinstance(a.content, str)
-        ]
-        await self.memory_api.insert_documents(bank_id, documents)
-
-        urls = [a.content for a in attachments if isinstance(a.content, URL)]
-        # TODO: we need to migrate URL away from str type
-        pattern = re.compile("^(https?://|file://|data:)")
-        urls += [URL(uri=a.content) for a in attachments if pattern.match(a.content)]
-        return await self.attachment_message(self.tempdir, urls)
-
     async def invoke_tool(
         self, tool_name: str, args: Dict[str, Any]
     ) -> ToolInvocationResult:
-        if args["session_id"] is None:
-            raise ValueError("session_id is required")
+        tool = await self.tool_store.get_tool(tool_name)
+        config = MemoryToolConfig()
+        if tool.metadata.get("config") is not None:
+            config = MemoryToolConfig(**tool.metadata["config"])

         context = await self._retrieve_context(
-            args["session_id"], args["input_messages"]
+            args["input_messages"],
+            [bank_config.bank_id for bank_config in config.memory_bank_configs],
         )
         if context is None:
             context = []
-        attachments = args["attachments"]
-        if attachments and len(attachments) > 0:
-            context += await self._process_attachments(args["session_id"], attachments)
         return ToolInvocationResult(
             content=concat_interleaved_content(context), error_code=0
         )
Tool runtime provider registry:

@@ -22,7 +22,7 @@ def available_providers() -> List[ProviderSpec]:
             provider_type="inline::memory-runtime",
             pip_packages=[],
             module="llama_stack.providers.inline.tool_runtime.memory",
-            config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolConfig",
+            config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
             api_dependencies=[Api.memory, Api.memory_banks, Api.inference],
         ),
         InlineProviderSpec(
Client SDK agent tests (TestCustomTool, test_rag_agent):

@@ -21,7 +21,7 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage

 class TestCustomTool(CustomTool):
     """Tool to give boiling point of a liquid
-    Returns the correct value for water in Celcius and Fahrenheit
+    Returns the correct value for polyjuice in Celcius and Fahrenheit
     and returns -1 for other liquids
     """

@@ -50,7 +50,7 @@ class TestCustomTool(CustomTool):
         return "get_boiling_point"

     def get_description(self) -> str:
-        return "Get the boiling point of a imaginary liquids (eg. polyjuice)"
+        return "Get the boiling point of imaginary liquids (eg. polyjuice)"

     def get_params_definition(self) -> Dict[str, Parameter]:
         return {

@@ -279,7 +279,6 @@ def test_rag_agent(llama_stack_client, agent_config):
         "What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.",
         "Was anything related to 'Llama3' discussed, if so what?",
         "Tell me how to use LoRA",
-        "What about Quantization?",
     ]

     for prompt in user_prompts: