test: add tests to ensure large document sets / queries are counted correctly

This commit is contained in:
Krrish Dholakia 2025-02-20 16:40:10 -08:00
parent 75e231b1e4
commit 4337f1657e
2 changed files with 50 additions and 0 deletions

View file

@ -29,6 +29,42 @@ from litellm.types.rerank import (
class BedrockRerankConfig:
def count_total_queries(
    self,
    query_tokens: int,
    document_tokens: list[int],
    cost_per_1000_queries: float = 1.0,
) -> float:
    """
    Count the billable queries for a rerank request.

    Each document is split into chunks of at most ``512 - query_tokens``
    tokens (every document counts as at least one chunk, even when its
    token count is 0), and every started group of 100 chunks counts as
    one query.
    NOTE(review): the 512 / 100 limits presumably mirror the Amazon Bedrock
    rerank pricing rules - confirm against the pricing page.

    Args:
        query_tokens (int): Number of tokens in the query
        document_tokens (list[int]): List of token counts for each document
        cost_per_1000_queries (float): Not read by this method; kept in the
            signature so existing callers keep working.

    Returns:
        The number of queries as a whole number (0 when ``document_tokens``
        is empty). This is a count, not a dollar amount, even though the
        return annotation is ``float``.

    Raises:
        ValueError: If ``query_tokens`` is 512 or more, which would leave
            no room for document tokens in a chunk.
    """
    max_tokens_per_chunk = 512
    chunks_per_billed_query = 100

    def _ceil_div(numerator, denominator):
        # Integer ceiling division without going through floats.
        return (numerator + denominator - 1) // denominator

    # Tokens left for document text once the query is accounted for.
    room_for_document = max_tokens_per_chunk - query_tokens
    if room_for_document <= 0:
        raise ValueError("Query tokens exceed maximum allowed tokens per document")

    # Even an empty document occupies one chunk.
    chunk_count = sum(
        max(1, _ceil_div(size, room_for_document)) for size in document_tokens
    )

    # Round the chunk total up to whole groups of `chunks_per_billed_query`.
    return _ceil_div(chunk_count, chunks_per_billed_query)
def _transform_sources(
self, documents: List[Union[str, dict]]

View file

@ -2185,6 +2185,20 @@ class TestBedrockRerank(BaseLLMRerankTest):
"model": "bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
}
def test_bedrock_rerank_large_query(self):
    """
    Rerank 500 short documents and check the billed search units and cost.

    The request is expected to be billed as 5 search units, with a
    response cost of ``round(5 * 0.001, 4)``.
    NOTE(review): presumably 5 = 500 documents / 100 chunks per query -
    confirm against the Bedrock rerank billing logic.
    """
    from litellm import rerank

    expected_search_units = 5
    corpus = ["Paris", "London", "Berlin", "Madrid", "Rome"] * 100

    response = rerank(
        model="bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
        query="What is the capital of France?",
        documents=corpus,
        top_n=3,
    )

    assert response is not None

    billed_units = response.meta["billed_units"]
    assert billed_units["search_units"] == expected_search_units

    expected_cost = round(expected_search_units * 0.001, 4)
    assert response._hidden_params["response_cost"] == expected_cost
class TestBedrockCohereRerank(BaseLLMRerankTest):
def get_custom_llm_provider(self) -> litellm.LlmProviders: