forked from phoenix-oss/llama-stack-mirror
Accept "query_config" params for the RAG tool
This commit is contained in:
parent
f4f47970e5
commit
08dcb9e31e
1 changed file with 12 additions and 12 deletions
|
@@ -18,6 +18,7 @@ from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
|
from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
from llama_stack.apis.agents import (
|
from llama_stack.apis.agents import (
|
||||||
AgentConfig,
|
AgentConfig,
|
||||||
|
@@ -60,13 +61,7 @@ from llama_stack.apis.inference import (
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
|
||||||
DefaultRAGQueryGeneratorConfig,
|
|
||||||
RAGDocument,
|
|
||||||
RAGQueryConfig,
|
|
||||||
ToolGroups,
|
|
||||||
ToolRuntime,
|
|
||||||
)
|
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
from llama_stack.providers.utils.kvstore import KVStore
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
||||||
|
@@ -410,6 +405,15 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
args = toolgroup_args.get(MEMORY_GROUP, {})
|
args = toolgroup_args.get(MEMORY_GROUP, {})
|
||||||
vector_db_ids = args.get("vector_db_ids", [])
|
vector_db_ids = args.get("vector_db_ids", [])
|
||||||
|
query_config = args.get("query_config")
|
||||||
|
if query_config:
|
||||||
|
query_config = TypeAdapter(RAGQueryConfig).validate_python(
|
||||||
|
query_config
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# handle someone passing an empty dict
|
||||||
|
query_config = RAGQueryConfig()
|
||||||
|
|
||||||
session_info = await self.storage.get_session_info(session_id)
|
session_info = await self.storage.get_session_info(session_id)
|
||||||
|
|
||||||
# if the session has a memory bank id, let the memory tool use it
|
# if the session has a memory bank id, let the memory tool use it
|
||||||
|
@@ -437,11 +441,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
[msg.content for msg in input_messages]
|
[msg.content for msg in input_messages]
|
||||||
),
|
),
|
||||||
vector_db_ids=vector_db_ids,
|
vector_db_ids=vector_db_ids,
|
||||||
query_config=RAGQueryConfig(
|
query_config=query_config,
|
||||||
query_generator_config=DefaultRAGQueryGeneratorConfig(),
|
|
||||||
max_tokens_in_context=4096,
|
|
||||||
max_chunks=5,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
retrieved_context = result.content
|
retrieved_context = result.content
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue