fix the rag query generator types

This commit is contained in:
Dinesh Yeduguru 2025-01-07 21:33:25 -08:00
parent efe3189728
commit 82395ba654
4 changed files with 19 additions and 10 deletions

View file

@ -401,7 +401,9 @@ class ChatAgent(ShieldRunnerMixin):
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
if session_info.memory_bank_id:
query_args["memory_bank_id"] = session_info.memory_bank_id
if "memory_bank_ids" not in query_args:
query_args["memory_bank_ids"] = []
query_args["memory_bank_ids"].append(session_info.memory_bank_id)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(

View file

@ -8,8 +8,10 @@
from typing import List
from jinja2 import Template
from pydantic import BaseModel
from llama_stack.apis.inference import Message, UserMessage
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import UserMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
@ -24,7 +26,7 @@ from .config import (
async def generate_rag_query(
config: MemoryQueryGeneratorConfig,
messages: List[Message],
messages: List[InterleavedContent],
**kwargs,
):
"""
@ -42,21 +44,26 @@ async def generate_rag_query(
async def default_rag_query_generator(
config: DefaultMemoryQueryGeneratorConfig,
messages: List[Message],
messages: List[InterleavedContent],
**kwargs,
):
return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
return config.sep.join(interleaved_content_as_str(m) for m in messages)
async def llm_rag_query_generator(
config: LLMMemoryQueryGeneratorConfig,
messages: List[Message],
messages: List[InterleavedContent],
**kwargs,
):
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
inference_api = kwargs["inference_api"]
m_dict = {"messages": [message.model_dump() for message in messages]}
m_dict = {
"messages": [
message.model_dump() if isinstance(message, BaseModel) else message
for message in messages
]
}
template = Template(config.template)
content = template.render(m_dict)

View file

@ -69,7 +69,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
]
async def _retrieve_context(
self, input_messages: List[str], bank_ids: List[str]
self, input_messages: List[InterleavedContent], bank_ids: List[str]
) -> Optional[List[InterleavedContent]]:
if not bank_ids:
return None

View file

@ -206,7 +206,7 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
def test_code_execution(llama_stack_client):
agent_config = AgentConfig(
model="meta-llama/Llama-3.1-70B-Instruct",
model="meta-llama/Llama-3.1-8B-Instruct",
instructions="You are a helpful assistant",
toolgroups=[
"builtin::code_interpreter",
@ -297,7 +297,7 @@ def test_rag_agent(llama_stack_client, agent_config):
agent_config["toolgroups"].append(
dict(
name="builtin::memory",
args={"memory_bank_id": memory_bank_id},
args={"memory_bank_ids": [memory_bank_id]},
)
)
agent = Agent(llama_stack_client, agent_config)