Mirror of https://github.com/meta-llama/llama-stack.git
feat: Make static prompt values in RAG/File Search configurable in Vector Store Config (#4368)
# What does this PR do?
- Enables users to configure the prompts used throughout the File Search / Vector Retrieval flow
- Configuration lives in the Vector Stores Config, so it can be modified at runtime
- Backwards compatible: the new fields are optional and default to the previously used values

This is a summary of the new options in the `run.yaml`:
```yaml
vector_stores:
  file_search_params:
    header_template: 'knowledge_search tool found {num_chunks} chunks:\nBEGIN of knowledge_search tool results.\n'
    footer_template: 'END of knowledge_search tool results.\n'
  context_prompt_params:
    chunk_annotation_template: 'Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n'
    context_template: 'The above results were retrieved to help answer the user''s query: "{query}". Use them as supporting information only in answering this query.{annotation_instruction}\n'
  annotation_prompt_params:
    enable_annotations: true
    annotation_instruction_template: 'Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format like ''This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.''. Do not add extra punctuation. Use only the file IDs provided, do not invent new ones.'
    chunk_annotation_template: '[{index}] {metadata_text} cite as <|{file_id}|>\n{chunk_text}\n'
```
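
For illustration, here is a minimal sketch of how these templates get filled in at retrieval time. The `Chunk` class and sample data are hypothetical stand-ins, not the provider's actual code path:

```python
from dataclasses import dataclass


@dataclass
class Chunk:
    content: str
    metadata: dict


# Templates as configured above (file_search_params + context_prompt_params).
header_template = "knowledge_search tool found {num_chunks} chunks:\nBEGIN of knowledge_search tool results.\n"
chunk_template = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
footer_template = "END of knowledge_search tool results.\n"

chunks = [
    Chunk("Llamas are grazing herbivores.", {"file_id": "file-abc"}),
    Chunk("Llamas can live 15 to 25 years.", {"file_id": "file-def"}),
]

# str.format resolves {chunk.content} via attribute access; {index} is 1-based.
parts = [header_template.format(num_chunks=len(chunks))]
for i, chunk in enumerate(chunks, start=1):
    parts.append(chunk_template.format(index=i, chunk=chunk, metadata=chunk.metadata))
parts.append(footer_template)
print("".join(parts))
```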
## Test Plan
Added tests.
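
The tests themselves aren't shown in this excerpt; a hypothetical unit test in the same spirit might look like this (the import path is an assumption suggested by the diff below, not confirmed by it):

```python
# Hypothetical test sketch, not the PR's actual tests.
import pytest
from pydantic import ValidationError

from llama_stack.core.datatypes import FileSearchParams  # assumed import path


def test_defaults_preserve_previous_behavior():
    # Omitting every field falls back to the previously hard-coded templates.
    params = FileSearchParams()
    assert "{num_chunks}" in params.header_template
    assert "knowledge_search" in params.header_template


def test_header_template_requires_num_chunks_placeholder():
    # The field validator rejects templates missing the {num_chunks} placeholder.
    with pytest.raises(ValidationError):
        FileSearchParams(header_template="knowledge_search results follow")
```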
---------
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Commit 62005dc1a9 (parent 4043dedeea): 47 changed files with 42,640 additions and 40 deletions
```diff
@@ -18,7 +18,15 @@ from llama_stack.core.storage.datatypes import (
     StorageConfig,
 )
 from llama_stack.log import LoggingConfig
-from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT
+from llama_stack.providers.utils.memory.constants import (
+    DEFAULT_ANNOTATION_INSTRUCTION_TEMPLATE,
+    DEFAULT_CHUNK_ANNOTATION_TEMPLATE,
+    DEFAULT_CHUNK_WITH_SOURCES_TEMPLATE,
+    DEFAULT_CONTEXT_TEMPLATE,
+    DEFAULT_FILE_SEARCH_FOOTER_TEMPLATE,
+    DEFAULT_FILE_SEARCH_HEADER_TEMPLATE,
+    DEFAULT_QUERY_REWRITE_PROMPT,
+)
 from llama_stack_api import (
     Api,
     Benchmark,
```
```diff
@@ -371,6 +379,125 @@ class RewriteQueryParams(BaseModel):
         description="Temperature for query expansion model (0.0 = deterministic, 1.0 = creative).",
     )
 
+    @field_validator("prompt")
+    @classmethod
+    def validate_prompt(cls, v: str) -> str:
+        if "{query}" not in v:
+            raise ValueError("prompt must contain {query} placeholder")
+        return v
+
+    @field_validator("max_tokens")
+    @classmethod
+    def validate_max_tokens(cls, v: int) -> int:
+        if v <= 0:
+            raise ValueError("max_tokens must be positive")
+        if v > 4096:
+            raise ValueError("max_tokens should not exceed 4096")
+        return v
+
+    @field_validator("temperature")
+    @classmethod
+    def validate_temperature(cls, v: float) -> float:
+        if v < 0.0 or v > 2.0:
+            raise ValueError("temperature must be between 0.0 and 2.0")
+        return v
+
+
+class FileSearchParams(BaseModel):
+    """Configuration for file search tool output formatting."""
+
+    header_template: str = Field(
+        default=DEFAULT_FILE_SEARCH_HEADER_TEMPLATE,
+        description="Template for the header text shown before search results. Available placeholders: {num_chunks} number of chunks found.",
+    )
+    footer_template: str = Field(
+        default=DEFAULT_FILE_SEARCH_FOOTER_TEMPLATE,
+        description="Template for the footer text shown after search results.",
+    )
+
+    @field_validator("header_template")
+    @classmethod
+    def validate_header_template(cls, v: str) -> str:
+        if len(v) == 0:
+            raise ValueError("header_template must not be empty")
+        if "{num_chunks}" not in v:
+            raise ValueError("header_template must contain {num_chunks} placeholder")
+        if "knowledge_search" not in v.lower():
+            raise ValueError(
+                "header_template must contain 'knowledge_search' keyword to ensure proper tool identification"
+            )
+        return v
+
+
+class ContextPromptParams(BaseModel):
+    """Configuration for LLM prompt content and chunk formatting."""
+
+    chunk_annotation_template: str = Field(
+        default=DEFAULT_CHUNK_ANNOTATION_TEMPLATE,
+        description="Template for formatting individual chunks in search results. Available placeholders: {index} 1-based chunk index, {chunk.content} chunk content, {metadata} chunk metadata dict.",
+    )
+    context_template: str = Field(
+        default=DEFAULT_CONTEXT_TEMPLATE,
+        description="Template for explaining the search results to the model. Available placeholders: {query} user's query, {num_chunks} number of chunks.",
+    )
+
+    @field_validator("chunk_annotation_template")
+    @classmethod
+    def validate_chunk_annotation_template(cls, v: str) -> str:
+        if len(v) == 0:
+            raise ValueError("chunk_annotation_template must not be empty")
+        if "{chunk.content}" not in v:
+            raise ValueError("chunk_annotation_template must contain {chunk.content} placeholder")
+        if "{index}" not in v:
+            raise ValueError("chunk_annotation_template must contain {index} placeholder")
+        return v
+
+    @field_validator("context_template")
+    @classmethod
+    def validate_context_template(cls, v: str) -> str:
+        if len(v) == 0:
+            raise ValueError("context_template must not be empty")
+        if "{query}" not in v:
+            raise ValueError("context_template must contain {query} placeholder")
+        return v
+
+
+class AnnotationPromptParams(BaseModel):
+    """Configuration for source annotation and attribution features."""
+
+    enable_annotations: bool = Field(
+        default=True,
+        description="Whether to include annotation information in results.",
+    )
+    annotation_instruction_template: str = Field(
+        default=DEFAULT_ANNOTATION_INSTRUCTION_TEMPLATE,
+        description="Instructions for how the model should cite sources. Used when enable_annotations is True.",
+    )
+    chunk_annotation_template: str = Field(
+        default=DEFAULT_CHUNK_WITH_SOURCES_TEMPLATE,
+        description="Template for chunks with annotation information. Available placeholders: {index} 1-based chunk index, {metadata_text} formatted metadata, {file_id} document identifier, {chunk_text} chunk content.",
+    )
+
+    @field_validator("chunk_annotation_template")
+    @classmethod
+    def validate_chunk_annotation_template(cls, v: str) -> str:
+        if len(v) == 0:
+            raise ValueError("chunk_annotation_template must not be empty")
+        if "{index}" not in v:
+            raise ValueError("chunk_annotation_template must contain {index} placeholder")
+        if "{chunk_text}" not in v:
+            raise ValueError("chunk_annotation_template must contain {chunk_text} placeholder")
+        if "{file_id}" not in v:
+            raise ValueError("chunk_annotation_template must contain {file_id} placeholder")
+        return v
+
+    @field_validator("annotation_instruction_template")
+    @classmethod
+    def validate_annotation_instruction_template(cls, v: str) -> str:
+        if len(v) == 0:
+            raise ValueError("annotation_instruction_template must not be empty")
+        return v
+
+
 class VectorStoresConfig(BaseModel):
     """Configuration for vector stores in the stack."""
```
|
||||
|
```diff
@@ -387,6 +514,18 @@ class VectorStoresConfig(BaseModel):
         default=None,
         description="Parameters for query rewriting/expansion. None disables query rewriting.",
     )
+    file_search_params: FileSearchParams = Field(
+        default_factory=FileSearchParams,
+        description="Configuration for file search tool output formatting.",
+    )
+    context_prompt_params: ContextPromptParams = Field(
+        default_factory=ContextPromptParams,
+        description="Configuration for LLM prompt content and chunk formatting.",
+    )
+    annotation_prompt_params: AnnotationPromptParams = Field(
+        default_factory=AnnotationPromptParams,
+        description="Configuration for source annotation and attribution features.",
+    )
 
 
 class SafetyConfig(BaseModel):
```
```diff
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
 import importlib
+import importlib.metadata
 import inspect
```
```diff
@@ -406,13 +407,17 @@ async def instantiate_provider(
         args = [provider_spec.api, inner_impls, deps, dist_registry, policy]
     else:
         method = "get_provider_impl"
+        provider_config = provider.config.copy()
+
+        # Inject vector_stores_config for providers that need it (introspection-based)
         config_type = instantiate_class_type(provider_spec.config_class)
-        config = config_type(**provider.config)
+        if hasattr(config_type, "__fields__") and "vector_stores_config" in config_type.__fields__:
+            provider_config["vector_stores_config"] = run_config.vector_stores
+
+        config = config_type(**provider_config)
         args = [config, deps]
         if "policy" in inspect.signature(getattr(module, method)).parameters:
             args.append(policy)
 
     fn = getattr(module, method)
     impl = await fn(*args)
     impl.__provider_id__ = provider.provider_id
```
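
The injection above is plain Pydantic introspection: a provider opts in simply by declaring a `vector_stores_config` field on its config model. A standalone sketch of the same pattern (hypothetical config classes, not llama-stack's, and using Pydantic v2's `model_fields` in place of the deprecated `__fields__` alias the diff checks):

```python
from pydantic import BaseModel


class PlainProviderConfig(BaseModel):
    url: str


class VectorProviderConfig(BaseModel):
    url: str
    vector_stores_config: dict | None = None


def build_config(config_type: type[BaseModel], raw: dict, vector_stores: dict) -> BaseModel:
    raw = dict(raw)
    # Only inject when the provider's config model declares the field.
    if "vector_stores_config" in config_type.model_fields:
        raw["vector_stores_config"] = vector_stores
    return config_type(**raw)


print(build_config(PlainProviderConfig, {"url": "http://db"}, {"enabled": True}))
print(build_config(VectorProviderConfig, {"url": "http://db"}, {"enabled": True}))
```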
```diff
@@ -229,8 +229,6 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig
     if vector_stores_config.rewrite_query_params:
         if vector_stores_config.rewrite_query_params.model:
             await _validate_rewrite_query_model(vector_stores_config.rewrite_query_params.model, impls)
-        if "{query}" not in vector_stores_config.rewrite_query_params.prompt:
-            raise ValueError("'{query}' placeholder is required in the prompt template")
 
 
 async def _validate_embedding_model(embedding_model: QualifiedModel, impls: dict[Api, Any]) -> None:
```

The inline `{query}` check is dropped here because the same validation now runs earlier, in `RewriteQueryParams.validate_prompt` (added above).