fix the rag query generator types

This commit is contained in:
Dinesh Yeduguru 2025-01-07 21:33:25 -08:00
parent efe3189728
commit 82395ba654
4 changed files with 19 additions and 10 deletions

View file

@ -8,8 +8,10 @@
from typing import List
from jinja2 import Template
from pydantic import BaseModel
from llama_stack.apis.inference import Message, UserMessage
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import UserMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
@ -24,7 +26,7 @@ from .config import (
async def generate_rag_query(
config: MemoryQueryGeneratorConfig,
messages: List[Message],
messages: List[InterleavedContent],
**kwargs,
):
"""
@ -42,21 +44,26 @@ async def generate_rag_query(
async def default_rag_query_generator(
config: DefaultMemoryQueryGeneratorConfig,
messages: List[Message],
messages: List[InterleavedContent],
**kwargs,
):
return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
return config.sep.join(interleaved_content_as_str(m) for m in messages)
async def llm_rag_query_generator(
config: LLMMemoryQueryGeneratorConfig,
messages: List[Message],
messages: List[InterleavedContent],
**kwargs,
):
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
inference_api = kwargs["inference_api"]
m_dict = {"messages": [message.model_dump() for message in messages]}
m_dict = {
"messages": [
message.model_dump() if isinstance(message, BaseModel) else message
for message in messages
]
}
template = Template(config.template)
content = template.render(m_dict)

View file

@ -69,7 +69,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
]
async def _retrieve_context(
self, input_messages: List[str], bank_ids: List[str]
self, input_messages: List[InterleavedContent], bank_ids: List[str]
) -> Optional[List[InterleavedContent]]:
if not bank_ids:
return None