feat: Making static prompt values in Rag/File Search configurable in Vector Store Config (#4368)

# What does this PR do?

- Enables users to configure prompts used throughout the File Search /
Vector Retrieval
- Configuration is defined in the Vector Stores Config so they can be
modified at runtime
- Backwards compatible, which means the fields are optional and default
to the previously used values

This is the summary of the new options in the `run.yaml`
```yaml
vector_stores:
  file_search_params:
    header_template: 'knowledge_search tool found {num_chunks} chunks:\nBEGIN of knowledge_search tool results.\n'
    footer_template: 'END of knowledge_search tool results.\n'
  context_prompt_params:
    chunk_annotation_template: 'Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n'
    context_template: 'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.{annotation_instruction}\n'
  annotation_prompt_params:
    enable_annotations: true
    annotation_instruction_template: 'Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format like \'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.\'. Do not add
extra punctuation. Use only the file IDs provided, do not invent new ones.'
    chunk_annotation_template: '[{index}] {metadata_text} cite as <|{file_id}|>\n{chunk_text}\n'
```

<!-- If resolving an issue, uncomment and update the line below -->
<!-- Closes #[issue-number] -->

## Test Plan
Added tests.

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-12-15 11:39:01 -05:00 committed by GitHub
parent 4043dedeea
commit 62005dc1a9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
47 changed files with 42640 additions and 40 deletions

View file

@ -18,7 +18,15 @@ from llama_stack.core.storage.datatypes import (
StorageConfig,
)
from llama_stack.log import LoggingConfig
from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT
from llama_stack.providers.utils.memory.constants import (
DEFAULT_ANNOTATION_INSTRUCTION_TEMPLATE,
DEFAULT_CHUNK_ANNOTATION_TEMPLATE,
DEFAULT_CHUNK_WITH_SOURCES_TEMPLATE,
DEFAULT_CONTEXT_TEMPLATE,
DEFAULT_FILE_SEARCH_FOOTER_TEMPLATE,
DEFAULT_FILE_SEARCH_HEADER_TEMPLATE,
DEFAULT_QUERY_REWRITE_PROMPT,
)
from llama_stack_api import (
Api,
Benchmark,
@ -371,6 +379,125 @@ class RewriteQueryParams(BaseModel):
description="Temperature for query expansion model (0.0 = deterministic, 1.0 = creative).",
)
@field_validator("prompt")
@classmethod
def validate_prompt(cls, v: str) -> str:
if "{query}" not in v:
raise ValueError("prompt must contain {query} placeholder")
return v
@field_validator("max_tokens")
@classmethod
def validate_max_tokens(cls, v: int) -> int:
if v <= 0:
raise ValueError("max_tokens must be positive")
if v > 4096:
raise ValueError("max_tokens should not exceed 4096")
return v
@field_validator("temperature")
@classmethod
def validate_temperature(cls, v: float) -> float:
if v < 0.0 or v > 2.0:
raise ValueError("temperature must be between 0.0 and 2.0")
return v
class FileSearchParams(BaseModel):
"""Configuration for file search tool output formatting."""
header_template: str = Field(
default=DEFAULT_FILE_SEARCH_HEADER_TEMPLATE,
description="Template for the header text shown before search results. Available placeholders: {num_chunks} number of chunks found.",
)
footer_template: str = Field(
default=DEFAULT_FILE_SEARCH_FOOTER_TEMPLATE,
description="Template for the footer text shown after search results.",
)
@field_validator("header_template")
@classmethod
def validate_header_template(cls, v: str) -> str:
if len(v) == 0:
raise ValueError("header_template must not be empty")
if "{num_chunks}" not in v:
raise ValueError("header_template must contain {num_chunks} placeholder")
if "knowledge_search" not in v.lower():
raise ValueError(
"header_template must contain 'knowledge_search' keyword to ensure proper tool identification"
)
return v
class ContextPromptParams(BaseModel):
"""Configuration for LLM prompt content and chunk formatting."""
chunk_annotation_template: str = Field(
default=DEFAULT_CHUNK_ANNOTATION_TEMPLATE,
description="Template for formatting individual chunks in search results. Available placeholders: {index} 1-based chunk index, {chunk.content} chunk content, {metadata} chunk metadata dict.",
)
context_template: str = Field(
default=DEFAULT_CONTEXT_TEMPLATE,
description="Template for explaining the search results to the model. Available placeholders: {query} user's query, {num_chunks} number of chunks.",
)
@field_validator("chunk_annotation_template")
@classmethod
def validate_chunk_annotation_template(cls, v: str) -> str:
if len(v) == 0:
raise ValueError("chunk_annotation_template must not be empty")
if "{chunk.content}" not in v:
raise ValueError("chunk_annotation_template must contain {chunk.content} placeholder")
if "{index}" not in v:
raise ValueError("chunk_annotation_template must contain {index} placeholder")
return v
@field_validator("context_template")
@classmethod
def validate_context_template(cls, v: str) -> str:
if len(v) == 0:
raise ValueError("context_template must not be empty")
if "{query}" not in v:
raise ValueError("context_template must contain {query} placeholder")
return v
class AnnotationPromptParams(BaseModel):
"""Configuration for source annotation and attribution features."""
enable_annotations: bool = Field(
default=True,
description="Whether to include annotation information in results.",
)
annotation_instruction_template: str = Field(
default=DEFAULT_ANNOTATION_INSTRUCTION_TEMPLATE,
description="Instructions for how the model should cite sources. Used when enable_annotations is True.",
)
chunk_annotation_template: str = Field(
default=DEFAULT_CHUNK_WITH_SOURCES_TEMPLATE,
description="Template for chunks with annotation information. Available placeholders: {index} 1-based chunk index, {metadata_text} formatted metadata, {file_id} document identifier, {chunk_text} chunk content.",
)
@field_validator("chunk_annotation_template")
@classmethod
def validate_chunk_annotation_template(cls, v: str) -> str:
if len(v) == 0:
raise ValueError("chunk_annotation_template must not be empty")
if "{index}" not in v:
raise ValueError("chunk_annotation_template must contain {index} placeholder")
if "{chunk_text}" not in v:
raise ValueError("chunk_annotation_template must contain {chunk_text} placeholder")
if "{file_id}" not in v:
raise ValueError("chunk_annotation_template must contain {file_id} placeholder")
return v
@field_validator("annotation_instruction_template")
@classmethod
def validate_annotation_instruction_template(cls, v: str) -> str:
if len(v) == 0:
raise ValueError("annotation_instruction_template must not be empty")
return v
class VectorStoresConfig(BaseModel):
"""Configuration for vector stores in the stack."""
@ -387,6 +514,18 @@ class VectorStoresConfig(BaseModel):
default=None,
description="Parameters for query rewriting/expansion. None disables query rewriting.",
)
file_search_params: FileSearchParams = Field(
default_factory=FileSearchParams,
description="Configuration for file search tool output formatting.",
)
context_prompt_params: ContextPromptParams = Field(
default_factory=ContextPromptParams,
description="Configuration for LLM prompt content and chunk formatting.",
)
annotation_prompt_params: AnnotationPromptParams = Field(
default_factory=AnnotationPromptParams,
description="Configuration for source annotation and attribution features.",
)
class SafetyConfig(BaseModel):

View file

@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
import importlib.metadata
import inspect
@ -406,13 +407,17 @@ async def instantiate_provider(
args = [provider_spec.api, inner_impls, deps, dist_registry, policy]
else:
method = "get_provider_impl"
provider_config = provider.config.copy()
# Inject vector_stores_config for providers that need it (introspection-based)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider.config)
if hasattr(config_type, "__fields__") and "vector_stores_config" in config_type.__fields__:
provider_config["vector_stores_config"] = run_config.vector_stores
config = config_type(**provider_config)
args = [config, deps]
if "policy" in inspect.signature(getattr(module, method)).parameters:
args.append(policy)
fn = getattr(module, method)
impl = await fn(*args)
impl.__provider_id__ = provider.provider_id

View file

@ -229,8 +229,6 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig
if vector_stores_config.rewrite_query_params:
if vector_stores_config.rewrite_query_params.model:
await _validate_rewrite_query_model(vector_stores_config.rewrite_query_params.model, impls)
if "{query}" not in vector_stores_config.rewrite_query_params.prompt:
raise ValueError("'{query}' placeholder is required in the prompt template")
async def _validate_embedding_model(embedding_model: QualifiedModel, impls: dict[Api, Any]) -> None: