# What does this PR do?

Rather than have a single `LLAMA_STACK_VERSION`, we need `_V1`, `_V1ALPHA`, and `_V1BETA` constants. This also necessitated adding `level` to `WebMethod` so that routing can be handled properly. For backwards compatibility, the `v1` routes are kept around and marked as `deprecated`; when used, the server logs a deprecation warning.

Deprecation log:

<img width="1224" height="134" alt="Screenshot 2025-09-25 at 2 43 36 PM" src="https://github.com/user-attachments/assets/0cc7c245-dafc-48f0-be99-269fb9a686f9" />

Moves:

1. post_training to `v1alpha`, as it is under heavy development and not near its final state.
2. eval to `v1alpha`: job scheduling is not implemented, and it relies heavily on the datasetio API, which is itself under development. Missing implementations of specific routes indicate the structure of those routes might change. Additionally, eval depends on the `inference` API, which is going to be deprecated; eval will likely need a major API surface change to conform to using completions properly.

Implements leveling in #3317.

Note: integration tests will fail until the SDK is regenerated with v1alpha/inference as opposed to v1/inference.

## Test Plan

Existing tests should pass with the newly generated schema. Conformance will also pass, as these routes are not the ones we currently test for stability.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
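For illustration, a minimal sketch of how a leveled route might be declared after this change. The `LLAMA_STACK_API_V1ALPHA` constant name is inferred from the description above, and the method shown is hypothetical; only the decorator shape (`level=` keyword) is confirmed by this PR:

```python
from typing import Protocol

from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA  # assumed constant name
from llama_stack.schema_utils import webmethod


class PostTraining(Protocol):
    # Served under /v1alpha/...; the legacy /v1/... route stays registered as
    # deprecated and logs a warning when hit.
    @webmethod(route="/post-training/jobs", method="GET", level=LLAMA_STACK_API_V1ALPHA)
    async def get_training_jobs(self): ...
```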
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum, StrEnum
from typing import Annotated, Any, Literal, Protocol

from pydantic import BaseModel, Field, field_validator
from typing_extensions import runtime_checkable

from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod


@json_schema_type
class RRFRanker(BaseModel):
    """
    Reciprocal Rank Fusion (RRF) ranker configuration.

    :param type: The type of ranker, always "rrf"
    :param impact_factor: The impact factor for RRF scoring. Higher values give more weight to higher-ranked results.
        Must be greater than 0
    """

    type: Literal["rrf"] = "rrf"
    impact_factor: float = Field(default=60.0, gt=0.0)  # default of 60 for optimal performance
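
# A worked sketch of the standard RRF formula that `impact_factor` parameterizes
# (illustrative only; the fusion itself happens in the provider implementation):
#
#     score(doc) = sum(1 / (impact_factor + rank_i(doc)) for each ranked list i)
#
# where rank_i(doc) is the 1-based position of doc in ranked list i. With the
# default impact_factor of 60, a document ranked 1st in one list contributes
# 1/61 to its fused score.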


@json_schema_type
class WeightedRanker(BaseModel):
    """
    Weighted ranker configuration that combines vector and keyword scores.

    :param type: The type of ranker, always "weighted"
    :param alpha: Weight factor between 0 and 1.
        0 means only use keyword scores,
        1 means only use vector scores,
        values in between blend both scores.
    """

    type: Literal["weighted"] = "weighted"
    alpha: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Weight factor between 0 and 1. 0 means only keyword scores, 1 means only vector scores.",
    )
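
# Sketch of the blend that `alpha` controls (illustrative only; the actual
# scoring is done by the provider, and score normalization is an assumption):
#
#     combined = alpha * vector_score + (1 - alpha) * keyword_score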


Ranker = Annotated[
    RRFRanker | WeightedRanker,
    Field(discriminator="type"),
]
register_schema(Ranker, name="Ranker")
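
# Example (hypothetical usage, not part of this module): the "type" field acts
# as a pydantic discriminator, so parsing a plain dict yields the right class:
#
#     from pydantic import TypeAdapter
#     TypeAdapter(Ranker).validate_python({"type": "rrf", "impact_factor": 60.0})  # -> RRFRanker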


@json_schema_type
class RAGDocument(BaseModel):
    """
    A document to be used for document ingestion in the RAG Tool.

    :param document_id: The unique identifier for the document.
    :param content: The content of the document.
    :param mime_type: The MIME type of the document.
    :param metadata: Additional metadata for the document.
    """

    document_id: str
    content: InterleavedContent | URL
    mime_type: str | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)
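
# Example (illustrative only): content can be inline text or a reference to
# fetch (the URL model is assumed to take a `uri` field):
#
#     RAGDocument(document_id="doc-1", content="Llama Stack is ...", mime_type="text/plain")
#     RAGDocument(document_id="doc-2", content=URL(uri="https://example.com/guide.md"))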


@json_schema_type
class RAGQueryResult(BaseModel):
    """Result of a RAG query containing retrieved content and metadata.

    :param content: (Optional) The retrieved content from the query
    :param metadata: Additional metadata about the query result
    """

    content: InterleavedContent | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)


@json_schema_type
class RAGQueryGenerator(Enum):
    """Types of query generators for RAG systems.

    :cvar default: Default query generator using simple text processing
    :cvar llm: LLM-based query generator for enhanced query understanding
    :cvar custom: Custom query generator implementation
    """

    default = "default"
    llm = "llm"
    custom = "custom"


@json_schema_type
class RAGSearchMode(StrEnum):
    """
    Search modes for RAG query retrieval:
    - VECTOR: Uses vector similarity search for semantic matching
    - KEYWORD: Uses keyword-based search for exact matching
    - HYBRID: Combines both vector and keyword search for better results
    """

    VECTOR = "vector"
    KEYWORD = "keyword"
    HYBRID = "hybrid"


@json_schema_type
class DefaultRAGQueryGeneratorConfig(BaseModel):
    """Configuration for the default RAG query generator.

    :param type: Type of query generator, always 'default'
    :param separator: String separator used to join query terms
    """

    type: Literal["default"] = "default"
    separator: str = " "


@json_schema_type
class LLMRAGQueryGeneratorConfig(BaseModel):
    """Configuration for the LLM-based RAG query generator.

    :param type: Type of query generator, always 'llm'
    :param model: Name of the language model to use for query generation
    :param template: Template string for formatting the query generation prompt
    """

    type: Literal["llm"] = "llm"
    model: str
    template: str


RAGQueryGeneratorConfig = Annotated[
    DefaultRAGQueryGeneratorConfig | LLMRAGQueryGeneratorConfig,
    Field(discriminator="type"),
]
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
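
# Example (illustrative only; the model name and template placeholder are
# hypothetical): configuring the LLM-based query generator.
#
#     LLMRAGQueryGeneratorConfig(
#         model="llama3.2:3b",
#         template="Rewrite the conversation as a search query: {messages}",
#     )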


@json_schema_type
class RAGQueryConfig(BaseModel):
    """
    Configuration for the RAG query generation.

    :param query_generator_config: Configuration for the query generator.
    :param max_tokens_in_context: Maximum number of tokens in the context.
    :param max_chunks: Maximum number of chunks to retrieve.
    :param chunk_template: Template for formatting each retrieved chunk in the context.
        Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
        Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
    :param mode: Search mode for retrieval: "vector", "keyword", or "hybrid". Default "vector".
    :param ranker: Configuration for the ranker to use in hybrid search. Defaults to RRF ranker.
    """

    # This config defines how a query is generated using the messages
    # for memory bank retrieval.
    query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
    max_tokens_in_context: int = 4096
    max_chunks: int = 5
    chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
    mode: RAGSearchMode | None = RAGSearchMode.VECTOR
    ranker: Ranker | None = Field(default=None)  # Only used for hybrid mode

    @field_validator("chunk_template")
    def validate_chunk_template(cls, v: str) -> str:
        # Check emptiness first so the error message is specific; an empty string
        # would otherwise fail the placeholder checks with a misleading message.
        if len(v) == 0:
            raise ValueError("chunk_template must not be empty")
        if "{chunk.content}" not in v:
            raise ValueError("chunk_template must contain {chunk.content}")
        if "{index}" not in v:
            raise ValueError("chunk_template must contain {index}")
        return v
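
# Example (illustrative only): a hybrid-search query config pairing keyword and
# vector retrieval with an RRF ranker.
#
#     RAGQueryConfig(
#         mode=RAGSearchMode.HYBRID,
#         ranker=RRFRanker(impact_factor=60.0),
#         max_chunks=10,
#     )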


@runtime_checkable
@trace_protocol
class RAGToolRuntime(Protocol):
    @webmethod(route="/tool-runtime/rag-tool/insert", method="POST", level=LLAMA_STACK_API_V1)
    async def insert(
        self,
        documents: list[RAGDocument],
        vector_db_id: str,
        chunk_size_in_tokens: int = 512,
    ) -> None:
        """Index documents so they can be used by the RAG system.

        :param documents: List of documents to index in the RAG system
        :param vector_db_id: ID of the vector database to store the document embeddings
        :param chunk_size_in_tokens: (Optional) Size in tokens for document chunking during indexing
        """
        ...

    @webmethod(route="/tool-runtime/rag-tool/query", method="POST", level=LLAMA_STACK_API_V1)
    async def query(
        self,
        content: InterleavedContent,
        vector_db_ids: list[str],
        query_config: RAGQueryConfig | None = None,
    ) -> RAGQueryResult:
        """Query the RAG system for context; typically invoked by the agent.

        :param content: The query content to search for in the indexed documents
        :param vector_db_ids: List of vector database IDs to search within
        :param query_config: (Optional) Configuration parameters for the query operation
        :returns: RAGQueryResult containing the retrieved content and metadata
        """
        ...
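
# Sketch of driving this protocol (hypothetical `runtime` object implementing
# RAGToolRuntime and a pre-registered vector DB "my-db"; illustrative only):
#
#     await runtime.insert(
#         documents=[RAGDocument(document_id="doc-1", content="...", mime_type="text/plain")],
#         vector_db_id="my-db",
#         chunk_size_in_tokens=512,
#     )
#     result = await runtime.query(
#         content="What does the guide say about chunking?",
#         vector_db_ids=["my-db"],
#         query_config=RAGQueryConfig(mode=RAGSearchMode.HYBRID, ranker=RRFRanker()),
#     )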