Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 09:21:45 +00:00)
fix the rag query generator types
commit 82395ba654 (parent efe3189728)
4 changed files with 19 additions and 10 deletions
@@ -401,7 +401,9 @@ class ChatAgent(ShieldRunnerMixin):
         session_info = await self.storage.get_session_info(session_id)
         # if the session has a memory bank id, let the memory tool use it
         if session_info.memory_bank_id:
-            query_args["memory_bank_id"] = session_info.memory_bank_id
+            if "memory_bank_ids" not in query_args:
+                query_args["memory_bank_ids"] = []
+            query_args["memory_bank_ids"].append(session_info.memory_bank_id)
         yield AgentTurnResponseStreamChunk(
             event=AgentTurnResponseEvent(
                 payload=AgentTurnResponseStepProgressPayload(
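The hunk above replaces a scalar `memory_bank_id` key with a `memory_bank_ids` list, so a bank id supplied by the caller is no longer silently overwritten by the session's bank. A minimal standalone sketch of that behavior (the names `query_args` and `session_bank` are stand-ins for the surrounding agent state, not the agent code itself):

    # Sketch of the list-accumulation behavior introduced above.
    def add_session_bank(query_args: dict, session_bank: str) -> dict:
        # Append rather than overwrite, so caller-supplied banks survive.
        if session_bank:
            if "memory_bank_ids" not in query_args:
                query_args["memory_bank_ids"] = []
            query_args["memory_bank_ids"].append(session_bank)
        return query_args

    assert add_session_bank({}, "bank-a") == {"memory_bank_ids": ["bank-a"]}
    assert add_session_bank({"memory_bank_ids": ["bank-b"]}, "bank-a") == {
        "memory_bank_ids": ["bank-b", "bank-a"]
    }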
@@ -8,8 +8,10 @@
 from typing import List
 
 from jinja2 import Template
+from pydantic import BaseModel
 
-from llama_stack.apis.inference import Message, UserMessage
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.inference import UserMessage
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
@@ -24,7 +26,7 @@ from .config import (
 
 async def generate_rag_query(
     config: MemoryQueryGeneratorConfig,
-    messages: List[Message],
+    messages: List[InterleavedContent],
     **kwargs,
 ):
     """
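With this signature change, callers hand the generator raw content values (in the simplest case plain strings) rather than `Message` objects; the next hunk applies the same change to each concrete generator. A hedged sketch of the simplified contract, with `InterleavedContent` reduced to `str` purely for illustration:

    from typing import List

    # Illustrative only: the real InterleavedContent is a union of str and
    # structured content items, not just str.
    def join_query(contents: List[str], sep: str = " ") -> str:
        # The default generator below does essentially this, via
        # interleaved_content_as_str on each element.
        return sep.join(contents)

    print(join_query(["what is llama stack?", "focus on the agents API"]))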
@@ -42,21 +44,26 @@ async def generate_rag_query(
 
 async def default_rag_query_generator(
     config: DefaultMemoryQueryGeneratorConfig,
-    messages: List[Message],
+    messages: List[InterleavedContent],
     **kwargs,
 ):
-    return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
+    return config.sep.join(interleaved_content_as_str(m) for m in messages)
 
 
 async def llm_rag_query_generator(
     config: LLMMemoryQueryGeneratorConfig,
-    messages: List[Message],
+    messages: List[InterleavedContent],
     **kwargs,
 ):
     assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
     inference_api = kwargs["inference_api"]
 
-    m_dict = {"messages": [message.model_dump() for message in messages]}
+    m_dict = {
+        "messages": [
+            message.model_dump() if isinstance(message, BaseModel) else message
+            for message in messages
+        ]
+    }
 
     template = Template(config.template)
     content = template.render(m_dict)
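The `isinstance(message, BaseModel)` guard (and the new `pydantic` import earlier in this file) exists because an `InterleavedContent` value can be a plain string, which has no `model_dump()`; only pydantic models get serialized. A small self-contained sketch of the guard (the `Item` model is hypothetical):

    from pydantic import BaseModel

    class Item(BaseModel):  # hypothetical stand-in for a structured content item
        text: str

    def dump_all(messages: list) -> list:
        # Strings pass through unchanged; pydantic models become dicts.
        return [m.model_dump() if isinstance(m, BaseModel) else m for m in messages]

    print(dump_all(["a plain string", Item(text="a model")]))
    # ['a plain string', {'text': 'a model'}]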
@@ -69,7 +69,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
     ]
 
     async def _retrieve_context(
-        self, input_messages: List[str], bank_ids: List[str]
+        self, input_messages: List[InterleavedContent], bank_ids: List[str]
     ) -> Optional[List[InterleavedContent]]:
         if not bank_ids:
             return None
@@ -206,7 +206,7 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
 
 def test_code_execution(llama_stack_client):
     agent_config = AgentConfig(
-        model="meta-llama/Llama-3.1-70B-Instruct",
+        model="meta-llama/Llama-3.1-8B-Instruct",
         instructions="You are a helpful assistant",
         toolgroups=[
             "builtin::code_interpreter",
@@ -297,7 +297,7 @@ def test_rag_agent(llama_stack_client, agent_config):
     agent_config["toolgroups"].append(
         dict(
             name="builtin::memory",
-            args={"memory_bank_id": memory_bank_id},
+            args={"memory_bank_ids": [memory_bank_id]},
         )
     )
     agent = Agent(llama_stack_client, agent_config)
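The plural `memory_bank_ids` key suggests a toolgroup entry can name several banks at once, which is what the agent-side accumulation in the first hunk relies on. A hedged sketch of that shape (both ids are illustrative placeholders, and the dict mirrors this test's structure rather than a documented schema):

    # Illustrative toolgroup entry with two placeholder bank ids.
    memory_toolgroup = dict(
        name="builtin::memory",
        args={"memory_bank_ids": ["docs_bank", "faq_bank"]},
    )
    print(memory_toolgroup["args"]["memory_bank_ids"])  # ['docs_bank', 'faq_bank']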