Merge branch 'main' into add-expiration-files-remote-s3

This commit is contained in:
Matthew Farrellee 2025-08-29 15:32:31 -04:00
commit 42d15a4516
27 changed files with 2165 additions and 2402 deletions

View file

@ -12,6 +12,60 @@ That means you'll get fast and efficient vector retrieval.
- Easy to use
- Fully integrated with Llama Stack
There are three implementations of search for PGVectoIndex available:
1. Vector Search:
- How it works:
- Uses PostgreSQL's vector extension (pgvector) to perform similarity search
- Compares query embeddings against stored embeddings using Cosine distance or other distance metrics
- Eg. SQL query: SELECT document, embedding <=> %s::vector AS distance FROM table ORDER BY distance
-Characteristics:
- Semantic understanding - finds documents similar in meaning even if they don't share keywords
- Works with high-dimensional vector embeddings (typically 768, 1024, or higher dimensions)
- Best for: Finding conceptually related content, handling synonyms, cross-language search
2. Keyword Search
- How it works:
- Uses PostgreSQL's full-text search capabilities with tsvector and ts_rank
- Converts text to searchable tokens using to_tsvector('english', text). Default language is English.
- Eg. SQL query: SELECT document, ts_rank(tokenized_content, plainto_tsquery('english', %s)) AS score
- Characteristics:
- Lexical matching - finds exact keyword matches and variations
- Uses GIN (Generalized Inverted Index) for fast text search performance
- Scoring: Uses PostgreSQL's ts_rank function for relevance scoring
- Best for: Exact term matching, proper names, technical terms, Boolean-style queries
3. Hybrid Search
- How it works:
- Combines both vector and keyword search results
- Runs both searches independently, then merges results using configurable reranking
- Two reranking strategies available:
- Reciprocal Rank Fusion (RRF) - (default: 60.0)
- Weighted Average - (default: 0.5)
- Characteristics:
- Best of both worlds: semantic understanding + exact matching
- Documents appearing in both searches get boosted scores
- Configurable balance between semantic and lexical matching
- Best for: General-purpose search where you want both precision and recall
4. Database Schema
The PGVector implementation stores data optimized for all three search types:
CREATE TABLE vector_store_xxx (
id TEXT PRIMARY KEY,
document JSONB, -- Original document
embedding vector(dimension), -- For vector search
content_text TEXT, -- Raw text content
tokenized_content TSVECTOR -- For keyword search
);
-- Indexes for performance
CREATE INDEX content_gin_idx ON table USING GIN(tokenized_content); -- Keyword search
-- Vector index created automatically by pgvector
## Usage
To use PGVector in your Llama Stack project, follow these steps:
@ -20,6 +74,25 @@ To use PGVector in your Llama Stack project, follow these steps:
2. Configure your Llama Stack project to use pgvector. (e.g. remote::pgvector).
3. Start storing and querying vectors.
## This is an example how you can set up your environment for using PGVector
1. Export env vars:
```bash
export ENABLE_PGVECTOR=true
export PGVECTOR_HOST=localhost
export PGVECTOR_PORT=5432
export PGVECTOR_DB=llamastack
export PGVECTOR_USER=llamastack
export PGVECTOR_PASSWORD=llamastack
```
2. Create DB:
```bash
psql -h localhost -U postgres -c "CREATE ROLE llamastack LOGIN PASSWORD 'llamastack';"
psql -h localhost -U postgres -c "CREATE DATABASE llamastack OWNER llamastack;"
psql -h localhost -U llamastack -d llamastack -c "CREATE EXTENSION IF NOT EXISTS vector;"
```
## Installation
You can install PGVector using docker:

View file

@ -17,6 +17,7 @@ Weaviate supports:
- Metadata filtering
- Multi-modal retrieval
## Usage
To use Weaviate in your Llama Stack project, follow these steps:

View file

@ -478,7 +478,6 @@ llama-stack-client scoring_functions list
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓
┃ identifier ┃ provider_id ┃ description ┃ type ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩
│ basic::bfcl │ basic │ BFCL complex scoring │ scoring_function │
│ basic::docvqa │ basic │ DocVQA Visual Question & Answer scoring function │ scoring_function │
│ basic::equality │ basic │ Returns 1.0 if the input is equal to the target, 0.0 │ scoring_function │
│ │ │ otherwise. │ │

View file

@ -105,12 +105,12 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
method = getattr(impls[api], register_method)
for obj in objects:
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
if hasattr(obj, "provider_id"):
# Do not register models on disabled providers
if hasattr(obj, "provider_id") and (not obj.provider_id or obj.provider_id == "__disabled__"):
if not obj.provider_id or obj.provider_id == "__disabled__":
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
continue
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
# we want to maintain the type information in arguments to method.
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,

View file

@ -43,7 +43,7 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo
"openai",
[
ProviderModelEntry(
provider_model_id="openai/gpt-4o",
provider_model_id="gpt-4o",
model_type=ModelType.llm,
)
],
@ -53,7 +53,7 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo
"anthropic",
[
ProviderModelEntry(
provider_model_id="anthropic/claude-3-5-sonnet-latest",
provider_model_id="claude-3-5-sonnet-latest",
model_type=ModelType.llm,
)
],
@ -206,13 +206,6 @@ def get_distribution_template() -> DistributionTemplate:
uri="huggingface://datasets/llamastack/math_500?split=test",
),
),
DatasetInput(
dataset_id="bfcl",
purpose=DatasetPurpose.eval_messages_answer,
source=URIDataSource(
uri="huggingface://datasets/llamastack/bfcl_v3?split=train",
),
),
DatasetInput(
dataset_id="ifeval",
purpose=DatasetPurpose.eval_messages_answer,
@ -250,11 +243,6 @@ def get_distribution_template() -> DistributionTemplate:
dataset_id="math_500",
scoring_functions=["basic::regex_parser_math_response"],
),
BenchmarkInput(
benchmark_id="meta-reference-bfcl",
dataset_id="bfcl",
scoring_functions=["basic::bfcl"],
),
BenchmarkInput(
benchmark_id="meta-reference-ifeval",
dataset_id="ifeval",

View file

@ -136,14 +136,14 @@ inference_store:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/inference_store.db
models:
- metadata: {}
model_id: openai/gpt-4o
model_id: gpt-4o
provider_id: openai
provider_model_id: openai/gpt-4o
provider_model_id: gpt-4o
model_type: llm
- metadata: {}
model_id: anthropic/claude-3-5-sonnet-latest
model_id: claude-3-5-sonnet-latest
provider_id: anthropic
provider_model_id: anthropic/claude-3-5-sonnet-latest
provider_model_id: claude-3-5-sonnet-latest
model_type: llm
- metadata: {}
model_id: gemini/gemini-1.5-flash
@ -188,12 +188,6 @@ datasets:
uri: huggingface://datasets/llamastack/math_500?split=test
metadata: {}
dataset_id: math_500
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/bfcl_v3?split=train
metadata: {}
dataset_id: bfcl
- purpose: eval/messages-answer
source:
type: uri
@ -228,11 +222,6 @@ benchmarks:
- basic::regex_parser_math_response
metadata: {}
benchmark_id: meta-reference-math-500
- dataset_id: bfcl
scoring_functions:
- basic::bfcl
metadata: {}
benchmark_id: meta-reference-bfcl
- dataset_id: ifeval
scoring_functions:
- basic::ifeval

View file

@ -22,7 +22,6 @@ from llama_stack.providers.utils.common.data_schema_validator import (
)
from .config import BasicScoringConfig
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.ifeval_scoring_fn import IfEvalScoringFn
@ -37,7 +36,6 @@ FIXED_FNS = [
SubsetOfScoringFn,
RegexParserScoringFn,
RegexParserMathResponseScoringFn,
BFCLScoringFn,
IfEvalScoringFn,
DocVQAScoringFn,
]

View file

@ -1,93 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import re
from typing import Any
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from ..utils.bfcl.ast_parser import decode_ast
from ..utils.bfcl.checker import ast_checker, is_empty_output
from .fn_defs.bfcl import bfcl
def postprocess(x: dict[str, Any], test_category: str) -> dict[str, Any]:
contain_func_call = False
error = None
error_type = None
checker_result = {}
try:
prediction = decode_ast(x["generated_answer"], x["language"]) or ""
contain_func_call = True
# if not is_function_calling_format_output(prediction):
if is_empty_output(prediction):
contain_func_call = False
error = "Did not output in the specified format. Note: the model_result is wrapped in a string to ensure json serializability."
error_type = "ast_decoder:decoder_wrong_output_format"
else:
checker_result = ast_checker(
json.loads(x["function"]),
prediction,
json.loads(x["ground_truth"]),
x["language"],
test_category=test_category,
model_name="",
)
except Exception as e:
prediction = ""
error = f"Invalid syntax. Failed to decode AST. {str(e)}"
error_type = "ast_decoder:decoder_failed"
return {
"prediction": prediction,
"contain_func_call": contain_func_call,
"valid": checker_result.get("valid", False),
"error": error or checker_result.get("error", ""),
"error_type": error_type or checker_result.get("error_type", ""),
}
def gen_valid(x: dict[str, Any]) -> dict[str, float]:
return {"valid": x["valid"]}
def gen_relevance_acc(x: dict[str, Any]) -> dict[str, float]:
# This function serves for both relevance and irrelevance tests, which share the exact opposite logic.
# If `test_category` is "irrelevance", the model is expected to output no function call.
# No function call means either the AST decoding fails (a error message is generated) or the decoded AST does not contain any function call (such as a empty list, `[]`).
# If `test_category` is "relevance", the model is expected to output to a function call, and empty list doesn't count as a function call.
acc = not x["contain_func_call"] if "irrelevance" in x["id"] else x["contain_func_call"]
return {"valid": float(acc)}
class BFCLScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn for BFCL
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
bfcl.identifier: bfcl,
}
async def score_row(
self,
input_row: dict[str, Any],
scoring_fn_identifier: str | None = "bfcl",
scoring_params: ScoringFnParams | None = None,
) -> ScoringResultRow:
test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"])
score_result = postprocess(input_row, test_category)
if test_category in {"irrelevance", "live_relevance", "live_irrelevance"}:
score = gen_relevance_acc(score_result)["valid"]
else:
score = gen_valid(score_result)["valid"]
return {
"score": float(score),
}

View file

