Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-23 20:53:57 +00:00
simplify toolgroups registration

parent ba242c04cc
commit f9a98c278a

15 changed files with 350 additions and 256 deletions
@@ -5,6 +5,8 @@
 # the root directory of this source tree.
 
+from typing import List
+
 from jinja2 import Template
 
 from llama_stack.apis.inference import Message, UserMessage
@@ -22,7 +24,7 @@ from .config import (
 
 async def generate_rag_query(
     config: MemoryQueryGeneratorConfig,
-    message: Message,
+    messages: List[Message],
     **kwargs,
 ):
     """
@@ -30,9 +32,9 @@ async def generate_rag_query(
     retrieving relevant information from the memory bank.
     """
     if config.type == MemoryQueryGenerator.default.value:
-        query = await default_rag_query_generator(config, message, **kwargs)
+        query = await default_rag_query_generator(config, messages, **kwargs)
     elif config.type == MemoryQueryGenerator.llm.value:
-        query = await llm_rag_query_generator(config, message, **kwargs)
+        query = await llm_rag_query_generator(config, messages, **kwargs)
     else:
         raise NotImplementedError(f"Unsupported memory query generator {config.type}")
     return query
@@ -40,21 +42,21 @@ async def generate_rag_query(
 
 async def default_rag_query_generator(
     config: DefaultMemoryQueryGeneratorConfig,
-    message: Message,
+    messages: List[Message],
     **kwargs,
 ):
-    return interleaved_content_as_str(message.content)
+    return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
 
 
 async def llm_rag_query_generator(
     config: LLMMemoryQueryGeneratorConfig,
-    message: Message,
+    messages: List[Message],
     **kwargs,
 ):
     assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
     inference_api = kwargs["inference_api"]
 
-    m_dict = {"messages": [message.model_dump()]}
+    m_dict = {"messages": [message.model_dump() for message in messages]}
 
     template = Template(config.template)
     content = template.render(m_dict)
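These hunks move the query generators from a single Message to the full message list: the default generator now joins every message's content on config.sep, and the LLM generator dumps all messages and renders them through the Jinja2 template. A standalone sketch of both behaviors; the message dicts and template text below are illustrative, not taken from the repo:

from jinja2 import Template

# Stand-in for a list of chat messages; llama-stack's Message objects
# expose .content the same way these dicts do under Jinja2's lookup.
messages = [
    {"role": "user", "content": "What is a toolgroup?"},
    {"role": "user", "content": "How do I register one?"},
]

# Default generator semantics: every message's content joined on a
# configurable separator (config.sep in the hunk above).
print(" ".join(m["content"] for m in messages))

# LLM generator semantics: render the dumped messages through a Jinja2
# template, as llm_rag_query_generator now does with config.template.
template = Template(
    "Write a search query covering: "
    "{% for m in messages %}{{ m.content }} {% endfor %}"
)
print(template.render({"messages": messages}))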
@@ -10,13 +10,14 @@ import secrets
 import string
 from typing import Any, Dict, List, Optional
 
-from llama_stack.apis.inference import Inference, InterleavedContent, Message
+from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.inference import Inference, InterleavedContent
 from llama_stack.apis.memory import Memory, QueryDocumentsResponse
 from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.tools import (
     ToolDef,
     ToolGroupDef,
     ToolInvocationResult,
     ToolParameter,
     ToolRuntime,
 )
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
@@ -50,17 +51,31 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
     async def initialize(self):
         pass
 
-    async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
-        return []
+    async def list_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]:
+        return [
+            ToolDef(
+                name="memory",
+                description="Retrieve context from memory",
+                parameters=[
+                    ToolParameter(
+                        name="input_messages",
+                        description="The input messages to search for",
+                        parameter_type="array",
+                    ),
+                ],
+            )
+        ]
 
     async def _retrieve_context(
-        self, message: Message, bank_ids: List[str]
+        self, input_messages: List[str], bank_ids: List[str]
     ) -> Optional[List[InterleavedContent]]:
         if not bank_ids:
             return None
         query = await generate_rag_query(
             self.config.query_generator_config,
-            message,
+            input_messages,
             inference_api=self.inference_api,
         )
         tasks = [
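On the runtime side, ToolGroupDef-based discover_tools gives way to list_tools keyed on an optional tool_group_id or MCP endpoint, and _retrieve_context now takes the raw input message list. A hedged restatement of what the memory runtime now advertises, copying the ToolDef fields from the hunk above (assumes a llama-stack install; the constant name is ours, not the repo's):

from llama_stack.apis.tools import ToolDef, ToolParameter

# The single tool that MemoryToolRuntimeImpl.list_tools now returns,
# restated as a module-level constant for client-side expectation
# checks; every field below is copied from the diff.
MEMORY_TOOLDEF = ToolDef(
    name="memory",
    description="Retrieve context from memory",
    parameters=[
        ToolParameter(
            name="input_messages",
            description="The input messages to search for",
            parameter_type="array",
        ),
    ],
)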
@@ -106,17 +121,22 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         self, tool_name: str, args: Dict[str, Any]
     ) -> ToolInvocationResult:
         tool = await self.tool_store.get_tool(tool_name)
+        tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
+        final_args = tool_group.args or {}
+        final_args.update(args)
         config = MemoryToolConfig()
-        if tool.metadata.get("config") is not None:
+        if tool.metadata and tool.metadata.get("config") is not None:
             config = MemoryToolConfig(**tool.metadata["config"])
-        if "memory_bank_id" in args:
-            bank_ids = [args["memory_bank_id"]]
+        if "memory_bank_ids" in final_args:
+            bank_ids = final_args["memory_bank_ids"]
         else:
             bank_ids = [
                 bank_config.bank_id for bank_config in config.memory_bank_configs
             ]
+        if "messages" not in final_args:
+            raise ValueError("messages are required")
         context = await self._retrieve_context(
-            args["query"],
+            final_args["messages"],
             bank_ids,
         )
         if context is None:
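The invoke_tool hunk is where the simplified registration pays off: args captured on the toolgroup at registration time act as defaults, per-call args override them, memory_bank_id becomes memory_bank_ids, and messages must be present after the merge. A standalone dict-based sketch of that merge rule (stand-in types and function name, not the llama-stack objects):

from typing import Any, Dict, Optional

def merge_tool_args(
    tool_group_args: Optional[Dict[str, Any]],
    call_args: Dict[str, Any],
) -> Dict[str, Any]:
    # Copy so registration-time defaults are never mutated in place.
    final_args = dict(tool_group_args or {})
    # Per-invocation arguments win over toolgroup defaults.
    final_args.update(call_args)
    if "messages" not in final_args:
        raise ValueError("messages are required")
    return final_args

# Toolgroup registered with default bank ids; the call supplies only
# the messages, mirroring the new invoke_tool contract.
print(merge_tool_args(
    {"memory_bank_ids": ["docs"]},
    {"messages": ["How do I register a toolgroup?"]},
))

Note the defensive copy in the sketch: the committed code writes final_args = tool_group.args or {} and then calls final_args.update(args), which updates the registered toolgroup's own args dict whenever it is non-empty.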