diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 1b375fba7..a57b989a0 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -18,6 +18,7 @@ from urllib.parse import urlparse import httpx from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDefinition +from pydantic import TypeAdapter from llama_stack.apis.agents import ( AgentConfig, @@ -60,13 +61,7 @@ from llama_stack.apis.inference import ( UserMessage, ) from llama_stack.apis.safety import Safety -from llama_stack.apis.tools import ( - DefaultRAGQueryGeneratorConfig, - RAGDocument, - RAGQueryConfig, - ToolGroups, - ToolRuntime, -) +from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content @@ -410,6 +405,15 @@ class ChatAgent(ShieldRunnerMixin): args = toolgroup_args.get(MEMORY_GROUP, {}) vector_db_ids = args.get("vector_db_ids", []) + query_config = args.get("query_config") + if query_config: + query_config = TypeAdapter(RAGQueryConfig).validate_python( + query_config + ) + else: + # handle someone passing an empty dict + query_config = RAGQueryConfig() + session_info = await self.storage.get_session_info(session_id) # if the session has a memory bank id, let the memory tool use it @@ -437,11 +441,7 @@ class ChatAgent(ShieldRunnerMixin): [msg.content for msg in input_messages] ), vector_db_ids=vector_db_ids, - query_config=RAGQueryConfig( - query_generator_config=DefaultRAGQueryGeneratorConfig(), - max_tokens_in_context=4096, - max_chunks=5, - ), + query_config=query_config, ) retrieved_context = result.content