@ -1,21 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
bfcl = ScoringFn(
identifier="basic::bfcl",
description="BFCL complex scoring",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="bfcl",
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
)

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,296 +0,0 @@
# ruff: noqa
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import ast
from .tree_sitter import get_parser
def parse_java_function_call(source_code):
if not source_code.endswith(";"):
source_code += ";" # Necessary for the parser not to register an error
parser = get_parser("java")
tree = parser.parse(bytes(source_code, "utf8"))
root_node = tree.root_node
if root_node.has_error:
raise Exception("Error parsing java the source code.")
def get_text(node):
"""Returns the text represented by the node."""
return source_code[node.start_byte : node.end_byte]
def traverse_node(node, nested=False):
if node.type == "string_literal":
if nested:
return get_text(node)
# Strip surrounding quotes from string literals
return get_text(node)[1:-1]
elif node.type == "character_literal":
if nested:
return get_text(node)
# Strip surrounding single quotes from character literals
return get_text(node)[1:-1]
"""Traverse the node to collect texts for complex structures."""
if node.type in [
"identifier",
"class_literal",
"type_identifier",
"method_invocation",
]:
return get_text(node)
elif node.type == "array_creation_expression":
# Handle array creation expression specifically
type_node = node.child_by_field_name("type")
value_node = node.child_by_field_name("value")
type_text = traverse_node(type_node, True)
value_text = traverse_node(value_node, True)
return f"new {type_text}[]{value_text}"
elif node.type == "object_creation_expression":
# Handle object creation expression specifically
type_node = node.child_by_field_name("type")
arguments_node = node.child_by_field_name("arguments")
type_text = traverse_node(type_node, True)
if arguments_node:
# Process each argument carefully, avoiding unnecessary punctuation
argument_texts = []
for child in arguments_node.children:
if child.type not in [
",",
"(",
")",
]: # Exclude commas and parentheses
argument_text = traverse_node(child, True)
argument_texts.append(argument_text)
arguments_text = ", ".join(argument_texts)
return f"new {type_text}({arguments_text})"
else:
return f"new {type_text}()"
elif node.type == "set":
# Handling sets specifically
items = [traverse_node(n, True) for n in node.children if n.type not in [",", "set"]]
return "{" + ", ".join(items) + "}"
elif node.child_count > 0:
return "".join(traverse_node(child, True) for child in node.children)
else:
return get_text(node)
def extract_arguments(args_node):
arguments = {}
for child in args_node.children:
if child.type == "assignment_expression":
# For named parameters
name_node, value_node = child.children[0], child.children[2]
name = get_text(name_node)
value = traverse_node(value_node)
if name in arguments:
if not isinstance(arguments[name], list):
arguments[name] = [arguments[name]]
arguments[name].append(value)
else:
arguments[name] = value
# arguments.append({'name': name, 'value': value})
elif child.type in ["identifier", "class_literal", "set"]:
# For unnamed parameters and handling sets
value = traverse_node(child)
if None in arguments:
if not isinstance(arguments[None], list):
arguments[None] = [arguments[None]]
arguments[None].append(value)
else:
arguments[None] = value
return arguments
def traverse(node):
if node.type == "method_invocation":
# Extract the function name and its arguments
method_name = get_text(node.child_by_field_name("name"))
class_name_node = node.child_by_field_name("object")
if class_name_node:
class_name = get_text(class_name_node)
function_name = f"{class_name}.{method_name}"
else:
function_name = method_name
arguments_node = node.child_by_field_name("arguments")
if arguments_node:
arguments = extract_arguments(arguments_node)
for key, value in arguments.items():
if isinstance(value, list):
raise Exception("Error: Multiple arguments with the same name are not supported.")
return [{function_name: arguments}]
else:
for child in node.children:
result = traverse(child)
if result:
return result
result = traverse(root_node)
return result if result else {}
def parse_javascript_function_call(source_code):
if not source_code.endswith(";"):
source_code += ";" # Necessary for the parser not to register an error
parser = get_parser("javascript")
# Parse the source code
tree = parser.parse(bytes(source_code, "utf8"))
root_node = tree.root_node
if root_node.has_error:
raise Exception("Error js parsing the source code.")
# Function to recursively extract argument details
def extract_arguments(node):
args = {}
for child in node.children:
if child.type == "assignment_expression":
# Extract left (name) and right (value) parts of the assignment
name = child.children[0].text.decode("utf-8")
value = child.children[2].text.decode("utf-8")
if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")):
value = value[1:-1] # Trim the quotation marks
if name in args:
if not isinstance(args[name], list):
args[name] = [args[name]]
args[name].append(value)
else:
args[name] = value
elif child.type == "identifier" or child.type == "true":
# Handle non-named arguments and boolean values
value = child.text.decode("utf-8")
if None in args:
if not isinstance(args[None], list):
args[None] = [args[None]]
args[None].append(value)
else:
args[None] = value
return args
# Find the function call and extract its name and arguments
if root_node.type == "program":
for child in root_node.children:
if child.type == "expression_statement":
for sub_child in child.children:
if sub_child.type == "call_expression":
function_name = sub_child.children[0].text.decode("utf8")
arguments_node = sub_child.children[1]
parameters = extract_arguments(arguments_node)
for key, value in parameters.items():
if isinstance(value, list):
raise Exception("Error: Multiple arguments with the same name are not supported.")
result = [{function_name: parameters}]
return result
def ast_parse(input_str, language="Python"):
if language == "Python":
cleaned_input = input_str.strip("[]'")
parsed = ast.parse(cleaned_input, mode="eval")
extracted = []
if isinstance(parsed.body, ast.Call):
extracted.append(resolve_ast_call(parsed.body))
else:
for elem in parsed.body.elts:
extracted.append(resolve_ast_call(elem))
return extracted
elif language == "Java":
return parse_java_function_call(input_str[1:-1]) # Remove the [ and ] from the string
elif language == "JavaScript":
return parse_javascript_function_call(input_str[1:-1])
else:
raise NotImplementedError(f"Unsupported language: {language}")
def resolve_ast_call(elem):
# Handle nested attributes for deeply nested module paths
func_parts = []
func_part = elem.func
while isinstance(func_part, ast.Attribute):
func_parts.append(func_part.attr)
func_part = func_part.value
if isinstance(func_part, ast.Name):
func_parts.append(func_part.id)
func_name = ".".join(reversed(func_parts))
args_dict = {}
# Parse when args are simply passed as an unnamed dictionary arg
for arg in elem.args:
if isinstance(arg, ast.Dict):
for key, value in zip(arg.keys, arg.values):
if isinstance(key, ast.Constant):
arg_name = key.value
output = resolve_ast_by_type(value)
args_dict[arg_name] = output
for arg in elem.keywords:
output = resolve_ast_by_type(arg.value)
args_dict[arg.arg] = output
return {func_name: args_dict}
def resolve_ast_by_type(value):
if isinstance(value, ast.Constant):
if value.value is Ellipsis:
output = "..."
else:
output = value.value
elif isinstance(value, ast.UnaryOp):
output = -value.operand.value
elif isinstance(value, ast.List):
output = [resolve_ast_by_type(v) for v in value.elts]
elif isinstance(value, ast.Dict):
output = {resolve_ast_by_type(k): resolve_ast_by_type(v) for k, v in zip(value.keys, value.values)}
elif isinstance(value, ast.NameConstant): # Added this condition to handle boolean values
output = value.value
elif isinstance(value, ast.BinOp): # Added this condition to handle function calls as arguments
output = eval(ast.unparse(value))
elif isinstance(value, ast.Name):
output = value.id
elif isinstance(value, ast.Call):
if len(value.keywords) == 0:
output = ast.unparse(value)
else:
output = resolve_ast_call(value)
elif isinstance(value, ast.Tuple):
output = tuple(resolve_ast_by_type(v) for v in value.elts)
elif isinstance(value, ast.Lambda):
output = eval(ast.unparse(value.body[0].value))
elif isinstance(value, ast.Ellipsis):
output = "..."
elif isinstance(value, ast.Subscript):
try:
output = ast.unparse(value.body[0].value)
except:
output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]"
else:
raise Exception(f"Unsupported AST type: {type(value)}")
return output
def decode_ast(result, language="Python"):
func = result
func = func.replace("\n", "") # remove new line characters
if not func.startswith("["):
func = "[" + func
if not func.endswith("]"):
func = func + "]"
decoded_output = ast_parse(func, language)
return decoded_output
def decode_execute(result):
func = result
func = func.replace("\n", "") # remove new line characters
if not func.startswith("["):
func = "[" + func
if not func.endswith("]"):
func = func + "]"
decode_output = ast_parse(func)
execution_list = []
for function_call in decode_output:
for key, value in function_call.items():
execution_list.append(f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})")
return execution_list

View file

