mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 10:13:05 +00:00
fix the rag query generator types
This commit is contained in:
parent efe3189728
commit 82395ba654

4 changed files with 19 additions and 10 deletions
@@ -401,7 +401,9 @@ class ChatAgent(ShieldRunnerMixin):
         session_info = await self.storage.get_session_info(session_id)
         # if the session has a memory bank id, let the memory tool use it
         if session_info.memory_bank_id:
-            query_args["memory_bank_id"] = session_info.memory_bank_id
+            if "memory_bank_ids" not in query_args:
+                query_args["memory_bank_ids"] = []
+            query_args["memory_bank_ids"].append(session_info.memory_bank_id)
         yield AgentTurnResponseStreamChunk(
             event=AgentTurnResponseEvent(
                 payload=AgentTurnResponseStepProgressPayload(
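
The new branch appends the session's bank to a shared "memory_bank_ids" list instead of overwriting a single "memory_bank_id", so banks passed explicitly by the caller survive. A minimal standalone sketch of that accumulation, with an illustrative dict in place of the agent's real query_args:

from typing import Any, Dict, Optional

def merge_session_bank(query_args: Dict[str, Any], session_bank_id: Optional[str]) -> None:
    # Mirrors the new logic: append rather than overwrite, creating
    # the list only if the caller did not already pass one.
    if session_bank_id:
        if "memory_bank_ids" not in query_args:
            query_args["memory_bank_ids"] = []
        query_args["memory_bank_ids"].append(session_bank_id)

query_args: Dict[str, Any] = {"memory_bank_ids": ["docs_bank"]}
merge_session_bank(query_args, "session_bank")
assert query_args["memory_bank_ids"] == ["docs_bank", "session_bank"]
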
@@ -8,8 +8,10 @@
 from typing import List

 from jinja2 import Template
+from pydantic import BaseModel

-from llama_stack.apis.inference import Message, UserMessage
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.inference import UserMessage
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
@@ -24,7 +26,7 @@ from .config import (

 async def generate_rag_query(
     config: MemoryQueryGeneratorConfig,
-    messages: List[Message],
+    messages: List[InterleavedContent],
     **kwargs,
 ):
     """
@@ -42,21 +44,26 @@ async def generate_rag_query(

 async def default_rag_query_generator(
     config: DefaultMemoryQueryGeneratorConfig,
-    messages: List[Message],
+    messages: List[InterleavedContent],
     **kwargs,
 ):
-    return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
+    return config.sep.join(interleaved_content_as_str(m) for m in messages)


 async def llm_rag_query_generator(
     config: LLMMemoryQueryGeneratorConfig,
-    messages: List[Message],
+    messages: List[InterleavedContent],
     **kwargs,
 ):
     assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
     inference_api = kwargs["inference_api"]

-    m_dict = {"messages": [message.model_dump() for message in messages]}
+    m_dict = {
+        "messages": [
+            message.model_dump() if isinstance(message, BaseModel) else message
+            for message in messages
+        ]
+    }

     template = Template(config.template)
     content = template.render(m_dict)
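
Because InterleavedContent may be a plain str rather than a pydantic model, the template payload now calls model_dump() only on pydantic values. A self-contained sketch of that guard; TextContentItem here is a hypothetical stand-in for llama_stack's real content item type:

from jinja2 import Template
from pydantic import BaseModel

class TextContentItem(BaseModel):
    # Hypothetical stand-in for llama_stack's text content item.
    type: str = "text"
    text: str

messages = ["What is RAG?", TextContentItem(text="Keep the answer short.")]

# Plain strings pass through untouched; pydantic models are dumped to
# dicts so the Jinja template can render either shape uniformly.
m_dict = {
    "messages": [
        m.model_dump() if isinstance(m, BaseModel) else m for m in messages
    ]
}

template = Template("{% for m in messages %}{{ m }} {% endfor %}")
print(template.render(m_dict))
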
@@ -69,7 +69,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         ]

     async def _retrieve_context(
-        self, input_messages: List[str], bank_ids: List[str]
+        self, input_messages: List[InterleavedContent], bank_ids: List[str]
     ) -> Optional[List[InterleavedContent]]:
         if not bank_ids:
             return None
@@ -206,7 +206,7 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):

 def test_code_execution(llama_stack_client):
     agent_config = AgentConfig(
-        model="meta-llama/Llama-3.1-70B-Instruct",
+        model="meta-llama/Llama-3.1-8B-Instruct",
         instructions="You are a helpful assistant",
         toolgroups=[
             "builtin::code_interpreter",
@@ -297,7 +297,7 @@ def test_rag_agent(llama_stack_client, agent_config):
     agent_config["toolgroups"].append(
         dict(
             name="builtin::memory",
-            args={"memory_bank_id": memory_bank_id},
+            args={"memory_bank_ids": [memory_bank_id]},
         )
     )
     agent = Agent(llama_stack_client, agent_config)
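
For clients, the visible change is the plural args key: builtin::memory now takes a list of bank IDs, so one agent turn can query several banks. A hedged usage sketch, assuming a local llama-stack server and the llama_stack_client Agent API used elsewhere in these tests (the base_url and bank IDs are illustrative, not defaults confirmed by this commit):

from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.types.agent_create_params import AgentConfig

client = LlamaStackClient(base_url="http://localhost:5000")  # illustrative URL

agent_config = AgentConfig(
    model="meta-llama/Llama-3.1-8B-Instruct",
    instructions="You are a helpful assistant",
    toolgroups=[
        # plural key, list value: several banks can be attached at once
        dict(name="builtin::memory", args={"memory_bank_ids": ["bank_a", "bank_b"]}),
    ],
    enable_session_persistence=False,
)
agent = Agent(client, agent_config)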