fix the rag query generator types

Dinesh Yeduguru 2025-01-07 21:33:25 -08:00
parent efe3189728
commit 82395ba654
4 changed files with 19 additions and 10 deletions

@@ -401,7 +401,9 @@ class ChatAgent(ShieldRunnerMixin):
         session_info = await self.storage.get_session_info(session_id)
         # if the session has a memory bank id, let the memory tool use it
         if session_info.memory_bank_id:
-            query_args["memory_bank_id"] = session_info.memory_bank_id
+            if "memory_bank_ids" not in query_args:
+                query_args["memory_bank_ids"] = []
+            query_args["memory_bank_ids"].append(session_info.memory_bank_id)
         yield AgentTurnResponseStreamChunk(
             event=AgentTurnResponseEvent(
                 payload=AgentTurnResponseStepProgressPayload(
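The new logic merges the session's memory bank into whatever banks the caller already passed instead of overwriting them. A minimal sketch of the resulting behavior, with placeholder bank ids standing in for real values:

# Stand-ins for the values the agent sees at this point (hypothetical ids).
query_args = {"memory_bank_ids": ["caller_bank"]}   # banks the caller passed in
session_memory_bank_id = "session_bank"             # bank attached to the session

# Before this commit the session bank replaced the caller's choice via the
# singular "memory_bank_id" key; now it is appended to the shared list.
if "memory_bank_ids" not in query_args:
    query_args["memory_bank_ids"] = []
query_args["memory_bank_ids"].append(session_memory_bank_id)

assert query_args["memory_bank_ids"] == ["caller_bank", "session_bank"]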

@@ -8,8 +8,10 @@
 from typing import List
 
 from jinja2 import Template
+from pydantic import BaseModel
 
-from llama_stack.apis.inference import Message, UserMessage
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.inference import UserMessage
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
@@ -24,7 +26,7 @@ from .config import (
 
 async def generate_rag_query(
     config: MemoryQueryGeneratorConfig,
-    messages: List[Message],
+    messages: List[InterleavedContent],
     **kwargs,
 ):
     """
@@ -42,21 +44,26 @@ async def generate_rag_query(
 
 async def default_rag_query_generator(
     config: DefaultMemoryQueryGeneratorConfig,
-    messages: List[Message],
+    messages: List[InterleavedContent],
     **kwargs,
 ):
-    return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
+    return config.sep.join(interleaved_content_as_str(m) for m in messages)
 
 
 async def llm_rag_query_generator(
     config: LLMMemoryQueryGeneratorConfig,
-    messages: List[Message],
+    messages: List[InterleavedContent],
     **kwargs,
 ):
     assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
     inference_api = kwargs["inference_api"]
 
-    m_dict = {"messages": [message.model_dump() for message in messages]}
+    m_dict = {
+        "messages": [
+            message.model_dump() if isinstance(message, BaseModel) else message
+            for message in messages
+        ]
+    }
 
     template = Template(config.template)
     content = template.render(m_dict)
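The guard around model_dump() exists because InterleavedContent admits plain strings alongside pydantic content models, so not every message can be serialized the same way before being handed to the Jinja template. A self-contained sketch of the same pattern; the TextItem model and the template string are illustrative stand-ins, not the library's own:

from jinja2 import Template
from pydantic import BaseModel

class TextItem(BaseModel):  # illustrative stand-in for a structured content item
    type: str = "text"
    text: str

# Dump pydantic models to dicts; pass plain strings through unchanged.
messages = ["plain string content", TextItem(text="structured content")]
m_dict = {
    "messages": [
        m.model_dump() if isinstance(m, BaseModel) else m
        for m in messages
    ]
}

template = Template("{% for m in messages %}{{ m }}\n{% endfor %}")
print(template.render(m_dict))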

@@ -69,7 +69,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         ]
 
     async def _retrieve_context(
-        self, input_messages: List[str], bank_ids: List[str]
+        self, input_messages: List[InterleavedContent], bank_ids: List[str]
     ) -> Optional[List[InterleavedContent]]:
         if not bank_ids:
             return None
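The old List[str] annotation was too narrow: the agent hands this method the same interleaved content it passes to the query generators, which may be plain strings or structured items. A rough sketch of the union this type admits; the shape is an assumption here, the real alias lives in llama_stack.apis.common.content_types:

from typing import List, Union

class InterleavedContentItem:  # placeholder for the library's content item models
    ...

# Approximate shape: a plain string, a single item, or a list of items.
InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]]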

@@ -206,7 +206,7 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
 
 def test_code_execution(llama_stack_client):
     agent_config = AgentConfig(
-        model="meta-llama/Llama-3.1-70B-Instruct",
+        model="meta-llama/Llama-3.1-8B-Instruct",
         instructions="You are a helpful assistant",
         toolgroups=[
             "builtin::code_interpreter",
@@ -297,7 +297,7 @@ def test_rag_agent(llama_stack_client, agent_config):
     agent_config["toolgroups"].append(
         dict(
             name="builtin::memory",
-            args={"memory_bank_id": memory_bank_id},
+            args={"memory_bank_ids": [memory_bank_id]},
         )
     )
     agent = Agent(llama_stack_client, agent_config)
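With the plural memory_bank_ids argument, a client config can point the memory tool at several banks at once. A hedged usage sketch; the dict config and bank ids are placeholders, not values from the test above:

# Minimal stand-in config; in the test it comes from AgentConfig.
agent_config = {"toolgroups": []}

# The arg is now a list, so multiple registered banks can be queried together.
agent_config["toolgroups"].append(
    dict(
        name="builtin::memory",
        args={"memory_bank_ids": ["docs_bank", "faq_bank"]},
    )
)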