@ -1,989 +0,0 @@
# ruff: noqa
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import re
import time
from typing import Any
# Comment out for now until we actually use the rest checker in evals
# import requests # Do not remove this import even though it seems to be unused. It's used in the executable_checker_rest function.
class NoAPIKeyError(Exception):
def __init__(self):
self.message = "Please fill in the API keys in the function_credential_config.json file. If you do not provide the API keys, the executable test category results will be inaccurate."
super().__init__(self.message)
REAL_TIME_MATCH_ALLOWED_DIFFERENCE = 0.2
JAVA_TYPE_CONVERSION = {
"byte": int,
"short": int,
"integer": int,
"float": float,
"double": float,
"long": int,
"boolean": bool,
"char": str,
"Array": list,
"ArrayList": list,
"Set": set,
"HashMap": dict,
"Hashtable": dict,
"Queue": list, # this can be `queue.Queue` as well, for simplicity we check with list
"Stack": list,
"String": str,
"any": str,
}
JS_TYPE_CONVERSION = {
"String": str,
"integer": int,
"float": float,
"Bigint": int,
"Boolean": bool,
"dict": dict,
"array": list,
"any": str,
}
# We switch to conditional import for the following two imports to avoid unnecessary installations.
# User doesn't need to setup the tree-sitter packages if they are not running the test for that language.
# from js_type_converter import js_type_converter
# from java_type_converter import java_type_converter
PYTHON_TYPE_MAPPING = {
"string": str,
"integer": int,
"float": float,
"boolean": bool,
"array": list,
"tuple": list,
"dict": dict,
"any": str,
}
# This is the list of types that we need to recursively check its values
PYTHON_NESTED_TYPE_CHECK_LIST = ["array", "tuple"]
NESTED_CONVERSION_TYPE_LIST = ["Array", "ArrayList", "array"]
#### Helper functions for AST ####
def find_description(func_descriptions, name):
if type(func_descriptions) == list:
for func_description in func_descriptions:
if func_description["name"] == name:
return func_description
return None
else:
# it is a dict, there is only one function
return func_descriptions
def get_possible_answer_type(possible_answer: list):
for answer in possible_answer:
if answer != "": # Optional parameter
return type(answer)
return None
def type_checker(
param: str,
value,
possible_answer: list,
expected_type_description: str,
expected_type_converted,
nested_type_converted,
):
# NOTE: This type checker only supports nested type checking for one level deep.
# We didn't implement recursive type checking for nested types, as it's not needed for the current use case and it's very complex.
result: Any = {
"valid": True,
"error": [],
"is_variable": False,
"error_type": "type_error:simple",
}
is_variable = False
# check for the case where a variable is used instead of a actual value.
# use the type in possible_answer as the expected type
possible_answer_type = get_possible_answer_type(possible_answer)
# if possible_answer only contains optional parameters, we can't determine the type
if possible_answer_type != None:
# we are being precise here.
# in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer
if possible_answer_type != expected_type_converted:
is_variable = True
# value is the same type as in function description
if type(value) == expected_type_converted:
# We don't need to do recursive check for simple types
if nested_type_converted == None:
result["is_variable"] = is_variable
return result
else:
for possible_answer_item in possible_answer:
flag = True # Each parameter should match to at least one possible answer type.
# Here, we assume that each item should be the same type. We could also relax it.
if type(possible_answer_item) == list:
for value_item in value:
checker_result = type_checker(
param,
value_item,
possible_answer_item,
str(nested_type_converted),
nested_type_converted,
None,
)
if not checker_result["valid"]:
flag = False
break
if flag:
return {"valid": True, "error": [], "is_variable": is_variable}
result["valid"] = False
result["error"] = [
f"Nested type checking failed for parameter {repr(param)}. Expected outer type {expected_type_description} with inner type {str(nested_type_converted)}. Parameter value: {repr(value)}."
]
result["error_type"] = "type_error:nested"
# value is not as expected, check for the case where a variable is used instead of a actual value
# use the type in possible_answer as the expected type
possible_answer_type = get_possible_answer_type(possible_answer)
# if possible_answer only contains optional parameters, we can't determine the type
if possible_answer_type != None:
# we are being precise here.
# in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer
if type(value) == possible_answer_type:
result["is_variable"] = True
return result
result["valid"] = False
result["error"].append(
f"Incorrect type for parameter {repr(param)}. Expected type {expected_type_description}, got {type(value).__name__}. Parameter value: {repr(value)}."
)
result["error_type"] = "type_error:simple"
return result
def standardize_string(input_string: str):
# This function standardizes the string by removing all the spaces, ",./-_*^" punctuation, and converting it to lowercase
# It will also convert all the single quotes to double quotes
# This is used to compare the model output with the possible answers
# We don't want to punish model for answer like April 1, 2024 vs April 1,2024, vs April 1 2024
regex_string = r"[ \,\.\/\-\_\*\^]"
return re.sub(regex_string, "", input_string).lower().replace("'", '"')
def string_checker(param: str, model_output: str, possible_answer: list):
standardize_possible_answer = []
standardize_model_output = standardize_string(model_output)
for i in range(len(possible_answer)):
if type(possible_answer[i]) == str:
standardize_possible_answer.append(standardize_string(possible_answer[i]))
if standardize_model_output not in standardize_possible_answer:
return {
"valid": False,
"error": [
f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}. Case insensitive."
],
"error_type": "value_error:string",
}
return {"valid": True, "error": []}
def list_checker(param: str, model_output: list, possible_answer: list):
# Convert the tuple to a list
standardize_model_output = list(model_output)
# If the element in the list is a string, we need to standardize it
for i in range(len(standardize_model_output)):
if type(standardize_model_output[i]) == str:
standardize_model_output[i] = standardize_string(model_output[i])
standardize_possible_answer: Any = []
# We also need to standardize the possible answers
for i in range(len(possible_answer)):
standardize_possible_answer.append([])
for j in range(len(possible_answer[i])):
if type(possible_answer[i][j]) == str:
standardize_possible_answer[i].append(standardize_string(possible_answer[i][j]))
else:
standardize_possible_answer[i].append(possible_answer[i][j])
if standardize_model_output not in standardize_possible_answer:
return {
"valid": False,
"error": [
f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}."
],
"error_type": "value_error:list/tuple",
}
return {"valid": True, "error": []}
def dict_checker(param: str, model_output: dict, possible_answers: list):
# This function works for simple dictionaries, but not dictionaries with nested dictionaries.
# The current dataset only contains simple dictionaries, so this is sufficient.
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
for i in range(len(possible_answers)):
if possible_answers[i] == "":
continue
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
flag = True
possible_answer = possible_answers[i]
# possible_anwer is a single dictionary
for key, value in model_output.items():
if key not in possible_answer:
result["valid"] = False
result["error"].append(f"Unexpected dict key parameter: '{key}'.") # type: ignore[attr-defined]
result["error_type"] = "value_error:dict_key"
flag = False
break
standardize_value = value
# If the value is a string, we need to standardize it
if type(value) == str:
standardize_value = standardize_string(value)
# We also need to standardize the possible answers if they are string
standardize_possible_answer = []
for i in range(len(possible_answer[key])):
if type(possible_answer[key][i]) == str:
standardize_possible_answer.append(standardize_string(possible_answer[key][i]))
else:
standardize_possible_answer.append(possible_answer[key][i])
if standardize_value not in standardize_possible_answer:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Invalid value for parameter {repr(key)}: {repr(value)}. Expected one of {standardize_possible_answer}."
)
result["error_type"] = "value_error:dict_value"
flag = False
break
for key, value in possible_answer.items():
if key not in model_output and "" not in value:
result["valid"] = False
result["error"].append(f"Missing dict key parameter: '{key}'.") # type: ignore[attr-defined]
result["error_type"] = "value_error:dict_key"
flag = False
break
if flag:
return {"valid": True, "error": []}
return result
def list_dict_checker(param: str, model_output: list, possible_answers: list):
# This function takes in a list of dictionaries and checks if each dictionary is valid
# The order of the dictionaries in the list must match the order of the possible answers
result = {"valid": False, "error": [], "error_type": "list_dict_checker:unclear"}
for answer_index in range(len(possible_answers)):
flag = True # True means so far, all dictionaries are valid
# Only proceed if the number of dictionaries in the list matches the number of dictionaries in the possible answers
if len(model_output) != len(possible_answers[answer_index]):
result["valid"] = False
result["error"] = ["Wrong number of dictionaries in the list."]
result["error_type"] = "value_error:list_dict_count"
flag = False
continue
for dict_index in range(len(model_output)):
result = dict_checker(
param,
model_output[dict_index],
[possible_answers[answer_index][dict_index]],
)
if not result["valid"]:
flag = False
break
if flag:
return {"valid": True, "error": []}
return result
def simple_function_checker(
func_description: dict,
model_output: dict,
possible_answer: dict,
language: str,
model_name: str,
):
possible_answer = list(possible_answer.values())[0]
# Extract function name and parameters details
func_name = func_description["name"]
param_details = func_description["parameters"]["properties"]
required_params = func_description["parameters"]["required"]
# Initialize a result dictionary
result = {
"valid": True,
"error": [],
"error_type": "simple_function_checker:unclear",
}
# Check if function name matches
if func_name not in model_output:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Function name {repr(func_name)} not found in model output."
)
result["error_type"] = "simple_function_checker:wrong_func_name"
return result
model_params = model_output[func_name]
# Check for required parameters in model output
for param in required_params:
if param not in model_params:
result["valid"] = False
result["error"].append(f"Missing required parameter: {repr(param)}.") # type: ignore[attr-defined]
result["error_type"] = "simple_function_checker:missing_required"
return result
# Validate types and values for each parameter in model output
for param, value in model_params.items():
if param not in param_details or param not in possible_answer:
result["valid"] = False
result["error"].append(f"Unexpected parameter: {repr(param)}.") # type: ignore[attr-defined]
result["error_type"] = "simple_function_checker:unexpected_param"
return result
full_param_details = param_details[param]
expected_type_description = full_param_details["type"] # This is a string
is_variable = False
nested_type_converted = None
if language == "Java":
from evals.utils.bfcl.java_type_converter import java_type_converter
expected_type_converted = JAVA_TYPE_CONVERSION[expected_type_description]
if expected_type_description in JAVA_TYPE_CONVERSION:
if type(value) != str:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
)
result["error_type"] = "type_error:java"
return result
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
nested_type = param_details[param]["items"]["type"]
nested_type_converted = JAVA_TYPE_CONVERSION[nested_type]
value = java_type_converter(value, expected_type_description, nested_type)
else:
value = java_type_converter(value, expected_type_description)
elif language == "JavaScript":
from evals.utils.bfcl.js_type_converter import js_type_converter
expected_type_converted = JS_TYPE_CONVERSION[expected_type_description]
if expected_type_description in JS_TYPE_CONVERSION:
if type(value) != str:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
)
result["error_type"] = "type_error:js"
return result
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
nested_type = param_details[param]["items"]["type"]
nested_type_converted = JS_TYPE_CONVERSION[nested_type]
value = js_type_converter(value, expected_type_description, nested_type)
else:
value = js_type_converter(value, expected_type_description)
elif language == "Python":
expected_type_converted = PYTHON_TYPE_MAPPING[expected_type_description]
if expected_type_description in PYTHON_NESTED_TYPE_CHECK_LIST:
nested_type = param_details[param]["items"]["type"]
nested_type_converted = PYTHON_TYPE_MAPPING[nested_type]
# We convert all tuple value to list when the expected type is tuple.
# The conversion is necessary because any tuple in the possible answer would become a list after being processed through json.dump() and json.load().
# This does introduce some false positive (eg, when the model provides a list value instead of tuple). We hope to find a better solution in the future.
if expected_type_description == "tuple" and type(value) == tuple:
value = list(value)
# Allow python auto conversion from int to float
if language == "Python" and expected_type_description == "float" and type(value) == int:
value = float(value)
# Type checking
# In fact, we only check for Python here.
# Type check for other languages are handled by the type converter, and so their value (after conversion) is always correct.
type_check_result = type_checker(
param,
value,
possible_answer[param],
expected_type_description,
expected_type_converted,
nested_type_converted,
)
is_variable = type_check_result["is_variable"]
if not type_check_result["valid"]:
return type_check_result
# It doesn't make sense to special handle dictionaries and list of dictionaries if the value is a variable.
# We can just treat the variable as a string and use the normal flow.
if not is_variable:
# Special handle for dictionaries
if expected_type_converted == dict:
result = dict_checker(param, value, possible_answer[param])
if not result["valid"]:
return result
continue
# Special handle for list of dictionaries
elif expected_type_converted == list and nested_type_converted == dict:
result = list_dict_checker(param, value, possible_answer[param])
if not result["valid"]:
return result
continue
# Special handle for strings
elif expected_type_converted == str:
# We don't check for case sensitivity for string, as long as it's not a variable
result = string_checker(param, value, possible_answer[param])
if not result["valid"]:
return result
continue
elif expected_type_converted == list:
result = list_checker(param, value, possible_answer[param])
if not result["valid"]:
return result
continue
# Check if the value is within the possible answers
if value not in possible_answer[param]:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Invalid value for parameter {repr(param)}: {repr(value)}. Expected one of {possible_answer[param]}."
)
result["error_type"] = "value_error:others"
return result
# Check for optional parameters not provided but allowed
for param in possible_answer:
if param not in model_params and "" not in possible_answer[param]:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Optional parameter {repr(param)} not provided and not marked as optional."
)
result["error_type"] = "simple_function_checker:missing_optional"
return result
return result
def parallel_function_checker_enforce_order(
func_descriptions: list,
model_output: list,
possible_answers: dict,
language: str,
model_name: str,
):
if len(model_output) != len(possible_answers):
return {
"valid": False,
"error": ["Wrong number of functions."],
"error_type": "parallel_function_checker_enforce_order:wrong_count",
}
func_name_list = list(possible_answers.keys())
possible_answers_list = []
for key, value in possible_answers.items():
possible_answers_list.append({key: value})
for i in range(len(possible_answers_list)):
func_description = find_description(func_descriptions, func_name_list[i])
result = simple_function_checker(
func_description,
model_output[i],
possible_answers_list[i],
language,
model_name,
)
if not result["valid"]:
return result
return {"valid": True, "error": []}
def parallel_function_checker_no_order(
func_descriptions: list,
model_output: list,
possible_answers: list,
language: str,
model_name: str,
):
if len(model_output) != len(possible_answers):
return {
"valid": False,
"error": ["Wrong number of functions."],
"error_type": "parallel_function_checker_no_order:wrong_count",
}
matched_indices = []
# We go throught the possible answers one by one, and eliminate the model output that matches the possible answer
# It must be this way because we need ground truth to fetch the correct function description
for i in range(len(possible_answers)):
# possible_answers[i] is a dictionary with only one key
func_name_expected = list(possible_answers[i].keys())[0]
func_description = find_description(func_descriptions, func_name_expected)
all_errors = []
for index in range(len(model_output)):
if index in matched_indices:
continue
result = simple_function_checker(
func_description,
model_output[index],
possible_answers[i],
language,
model_name,
)
if result["valid"]:
matched_indices.append(index)
break
else:
all_errors.append(
{
f"Model Result Index {index}": {
"sub_error": result["error"],
"sub_error_type": result["error_type"],
"model_output_item": model_output[index],
"possible_answer_item": possible_answers[i],
}
}
)
if not result["valid"]:
considered_indices = [i for i in range(len(model_output)) if i not in matched_indices]
all_errors.insert(
0,
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
)
return {
"valid": False,
"error": all_errors,
"error_type": "parallel_function_checker_no_order:cannot_find_match",
}
return {"valid": True, "error": []}
def multiple_function_checker(
func_descriptions: list,
model_output: list,
possible_answers: list,
language: str,
model_name: str,
):
if len(model_output) != len(possible_answers):
return {
"valid": False,
"error": ["Wrong number of functions."],
"error_type": "multiple_function_checker:wrong_count",
}
# possible_answers is a list of only one dictionary with only one key
func_name_expected = list(possible_answers[0].keys())[0]
func_description = find_description(func_descriptions, func_name_expected)
return simple_function_checker(
func_description,
model_output[0],
possible_answers[0],
language,
model_name,
)
def patten_matcher(exec_output, expected_result, function_call, is_sanity_check):
result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
if type(exec_output) != type(expected_result):
return {
"valid": False,
"error": [
f"Wrong execution result type for {repr(function_call)}. Expected type: {type(expected_result)}, but got: {type(exec_output)}."
],
"error_type": "executable_checker:wrong_result_type",
"model_executed_output": exec_output,
}
if type(exec_output) == dict:
# We loose the requirement for the sanity check as the expected result used in the sanity check might not be the most up-to-date one.
# This happens when the key is a timestamp or a random number.
if is_sanity_check:
if len(exec_output) != len(expected_result):
return {
"valid": False,
"error": [
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
],
"error_type": "executable_checker:wrong_result_type:dict_length",
"model_executed_output": exec_output,
}
else:
return result
for key, value in expected_result.items():
if key not in exec_output:
return {
"valid": False,
"error": [
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not found in the model output."
],
"error_type": "executable_checker:wrong_result_type:dict_key_not_found",
"model_executed_output": exec_output,
}
for key, value in exec_output.items():
if key not in expected_result:
return {
"valid": False,
"error": [
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not expected in the model output."
],
"error_type": "executable_checker:wrong_result_type:dict_extra_key",
"model_executed_output": exec_output,
}
if type(exec_output) == list:
if len(exec_output) != len(expected_result):
return {
"valid": False,
"error": [
f"Wrong execution result pattern for {repr(function_call)}. Expect type list, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
],
"error_type": "executable_checker:wrong_result_type:list_length",
"model_executed_output": exec_output,
}
return result
#### Helper functions for Exec ####
def executable_checker_simple(
function_call: str,
expected_result,
expected_result_type: str,
is_sanity_check=False,
):
result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
exec_dict: Any = {}
try:
exec(
"from executable_python_function import *" + "\nresult=" + function_call,
exec_dict,
)
exec_output = exec_dict["result"]
except NoAPIKeyError as e:
raise e
except Exception as e:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Error in execution: {repr(function_call)}. Error: {str(e)}"
)
result["error_type"] = "executable_checker:execution_error"
return result
# We need to special handle the case where the execution result is a tuple and convert it to a list
# Because when json is stored, the tuple is converted to a list, and so the expected result is a list when loaded from json
if isinstance(exec_output, tuple):
exec_output = list(exec_output)
if expected_result_type == "exact_match":
if exec_output != expected_result:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}."
)
result["error_type"] = "executable_checker:wrong_result"
result["model_executed_output"] = exec_output
return result
elif expected_result_type == "real_time_match":
# Allow for 5% difference
if (type(expected_result) == float or type(expected_result) == int) and (
type(exec_output) == float or type(exec_output) == int
):
if not (
expected_result * (1 - REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
<= exec_output
<= expected_result * (1 + REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
):
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. {REAL_TIME_MATCH_ALLOWED_DIFFERENCE * 100}% difference allowed."
)
result["error_type"] = "executable_checker:wrong_result_real_time"
result["model_executed_output"] = exec_output
return result
else:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. Type needs to be float or int for real time match criteria."
)
result["error_type"] = "executable_checker:wrong_result_real_time"
result["model_executed_output"] = exec_output
return result
else:
# structural match
pattern_match_result = patten_matcher(exec_output, expected_result, function_call, is_sanity_check)
if not pattern_match_result["valid"]:
return pattern_match_result
return result
def executable_checker_parallel_no_order(
decoded_result: list, expected_exec_result: list, expected_exec_result_type: list
):
if len(decoded_result) != len(expected_exec_result):
return {
"valid": False,
"error": [
f"Wrong number of functions provided. Expected {len(expected_exec_result)}, but got {len(decoded_result)}."
],
"error_type": "value_error:exec_result_count",
}
matched_indices = []
for i in range(len(expected_exec_result)):
all_errors = []
for index in range(len(decoded_result)):
if index in matched_indices:
continue
result = executable_checker_simple(
decoded_result[index],
expected_exec_result[i],
expected_exec_result_type[i],
False,
)
if result["valid"]:
matched_indices.append(index)
break
else:
all_errors.append(
{
f"Model Result Index {index}": {
"sub_error": result["error"],
"sub_error_type": result["error_type"],
"model_executed_output": (
result["model_executed_output"] if "model_executed_output" in result else None
),
}
}
)
if not result["valid"]:
considered_indices = [i for i in range(len(decoded_result)) if i not in matched_indices]
all_errors.insert(
0,
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
)
return {
"valid": False,
"error": all_errors,
"error_type": "executable_checker:cannot_find_match",
}
return {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
#### Main function ####
def executable_checker_rest(func_call, idx):
# Move this here for now to avoid needing to read this file / fix paths to be relative to dataset_dir. Fix when it's actually needed / used.
EVAL_GROUND_TRUTH_PATH = "/mnt/wsfuse/fair_llm_v2/datasets/eval/bfcl/rest-eval-response_v5.jsonl" # Ground truth file for v5 for rest execution
with open(EVAL_GROUND_TRUTH_PATH, "r") as f:
EVAL_GROUND_TRUTH = f.readlines()
if "https://geocode.maps.co" in func_call:
time.sleep(2)
if "requests_get" in func_call:
func_call = func_call.replace("requests_get", "requests.get")
try:
response = eval(func_call)
except Exception as e:
return {
"valid": False,
"error": [f"Execution failed. {str(e)}"],
"error_type": "executable_checker_rest:execution_error",
}
try:
if response.status_code == 200:
eval_GT_json = json.loads(EVAL_GROUND_TRUTH[idx])
try:
if isinstance(eval_GT_json, dict):
if isinstance(response.json(), dict):
if set(eval_GT_json.keys()) == set(response.json().keys()):
return {"valid": True, "error": [], "error_type": ""}
return {
"valid": False,
"error": ["Key inconsistency"],
"error_type": "executable_checker_rest:wrong_key",
}
return {
"valid": False,
"error": [f"Expected dictionary, but got {type(response.json())}"],
"error_type": "executable_checker_rest:wrong_type",
}
elif isinstance(eval_GT_json, list):
if isinstance(response.json(), list):
if len(eval_GT_json) != len(response.json()):
return {
"valid": False,
"error": [f"Response list length inconsistency."],
"error_type": "value_error:exec_result_rest_count",
}
else:
for i in range(len(eval_GT_json)):
if set(eval_GT_json[i].keys()) != set(response.json()[i].keys()):
return {
"valid": False,
"error": [f"Key inconsistency"],
"error_type": "executable_checker_rest:wrong_key",
}
return {"valid": True, "error": []}
else:
return {
"valid": False,
"error": [f"Expected list, but got {type(response.json())}"],
"error_type": "executable_checker_rest:wrong_type",
}
return {
"valid": False,
"error": [f"Expected dict or list, but got {type(response.json())}"],
"error_type": "executable_checker_rest:wrong_type",
}
except Exception as e:
return {
"valid": False,
"error": [
f"Error in execution and type checking. Status code: {response.status_code}. Error: {str(e)}"
],
"error_type": "executable_checker_rest:response_format_error",
}
else:
return {
"valid": False,
"error": [f"Execution result status code is not 200, got {response.status_code}"],
"error_type": "executable_checker_rest:wrong_status_code",
}
except Exception as e:
return {
"valid": False,
"error": [f"Cannot get status code of the response. Error: {str(e)}"],
"error_type": "executable_checker_rest:cannot_get_status_code",
}
def ast_checker(func_description, model_output, possible_answer, language, test_category, model_name):
if "parallel" in test_category:
return parallel_function_checker_no_order(func_description, model_output, possible_answer, language, model_name)
elif "multiple" in test_category:
return multiple_function_checker(func_description, model_output, possible_answer, language, model_name)
else:
if len(model_output) != 1:
return {
"valid": False,
"error": ["Wrong number of functions."],
"error_type": "simple_function_checker:wrong_count",
}
return simple_function_checker(
func_description[0],
model_output[0],
possible_answer[0],
language,
model_name,
)
def exec_checker(decoded_result: list, func_description: dict, test_category: str):
if "multiple" in test_category or "parallel" in test_category:
return executable_checker_parallel_no_order(
decoded_result,
func_description["execution_result"],
func_description["execution_result_type"],
)
else:
if len(decoded_result) != 1:
return {
"valid": False,
"error": ["Wrong number of functions."],
"error_type": "simple_exec_checker:wrong_count",
}
return executable_checker_simple(
decoded_result[0],
func_description["execution_result"][0],
func_description["execution_result_type"][0],
False,
)
def is_empty_output(decoded_output):
# This function is a patch to the ast decoder for relevance detection
# Sometimes the ast decoder will parse successfully, but the input doens't really have a function call
# [], [{}], and anything that is not in function calling format is considered empty (and thus should be marked as correct)
if not is_function_calling_format_output(decoded_output):
return True
if len(decoded_output) == 0:
return True
if len(decoded_output) == 1 and len(decoded_output[0]) == 0:
return True
def is_function_calling_format_output(decoded_output):
# Ensure the output is a list of dictionaries
if type(decoded_output) == list:
for item in decoded_output:
if type(item) != dict:
return False
return True
return False

View file

@ -1,40 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Tree-sitter changes its API with unfortunate frequency. Modules that need it should
import it from here so that we can centrally manage things as necessary.
"""
# These currently work with tree-sitter 0.23.0
# NOTE: Don't import tree-sitter or any of the language modules in the main module
# because not all environments have them. Import lazily inside functions where needed.
import importlib
import typing
if typing.TYPE_CHECKING:
import tree_sitter
def get_language(language: str) -> "tree_sitter.Language":
import tree_sitter
language_module_name = f"tree_sitter_{language}"
try:
language_module = importlib.import_module(language_module_name)
except ModuleNotFoundError as exc:
raise ValueError(
f"Language {language} is not found. Please install the tree-sitter-{language} package."
) from exc
return tree_sitter.Language(language_module.language())
def get_parser(language: str, **kwargs) -> "tree_sitter.Parser":
import tree_sitter
lang = get_language(language)
return tree_sitter.Parser(lang, **kwargs)

View file

@ -404,6 +404,60 @@ That means you'll get fast and efficient vector retrieval.
- Easy to use
- Fully integrated with Llama Stack
There are three implementations of search for PGVectoIndex available:
1. Vector Search:
- How it works:
- Uses PostgreSQL's vector extension (pgvector) to perform similarity search
- Compares query embeddings against stored embeddings using Cosine distance or other distance metrics
- Eg. SQL query: SELECT document, embedding <=> %s::vector AS distance FROM table ORDER BY distance
-Characteristics:
- Semantic understanding - finds documents similar in meaning even if they don't share keywords
- Works with high-dimensional vector embeddings (typically 768, 1024, or higher dimensions)
- Best for: Finding conceptually related content, handling synonyms, cross-language search
2. Keyword Search
- How it works:
- Uses PostgreSQL's full-text search capabilities with tsvector and ts_rank
- Converts text to searchable tokens using to_tsvector('english', text). Default language is English.
- Eg. SQL query: SELECT document, ts_rank(tokenized_content, plainto_tsquery('english', %s)) AS score
- Characteristics:
- Lexical matching - finds exact keyword matches and variations
- Uses GIN (Generalized Inverted Index) for fast text search performance
- Scoring: Uses PostgreSQL's ts_rank function for relevance scoring
- Best for: Exact term matching, proper names, technical terms, Boolean-style queries
3. Hybrid Search
- How it works:
- Combines both vector and keyword search results
- Runs both searches independently, then merges results using configurable reranking
- Two reranking strategies available:
- Reciprocal Rank Fusion (RRF) - (default: 60.0)
- Weighted Average - (default: 0.5)
- Characteristics:
- Best of both worlds: semantic understanding + exact matching
- Documents appearing in both searches get boosted scores
- Configurable balance between semantic and lexical matching
- Best for: General-purpose search where you want both precision and recall
4. Database Schema
The PGVector implementation stores data optimized for all three search types:
CREATE TABLE vector_store_xxx (
id TEXT PRIMARY KEY,
document JSONB, -- Original document
embedding vector(dimension), -- For vector search
content_text TEXT, -- Raw text content
tokenized_content TSVECTOR -- For keyword search
);
-- Indexes for performance
CREATE INDEX content_gin_idx ON table USING GIN(tokenized_content); -- Keyword search
-- Vector index created automatically by pgvector
## Usage
To use PGVector in your Llama Stack project, follow these steps:
@ -412,6 +466,25 @@ To use PGVector in your Llama Stack project, follow these steps:
2. Configure your Llama Stack project to use pgvector. (e.g. remote::pgvector).
3. Start storing and querying vectors.
## This is an example how you can set up your environment for using PGVector
1. Export env vars:
```bash
export ENABLE_PGVECTOR=true
export PGVECTOR_HOST=localhost
export PGVECTOR_PORT=5432
export PGVECTOR_DB=llamastack
export PGVECTOR_USER=llamastack
export PGVECTOR_PASSWORD=llamastack
```
2. Create DB:
```bash
psql -h localhost -U postgres -c "CREATE ROLE llamastack LOGIN PASSWORD 'llamastack';"
psql -h localhost -U postgres -c "CREATE DATABASE llamastack OWNER llamastack;"
psql -h localhost -U llamastack -d llamastack -c "CREATE EXTENSION IF NOT EXISTS vector;"
```
## Installation
You can install PGVector using docker:
@ -449,6 +522,7 @@ Weaviate supports:
- Metadata filtering
- Multi-modal retrieval
## Usage
To use Weaviate in your Llama Stack project, follow these steps:

View file

@ -6,15 +6,14 @@
from typing import Any
from llama_stack.core.datatypes import Api
from llama_stack.core.datatypes import AccessRule, Api
from .config import S3FilesImplConfig
async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any]):
async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule] | None = None):
from .files import S3FilesImpl
# TODO: authorization policies and user separation
impl = S3FilesImpl(config)
impl = S3FilesImpl(config, policy or [])
await impl.initialize()
return impl

View file

@ -22,8 +22,10 @@ from llama_stack.apis.files import (
OpenAIFileObject,
OpenAIFilePurpose,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
from .config import S3FilesImplConfig
@ -120,10 +122,11 @@ def _make_file_object(
class S3FilesImpl(Files):
"""S3-based implementation of the Files API."""
def __init__(self, config: S3FilesImplConfig) -> None:
def __init__(self, config: S3FilesImplConfig, policy: list[AccessRule]) -> None:
self._config = config
self.policy = policy
self._client: boto3.client | None = None
self._sql_store: SqlStore | None = None
self._sql_store: AuthorizedSqlStore | None = None
def _now(self) -> int:
"""Return current UTC timestamp as int seconds."""
@ -133,7 +136,7 @@ class S3FilesImpl(Files):
where: dict[str, str | dict] = {"id": file_id}
if not return_expired:
where["expires_at"] = {">": self._now()}
if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)):
raise ResourceNotFoundError(file_id, "File", "files.list()")
return row
@ -160,7 +163,7 @@ class S3FilesImpl(Files):
self._client = _create_s3_client(self._config)
await _create_bucket_if_not_exists(self._client, self._config)
self._sql_store = sqlstore_impl(self._config.metadata_store)
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store))
await self._sql_store.create_table(
"openai_files",
{
@ -183,7 +186,7 @@ class S3FilesImpl(Files):
return self._client
@property
def sql_store(self) -> SqlStore:
def sql_store(self) -> AuthorizedSqlStore:
assert self._sql_store is not None, "Provider not initialized"
return self._sql_store
@ -264,6 +267,7 @@ class S3FilesImpl(Files):
paginated_result = await self.sql_store.fetch_all(
table="openai_files",
policy=self.policy,
where=where_conditions,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import heapq
from typing import Any
import psycopg2
@ -23,6 +24,9 @@ from llama_stack.apis.vector_io import (
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
@ -31,6 +35,7 @@ from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
)
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name
from .config import PGVectorVectorIOConfig
@ -72,26 +77,64 @@ def load_models(cur, cls):
class PGVectorIndex(EmbeddingIndex):
def __init__(self, vector_db: VectorDB, dimension: int, conn, kvstore: KVStore | None = None):
# reference: https://github.com/pgvector/pgvector?tab=readme-ov-file#querying
PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION: dict[str, str] = {
"L2": "<->",
"L1": "<+>",
"COSINE": "<=>",
"INNER_PRODUCT": "<#>",
"HAMMING": "<~>",
"JACCARD": "<%>",
}
def __init__(
self,
vector_db: VectorDB,
dimension: int,
conn: psycopg2.extensions.connection,
kvstore: KVStore | None = None,
distance_metric: str = "COSINE",
):
self.vector_db = vector_db
self.dimension = dimension
self.conn = conn
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
self.kvstore = kvstore
self.check_distance_metric_availability(distance_metric)
self.distance_metric = distance_metric
self.table_name = None
async def initialize(self) -> None:
try:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
# Sanitize the table name by replacing hyphens with underscores
# SQL doesn't allow hyphens in table names, and vector_db.identifier may contain hyphens
# when created with patterns like "test-vector-db-{uuid4()}"
sanitized_identifier = vector_db.identifier.replace("-", "_")
self.table_name = f"vector_store_{sanitized_identifier}"
self.kvstore = kvstore
sanitized_identifier = sanitize_collection_name(self.vector_db.identifier)
self.table_name = f"vs_{sanitized_identifier}"
cur.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
id TEXT PRIMARY KEY,
document JSONB,
embedding vector({dimension})
embedding vector({self.dimension}),
content_text TEXT,
tokenized_content TSVECTOR
)
"""
)
# Create GIN index for full-text search performance
cur.execute(
f"""
CREATE INDEX IF NOT EXISTS {self.table_name}_content_gin_idx
ON {self.table_name} USING GIN(tokenized_content)
"""
)
except Exception as e:
log.exception(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}")
raise RuntimeError(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}") from e
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
@ -99,29 +142,49 @@ class PGVectorIndex(EmbeddingIndex):
values = []
for i, chunk in enumerate(chunks):
content_text = interleaved_content_as_str(chunk.content)
values.append(
(
f"{chunk.chunk_id}",
Json(chunk.model_dump()),
embeddings[i].tolist(),
content_text,
content_text, # Pass content_text twice - once for content_text column, once for to_tsvector function. Eg. to_tsvector(content_text) = tokenized_content
)
)
query = sql.SQL(
f"""
INSERT INTO {self.table_name} (id, document, embedding)
INSERT INTO {self.table_name} (id, document, embedding, content_text, tokenized_content)
VALUES %s
ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, document = EXCLUDED.document
ON CONFLICT (id) DO UPDATE SET
embedding = EXCLUDED.embedding,
document = EXCLUDED.document,
content_text = EXCLUDED.content_text,
tokenized_content = EXCLUDED.tokenized_content
"""
)
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
execute_values(cur, query, values, template="(%s, %s, %s::vector)")
execute_values(cur, query, values, template="(%s, %s, %s::vector, %s, to_tsvector('english', %s))")
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Performs vector similarity search using PostgreSQL's search function. Default distance metric is COSINE.
Args:
embedding: The query embedding vector
k: Number of results to return
score_threshold: Minimum similarity score threshold
Returns:
QueryChunksResponse with combined results
"""
pgvector_search_function = self.get_pgvector_search_function()
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(
f"""
SELECT document, embedding <-> %s::vector AS distance
SELECT document, embedding {pgvector_search_function} %s::vector AS distance
FROM {self.table_name}
ORDER BY distance
LIMIT %s
@ -147,7 +210,40 @@ class PGVectorIndex(EmbeddingIndex):
k: int,
score_threshold: float,
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in PGVector")
"""
Performs keyword-based search using PostgreSQL's full-text search with ts_rank scoring.
Args:
query_string: The text query for keyword search
k: Number of results to return
score_threshold: Minimum similarity score threshold
Returns:
QueryChunksResponse with combined results
"""
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
# Use plainto_tsquery to handle user input safely and ts_rank for relevance scoring
cur.execute(
f"""
SELECT document, ts_rank(tokenized_content, plainto_tsquery('english', %s)) AS score
FROM {self.table_name}
WHERE tokenized_content @@ plainto_tsquery('english', %s)
ORDER BY score DESC
LIMIT %s
""",
(query_string, query_string, k),
)
results = cur.fetchall()
chunks = []
scores = []
for doc, score in results:
if score < score_threshold:
continue
chunks.append(Chunk(**doc))
scores.append(float(score))
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_hybrid(
self,
@ -158,7 +254,59 @@ class PGVectorIndex(EmbeddingIndex):
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in PGVector")
"""
Hybrid search combining vector similarity and keyword search using configurable reranking.
Args:
embedding: The query embedding vector
query_string: The text query for keyword search
k: Number of results to return
score_threshold: Minimum similarity score threshold
reranker_type: Type of reranker to use ("rrf" or "weighted")
reranker_params: Parameters for the reranker
Returns:
QueryChunksResponse with combined results
"""
if reranker_params is None:
reranker_params = {}
# Get results from both search methods
vector_response = await self.query_vector(embedding, k, score_threshold)
keyword_response = await self.query_keyword(query_string, k, score_threshold)
# Convert responses to score dictionaries using chunk_id
vector_scores = {
chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
}
keyword_scores = {
chunk.chunk_id: score
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
}
# Combine scores using the reranking utility
combined_scores = WeightedInMemoryAggregator.combine_search_results(
vector_scores, keyword_scores, reranker_type, reranker_params
)
# Efficient top-k selection because it only tracks the k best candidates it's seen so far
top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1])
# Filter by score threshold
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
# Create a map of chunk_id to chunk for both responses
chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks}
# Use the map to look up chunks by their IDs
chunks = []
scores = []
for doc_id, score in filtered_items:
if doc_id in chunk_map:
chunks.append(chunk_map[doc_id])
scores.append(score)
return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete(self):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
@ -170,6 +318,25 @@ class PGVectorIndex(EmbeddingIndex):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))
def get_pgvector_search_function(self) -> str:
return self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION[self.distance_metric]
def check_distance_metric_availability(self, distance_metric: str) -> None:
"""Check if the distance metric is supported by PGVector.
Args:
distance_metric: The distance metric to check
Raises:
ValueError: If the distance metric is not supported
"""
if distance_metric not in self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION:
supported_metrics = list(self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION.keys())
raise ValueError(
f"Distance metric '{distance_metric}' is not supported by PGVector. "
f"Supported metrics are: {', '.join(supported_metrics)}"
)
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(
@ -185,8 +352,8 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
self.files_api = files_api
self.kvstore: KVStore | None = None
self.vector_db_store = None
self.openai_vector_store: dict[str, dict[str, Any]] = {}
self.metadatadata_collection_name = "openai_vector_stores_metadata"
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.metadata_collection_name = "openai_vector_stores_metadata"
async def initialize(self) -> None:
log.info(f"Initializing PGVector memory adapter with config: {self.config}")
@ -233,9 +400,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
upsert_models(self.conn, [(vector_db.identifier, vector_db)])
# Create and cache the PGVector index table for the vector DB
pgvector_index = PGVectorIndex(
vector_db=vector_db, dimension=vector_db.embedding_dimension, conn=self.conn, kvstore=self.kvstore
)
await pgvector_index.initialize()
index = VectorDBWithIndex(
vector_db,
index=PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn, kvstore=self.kvstore),
index=pgvector_index,
inference_api=self.inference_api,
)
self.cache[vector_db.identifier] = index
@ -272,8 +443,15 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
if vector_db_id in self.cache:
return self.cache[vector_db_id]
if self.vector_db_store is None:
raise VectorStoreNotFoundError(vector_db_id)
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise VectorStoreNotFoundError(vector_db_id)
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
await index.initialize()
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
return self.cache[vector_db_id]

View file

@ -37,3 +37,122 @@ def sanitize_collection_name(name: str, weaviate_format=False) -> str:
else:
s = proper_case(re.sub(r"[^a-zA-Z0-9]", "", name))
return s
class WeightedInMemoryAggregator:
@staticmethod
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
"""
Normalize scores to 0-1 range using min-max normalization.
Args:
scores: dictionary of scores with document IDs as keys and scores as values
Returns:
Normalized scores with document IDs as keys and normalized scores as values
"""
if not scores:
return {}
min_score, max_score = min(scores.values()), max(scores.values())
score_range = max_score - min_score
if score_range > 0:
return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
return dict.fromkeys(scores, 1.0)
@staticmethod
def weighted_rerank(
vector_scores: dict[str, float],
keyword_scores: dict[str, float],
alpha: float = 0.5,
) -> dict[str, float]:
"""
Rerank via weighted average of scores.
Args:
vector_scores: scores from vector search
keyword_scores: scores from keyword search
alpha: weight factor between 0 and 1 (default: 0.5)
0 = keyword only, 1 = vector only, 0.5 = equal weight
Returns:
All unique document IDs with weighted combined scores
"""
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
normalized_vector_scores = WeightedInMemoryAggregator._normalize_scores(vector_scores)
normalized_keyword_scores = WeightedInMemoryAggregator._normalize_scores(keyword_scores)
# Weighted formula: score = (1-alpha) * keyword_score + alpha * vector_score
# alpha=0 means keyword only, alpha=1 means vector only
return {
doc_id: ((1 - alpha) * normalized_keyword_scores.get(doc_id, 0.0))
+ (alpha * normalized_vector_scores.get(doc_id, 0.0))
for doc_id in all_ids
}
@staticmethod
def rrf_rerank(
vector_scores: dict[str, float],
keyword_scores: dict[str, float],
impact_factor: float = 60.0,
) -> dict[str, float]:
"""
Rerank via Reciprocal Rank Fusion.
Args:
vector_scores: scores from vector search
keyword_scores: scores from keyword search
impact_factor: impact factor for RRF (default: 60.0)
Returns:
All unique document IDs with RRF combined scores
"""
# Convert scores to ranks
vector_ranks = {
doc_id: i + 1
for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
}
keyword_ranks = {
doc_id: i + 1
for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
}
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
rrf_scores = {}
for doc_id in all_ids:
vector_rank = vector_ranks.get(doc_id, float("inf"))
keyword_rank = keyword_ranks.get(doc_id, float("inf"))
# RRF formula: score = 1/(k + r) where k is impact_factor (default: 60.0) and r is the rank
rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
return rrf_scores
@staticmethod
def combine_search_results(
vector_scores: dict[str, float],
keyword_scores: dict[str, float],
reranker_type: str = "rrf",
reranker_params: dict[str, float] | None = None,
) -> dict[str, float]:
"""
Combine vector and keyword search results using specified reranking strategy.
Args:
vector_scores: scores from vector search
keyword_scores: scores from keyword search
reranker_type: type of reranker to use (default: RERANKER_TYPE_RRF)
reranker_params: parameters for the reranker
Returns:
All unique document IDs with combined scores
"""
if reranker_params is None:
reranker_params = {}
if reranker_type == "weighted":
alpha = reranker_params.get("alpha", 0.5)
return WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha)
else:
# Default to RRF for None, RRF, or any unknown types
impact_factor = reranker_params.get("impact_factor", 60.0)
return WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor)

