Accept "query_config" params for the RAG tool

This commit is contained in:
Ashwin Bharambe 2025-01-22 16:42:36 -08:00
parent f4f47970e5
commit 08dcb9e31e

View file

@ -18,6 +18,7 @@ from urllib.parse import urlparse
import httpx import httpx
from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDefinition from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
from pydantic import TypeAdapter
from llama_stack.apis.agents import ( from llama_stack.apis.agents import (
AgentConfig, AgentConfig,
@ -60,13 +61,7 @@ from llama_stack.apis.inference import (
UserMessage, UserMessage,
) )
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ( from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
DefaultRAGQueryGeneratorConfig,
RAGDocument,
RAGQueryConfig,
ToolGroups,
ToolRuntime,
)
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
@ -410,6 +405,15 @@ class ChatAgent(ShieldRunnerMixin):
args = toolgroup_args.get(MEMORY_GROUP, {}) args = toolgroup_args.get(MEMORY_GROUP, {})
vector_db_ids = args.get("vector_db_ids", []) vector_db_ids = args.get("vector_db_ids", [])
query_config = args.get("query_config")
if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(
query_config
)
else:
# no usable query_config (key missing, None, or an empty dict): fall back to the defaults
query_config = RAGQueryConfig()
session_info = await self.storage.get_session_info(session_id) session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it # if the session has a memory bank id, let the memory tool use it
@ -437,11 +441,7 @@ class ChatAgent(ShieldRunnerMixin):
[msg.content for msg in input_messages] [msg.content for msg in input_messages]
), ),
vector_db_ids=vector_db_ids, vector_db_ids=vector_db_ids,
query_config=RAGQueryConfig( query_config=query_config,
query_generator_config=DefaultRAGQueryGeneratorConfig(),
max_tokens_in_context=4096,
max_chunks=5,
),
) )
retrieved_context = result.content retrieved_context = result.content