# What does this PR do?

This PR addresses the content dominance problem that frequently arises with multiple models when executing queries with the RAG tool. When the retrieved content is too large, it disproportionately influences the generation process, causing the model to ignore the original question and to provide meaningless comments on the retrieved information instead. This situation is especially common with agentic RAG, which is the standard way of doing RAG in Llama Stack, since directly manipulating the prompt that combines the query with the retrieved content is not possible.

This PR appends a grounding message to the results returned by the knowledge search tool, reminding the model of the original query and the purpose of the inference call. This makes the problem significantly less likely to occur.

## Test Plan

Running the following script before the fix demonstrates the content dominance problem: the model insists on commenting on the retrieved content and refuses to address the question. Running the script after the fix yields the correct answer.

```
import os
import uuid

from llama_stack_client import Agent, AgentEventLogger, RAGDocument, LlamaStackClient

# the server endpoint
LLAMA_STACK_SERVER_URL = "http://localhost:8321"

# inference settings
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
SYSTEM_PROMPT = "You are a helpful assistant. "

# RAG settings
VECTOR_DB_EMBEDDING_MODEL = "all-MiniLM-L6-v2"
VECTOR_DB_EMBEDDING_DIMENSION = 384
VECTOR_DB_CHUNK_SIZE = 512

# initialize the server connection
client = LlamaStackClient(base_url=os.environ.get("LLAMA_STACK_ENDPOINT", LLAMA_STACK_SERVER_URL))

# init the RAG retrieval parameters
vector_db_id = f"test_vector_db_{uuid.uuid4()}"
vector_providers = [
    provider for provider in client.providers.list() if provider.api == "vector_io"
]
vector_provider_to_use = vector_providers[0]

# define and register the document collection to be used
client.vector_dbs.register(
    vector_db_id=vector_db_id,
    embedding_model=VECTOR_DB_EMBEDDING_MODEL,
    embedding_dimension=VECTOR_DB_EMBEDDING_DIMENSION,
    provider_id=vector_provider_to_use.provider_id,
)

# ingest the documents into the newly created document collection
urls = [
    ("https://www.openshift.guide/openshift-guide-screen.pdf", "application/pdf"),
]
documents = [
    RAGDocument(
        document_id=f"num-{i}",
        content=url,
        mime_type=url_type,
        metadata={},
    )
    for i, (url, url_type) in enumerate(urls)
]
client.tool_runtime.rag_tool.insert(
    documents=documents,
    vector_db_id=vector_db_id,
    chunk_size_in_tokens=VECTOR_DB_CHUNK_SIZE,
)

queries = [
    "How to install OpenShift?",
]

# initializing the agent
agent = Agent(
    client,
    model=MODEL_ID,
    instructions=SYSTEM_PROMPT,
    # we make our agent aware of the RAG tool by including builtin::rag/knowledge_search in the list of tools
    tools=[
        dict(
            name="builtin::rag/knowledge_search",
            args={
                "vector_db_ids": [vector_db_id],  # list of IDs of document collections to consider during retrieval
            },
        )
    ],
)

for prompt in queries:
    print(f"User> {prompt}")

    # create a new turn with a new session ID for each prompt
    response = agent.create_turn(
        messages=[
            {
                "role": "user",
                "content": prompt,
            }
        ],
        session_id=agent.create_session(f"rag-session_{uuid.uuid4()}"),
    )

    # print the response, including tool calls output
    for log in AgentEventLogger().log(response):
        print(log.content, end='')
```
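As an additional check, the appended grounding message can be inspected directly, without going through the agent loop, by querying the RAG tool runtime from the client. This is a minimal sketch under the same setup as the script above: it reuses `vector_db_id` from the registration step and assumes the `rag_tool.query` client call, the client-side counterpart of the `query` method in the implementation below.

```
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# vector_db_id is the collection registered and populated by the script above
result = client.tool_runtime.rag_tool.query(
    content="How to install OpenShift?",
    vector_db_ids=[vector_db_id],
)

# the final text item should now be the grounding message that restates the
# original query after the retrieved chunks
for item in result.content:
    print(item.text)
```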
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import logging
import secrets
import string
from typing import Any, Dict, List, Optional

from pydantic import TypeAdapter

from llama_stack.apis.common.content_types import (
    URL,
    InterleavedContent,
    InterleavedContentItem,
    TextContentItem,
)
from llama_stack.apis.inference import Inference
from llama_stack.apis.tools import (
    ListToolDefsResponse,
    RAGDocument,
    RAGQueryConfig,
    RAGQueryResult,
    RAGToolRuntime,
    Tool,
    ToolDef,
    ToolInvocationResult,
    ToolParameter,
    ToolRuntime,
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import (
    content_from_doc,
    make_overlapped_chunks,
)

from .config import RagToolRuntimeConfig
from .context_retriever import generate_rag_query

log = logging.getLogger(__name__)


def make_random_string(length: int = 8):
    return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))


class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
    def __init__(
        self,
        config: RagToolRuntimeConfig,
        vector_io_api: VectorIO,
        inference_api: Inference,
    ):
        self.config = config
        self.vector_io_api = vector_io_api
        self.inference_api = inference_api

    async def initialize(self):
        pass

    async def shutdown(self):
        pass

    async def register_tool(self, tool: Tool) -> None:
        pass

    async def unregister_tool(self, tool_id: str) -> None:
        return

    async def insert(
        self,
        documents: List[RAGDocument],
        vector_db_id: str,
        chunk_size_in_tokens: int = 512,
    ) -> None:
        chunks = []
        for doc in documents:
            content = await content_from_doc(doc)
            chunks.extend(
                make_overlapped_chunks(
                    doc.document_id,
                    content,
                    chunk_size_in_tokens,
                    chunk_size_in_tokens // 4,
                )
            )

        if not chunks:
            return

        await self.vector_io_api.insert_chunks(
            chunks=chunks,
            vector_db_id=vector_db_id,
        )

    async def query(
        self,
        content: InterleavedContent,
        vector_db_ids: List[str],
        query_config: Optional[RAGQueryConfig] = None,
    ) -> RAGQueryResult:
        if not vector_db_ids:
            return RAGQueryResult(content=None)

        query_config = query_config or RAGQueryConfig()
        query = await generate_rag_query(
            query_config.query_generator_config,
            content,
            inference_api=self.inference_api,
        )
        # query all requested vector DBs concurrently
        tasks = [
            self.vector_io_api.query_chunks(
                vector_db_id=vector_db_id,
                query=query,
                params={
                    "max_chunks": query_config.max_chunks,
                },
            )
            for vector_db_id in vector_db_ids
        ]
        results: List[QueryChunksResponse] = await asyncio.gather(*tasks)
        chunks = [c for r in results for c in r.chunks]
        scores = [s for r in results for s in r.scores]

        if not chunks:
            return RAGQueryResult(content=None)

        # sort by score
        chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)  # type: ignore
        chunks = chunks[: query_config.max_chunks]

        tokens = 0
        picked: list[InterleavedContentItem] = [
            TextContentItem(
                text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
            )
        ]
        for i, c in enumerate(chunks):
            metadata = c.metadata
            tokens += metadata["token_count"]
            if tokens > query_config.max_tokens_in_context:
                log.error(
                    f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
                )
                break
            picked.append(
                TextContentItem(
                    text=f"Result {i + 1}:\nDocument_id:{metadata['document_id'][:5]}\nContent: {c.content}\n",
                )
            )
        picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
        # append a grounding message that restates the original user query, so the
        # retrieved content does not dominate the generation and the model stays
        # focused on answering the question
        picked.append(
            TextContentItem(
                text=f'The above results were retrieved to help answer the user\'s query: "{interleaved_content_as_str(content)}". Use them as supporting information only in answering this query.\n',
            )
        )

        return RAGQueryResult(
            content=picked,
            metadata={
                "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
            },
        )

    async def list_runtime_tools(
        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
    ) -> ListToolDefsResponse:
        # Parameters are not listed since these methods are not yet invoked automatically
        # by the LLM. The method is only implemented so things like /tools can list without
        # encountering fatals.
        return ListToolDefsResponse(
            data=[
                ToolDef(
                    name="insert_into_memory",
                    description="Insert documents into memory",
                ),
                ToolDef(
                    name="knowledge_search",
                    description="Search for information in a database.",
                    parameters=[
                        ToolParameter(
                            name="query",
                            description="The query to search for. Can be a natural language sentence or keywords.",
                            parameter_type="string",
                        ),
                    ],
                ),
            ]
        )

    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
        vector_db_ids = kwargs.get("vector_db_ids", [])
        query_config = kwargs.get("query_config")
        if query_config:
            query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
        else:
            # handle someone passing an empty dict
            query_config = RAGQueryConfig()

        query = kwargs["query"]
        result = await self.query(
            content=query,
            vector_db_ids=vector_db_ids,
            query_config=query_config,
        )

        return ToolInvocationResult(
            content=result.content,
            metadata=result.metadata,
        )