View file

@ -84,6 +84,7 @@ unit = [
"openai",
"aiosqlite",
"aiohttp",
"psycopg2-binary>=2.9.0",
"pypdf",
"mcp",
"chardet",
@ -111,6 +112,7 @@ test = [
"torch>=2.6.0",
"torchvision>=0.21.0",
"chardet",
"psycopg2-binary>=2.9.0",
"pypdf",
"mcp",
"datasets",

View file

@ -57,11 +57,13 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
"inline::sqlite-vec",
"remote::milvus",
"inline::milvus",
"remote::pgvector",
],
"hybrid": [
"inline::sqlite-vec",
"inline::milvus",
"remote::milvus",
"remote::pgvector",
],
}
supported_providers = search_mode_support.get(search_mode, [])

View file

@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import boto3
import pytest
from moto import mock_aws
from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
class MockUploadFile:
def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"):
self.content = content
self.filename = filename
self.content_type = content_type
async def read(self):
return self.content
@pytest.fixture
def sample_text_file():
content = b"Hello, this is a test file for the S3 Files API!"
return MockUploadFile(content, "sample_text_file-0.txt")
@pytest.fixture
def sample_text_file2():
content = b"Hello, this is a second test file for the S3 Files API!"
return MockUploadFile(content, "sample_text_file-1.txt")
@pytest.fixture
def s3_config(tmp_path):
db_path = tmp_path / "s3_files_metadata.db"
return S3FilesImplConfig(
bucket_name=f"test-bucket-{tmp_path.name}",
region="not-a-region",
auto_create_bucket=True,
metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
)
@pytest.fixture
def s3_client():
# we use `with mock_aws()` because @mock_aws decorator does not support
# being a generator
with mock_aws():
# must yield or the mock will be reset before it is used
yield boto3.client("s3")
@pytest.fixture
async def s3_provider(s3_config, s3_client): # s3_client provides the moto mock, don't remove it
provider = await get_adapter_impl(s3_config, {})
yield provider
await provider.shutdown()

