Add cost tracking for rerank via bedrock (#8691)

* feat(bedrock/rerank): infer model region if model given as arn

* test: add unit testing to ensure bedrock region name inferred from arn on rerank

* feat(bedrock/rerank/transformation.py): include search units for bedrock rerank result

Resolves https://github.com/BerriAI/litellm/issues/7258#issuecomment-2671557137

* test(test_bedrock_completion.py): add testing for bedrock cohere rerank

* feat(cost_calculator.py): refactor rerank cost tracking to support bedrock cost tracking

* build(model_prices_and_context_window.json): add amazon.rerank model to model cost map

* fix(bedrock/common_utils.py): get base model from model given as ARN -> handles rerank models

* build(model_prices_and_context_window.json): add bedrock cohere rerank pricing

* feat(bedrock/rerank): migrate bedrock config to basererank config

* Revert "feat(bedrock/rerank): migrate bedrock config to basererank config"

This reverts commit 84fae1f167.

* test: add testing to ensure large doc / queries are correctly counted

* Revert "test: add testing to ensure large doc / queries are correctly counted"

This reverts commit 4337f1657e.

* fix(migrate-jina-ai-to-rerank-config): enables cost tracking

* refactor(jina_ai/): finish migrating jina ai to base rerank config

enables cost tracking

* fix(jina_ai/rerank): e2e jina ai rerank cost tracking

* fix: cleanup dead code

* fix: fix python3.8 compatibility error

* test: fix test

* test: add e2e testing for azure ai rerank

* fix: fix linting error

* test: mark cohere as flaky
Krish Dholakia 2025-02-20 21:00:18 -08:00, committed by GitHub
commit b682dc4ec8 (parent 4c9517fd78)
26 changed files with 524 additions and 296 deletions
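For context, a minimal sketch of what this change enables end to end (ARN, query, and documents are illustrative; the cost figure lands in the response's hidden params, as the tests below assert):

    import litellm

    # Rerank via a Bedrock model given as a full ARN. The AWS region
    # (us-west-2) is inferred from the ARN instead of needing to be set upfront.
    response = litellm.rerank(
        model="bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
        query="hello",
        documents=["hello", "world"],
    )

    # Cost is computed from the billed search units and the model cost map.
    print(response._hidden_params["response_cost"])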

File: docs/my-website/docs/providers/bedrock.md

@@ -10,6 +10,7 @@ ALL Bedrock models (Anthropic, Meta, Deepseek, Mistral, Amazon, etc.) are Suppor
| Provider Route on LiteLLM | `bedrock/`, [`bedrock/converse/`](#set-converse--invoke-route), [`bedrock/invoke/`](#set-invoke-route), [`bedrock/converse_like/`](#calling-via-internal-proxy), [`bedrock/llama/`](#deepseek-not-r1), [`bedrock/deepseek_r1/`](#deepseek-r1) |
| Provider Doc | [Amazon Bedrock ↗](https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html) |
| Supported OpenAI Endpoints | `/chat/completions`, `/completions`, `/embeddings`, `/images/generations` |
| Rerank Endpoint | `/rerank` |
| Pass-through Endpoint | [Supported](../pass_through/bedrock.md) |

File: litellm/__init__.py

@@ -820,6 +820,7 @@ from .llms.cohere.completion.transformation import CohereTextConfig as CohereCon
from .llms.cohere.rerank.transformation import CohereRerankConfig
from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
from .llms.infinity.rerank.transformation import InfinityRerankConfig
from .llms.jina_ai.rerank.transformation import JinaAIRerankConfig
from .llms.clarifai.chat.transformation import ClarifaiConfig
from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
from .llms.together_ai.chat import TogetherAIConfig

File: litellm/cost_calculator.py

@@ -16,15 +16,9 @@ from litellm.llms.anthropic.cost_calculation import (
from litellm.llms.azure.cost_calculation import (
cost_per_token as azure_openai_cost_per_token,
)
from litellm.llms.azure_ai.cost_calculator import (
cost_per_query as azure_ai_rerank_cost_per_query,
)
from litellm.llms.bedrock.image.cost_calculator import (
cost_calculator as bedrock_image_cost_calculator,
)
from litellm.llms.cohere.cost_calculator import (
cost_per_query as cohere_rerank_cost_per_query,
)
from litellm.llms.databricks.cost_calculator import (
cost_per_token as databricks_cost_per_token,
)
@@ -51,10 +45,12 @@ from litellm.llms.vertex_ai.image_generation.cost_calculator import (
cost_calculator as vertex_ai_image_cost_calculator,
)
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.rerank import RerankResponse
from litellm.types.rerank import RerankBilledUnits, RerankResponse
from litellm.types.utils import (
CallTypesLiteral,
LlmProviders,
LlmProvidersSet,
ModelInfo,
PassthroughCallTypes,
Usage,
)
@@ -64,6 +60,7 @@ from litellm.utils import (
EmbeddingResponse,
ImageResponse,
ModelResponse,
ProviderConfigManager,
TextCompletionResponse,
TranscriptionResponse,
_cached_get_model_info_helper,
@@ -114,6 +111,8 @@ def cost_per_token( # noqa: PLR0915
number_of_queries: Optional[int] = None,
### USAGE OBJECT ###
usage_object: Optional[Usage] = None, # just read the usage object if provided
### BILLED UNITS ###
rerank_billed_units: Optional[RerankBilledUnits] = None,
### CALL TYPE ###
call_type: CallTypesLiteral = "completion",
audio_transcription_file_duration: float = 0.0, # for audio transcription calls - the file time in seconds
@@ -238,6 +237,7 @@ def cost_per_token( # noqa: PLR0915
return rerank_cost(
model=model,
custom_llm_provider=custom_llm_provider,
billed_units=rerank_billed_units,
)
elif call_type == "atranscription" or call_type == "transcription":
return openai_cost_per_second(
@@ -552,6 +552,7 @@ def completion_cost( # noqa: PLR0915
cost_per_token_usage_object: Optional[Usage] = _get_usage_object(
completion_response=completion_response
)
rerank_billed_units: Optional[RerankBilledUnits] = None
model = _select_model_name_for_cost_calc(
model=model,
completion_response=completion_response,
@@ -698,6 +699,11 @@ def completion_cost( # noqa: PLR0915
else:
billed_units = {}
rerank_billed_units = RerankBilledUnits(
search_units=billed_units.get("search_units"),
total_tokens=billed_units.get("total_tokens"),
)
search_units = (
billed_units.get("search_units") or 1
) # cohere charges per request by default.
@@ -763,6 +769,7 @@ def completion_cost( # noqa: PLR0915
usage_object=cost_per_token_usage_object,
call_type=call_type,
audio_transcription_file_duration=audio_transcription_file_duration,
rerank_billed_units=rerank_billed_units,
)
_final_cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
@@ -836,27 +843,33 @@ def response_cost_calculator(
def rerank_cost(
model: str,
custom_llm_provider: Optional[str],
billed_units: Optional[RerankBilledUnits] = None,
) -> Tuple[float, float]:
"""
Returns
- Tuple[float, float]: prompt_cost_in_usd, completion_cost_in_usd
"""
default_num_queries = 1
_, custom_llm_provider, _, _ = litellm.get_llm_provider(
model=model, custom_llm_provider=custom_llm_provider
)
try:
if custom_llm_provider == "cohere":
return cohere_rerank_cost_per_query(
model=model, num_queries=default_num_queries
config = ProviderConfigManager.get_provider_rerank_config(
model=model, provider=LlmProviders(custom_llm_provider)
)
try:
model_info: Optional[ModelInfo] = litellm.get_model_info(
model=model, custom_llm_provider=custom_llm_provider
)
elif custom_llm_provider == "azure_ai":
return azure_ai_rerank_cost_per_query(
model=model, num_queries=default_num_queries
)
raise ValueError(
f"invalid custom_llm_provider for rerank model: {model}, custom_llm_provider: {custom_llm_provider}"
except Exception:
model_info = None
return config.calculate_rerank_cost(
model=model,
custom_llm_provider=custom_llm_provider,
billed_units=billed_units,
model_info=model_info,
)
except Exception as e:
raise e

File: litellm/llms/azure_ai/cost_calculator.py (deleted)

@@ -1,32 +0,0 @@
"""
Handles custom cost calculation for Azure AI models.
Custom cost calculation for Azure AI models is only required for rerank.
"""
from typing import Tuple
from litellm.utils import get_model_info
def cost_per_query(model: str, num_queries: int = 1) -> Tuple[float, float]:
"""
Calculates the cost per query for a given rerank model.
Input:
- model: str, the model name without provider prefix
Returns:
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
"""
model_info = get_model_info(model=model, custom_llm_provider="azure_ai")
if (
"input_cost_per_query" not in model_info
or model_info["input_cost_per_query"] is None
):
return 0.0, 0.0
prompt_cost = model_info["input_cost_per_query"] * num_queries
return prompt_cost, 0.0

File: litellm/llms/base_llm/rerank/transformation.py

@@ -1,9 +1,10 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
from litellm.types.rerank import OptionalRerankParams, RerankResponse
from litellm.types.rerank import OptionalRerankParams, RerankBilledUnits, RerankResponse
from litellm.types.utils import ModelInfo
from ..chat.transformation import BaseLLMException
@@ -66,7 +67,7 @@ class BaseRerankConfig(ABC):
@abstractmethod
def map_cohere_rerank_params(
self,
non_default_params: Optional[dict],
non_default_params: dict,
model: str,
drop_params: bool,
query: str,
@@ -79,8 +80,48 @@
) -> OptionalRerankParams:
pass
@abstractmethod
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
pass
raise BaseLLMException(
status_code=status_code,
message=error_message,
headers=headers,
)
def calculate_rerank_cost(
self,
model: str,
custom_llm_provider: Optional[str] = None,
billed_units: Optional[RerankBilledUnits] = None,
model_info: Optional[ModelInfo] = None,
) -> Tuple[float, float]:
"""
Calculates the cost per query for a given rerank model.
Input:
- model: str, the model name without provider prefix
- custom_llm_provider: str, the provider used for the model. If provided, used to check if the litellm model info is for that provider.
- billed_units: RerankBilledUnits, the billed units (e.g. search_units) returned for the request
- model_info: ModelInfo, the model info for the given model
Returns:
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
"""
if (
model_info is None
or "input_cost_per_query" not in model_info
or model_info["input_cost_per_query"] is None
or billed_units is None
):
return 0.0, 0.0
search_units = billed_units.get("search_units")
if search_units is None:
return 0.0, 0.0
prompt_cost = model_info["input_cost_per_query"] * search_units
return prompt_cost, 0.0

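The default implementation therefore reduces to input_cost_per_query * search_units. A worked example using the Bedrock Cohere pricing added in this PR (the billed_units value is illustrative):

    # cohere.rerank-v3-5:0 is priced at $0.002 per query in the cost map.
    input_cost_per_query = 0.002
    billed_units = {"search_units": 3}  # as returned in the response meta

    search_units = billed_units.get("search_units")
    prompt_cost = input_cost_per_query * search_units if search_units else 0.0
    print((prompt_cost, 0.0))  # (0.006, 0.0) -> (prompt cost, completion cost)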
File: litellm/llms/bedrock/base_aws_llm.py

@@ -202,6 +202,61 @@ class BaseAWSLLM:
self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl)
return credentials
def _get_aws_region_from_model_arn(self, model: Optional[str]) -> Optional[str]:
try:
# First check if the string contains the expected prefix
if not isinstance(model, str) or "arn:aws:bedrock" not in model:
return None
# Split the ARN and check if we have enough parts
parts = model.split(":")
if len(parts) < 4:
return None
# Get the region from the correct position
region = parts[3]
if not region: # Check if region is empty
return None
return region
except Exception:
# Catch any unexpected errors and return None
return None
def _get_aws_region_name(
self, optional_params: dict, model: Optional[str] = None
) -> str:
"""
Get the AWS region name from the environment variables
"""
aws_region_name = optional_params.get("aws_region_name", None)
### SET REGION NAME ###
if aws_region_name is None:
# check model arn #
aws_region_name = self._get_aws_region_from_model_arn(model)
# check env #
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
if (
aws_region_name is None
and litellm_aws_region_name is not None
and isinstance(litellm_aws_region_name, str)
):
aws_region_name = litellm_aws_region_name
standard_aws_region_name = get_secret("AWS_REGION", None)
if (
aws_region_name is None
and standard_aws_region_name is not None
and isinstance(standard_aws_region_name, str)
):
aws_region_name = standard_aws_region_name
if aws_region_name is None:
aws_region_name = "us-west-2"
return aws_region_name
@tracer.wrap()
def _auth_with_web_identity_token(
self,
@@ -423,7 +478,7 @@ class BaseAWSLLM:
return endpoint_url, proxy_endpoint_url
def _get_boto_credentials_from_optional_params(
self, optional_params: dict
self, optional_params: dict, model: Optional[str] = None
) -> Boto3CredentialsInfo:
"""
Get boto3 credentials from optional params
@@ -443,7 +498,7 @@ class BaseAWSLLM:
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_region_name = self._get_aws_region_name(optional_params, model)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None)
@@ -453,25 +508,6 @@
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
### SET REGION NAME ###
if aws_region_name is None:
# check env #
litellm_aws_region_name = get_secret_str("AWS_REGION_NAME", None)
if litellm_aws_region_name is not None and isinstance(
litellm_aws_region_name, str
):
aws_region_name = litellm_aws_region_name
standard_aws_region_name = get_secret_str("AWS_REGION", None)
if standard_aws_region_name is not None and isinstance(
standard_aws_region_name, str
):
aws_region_name = standard_aws_region_name
if aws_region_name is None:
aws_region_name = "us-west-2"
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,

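The region sits in the fourth colon-delimited field of a Bedrock ARN (arn:aws:bedrock:<region>:<account>:<resource>), so the helper is a bounds-checked split. A standalone sketch of the same parsing:

    from typing import Optional

    def region_from_bedrock_arn(model: str) -> Optional[str]:
        if "arn:aws:bedrock" not in model:
            return None
        parts = model.split(":")
        # parts[3] is the region; guard against short or region-less ARNs
        if len(parts) < 4 or not parts[3]:
            return None
        return parts[3]

    arn = "arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0"
    print(region_from_bedrock_arn(arn))  # us-west-2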
File: litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py

@@ -27,7 +27,7 @@ from litellm.llms.custom_httpx.http_handler import (
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage
from litellm.utils import CustomStreamWrapper, get_secret
from litellm.utils import CustomStreamWrapper
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -598,61 +598,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
)
return modelId
def get_aws_region_from_model_arn(self, model: Optional[str]) -> Optional[str]:
try:
# First check if the string contains the expected prefix
if not isinstance(model, str) or "arn:aws:bedrock" not in model:
return None
# Split the ARN and check if we have enough parts
parts = model.split(":")
if len(parts) < 4:
return None
# Get the region from the correct position
region = parts[3]
if not region: # Check if region is empty
return None
return region
except Exception:
# Catch any unexpected errors and return None
return None
def _get_aws_region_name(
self, optional_params: dict, model: Optional[str] = None
) -> str:
"""
Get the AWS region name from the environment variables
"""
aws_region_name = optional_params.get("aws_region_name", None)
### SET REGION NAME ###
if aws_region_name is None:
# check model arn #
aws_region_name = self.get_aws_region_from_model_arn(model)
# check env #
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
if (
aws_region_name is None
and litellm_aws_region_name is not None
and isinstance(litellm_aws_region_name, str)
):
aws_region_name = litellm_aws_region_name
standard_aws_region_name = get_secret("AWS_REGION", None)
if (
aws_region_name is None
and standard_aws_region_name is not None
and isinstance(standard_aws_region_name, str)
):
aws_region_name = standard_aws_region_name
if aws_region_name is None:
aws_region_name = "us-west-2"
return aws_region_name
def _get_model_id_from_model_with_spec(
self,
model: str,

File: litellm/llms/bedrock/common_utils.py

@@ -318,6 +318,23 @@ class BedrockModelInfo(BaseLLMModelInfo):
global_config = AmazonBedrockGlobalConfig()
all_global_regions = global_config.get_all_regions()
@staticmethod
def extract_model_name_from_arn(model: str) -> str:
"""
Extract the model name from an AWS Bedrock ARN.
Returns the string after the last '/' if 'arn' is in the input string.
Args:
model (str): The model string, which may be a full Bedrock ARN
Returns:
str: The extracted model name if 'arn' is in the string,
otherwise returns the original string
"""
if "arn" in model.lower():
return model.split("/")[-1]
return model
@staticmethod
def get_base_model(model: str) -> str:
"""
@@ -335,6 +352,8 @@ class BedrockModelInfo(BaseLLMModelInfo):
if model.startswith("invoke/"):
model = model.split("/", 1)[1]
model = BedrockModelInfo.extract_model_name_from_arn(model)
potential_region = model.split(".", 1)[0]
alt_potential_region = model.split("/", 1)[

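Together with get_base_model, this lets a full ARN resolve to a cost-map key. A sketch of the lookup (import path assumed from this diff's module layout):

    from litellm.llms.bedrock.common_utils import BedrockModelInfo

    arn = "arn:aws:bedrock:us-west-2::foundation-model/cohere.rerank-v3-5:0"
    # "arn" is in the string, so the segment after the last "/" is returned.
    print(BedrockModelInfo.extract_model_name_from_arn(arn))
    # -> "cohere.rerank-v3-5:0", matching the cost-map entry added below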
File: litellm/llms/bedrock/image/image_handler.py

@@ -163,7 +163,7 @@ class BedrockImageGeneration(BaseAWSLLM):
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
optional_params
optional_params, model
)
### SET RUNTIME ENDPOINT ###

File: litellm/llms/bedrock/rerank/handler.py

@@ -6,6 +6,8 @@ import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
@@ -27,8 +29,10 @@ class BedrockRerankHandler(BaseAWSLLM):
async def arerank(
self,
prepared_request: BedrockPreparedRequest,
client: Optional[AsyncHTTPHandler] = None,
):
client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
if client is None:
client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
try:
response = await client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
response.raise_for_status()
@@ -54,7 +58,9 @@
_is_async: Optional[bool] = False,
api_base: Optional[str] = None,
extra_headers: Optional[dict] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> RerankResponse:
request_data = RerankRequest(
model=model,
query=query,
@@ -66,6 +72,7 @@
data = BedrockRerankConfig()._transform_request(request_data)
prepared_request = self._prepare_request(
model=model,
optional_params=optional_params,
api_base=api_base,
extra_headers=extra_headers,
@@ -83,9 +90,10 @@
)
if _is_async:
return self.arerank(prepared_request) # type: ignore
return self.arerank(prepared_request, client=client if client is not None and isinstance(client, AsyncHTTPHandler) else None) # type: ignore
client = _get_httpx_client()
if client is None or not isinstance(client, HTTPHandler):
client = _get_httpx_client()
try:
response = client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
response.raise_for_status()
@@ -95,10 +103,18 @@
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
return BedrockRerankConfig()._transform_response(response.json())
logging_obj.post_call(
original_response=response.text,
api_key="",
)
response_json = response.json()
return BedrockRerankConfig()._transform_response(response_json)
def _prepare_request(
self,
model: str,
api_base: Optional[str],
extra_headers: Optional[dict],
data: dict,
@@ -110,7 +126,7 @@ class BedrockRerankHandler(BaseAWSLLM):
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
optional_params
optional_params, model
)
### SET RUNTIME ENDPOINT ###

File: litellm/llms/bedrock/rerank/transformation.py

@@ -91,7 +91,9 @@ class BedrockRerankConfig:
example input:
{"results":[{"index":0,"relevanceScore":0.6847912669181824},{"index":1,"relevanceScore":0.5980774760246277}]}
"""
_billed_units = RerankBilledUnits(**response.get("usage", {}))
_billed_units = RerankBilledUnits(
**response.get("usage", {"search_units": 1})
) # by default 1 search unit
_tokens = RerankTokens(**response.get("usage", {}))
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)

File: litellm/llms/cohere/cost_calculator.py (deleted)

@@ -1,31 +0,0 @@
"""
Custom cost calculator for Cohere rerank models
"""
from typing import Tuple
from litellm.utils import get_model_info
def cost_per_query(model: str, num_queries: int = 1) -> Tuple[float, float]:
"""
Calculates the cost per query for a given rerank model.
Input:
- model: str, the model name without provider prefix
Returns:
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
"""
model_info = get_model_info(model=model, custom_llm_provider="cohere")
if (
"input_cost_per_query" not in model_info
or model_info["input_cost_per_query"] is None
):
return 0.0, 0.0
prompt_cost = model_info["input_cost_per_query"] * num_queries
return prompt_cost, 0.0

File: litellm/llms/jina_ai/rerank/handler.py

@@ -1,92 +1,3 @@
"""
Re rank api
LiteLLM supports the re rank API format, no parameter transformation occurs
HTTP calling migrated to `llm_http_handler.py`
"""
from typing import Any, Dict, List, Optional, Union
import litellm
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.llms.jina_ai.rerank.transformation import JinaAIRerankConfig
from litellm.types.rerank import RerankRequest, RerankResponse
class JinaAIRerank(BaseLLM):
def rerank(
self,
model: str,
api_key: str,
query: str,
documents: List[Union[str, Dict[str, Any]]],
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
_is_async: Optional[bool] = False,
) -> RerankResponse:
client = _get_httpx_client()
request_data = RerankRequest(
model=model,
query=query,
top_n=top_n,
documents=documents,
rank_fields=rank_fields,
return_documents=return_documents,
)
# exclude None values from request_data
request_data_dict = request_data.dict(exclude_none=True)
if _is_async:
return self.async_rerank(request_data_dict, api_key) # type: ignore # Call async method
response = client.post(
"https://api.jina.ai/v1/rerank",
headers={
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {api_key}",
},
json=request_data_dict,
)
if response.status_code != 200:
raise Exception(response.text)
_json_response = response.json()
return JinaAIRerankConfig()._transform_response(_json_response)
async def async_rerank( # New async method
self,
request_data_dict: Dict[str, Any],
api_key: str,
) -> RerankResponse:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.JINA_AI
) # Use async client
response = await client.post(
"https://api.jina.ai/v1/rerank",
headers={
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {api_key}",
},
json=request_data_dict,
)
if response.status_code != 200:
raise Exception(response.text)
_json_response = response.json()
return JinaAIRerankConfig()._transform_response(_json_response)
pass

File: litellm/llms/jina_ai/rerank/transformation.py

@@ -7,30 +7,136 @@ Docs - https://jina.ai/reranker
"""
import uuid
from typing import List, Optional
from typing import Any, Dict, List, Optional, Tuple, Union
from httpx import URL, Response
from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.types.rerank import (
OptionalRerankParams,
RerankBilledUnits,
RerankResponse,
RerankResponseMeta,
RerankTokens,
)
from litellm.types.utils import ModelInfo
class JinaAIRerankConfig:
def _transform_response(self, response: dict) -> RerankResponse:
class JinaAIRerankConfig(BaseRerankConfig):
def get_supported_cohere_rerank_params(self, model: str) -> list:
return [
"query",
"top_n",
"documents",
"return_documents",
]
_billed_units = RerankBilledUnits(**response.get("usage", {}))
_tokens = RerankTokens(**response.get("usage", {}))
def map_cohere_rerank_params(
self,
non_default_params: dict,
model: str,
drop_params: bool,
query: str,
documents: List[Union[str, Dict[str, Any]]],
custom_llm_provider: Optional[str] = None,
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
) -> OptionalRerankParams:
optional_params = {}
supported_params = self.get_supported_cohere_rerank_params(model)
for k, v in non_default_params.items():
if k in supported_params:
optional_params[k] = v
return OptionalRerankParams(
**optional_params,
)
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
base_path = "/v1/rerank"
if api_base is None:
return "https://api.jina.ai/v1/rerank"
base = URL(api_base)
# Reconstruct URL with cleaned path
cleaned_base = str(base.copy_with(path=base_path))
return cleaned_base
def transform_rerank_request(
self, model: str, optional_rerank_params: OptionalRerankParams, headers: Dict
) -> Dict:
return {"model": model, **optional_rerank_params}
def transform_rerank_response(
self,
model: str,
raw_response: Response,
model_response: RerankResponse,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
request_data: Dict = {},
optional_params: Dict = {},
litellm_params: Dict = {},
) -> RerankResponse:
if raw_response.status_code != 200:
raise Exception(raw_response.text)
logging_obj.post_call(original_response=raw_response.text)
_json_response = raw_response.json()
_billed_units = RerankBilledUnits(**_json_response.get("usage", {}))
_tokens = RerankTokens(**_json_response.get("usage", {}))
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
_results: Optional[List[dict]] = response.get("results")
_results: Optional[List[dict]] = _json_response.get("results")
if _results is None:
raise ValueError(f"No results found in the response={response}")
raise ValueError(f"No results found in the response={_json_response}")
return RerankResponse(
id=response.get("id") or str(uuid.uuid4()),
id=_json_response.get("id") or str(uuid.uuid4()),
results=_results, # type: ignore
meta=rerank_meta,
) # Return response
def validate_environment(
self, headers: Dict, model: str, api_key: Optional[str] = None
) -> Dict:
if api_key is None:
raise ValueError(
"api_key is required. Set via `api_key` parameter or `JINA_API_KEY` environment variable."
)
return {
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {api_key}",
}
def calculate_rerank_cost(
self,
model: str,
custom_llm_provider: Optional[str] = None,
billed_units: Optional[RerankBilledUnits] = None,
model_info: Optional[ModelInfo] = None,
) -> Tuple[float, float]:
"""
Jina AI reranker is priced at $0.000000018 per token.
"""
if (
model_info is None
or "input_cost_per_token" not in model_info
or model_info["input_cost_per_token"] is None
or billed_units is None
):
return 0.0, 0.0
total_tokens = billed_units.get("total_tokens")
if total_tokens is None:
return 0.0, 0.0
input_cost = model_info["input_cost_per_token"] * total_tokens
return input_cost, 0.0

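Unlike the search-unit providers, Jina AI bills per token, so this override multiplies total_tokens by input_cost_per_token. A quick worked example at the $0.000000018/token rate added below (token count illustrative):

    input_cost_per_token = 1.8e-8  # jina-reranker-v2-base-multilingual
    total_tokens = 50_000          # from billed_units["total_tokens"]

    input_cost = input_cost_per_token * total_tokens
    print(f"${input_cost:.6f}")  # $0.000900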
File: litellm/model_prices_and_context_window_backup.json

@@ -5982,6 +5982,19 @@
"litellm_provider": "bedrock",
"mode": "chat"
},
"amazon.rerank-v1:0": {
"max_tokens": 32000,
"max_input_tokens": 32000,
"max_output_tokens": 32000,
"max_query_tokens": 32000,
"max_document_chunks_per_query": 100,
"max_tokens_per_document_chunk": 512,
"input_cost_per_token": 0.0,
"input_cost_per_query": 0.001,
"output_cost_per_token": 0.0,
"litellm_provider": "bedrock",
"mode": "rerank"
},
"amazon.titan-text-lite-v1": {
"max_tokens": 4000,
"max_input_tokens": 42000,
@@ -7022,6 +7035,19 @@
"mode": "chat",
"supports_tool_choice": true
},
"cohere.rerank-v3-5:0": {
"max_tokens": 32000,
"max_input_tokens": 32000,
"max_output_tokens": 32000,
"max_query_tokens": 32000,
"max_document_chunks_per_query": 100,
"max_tokens_per_document_chunk": 512,
"input_cost_per_token": 0.0,
"input_cost_per_query": 0.002,
"output_cost_per_token": 0.0,
"litellm_provider": "bedrock",
"mode": "rerank"
},
"cohere.command-text-v14": {
"max_tokens": 4096,
"max_input_tokens": 4096,
@@ -9154,5 +9180,15 @@
"input_cost_per_second": 0.00003333,
"output_cost_per_second": 0.00,
"litellm_provider": "assemblyai"
},
"jina-reranker-v2-base-multilingual": {
"max_tokens": 1024,
"max_input_tokens": 1024,
"max_output_tokens": 1024,
"max_document_chunks_per_query": 2048,
"input_cost_per_token": 0.000000018,
"output_cost_per_token": 0.000000018,
"litellm_provider": "jina_ai",
"mode": "rerank"
}
}

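With these entries in place, rerank cost tracking is a cost-map lookup plus a multiply. A hedged sketch (assumes the running cost map includes the new entries; the tests below force this via LITELLM_LOCAL_MODEL_COST_MAP):

    import litellm

    model_info = litellm.get_model_info(
        model="amazon.rerank-v1:0", custom_llm_provider="bedrock"
    )
    # One billed search unit at $0.001 per query -> $0.001 for the request.
    print(model_info["input_cost_per_query"])  # 0.001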
File: litellm/rerank_api/main.py

@@ -9,7 +9,6 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from litellm.llms.jina_ai.rerank.handler import JinaAIRerank
from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
from litellm.rerank_api.rerank_utils import get_optional_rerank_params
from litellm.secret_managers.main import get_secret, get_secret_str
@@ -20,7 +19,6 @@ from litellm.utils import ProviderConfigManager, client, exception_type
####### ENVIRONMENT VARIABLES ###################
# Initialize any necessary instances or variables here
together_rerank = TogetherAIRerank()
jina_ai_rerank = JinaAIRerank()
bedrock_rerank = BedrockRerankHandler()
base_llm_http_handler = BaseLLMHTTPHandler()
#################################################
@@ -264,16 +262,26 @@ def rerank( # noqa: PLR0915
raise ValueError(
"Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment"
)
response = jina_ai_rerank.rerank(
api_base = (
dynamic_api_base
or optional_params.api_base
or litellm.api_base
or get_secret("BEDROCK_API_BASE") # type: ignore
)
response = base_llm_http_handler.rerank(
model=model,
api_key=dynamic_api_key,
query=query,
documents=documents,
top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
custom_llm_provider=_custom_llm_provider,
optional_rerank_params=optional_rerank_params,
logging_obj=litellm_logging_obj,
timeout=optional_params.timeout,
api_key=dynamic_api_key or optional_params.api_key,
api_base=api_base,
_is_async=_is_async,
headers=headers or litellm.headers or {},
client=client,
model_response=model_response,
)
elif _custom_llm_provider == "bedrock":
api_base = (
@@ -295,6 +303,7 @@
optional_params=optional_params.model_dump(exclude_unset=True),
api_base=api_base,
logging_obj=litellm_logging_obj,
client=client,
)
else:
raise ValueError(f"Unsupported provider: {_custom_llm_provider}")

File: litellm/rerank_api/rerank_utils.py

@@ -17,6 +17,17 @@ def get_optional_rerank_params(
max_chunks_per_doc: Optional[int] = None,
non_default_params: Optional[dict] = None,
) -> OptionalRerankParams:
all_non_default_params = non_default_params or {}
if query is not None:
all_non_default_params["query"] = query
if top_n is not None:
all_non_default_params["top_n"] = top_n
if documents is not None:
all_non_default_params["documents"] = documents
if return_documents is not None:
all_non_default_params["return_documents"] = return_documents
if max_chunks_per_doc is not None:
all_non_default_params["max_chunks_per_doc"] = max_chunks_per_doc
return rerank_provider_config.map_cohere_rerank_params(
model=model,
drop_params=drop_params,
@@ -27,5 +38,5 @@
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
non_default_params=non_default_params,
non_default_params=all_non_default_params,
)

File: litellm/utils.py

@@ -6198,6 +6198,8 @@ class ProviderConfigManager:
return litellm.AzureAIRerankConfig()
elif litellm.LlmProviders.INFINITY == provider:
return litellm.InfinityRerankConfig()
elif litellm.LlmProviders.JINA_AI == provider:
return litellm.JinaAIRerankConfig()
return litellm.CohereRerankConfig()
@staticmethod

File: model_prices_and_context_window.json

@@ -5982,6 +5982,19 @@
"litellm_provider": "bedrock",
"mode": "chat"
},
"amazon.rerank-v1:0": {
"max_tokens": 32000,
"max_input_tokens": 32000,
"max_output_tokens": 32000,
"max_query_tokens": 32000,
"max_document_chunks_per_query": 100,
"max_tokens_per_document_chunk": 512,
"input_cost_per_token": 0.0,
"input_cost_per_query": 0.001,
"output_cost_per_token": 0.0,
"litellm_provider": "bedrock",
"mode": "rerank"
},
"amazon.titan-text-lite-v1": {
"max_tokens": 4000,
"max_input_tokens": 42000,
@@ -7022,6 +7035,19 @@
"mode": "chat",
"supports_tool_choice": true
},
"cohere.rerank-v3-5:0": {
"max_tokens": 32000,
"max_input_tokens": 32000,
"max_output_tokens": 32000,
"max_query_tokens": 32000,
"max_document_chunks_per_query": 100,
"max_tokens_per_document_chunk": 512,
"input_cost_per_token": 0.0,
"input_cost_per_query": 0.002,
"output_cost_per_token": 0.0,
"litellm_provider": "bedrock",
"mode": "rerank"
},
"cohere.command-text-v14": {
"max_tokens": 4096,
"max_input_tokens": 4096,
@@ -9154,5 +9180,15 @@
"input_cost_per_second": 0.00003333,
"output_cost_per_second": 0.00,
"litellm_provider": "assemblyai"
},
"jina-reranker-v2-base-multilingual": {
"max_tokens": 1024,
"max_input_tokens": 1024,
"max_output_tokens": 1024,
"max_document_chunks_per_query": 2048,
"input_cost_per_token": 0.000000018,
"output_cost_per_token": 0.000000018,
"litellm_provider": "jina_ai",
"mode": "rerank"
}
}

File: (new test file, path not shown)

@@ -0,0 +1,14 @@
import json
import os
import sys
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from unittest.mock import MagicMock, patch
from litellm import rerank
from litellm.llms.custom_httpx.http_handler import HTTPHandler

File: (new test file, path not shown)

@@ -0,0 +1,51 @@
import json
import os
import sys
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from unittest.mock import MagicMock, patch
from litellm import rerank
from litellm.llms.custom_httpx.http_handler import HTTPHandler
def test_rerank_infer_region_from_model_arn(monkeypatch):
mock_response = MagicMock()
monkeypatch.setenv("AWS_REGION_NAME", "us-east-1")
args = {
"model": "bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
"query": "hello",
"documents": ["hello", "world"],
}
def return_val():
return {
"results": [
{"index": 0, "relevanceScore": 0.6716859340667725},
{"index": 1, "relevanceScore": 0.0004994205664843321},
]
}
mock_response.json = return_val
mock_response.headers = {"key": "value"}
mock_response.status_code = 200
client = HTTPHandler()
with patch.object(client, "post", return_value=mock_response) as mock_post:
rerank(
model=args["model"],
query=args["query"],
documents=args["documents"],
client=client,
)
mock_post.assert_called_once()
print(f"mock_post.call_args: {mock_post.call_args.kwargs}")
assert "us-west-2" in mock_post.call_args.kwargs["url"]
assert "us-east-1" not in mock_post.call_args.kwargs["url"]

File: base_rerank_unit_tests.py

@@ -33,6 +33,7 @@ def assert_response_shape(response, custom_llm_provider):
expected_api_version_shape = {"version": str}
expected_billed_units_shape = {"search_units": int}
expected_billed_units_total_tokens_shape = {"total_tokens": int}
assert isinstance(response.id, expected_response_shape["id"])
assert isinstance(response.results, expected_response_shape["results"])
@@ -52,9 +53,15 @@
response.meta["api_version"]["version"],
expected_api_version_shape["version"],
)
assert isinstance(
response.meta["billed_units"], expected_meta_shape["billed_units"]
)
if "total_tokens" in response.meta["billed_units"]:
assert isinstance(
response.meta["billed_units"], expected_meta_shape["billed_units"]
response.meta["billed_units"]["total_tokens"],
expected_billed_units_total_tokens_shape["total_tokens"],
)
else:
assert isinstance(
response.meta["billed_units"]["search_units"],
expected_billed_units_shape["search_units"],
@@ -79,7 +86,9 @@ class BaseLLMRerankTest(ABC):
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_basic_rerank(self, sync_mode):
litellm.set_verbose = True
litellm._turn_on_debug()
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
rerank_call_args = self.get_base_rerank_call_args()
custom_llm_provider = self.get_custom_llm_provider()
if sync_mode is True:
@@ -95,6 +104,9 @@
assert response.id is not None
assert response.results is not None
assert response._hidden_params["response_cost"] is not None
assert response._hidden_params["response_cost"] > 0
assert_response_shape(
response=response, custom_llm_provider=custom_llm_provider.value
)

File: test_azure_ai.py

@@ -14,6 +14,7 @@ from litellm.llms.anthropic.chat import ModelResponseIterator
import httpx
import json
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from base_rerank_unit_tests import BaseLLMRerankTest
load_dotenv()
import io
@@ -185,6 +186,7 @@ def test_completion_azure_ai_command_r():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_azure_deepseek_reasoning_content():
import json
@@ -192,37 +194,48 @@ def test_azure_deepseek_reasoning_content():
with patch.object(client, "post") as mock_post:
mock_response = MagicMock()
mock_response.text = json.dumps(
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"message": {
"content": "<think>I am thinking here</think>\n\nThe sky is a canvas of blue",
"role": "assistant",
}
"finish_reason": "stop",
"index": 0,
"message": {
"content": "<think>I am thinking here</think>\n\nThe sky is a canvas of blue",
"role": "assistant",
},
}
],
}
)
)
mock_response.status_code = 200
# Add required response attributes
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json = lambda: json.loads(mock_response.text)
mock_post.return_value = mock_response
response = litellm.completion(
model='azure_ai/deepseek-r1',
model="azure_ai/deepseek-r1",
messages=[{"role": "user", "content": "Hello, world!"}],
api_base="https://litellm8397336933.services.ai.azure.com/models/chat/completions",
api_key="my-fake-api-key",
client=client
)
client=client,
)
print(response)
assert(response.choices[0].message.reasoning_content == "I am thinking here")
assert(response.choices[0].message.content == "\n\nThe sky is a canvas of blue")
assert response.choices[0].message.reasoning_content == "I am thinking here"
assert response.choices[0].message.content == "\n\nThe sky is a canvas of blue"
class TestAzureAIRerank(BaseLLMRerankTest):
def get_custom_llm_provider(self) -> litellm.LlmProviders:
return litellm.LlmProviders.AZURE_AI
def get_base_rerank_call_args(self) -> dict:
return {
"model": "azure_ai/cohere-rerank-v3-english",
"api_base": os.getenv("AZURE_AI_COHERE_API_BASE"),
"api_key": os.getenv("AZURE_AI_COHERE_API_KEY"),
}

File: test_bedrock_completion.py

@@ -2186,6 +2186,16 @@ class TestBedrockRerank(BaseLLMRerankTest):
}
class TestBedrockCohereRerank(BaseLLMRerankTest):
def get_custom_llm_provider(self) -> litellm.LlmProviders:
return litellm.LlmProviders.BEDROCK
def get_base_rerank_call_args(self) -> dict:
return {
"model": "bedrock/arn:aws:bedrock:us-west-2::foundation-model/cohere.rerank-v3-5:0",
}
@pytest.mark.parametrize(
"messages, continue_message_index",
[

File: test_rerank.py

@@ -66,6 +66,7 @@ def assert_response_shape(response, custom_llm_provider):
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.flaky(retries=3, delay=1)
async def test_basic_rerank(sync_mode):
litellm.set_verbose = True
if sync_mode is True:
@@ -311,6 +312,7 @@ def test_complete_base_url_cohere():
(3, None, False),
],
)
@pytest.mark.flaky(retries=3, delay=1)
async def test_basic_rerank_caching(sync_mode, top_n_1, top_n_2, expect_cache_hit):
from litellm.caching.caching import Cache

File: test_completion_cost.py

@@ -1574,7 +1574,11 @@ def test_completion_cost_azure_ai_rerank(model):
"relevance_score": 0.990732,
},
],
meta={},
meta={
"billed_units": {
"search_units": 1,
}
},
)
print("response", response)
model = model