def count_total_queries(
    self,
    query_tokens: int,
    document_tokens: list[int],
    cost_per_1000_queries: float = 1.0,
) -> int:
    """
    Count the number of rerank queries a request is split into.

    Each document is divided into chunks so that the query plus one chunk
    fits into a 512-token window (a document always counts as at least one
    chunk). Every started group of 100 chunks then counts as one query.

    NOTE(review): this presumably mirrors how Bedrock bills rerank requests
    (search units) -- confirm against the provider's pricing documentation.

    Args:
        query_tokens (int): Number of tokens in the query.
        document_tokens (list[int]): Token count of each document.
        cost_per_1000_queries (float): Not used by this method; kept so that
            existing callers that pass it keep working.

    Returns:
        int: Total chunks divided by 100, rounded up. An empty
            ``document_tokens`` list yields 0.

    Raises:
        ValueError: If ``query_tokens`` is 512 or more, which would leave no
            room for document tokens in a window.
    """
    TOKENS_PER_DOCUMENT = 512
    CHUNKS_PER_QUERY = 100

    # Guard first: a query that fills the whole window would make
    # ``available_tokens`` zero (or negative) and break the division below.
    if query_tokens >= TOKENS_PER_DOCUMENT:
        raise ValueError("Query tokens exceed maximum allowed tokens per document")

    # Tokens left for document text in each window once the query is included.
    available_tokens = TOKENS_PER_DOCUMENT - query_tokens

    # ``(a + b - 1) // b`` is integer ceiling division; ``max(1, ...)`` makes
    # an empty document still count as one chunk.
    total_chunks = sum(
        max(1, (doc_tokens + available_tokens - 1) // available_tokens)
        for doc_tokens in document_tokens
    )

    # Round the chunk total up to whole queries of CHUNKS_PER_QUERY chunks.
    return (total_chunks + CHUNKS_PER_QUERY - 1) // CHUNKS_PER_QUERY
def test_bedrock_rerank_large_query(self):
    """Rerank 500 short documents and check the billed units and cost.

    Five city names repeated 100 times give 500 documents; the test expects
    the response to report 5 search units and a response cost equal to
    ``round(5 * 0.001, 4)``.
    """
    from litellm import rerank

    city_names = ["Paris", "London", "Berlin", "Madrid", "Rome"]
    many_documents = city_names * 100

    response = rerank(
        model="bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
        query="What is the capital of France?",
        documents=many_documents,
        top_n=3,
    )
    assert response is not None

    billed_units = response.meta["billed_units"]
    search_units = billed_units["search_units"]
    assert search_units == 5

    expected_cost = round(5 * 0.001, 4)
    assert response._hidden_params["response_cost"] == expected_cost