View file

@ -6,63 +6,11 @@
from unittest.mock import patch
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_aws
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.files import OpenAIFilePurpose
from llama_stack.providers.remote.files.s3 import (
S3FilesImplConfig,
get_adapter_impl,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
class MockUploadFile:
def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"):
self.content = content
self.filename = filename
self.content_type = content_type
async def read(self):
return self.content
@pytest.fixture
def s3_config(tmp_path):
db_path = tmp_path / "s3_files_metadata.db"
return S3FilesImplConfig(
bucket_name="test-bucket",
region="not-a-region",
auto_create_bucket=True,
metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
)
@pytest.fixture
def s3_client():
"""Create a mocked S3 client for testing."""
# we use `with mock_aws()` because @mock_aws decorator does not support being a generator
with mock_aws():
# must yield or the mock will be reset before it is used
yield boto3.client("s3")
@pytest.fixture
async def s3_provider(s3_config, s3_client):
"""Create an S3 files provider with mocked S3 for testing."""
provider = await get_adapter_impl(s3_config, {})
yield provider
await provider.shutdown()
@pytest.fixture
def sample_text_file():
content = b"Hello, this is a test file for the S3 Files API!"
return MockUploadFile(content, "sample_text_file.txt")
class TestS3FilesImpl:
@ -143,7 +91,7 @@ class TestS3FilesImpl:
s3_client.head_object(Bucket=s3_config.bucket_name, Key=uploaded.id)
assert exc_info.value.response["Error"]["Code"] == "404"
async def test_list_files(self, s3_provider, sample_text_file):
async def test_list_files(self, s3_provider, sample_text_file, sample_text_file2):
"""Test listing files after uploading some."""
sample_text_file.filename = "test_list_files_with_content_file1"
file1 = await s3_provider.openai_upload_file(
@ -151,9 +99,9 @@ class TestS3FilesImpl:
purpose=OpenAIFilePurpose.ASSISTANTS,
)
file2_content = MockUploadFile(b"Second file content", "test_list_files_with_content_file2")
sample_text_file2.filename = "test_list_files_with_content_file2"
file2 = await s3_provider.openai_upload_file(
file=file2_content,
file=sample_text_file2,
purpose=OpenAIFilePurpose.BATCH,
)
@ -164,7 +112,7 @@ class TestS3FilesImpl:
assert file1.id in file_ids
assert file2.id in file_ids
async def test_list_files_with_purpose_filter(self, s3_provider, sample_text_file):
async def test_list_files_with_purpose_filter(self, s3_provider, sample_text_file, sample_text_file2):
"""Test listing files with purpose filter."""
sample_text_file.filename = "test_list_files_with_purpose_filter_file1"
file1 = await s3_provider.openai_upload_file(
@ -172,9 +120,9 @@ class TestS3FilesImpl:
purpose=OpenAIFilePurpose.ASSISTANTS,
)
file2_content = MockUploadFile(b"Batch file content", "test_list_files_with_purpose_filter_file2")
sample_text_file2.filename = "test_list_files_with_purpose_filter_file2"
await s3_provider.openai_upload_file(
file=file2_content,
file=sample_text_file2,
purpose=OpenAIFilePurpose.BATCH,
)

View file

@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import patch
import pytest
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.files import OpenAIFilePurpose
from llama_stack.core.datatypes import User
from llama_stack.providers.remote.files.s3.files import S3FilesImpl
async def test_listing_hides_other_users_file(s3_provider, sample_text_file):
"""Listing should not show files uploaded by other users."""
user_a = User("user-a", {"roles": ["team-a"]})
user_b = User("user-b", {"roles": ["team-b"]})
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
mock_get_user.return_value = user_a
uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
mock_get_user.return_value = user_b
listed = await s3_provider.openai_list_files()
assert all(f.id != uploaded.id for f in listed.data)
@pytest.mark.parametrize(
"op",
[S3FilesImpl.openai_retrieve_file, S3FilesImpl.openai_retrieve_file_content, S3FilesImpl.openai_delete_file],
ids=["retrieve", "content", "delete"],
)
async def test_cannot_access_other_user_file(s3_provider, sample_text_file, op):
"""Operations (metadata/content/delete) on another user's file should raise ResourceNotFoundError.
`op` is an async callable (provider, file_id) -> awaits the requested operation.
"""
user_a = User("user-a", {"roles": ["team-a"]})
user_b = User("user-b", {"roles": ["team-b"]})
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
mock_get_user.return_value = user_a
uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
mock_get_user.return_value = user_b
with pytest.raises(ResourceNotFoundError):
await op(s3_provider, uploaded.id)
async def test_shared_role_allows_listing(s3_provider, sample_text_file):
"""Listing should show files uploaded by other users when roles are shared."""
user_a = User("user-a", {"roles": ["shared-role"]})
user_b = User("user-b", {"roles": ["shared-role"]})
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
mock_get_user.return_value = user_a
uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
mock_get_user.return_value = user_b
listed = await s3_provider.openai_list_files()
assert any(f.id == uploaded.id for f in listed.data)
@pytest.mark.parametrize(
"op",
[S3FilesImpl.openai_retrieve_file, S3FilesImpl.openai_retrieve_file_content, S3FilesImpl.openai_delete_file],
ids=["retrieve", "content", "delete"],
)
async def test_shared_role_allows_access(s3_provider, sample_text_file, op):
"""Operations (metadata/content/delete) on another user's file should succeed when users share a role.
`op` is an async callable (provider, file_id) -> awaits the requested operation.
"""
user_x = User("user-x", {"roles": ["shared-role"]})
user_y = User("user-y", {"roles": ["shared-role"]})
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
mock_get_user.return_value = user_x
uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
mock_get_user.return_value = user_y
await op(s3_provider, uploaded.id)

View file

@ -0,0 +1,248 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.memory.vector_store import RERANKER_TYPE_RRF, RERANKER_TYPE_WEIGHTED
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator
class TestNormalizeScores:
"""Test cases for score normalization."""
def test_normalize_scores_basic(self):
"""Test basic score normalization."""
scores = {"doc1": 10.0, "doc2": 5.0, "doc3": 0.0}
normalized = WeightedInMemoryAggregator._normalize_scores(scores)
assert normalized["doc1"] == 1.0 # Max score
assert normalized["doc3"] == 0.0 # Min score
assert normalized["doc2"] == 0.5 # Middle score
assert all(0 <= score <= 1 for score in normalized.values())
def test_normalize_scores_identical(self):
"""Test normalization when all scores are identical."""
scores = {"doc1": 5.0, "doc2": 5.0, "doc3": 5.0}
normalized = WeightedInMemoryAggregator._normalize_scores(scores)
# All scores should be 1.0 when identical
assert all(score == 1.0 for score in normalized.values())
def test_normalize_scores_empty(self):
"""Test normalization with empty scores."""
scores = {}
normalized = WeightedInMemoryAggregator._normalize_scores(scores)
assert normalized == {}
def test_normalize_scores_single(self):
"""Test normalization with single score."""
scores = {"doc1": 7.5}
normalized = WeightedInMemoryAggregator._normalize_scores(scores)
assert normalized["doc1"] == 1.0
class TestWeightedRerank:
"""Test cases for weighted reranking."""
def test_weighted_rerank_basic(self):
"""Test basic weighted reranking."""
vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5}
keyword_scores = {"doc1": 0.6, "doc2": 0.8, "doc4": 0.9}
combined = WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha=0.5)
# Should include all documents
expected_docs = {"doc1", "doc2", "doc3", "doc4"}
assert set(combined.keys()) == expected_docs
# All scores should be between 0 and 1
assert all(0 <= score <= 1 for score in combined.values())
# doc1 appears in both searches, should have higher combined score
assert combined["doc1"] > 0
def test_weighted_rerank_alpha_zero(self):
"""Test weighted reranking with alpha=0 (keyword only)."""
vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5} # All docs present in vector
keyword_scores = {"doc1": 0.1, "doc2": 0.3, "doc3": 0.9} # All docs present in keyword
combined = WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha=0.0)
# Alpha=0 means vector scores are ignored, keyword scores dominate
# doc3 should score highest since it has highest keyword score
assert combined["doc3"] > combined["doc2"] > combined["doc1"]
def test_weighted_rerank_alpha_one(self):
"""Test weighted reranking with alpha=1 (vector only)."""
vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5} # All docs present in vector
keyword_scores = {"doc1": 0.1, "doc2": 0.3, "doc3": 0.9} # All docs present in keyword
combined = WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha=1.0)
# Alpha=1 means keyword scores are ignored, vector scores dominate
# doc1 should score highest since it has highest vector score
assert combined["doc1"] > combined["doc2"] > combined["doc3"]
def test_weighted_rerank_no_overlap(self):
"""Test weighted reranking with no overlapping documents."""
vector_scores = {"doc1": 0.9, "doc2": 0.7}
keyword_scores = {"doc3": 0.8, "doc4": 0.6}
combined = WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha=0.5)
assert len(combined) == 4
# With min-max normalization, lowest scoring docs in each group get 0.0
# but highest scoring docs should get positive scores
assert all(score >= 0 for score in combined.values())
assert combined["doc1"] > 0 # highest vector score
assert combined["doc3"] > 0 # highest keyword score
class TestRRFRerank:
"""Test cases for RRF (Reciprocal Rank Fusion) reranking."""
def test_rrf_rerank_basic(self):
"""Test basic RRF reranking."""
vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5}
keyword_scores = {"doc1": 0.6, "doc2": 0.8, "doc4": 0.9}
combined = WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0)
# Should include all documents
expected_docs = {"doc1", "doc2", "doc3", "doc4"}
assert set(combined.keys()) == expected_docs
# All scores should be positive
assert all(score > 0 for score in combined.values())
# Documents appearing in both searches should have higher scores
# doc1 and doc2 appear in both, doc3 and doc4 appear in only one
assert combined["doc1"] > combined["doc3"]
assert combined["doc2"] > combined["doc4"]
def test_rrf_rerank_rank_calculation(self):
"""Test that RRF correctly calculates ranks."""
# Create clear ranking order
vector_scores = {"doc1": 1.0, "doc2": 0.8, "doc3": 0.6} # Ranks: 1, 2, 3
keyword_scores = {"doc1": 0.5, "doc2": 1.0, "doc3": 0.7} # Ranks: 3, 1, 2
combined = WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0)
# doc1: rank 1 in vector, rank 3 in keyword
# doc2: rank 2 in vector, rank 1 in keyword
# doc3: rank 3 in vector, rank 2 in keyword
# doc2 should have the highest combined score (ranks 2+1=3)
# followed by doc1 (ranks 1+3=4) and doc3 (ranks 3+2=5)
# Remember: lower rank sum = higher RRF score
assert combined["doc2"] > combined["doc1"] > combined["doc3"]
def test_rrf_rerank_impact_factor(self):
"""Test that impact factor affects RRF scores."""
vector_scores = {"doc1": 0.9, "doc2": 0.7}
keyword_scores = {"doc1": 0.8, "doc2": 0.6}
combined_low = WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor=10.0)
combined_high = WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor=100.0)
# Higher impact factor should generally result in lower scores
# (because 1/(k+r) decreases as k increases)
assert combined_low["doc1"] > combined_high["doc1"]
assert combined_low["doc2"] > combined_high["doc2"]
def test_rrf_rerank_missing_documents(self):
"""Test RRF handling of documents missing from one search."""
vector_scores = {"doc1": 0.9, "doc2": 0.7}
keyword_scores = {"doc1": 0.8, "doc3": 0.6}
combined = WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0)
# Should include all documents
assert len(combined) == 3
# doc1 appears in both searches, should have highest score
assert combined["doc1"] > combined["doc2"]
assert combined["doc1"] > combined["doc3"]
class TestCombineSearchResults:
"""Test cases for the main combine_search_results function."""
def test_combine_search_results_rrf_default(self):
"""Test combining with RRF as default."""
vector_scores = {"doc1": 0.9, "doc2": 0.7}
keyword_scores = {"doc1": 0.6, "doc3": 0.8}
combined = WeightedInMemoryAggregator.combine_search_results(vector_scores, keyword_scores)
# Should default to RRF
assert len(combined) == 3
assert all(score > 0 for score in combined.values())
def test_combine_search_results_rrf_explicit(self):
"""Test combining with explicit RRF."""
vector_scores = {"doc1": 0.9, "doc2": 0.7}
keyword_scores = {"doc1": 0.6, "doc3": 0.8}
combined = WeightedInMemoryAggregator.combine_search_results(
vector_scores, keyword_scores, reranker_type=RERANKER_TYPE_RRF, reranker_params={"impact_factor": 50.0}
)
assert len(combined) == 3
assert all(score > 0 for score in combined.values())
def test_combine_search_results_weighted(self):
"""Test combining with weighted reranking."""
vector_scores = {"doc1": 0.9, "doc2": 0.7}
keyword_scores = {"doc1": 0.6, "doc3": 0.8}
combined = WeightedInMemoryAggregator.combine_search_results(
vector_scores, keyword_scores, reranker_type=RERANKER_TYPE_WEIGHTED, reranker_params={"alpha": 0.3}
)
assert len(combined) == 3
assert all(0 <= score <= 1 for score in combined.values())
def test_combine_search_results_unknown_type(self):
"""Test combining with unknown reranker type defaults to RRF."""
vector_scores = {"doc1": 0.9}
keyword_scores = {"doc2": 0.8}
combined = WeightedInMemoryAggregator.combine_search_results(
vector_scores, keyword_scores, reranker_type="unknown_type"
)
# Should fall back to RRF
assert len(combined) == 2
assert all(score > 0 for score in combined.values())
def test_combine_search_results_empty_params(self):
"""Test combining with empty parameters."""
vector_scores = {"doc1": 0.9}
keyword_scores = {"doc2": 0.8}
combined = WeightedInMemoryAggregator.combine_search_results(vector_scores, keyword_scores, reranker_params={})
# Should use default parameters
assert len(combined) == 2
assert all(score > 0 for score in combined.values())
def test_combine_search_results_empty_scores(self):
"""Test combining with empty score dictionaries."""
# Test with empty vector scores
combined = WeightedInMemoryAggregator.combine_search_results({}, {"doc1": 0.8})
assert len(combined) == 1
assert combined["doc1"] > 0
# Test with empty keyword scores
combined = WeightedInMemoryAggregator.combine_search_results({"doc1": 0.9}, {})
assert len(combined) == 1
assert combined["doc1"] > 0
# Test with both empty
combined = WeightedInMemoryAggregator.combine_search_results({}, {})
assert len(combined) == 0

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import random
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
@ -12,7 +13,7 @@ from chromadb import PersistentClient
from pymilvus import MilvusClient, connections
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
@ -22,6 +23,8 @@ from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConf
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
EMBEDDING_DIMENSION = 384
@ -29,7 +32,7 @@ COLLECTION_PREFIX = "test_collection"
MILVUS_ALIAS = "test_milvus"
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma"])
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector"])
def vector_provider(request):
return request.param
@ -333,15 +336,127 @@ async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension):
await index.delete()
@pytest.fixture
def mock_psycopg2_connection():
connection = MagicMock()
cursor = MagicMock()
cursor.__enter__ = MagicMock(return_value=cursor)
cursor.__exit__ = MagicMock()
connection.cursor.return_value = cursor
return connection, cursor
@pytest.fixture
async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
connection, cursor = mock_psycopg2_connection
vector_db = VectorDB(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
provider_id="pgvector",
provider_resource_id="pgvector:test-vector-db",
)
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.execute_values"):
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
index._test_chunks = []
original_add_chunks = index.add_chunks
async def mock_add_chunks(chunks, embeddings):
index._test_chunks = list(chunks)
await original_add_chunks(chunks, embeddings)
index.add_chunks = mock_add_chunks
async def mock_query_vector(embedding, k, score_threshold):
chunks = index._test_chunks[:k] if hasattr(index, "_test_chunks") else []
scores = [1.0] * len(chunks)
return QueryChunksResponse(chunks=chunks, scores=scores)
index.query_vector = mock_query_vector
yield index
@pytest.fixture
async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
config = PGVectorVectorIOConfig(
host="localhost",
port=5432,
db="test_db",
user="test_user",
password="test_password",
kvstore=SqliteKVStoreConfig(),
)
adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2.connect") as mock_connect:
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
mock_cursor.__exit__ = MagicMock()
mock_conn.cursor.return_value = mock_cursor
mock_conn.autocommit = True
mock_connect.return_value = mock_conn
with patch(
"llama_stack.providers.remote.vector_io.pgvector.pgvector.check_extension_version"
) as mock_check_version:
mock_check_version.return_value = "0.5.1"
with patch("llama_stack.providers.utils.kvstore.kvstore_impl") as mock_kvstore_impl:
mock_kvstore = AsyncMock()
mock_kvstore_impl.return_value = mock_kvstore
with patch.object(adapter, "initialize_openai_vector_stores", new_callable=AsyncMock):
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.upsert_models"):
await adapter.initialize()
adapter.conn = mock_conn
async def mock_insert_chunks(vector_db_id, chunks, ttl_seconds=None):
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
await index.insert_chunks(chunks)
adapter.insert_chunks = mock_insert_chunks
async def mock_query_chunks(vector_db_id, query, params=None):
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params)
adapter.query_chunks = mock_query_chunks
test_vector_db = VectorDB(
identifier=f"pgvector_test_collection_{random.randint(1, 1_000_000)}",
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
await adapter.register_vector_db(test_vector_db)
adapter.test_collection_id = test_vector_db.identifier
yield adapter
await adapter.shutdown()
@pytest.fixture
def vector_io_adapter(vector_provider, request):
"""Returns the appropriate vector IO adapter based on the provider parameter."""
vector_provider_dict = {
"milvus": "milvus_vec_adapter",
"faiss": "faiss_vec_adapter",
"sqlite_vec": "sqlite_vec_adapter",
"chroma": "chroma_vec_adapter",
"qdrant": "qdrant_vec_adapter",
"pgvector": "pgvector_vec_adapter",
}
return request.getfixturevalue(vector_provider_dict[vector_provider])

View file

@ -0,0 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from unittest.mock import patch
import pytest
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex
PGVECTOR_PROVIDER = "pgvector"
@pytest.fixture(scope="session")
def loop():
return asyncio.new_event_loop()
@pytest.fixture
def embedding_dimension():
"""Default embedding dimension for tests."""
return 384
@pytest.fixture
async def pgvector_index(embedding_dimension, mock_psycopg2_connection):
"""Create a PGVectorIndex instance with mocked database connection."""
connection, cursor = mock_psycopg2_connection
vector_db = VectorDB(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
provider_id=PGVECTOR_PROVIDER,
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
)
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
# Use explicit COSINE distance metric for consistent testing
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
return index, cursor
class TestPGVectorIndex:
def test_distance_metric_validation(self, embedding_dimension, mock_psycopg2_connection):
connection, cursor = mock_psycopg2_connection
vector_db = VectorDB(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
provider_id=PGVECTOR_PROVIDER,
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
)
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="L2")
assert index.distance_metric == "L2"
with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID")
def test_get_pgvector_search_function(self, pgvector_index):
index, cursor = pgvector_index
supported_metrics = index.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION
for metric, function in supported_metrics.items():
index.distance_metric = metric
assert index.get_pgvector_search_function() == function
def test_check_distance_metric_availability(self, pgvector_index):
index, cursor = pgvector_index
supported_metrics = index.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION
for metric in supported_metrics:
index.check_distance_metric_availability(metric)
with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
index.check_distance_metric_availability("INVALID")
def test_constructor_invalid_distance_metric(self, embedding_dimension, mock_psycopg2_connection):
connection, cursor = mock_psycopg2_connection
vector_db = VectorDB(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
provider_id=PGVECTOR_PROVIDER,
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
)
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
with pytest.raises(ValueError, match="Distance metric 'INVALID_METRIC' is not supported by PGVector"):
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID_METRIC")
with pytest.raises(ValueError, match="Supported metrics are:"):
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="UNKNOWN")
try:
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
assert index.distance_metric == "COSINE"
except ValueError:
pytest.fail("Valid distance metric 'COSINE' should not raise ValueError")
def test_constructor_all_supported_distance_metrics(self, embedding_dimension, mock_psycopg2_connection):
connection, cursor = mock_psycopg2_connection
vector_db = VectorDB(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
provider_id=PGVECTOR_PROVIDER,
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
)
supported_metrics = ["L2", "L1", "COSINE", "INNER_PRODUCT", "HAMMING", "JACCARD"]
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
for metric in supported_metrics:
try:
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric=metric)
assert index.distance_metric == metric
expected_operators = {
"L2": "<->",
"L1": "<+>",
"COSINE": "<=>",
"INNER_PRODUCT": "<#>",
"HAMMING": "<~>",
"JACCARD": "<%>",
}
assert index.get_pgvector_search_function() == expected_operators[metric]
except Exception as e:
pytest.fail(f"Valid distance metric '{metric}' should not raise exception: {e}")

1827
uv.lock generated

File diff suppressed because it is too large Load diff