simplify toolgroups registration

Dinesh Yeduguru 2025-01-07 15:37:52 -08:00
parent ba242c04cc
commit f9a98c278a
15 changed files with 350 additions and 256 deletions

View file

@@ -7,9 +7,16 @@
 import logging
 import tempfile
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
-from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
+from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.tools import (
+    Tool,
+    ToolDef,
+    ToolInvocationResult,
+    ToolParameter,
+    ToolRuntime,
+)
 
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
 
 from .code_execution import CodeExecutionContext, CodeExecutionRequest, CodeExecutor
@@ -35,8 +42,22 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
     async def unregister_tool(self, tool_id: str) -> None:
         return
 
-    async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
-        raise NotImplementedError("Code interpreter tool group not supported")
+    async def list_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]:
+        return [
+            ToolDef(
+                name="code_interpreter",
+                description="Execute code",
+                parameters=[
+                    ToolParameter(
+                        name="code",
+                        description="The code to execute",
+                        parameter_type="string",
+                    ),
+                ],
+            )
+        ]
 
     async def invoke_tool(
         self, tool_name: str, args: Dict[str, Any]
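
With discover_tools(ToolGroupDef) removed, callers enumerate a provider's tools through the new list_tools() signature. A minimal usage sketch, assuming `impl` is an already-initialized CodeInterpreterToolRuntimeImpl (construction elided):

import asyncio

async def show_tools(impl) -> None:
    # The built-in interpreter needs neither a tool_group_id nor an mcp_endpoint.
    for tool_def in await impl.list_tools():
        params = ", ".join(p.name for p in tool_def.parameters)
        print(f"{tool_def.name}({params}): {tool_def.description}")

# asyncio.run(show_tools(impl)) prints: code_interpreter(code): Execute code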

View file

@@ -5,6 +5,8 @@
 # the root directory of this source tree.
 
+from typing import List
+
 from jinja2 import Template
 
 from llama_stack.apis.inference import Message, UserMessage
@@ -22,7 +24,7 @@ from .config import (
 
 async def generate_rag_query(
     config: MemoryQueryGeneratorConfig,
-    message: Message,
+    messages: List[Message],
     **kwargs,
 ):
     """
@@ -30,9 +32,9 @@
     retrieving relevant information from the memory bank.
     """
     if config.type == MemoryQueryGenerator.default.value:
-        query = await default_rag_query_generator(config, message, **kwargs)
+        query = await default_rag_query_generator(config, messages, **kwargs)
     elif config.type == MemoryQueryGenerator.llm.value:
-        query = await llm_rag_query_generator(config, message, **kwargs)
+        query = await llm_rag_query_generator(config, messages, **kwargs)
     else:
         raise NotImplementedError(f"Unsupported memory query generator {config.type}")
     return query
@@ -40,21 +42,21 @@ async def generate_rag_query(
 
 async def default_rag_query_generator(
     config: DefaultMemoryQueryGeneratorConfig,
-    message: Message,
+    messages: List[Message],
     **kwargs,
 ):
-    return interleaved_content_as_str(message.content)
+    return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
 
 
 async def llm_rag_query_generator(
     config: LLMMemoryQueryGeneratorConfig,
-    message: Message,
+    messages: List[Message],
     **kwargs,
 ):
     assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
     inference_api = kwargs["inference_api"]
 
-    m_dict = {"messages": [message.model_dump()]}
+    m_dict = {"messages": [message.model_dump() for message in messages]}
     template = Template(config.template)
     content = template.render(m_dict)
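
Since the generators now take a list of messages, the default generator joins every message's content with the configured separator before the query hits the memory banks. A small sketch of the new behavior; the sep value and message contents are illustrative only, and DefaultMemoryQueryGeneratorConfig / default_rag_query_generator come from the module shown in this diff:

import asyncio
from llama_stack.apis.inference import UserMessage

async def demo_query() -> str:
    # sep="\n" here is an example value, not necessarily the shipped default.
    config = DefaultMemoryQueryGeneratorConfig(sep="\n")
    messages = [
        UserMessage(content="How do I register a toolgroup?"),
        UserMessage(content="How are its args merged when a tool is invoked?"),
    ]
    return await default_rag_query_generator(config, messages)

# asyncio.run(demo_query()) ==
#   "How do I register a toolgroup?\nHow are its args merged when a tool is invoked?"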

View file

@@ -10,13 +10,14 @@ import secrets
 import string
 from typing import Any, Dict, List, Optional
 
-from llama_stack.apis.inference import Inference, InterleavedContent, Message
+from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.inference import Inference, InterleavedContent
 from llama_stack.apis.memory import Memory, QueryDocumentsResponse
 from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.tools import (
     ToolDef,
-    ToolGroupDef,
     ToolInvocationResult,
+    ToolParameter,
     ToolRuntime,
 )
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
@@ -50,17 +51,31 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
     async def initialize(self):
         pass
 
-    async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
-        return []
+    async def list_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]:
+        return [
+            ToolDef(
+                name="memory",
+                description="Retrieve context from memory",
+                parameters=[
+                    ToolParameter(
+                        name="input_messages",
+                        description="The input messages to search for",
+                        parameter_type="array",
+                    ),
+                ],
+            )
+        ]
 
     async def _retrieve_context(
-        self, message: Message, bank_ids: List[str]
+        self, input_messages: List[str], bank_ids: List[str]
     ) -> Optional[List[InterleavedContent]]:
         if not bank_ids:
             return None
         query = await generate_rag_query(
             self.config.query_generator_config,
-            message,
+            input_messages,
             inference_api=self.inference_api,
         )
         tasks = [
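
This uniform list_tools() shape is what lets toolgroup registration be simplified: rather than each provider interpreting a ToolGroupDef in discover_tools(), a routing layer can ask any runtime for its ToolDefs and persist one Tool per def. A hedged sketch of that registration loop; `runtime`, `registry`, and make_tool() are placeholders for illustration, not actual llama_stack APIs:

async def register_toolgroup(toolgroup_id: str, runtime, registry) -> None:
    # Ask the runtime what it offers, then record each tool under the group.
    for tool_def in await runtime.list_tools(tool_group_id=toolgroup_id):
        tool = make_tool(  # hypothetical ToolDef -> Tool mapping helper
            toolgroup_id=toolgroup_id,
            name=tool_def.name,
            description=tool_def.description,
            parameters=tool_def.parameters,
        )
        await registry.register(tool)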
@@ -106,17 +121,22 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         self, tool_name: str, args: Dict[str, Any]
     ) -> ToolInvocationResult:
         tool = await self.tool_store.get_tool(tool_name)
+        tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
+        final_args = tool_group.args or {}
+        final_args.update(args)
         config = MemoryToolConfig()
-        if tool.metadata.get("config") is not None:
+        if tool.metadata and tool.metadata.get("config") is not None:
             config = MemoryToolConfig(**tool.metadata["config"])
-        if "memory_bank_id" in args:
-            bank_ids = [args["memory_bank_id"]]
+        if "memory_bank_ids" in final_args:
+            bank_ids = final_args["memory_bank_ids"]
         else:
             bank_ids = [
                 bank_config.bank_id for bank_config in config.memory_bank_configs
             ]
+        if "messages" not in final_args:
+            raise ValueError("messages are required")
         context = await self._retrieve_context(
-            args["query"],
+            final_args["messages"],
             bank_ids,
         )
         if context is None:
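
invoke_tool now seeds final_args from the toolgroup's registration-time args and overlays the per-call args, so memory_bank_ids can be pinned once when the group is registered while each call supplies only messages. The merge semantics in isolation (values illustrative):

# Registration-time args on the toolgroup vs. args passed to invoke_tool.
tool_group_args = {"memory_bank_ids": ["docs_bank"]}
call_args = {"messages": [{"role": "user", "content": "What is a toolgroup?"}]}

final_args = dict(tool_group_args)  # mirrors: final_args = tool_group.args or {}
final_args.update(call_args)        # per-call args win on key collisions

assert final_args["memory_bank_ids"] == ["docs_bank"]
assert "messages" in final_args     # a missing "messages" key raises ValueError above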