diff --git a/docs/my-website/docs/providers/bedrock.md b/docs/my-website/docs/providers/bedrock.md
index 744be74c09..00fe45e99e 100644
--- a/docs/my-website/docs/providers/bedrock.md
+++ b/docs/my-website/docs/providers/bedrock.md
@@ -10,6 +10,7 @@ ALL Bedrock models (Anthropic, Meta, Deepseek, Mistral, Amazon, etc.) are Suppor
| Provider Route on LiteLLM | `bedrock/`, [`bedrock/converse/`](#set-converse--invoke-route), [`bedrock/invoke/`](#set-invoke-route), [`bedrock/converse_like/`](#calling-via-internal-proxy), [`bedrock/llama/`](#deepseek-not-r1), [`bedrock/deepseek_r1/`](#deepseek-r1) |
| Provider Doc | [Amazon Bedrock ↗](https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html) |
| Supported OpenAI Endpoints | `/chat/completions`, `/completions`, `/embeddings`, `/images/generations` |
+| Rerank Endpoint | `/rerank` |
| Pass-through Endpoint | [Supported](../pass_through/bedrock.md) |
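
The new `/rerank` row means Bedrock rerank models are callable through the standard SDK entrypoint. A minimal sketch, assuming AWS credentials are already set in the environment (the model id is illustrative):

    import litellm

    # Assumes AWS credentials (and optionally AWS_REGION_NAME) in the environment.
    response = litellm.rerank(
        model="bedrock/amazon.rerank-v1:0",
        query="What is the capital of France?",
        documents=["Paris is the capital of France.", "The Nile is a river in Africa."],
        top_n=1,
    )
    print(response.results)
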
diff --git a/litellm/__init__.py b/litellm/__init__.py
index cbbeca6750..10177bb53c 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -820,6 +820,7 @@ from .llms.cohere.completion.transformation import CohereTextConfig as CohereCon
from .llms.cohere.rerank.transformation import CohereRerankConfig
from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
from .llms.infinity.rerank.transformation import InfinityRerankConfig
+from .llms.jina_ai.rerank.transformation import JinaAIRerankConfig
from .llms.clarifai.chat.transformation import ClarifaiConfig
from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
from .llms.together_ai.chat import TogetherAIConfig
diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py
index edf45d77a8..488684f02b 100644
--- a/litellm/cost_calculator.py
+++ b/litellm/cost_calculator.py
@@ -16,15 +16,9 @@ from litellm.llms.anthropic.cost_calculation import (
from litellm.llms.azure.cost_calculation import (
cost_per_token as azure_openai_cost_per_token,
)
-from litellm.llms.azure_ai.cost_calculator import (
- cost_per_query as azure_ai_rerank_cost_per_query,
-)
from litellm.llms.bedrock.image.cost_calculator import (
cost_calculator as bedrock_image_cost_calculator,
)
-from litellm.llms.cohere.cost_calculator import (
- cost_per_query as cohere_rerank_cost_per_query,
-)
from litellm.llms.databricks.cost_calculator import (
cost_per_token as databricks_cost_per_token,
)
@@ -51,10 +45,12 @@ from litellm.llms.vertex_ai.image_generation.cost_calculator import (
cost_calculator as vertex_ai_image_cost_calculator,
)
from litellm.types.llms.openai import HttpxBinaryResponseContent
-from litellm.types.rerank import RerankResponse
+from litellm.types.rerank import RerankBilledUnits, RerankResponse
from litellm.types.utils import (
CallTypesLiteral,
+ LlmProviders,
LlmProvidersSet,
+ ModelInfo,
PassthroughCallTypes,
Usage,
)
@@ -64,6 +60,7 @@ from litellm.utils import (
EmbeddingResponse,
ImageResponse,
ModelResponse,
+ ProviderConfigManager,
TextCompletionResponse,
TranscriptionResponse,
_cached_get_model_info_helper,
@@ -114,6 +111,8 @@ def cost_per_token( # noqa: PLR0915
number_of_queries: Optional[int] = None,
### USAGE OBJECT ###
usage_object: Optional[Usage] = None, # just read the usage object if provided
+ ### BILLED UNITS ###
+ rerank_billed_units: Optional[RerankBilledUnits] = None,
### CALL TYPE ###
call_type: CallTypesLiteral = "completion",
audio_transcription_file_duration: float = 0.0, # for audio transcription calls - the file time in seconds
@@ -238,6 +237,7 @@ def cost_per_token( # noqa: PLR0915
return rerank_cost(
model=model,
custom_llm_provider=custom_llm_provider,
+ billed_units=rerank_billed_units,
)
elif call_type == "atranscription" or call_type == "transcription":
return openai_cost_per_second(
@@ -552,6 +552,7 @@ def completion_cost( # noqa: PLR0915
cost_per_token_usage_object: Optional[Usage] = _get_usage_object(
completion_response=completion_response
)
+ rerank_billed_units: Optional[RerankBilledUnits] = None
model = _select_model_name_for_cost_calc(
model=model,
completion_response=completion_response,
@@ -698,6 +699,11 @@ def completion_cost( # noqa: PLR0915
else:
billed_units = {}
+ rerank_billed_units = RerankBilledUnits(
+ search_units=billed_units.get("search_units"),
+ total_tokens=billed_units.get("total_tokens"),
+ )
+
search_units = (
billed_units.get("search_units") or 1
) # cohere charges per request by default.
@@ -763,6 +769,7 @@ def completion_cost( # noqa: PLR0915
usage_object=cost_per_token_usage_object,
call_type=call_type,
audio_transcription_file_duration=audio_transcription_file_duration,
+ rerank_billed_units=rerank_billed_units,
)
_final_cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
@@ -836,27 +843,33 @@ def response_cost_calculator(
def rerank_cost(
model: str,
custom_llm_provider: Optional[str],
+ billed_units: Optional[RerankBilledUnits] = None,
) -> Tuple[float, float]:
"""
Returns
- float or None: cost of response OR none if error.
"""
- default_num_queries = 1
_, custom_llm_provider, _, _ = litellm.get_llm_provider(
model=model, custom_llm_provider=custom_llm_provider
)
try:
- if custom_llm_provider == "cohere":
- return cohere_rerank_cost_per_query(
- model=model, num_queries=default_num_queries
+ config = ProviderConfigManager.get_provider_rerank_config(
+ model=model, provider=LlmProviders(custom_llm_provider)
+ )
+
+ try:
+ model_info: Optional[ModelInfo] = litellm.get_model_info(
+ model=model, custom_llm_provider=custom_llm_provider
)
- elif custom_llm_provider == "azure_ai":
- return azure_ai_rerank_cost_per_query(
- model=model, num_queries=default_num_queries
- )
- raise ValueError(
- f"invalid custom_llm_provider for rerank model: {model}, custom_llm_provider: {custom_llm_provider}"
+ except Exception:
+ model_info = None
+
+ return config.calculate_rerank_cost(
+ model=model,
+ custom_llm_provider=custom_llm_provider,
+ billed_units=billed_units,
+ model_info=model_info,
)
except Exception as e:
raise e
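
The provider-specific `cost_per_query` helpers removed above are replaced by a single dispatch through the provider's rerank config. A rough sketch of the resulting flow, with illustrative values standing in for a real response:

    import litellm
    from litellm.types.rerank import RerankBilledUnits
    from litellm.types.utils import LlmProviders
    from litellm.utils import ProviderConfigManager

    # Illustrative: resolve the provider's rerank config, then let it price the call.
    config = ProviderConfigManager.get_provider_rerank_config(
        model="rerank-english-v3.0", provider=LlmProviders.COHERE
    )
    prompt_cost, completion_cost = config.calculate_rerank_cost(
        model="rerank-english-v3.0",
        custom_llm_provider="cohere",
        billed_units=RerankBilledUnits(search_units=1),
        model_info=litellm.get_model_info(
            model="rerank-english-v3.0", custom_llm_provider="cohere"
        ),
    )
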
diff --git a/litellm/llms/azure_ai/cost_calculator.py b/litellm/llms/azure_ai/cost_calculator.py
deleted file mode 100644
index 96d7018458..0000000000
--- a/litellm/llms/azure_ai/cost_calculator.py
+++ /dev/null
@@ -1,32 +0,0 @@
-"""
-Handles custom cost calculation for Azure AI models.
-
-Custom cost calculation for Azure AI models only requied for rerank.
-"""
-
-from typing import Tuple
-
-from litellm.utils import get_model_info
-
-
-def cost_per_query(model: str, num_queries: int = 1) -> Tuple[float, float]:
- """
- Calculates the cost per query for a given rerank model.
-
- Input:
- - model: str, the model name without provider prefix
-
- Returns:
- Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
- """
- model_info = get_model_info(model=model, custom_llm_provider="azure_ai")
-
- if (
- "input_cost_per_query" not in model_info
- or model_info["input_cost_per_query"] is None
- ):
- return 0.0, 0.0
-
- prompt_cost = model_info["input_cost_per_query"] * num_queries
-
- return prompt_cost, 0.0
diff --git a/litellm/llms/base_llm/rerank/transformation.py b/litellm/llms/base_llm/rerank/transformation.py
index d956c9a555..524ed0f8d9 100644
--- a/litellm/llms/base_llm/rerank/transformation.py
+++ b/litellm/llms/base_llm/rerank/transformation.py
@@ -1,9 +1,10 @@
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
-from litellm.types.rerank import OptionalRerankParams, RerankResponse
+from litellm.types.rerank import OptionalRerankParams, RerankBilledUnits, RerankResponse
+from litellm.types.utils import ModelInfo
from ..chat.transformation import BaseLLMException
@@ -66,7 +67,7 @@ class BaseRerankConfig(ABC):
@abstractmethod
def map_cohere_rerank_params(
self,
- non_default_params: Optional[dict],
+ non_default_params: dict,
model: str,
drop_params: bool,
query: str,
@@ -79,8 +80,48 @@ class BaseRerankConfig(ABC):
) -> OptionalRerankParams:
pass
- @abstractmethod
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
- pass
+ raise BaseLLMException(
+ status_code=status_code,
+ message=error_message,
+ headers=headers,
+ )
+
+ def calculate_rerank_cost(
+ self,
+ model: str,
+ custom_llm_provider: Optional[str] = None,
+ billed_units: Optional[RerankBilledUnits] = None,
+ model_info: Optional[ModelInfo] = None,
+ ) -> Tuple[float, float]:
+ """
+ Calculates the cost per query for a given rerank model.
+
+ Input:
+ - model: str, the model name without provider prefix
+        - custom_llm_provider: str, the provider used for the model; if provided, used to check that the litellm model info is for that provider
+        - billed_units: RerankBilledUnits, the billed units for the request (e.g. search_units)
+ - model_info: ModelInfo, the model info for the given model
+
+ Returns:
+ Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
+ """
+
+ if (
+ model_info is None
+ or "input_cost_per_query" not in model_info
+ or model_info["input_cost_per_query"] is None
+ or billed_units is None
+ ):
+ return 0.0, 0.0
+
+ search_units = billed_units.get("search_units")
+
+ if search_units is None:
+ return 0.0, 0.0
+
+ prompt_cost = model_info["input_cost_per_query"] * search_units
+
+ return prompt_cost, 0.0
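
Under this base implementation the charge is simply `input_cost_per_query * search_units`. A worked sketch with illustrative numbers:

    from litellm.types.rerank import RerankBilledUnits

    # Illustrative: a model priced at $0.002 per query, billed for 3 search units.
    input_cost_per_query = 0.002
    billed_units = RerankBilledUnits(search_units=3)
    prompt_cost = input_cost_per_query * (billed_units.get("search_units") or 0)
    assert abs(prompt_cost - 0.006) < 1e-12
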
diff --git a/litellm/llms/bedrock/base_aws_llm.py b/litellm/llms/bedrock/base_aws_llm.py
index c46a5f8a0e..33597879c0 100644
--- a/litellm/llms/bedrock/base_aws_llm.py
+++ b/litellm/llms/bedrock/base_aws_llm.py
@@ -202,6 +202,61 @@ class BaseAWSLLM:
self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl)
return credentials
+ def _get_aws_region_from_model_arn(self, model: Optional[str]) -> Optional[str]:
+ try:
+ # First check if the string contains the expected prefix
+ if not isinstance(model, str) or "arn:aws:bedrock" not in model:
+ return None
+
+ # Split the ARN and check if we have enough parts
+ parts = model.split(":")
+ if len(parts) < 4:
+ return None
+
+ # Get the region from the correct position
+ region = parts[3]
+ if not region: # Check if region is empty
+ return None
+
+ return region
+ except Exception:
+ # Catch any unexpected errors and return None
+ return None
+
+ def _get_aws_region_name(
+ self, optional_params: dict, model: Optional[str] = None
+ ) -> str:
+ """
+ Get the AWS region name from the environment variables
+ """
+ aws_region_name = optional_params.get("aws_region_name", None)
+ ### SET REGION NAME ###
+ if aws_region_name is None:
+ # check model arn #
+ aws_region_name = self._get_aws_region_from_model_arn(model)
+ # check env #
+ litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
+
+ if (
+ aws_region_name is None
+ and litellm_aws_region_name is not None
+ and isinstance(litellm_aws_region_name, str)
+ ):
+ aws_region_name = litellm_aws_region_name
+
+ standard_aws_region_name = get_secret("AWS_REGION", None)
+ if (
+ aws_region_name is None
+ and standard_aws_region_name is not None
+ and isinstance(standard_aws_region_name, str)
+ ):
+ aws_region_name = standard_aws_region_name
+
+ if aws_region_name is None:
+ aws_region_name = "us-west-2"
+
+ return aws_region_name
+
@tracer.wrap()
def _auth_with_web_identity_token(
self,
@@ -423,7 +478,7 @@ class BaseAWSLLM:
return endpoint_url, proxy_endpoint_url
def _get_boto_credentials_from_optional_params(
- self, optional_params: dict
+ self, optional_params: dict, model: Optional[str] = None
) -> Boto3CredentialsInfo:
"""
Get boto3 credentials from optional params
@@ -443,7 +498,7 @@ class BaseAWSLLM:
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None)
- aws_region_name = optional_params.pop("aws_region_name", None)
+ aws_region_name = self._get_aws_region_name(optional_params, model)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None)
@@ -453,25 +508,6 @@ class BaseAWSLLM:
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
- ### SET REGION NAME ###
- if aws_region_name is None:
- # check env #
- litellm_aws_region_name = get_secret_str("AWS_REGION_NAME", None)
-
- if litellm_aws_region_name is not None and isinstance(
- litellm_aws_region_name, str
- ):
- aws_region_name = litellm_aws_region_name
-
- standard_aws_region_name = get_secret_str("AWS_REGION", None)
- if standard_aws_region_name is not None and isinstance(
- standard_aws_region_name, str
- ):
- aws_region_name = standard_aws_region_name
-
- if aws_region_name is None:
- aws_region_name = "us-west-2"
-
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
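
Region resolution in `_get_aws_region_name` now follows a fixed precedence: an explicit `aws_region_name` in `optional_params`, then the region embedded in a model ARN, then the `AWS_REGION_NAME` and `AWS_REGION` environment variables, and finally the `us-west-2` default. A standalone sketch of just the ARN step, assuming the usual Bedrock ARN layout:

    from typing import Optional

    # Hypothetical helper mirroring _get_aws_region_from_model_arn: the region
    # is the fourth ':'-separated field of a Bedrock ARN.
    def region_from_bedrock_arn(model: str) -> Optional[str]:
        if "arn:aws:bedrock" not in model:
            return None
        parts = model.split(":")
        if len(parts) < 4 or not parts[3]:
            return None
        return parts[3]

    arn = "arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0"
    assert region_from_bedrock_arn(arn) == "us-west-2"
    assert region_from_bedrock_arn("amazon.rerank-v1:0") is None
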
diff --git a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py
index f369057744..b7d4f0ae6d 100644
--- a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py
+++ b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py
@@ -27,7 +27,7 @@ from litellm.llms.custom_httpx.http_handler import (
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage
-from litellm.utils import CustomStreamWrapper, get_secret
+from litellm.utils import CustomStreamWrapper
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -598,61 +598,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
)
return modelId
- def get_aws_region_from_model_arn(self, model: Optional[str]) -> Optional[str]:
- try:
- # First check if the string contains the expected prefix
- if not isinstance(model, str) or "arn:aws:bedrock" not in model:
- return None
-
- # Split the ARN and check if we have enough parts
- parts = model.split(":")
- if len(parts) < 4:
- return None
-
- # Get the region from the correct position
- region = parts[3]
- if not region: # Check if region is empty
- return None
-
- return region
- except Exception:
- # Catch any unexpected errors and return None
- return None
-
- def _get_aws_region_name(
- self, optional_params: dict, model: Optional[str] = None
- ) -> str:
- """
- Get the AWS region name from the environment variables
- """
- aws_region_name = optional_params.get("aws_region_name", None)
- ### SET REGION NAME ###
- if aws_region_name is None:
- # check model arn #
- aws_region_name = self.get_aws_region_from_model_arn(model)
- # check env #
- litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
-
- if (
- aws_region_name is None
- and litellm_aws_region_name is not None
- and isinstance(litellm_aws_region_name, str)
- ):
- aws_region_name = litellm_aws_region_name
-
- standard_aws_region_name = get_secret("AWS_REGION", None)
- if (
- aws_region_name is None
- and standard_aws_region_name is not None
- and isinstance(standard_aws_region_name, str)
- ):
- aws_region_name = standard_aws_region_name
-
- if aws_region_name is None:
- aws_region_name = "us-west-2"
-
- return aws_region_name
-
def _get_model_id_from_model_with_spec(
self,
model: str,
diff --git a/litellm/llms/bedrock/common_utils.py b/litellm/llms/bedrock/common_utils.py
index 8a534f6eac..54be359897 100644
--- a/litellm/llms/bedrock/common_utils.py
+++ b/litellm/llms/bedrock/common_utils.py
@@ -318,6 +318,23 @@ class BedrockModelInfo(BaseLLMModelInfo):
global_config = AmazonBedrockGlobalConfig()
all_global_regions = global_config.get_all_regions()
+ @staticmethod
+ def extract_model_name_from_arn(model: str) -> str:
+ """
+ Extract the model name from an AWS Bedrock ARN.
+ Returns the string after the last '/' if 'arn' is in the input string.
+
+ Args:
+            model (str): The model string to parse, which may be an ARN
+
+ Returns:
+            str: The extracted model name if 'arn' is in the string,
+                otherwise the original string
+ """
+ if "arn" in model.lower():
+ return model.split("/")[-1]
+ return model
+
@staticmethod
def get_base_model(model: str) -> str:
"""
@@ -335,6 +352,8 @@ class BedrockModelInfo(BaseLLMModelInfo):
if model.startswith("invoke/"):
model = model.split("/", 1)[1]
+ model = BedrockModelInfo.extract_model_name_from_arn(model)
+
potential_region = model.split(".", 1)[0]
alt_potential_region = model.split("/", 1)[
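
This lets `get_base_model` price ARN-style model ids by their trailing model name. A sketch of the intended behavior:

    from litellm.llms.bedrock.common_utils import BedrockModelInfo

    arn = "arn:aws:bedrock:us-west-2::foundation-model/cohere.rerank-v3-5:0"
    # The text after the last '/' is what pricing lookups key on.
    assert BedrockModelInfo.extract_model_name_from_arn(arn) == "cohere.rerank-v3-5:0"
    # Non-ARN ids pass through unchanged.
    assert BedrockModelInfo.extract_model_name_from_arn("cohere.rerank-v3-5:0") == "cohere.rerank-v3-5:0"
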
diff --git a/litellm/llms/bedrock/image/image_handler.py b/litellm/llms/bedrock/image/image_handler.py
index 5b14833f42..4bd63fd21b 100644
--- a/litellm/llms/bedrock/image/image_handler.py
+++ b/litellm/llms/bedrock/image/image_handler.py
@@ -163,7 +163,7 @@ class BedrockImageGeneration(BaseAWSLLM):
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
- optional_params
+ optional_params, model
)
### SET RUNTIME ENDPOINT ###
diff --git a/litellm/llms/bedrock/rerank/handler.py b/litellm/llms/bedrock/rerank/handler.py
index 3683be06b6..cd8be6912c 100644
--- a/litellm/llms/bedrock/rerank/handler.py
+++ b/litellm/llms/bedrock/rerank/handler.py
@@ -6,6 +6,8 @@ import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
@@ -27,8 +29,10 @@ class BedrockRerankHandler(BaseAWSLLM):
async def arerank(
self,
prepared_request: BedrockPreparedRequest,
+ client: Optional[AsyncHTTPHandler] = None,
):
- client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
+ if client is None:
+ client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
try:
response = await client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
response.raise_for_status()
@@ -54,7 +58,9 @@ class BedrockRerankHandler(BaseAWSLLM):
_is_async: Optional[bool] = False,
api_base: Optional[str] = None,
extra_headers: Optional[dict] = None,
+ client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> RerankResponse:
+
request_data = RerankRequest(
model=model,
query=query,
@@ -66,6 +72,7 @@ class BedrockRerankHandler(BaseAWSLLM):
data = BedrockRerankConfig()._transform_request(request_data)
prepared_request = self._prepare_request(
+ model=model,
optional_params=optional_params,
api_base=api_base,
extra_headers=extra_headers,
@@ -83,9 +90,10 @@ class BedrockRerankHandler(BaseAWSLLM):
)
if _is_async:
- return self.arerank(prepared_request) # type: ignore
+ return self.arerank(prepared_request, client=client if client is not None and isinstance(client, AsyncHTTPHandler) else None) # type: ignore
- client = _get_httpx_client()
+ if client is None or not isinstance(client, HTTPHandler):
+ client = _get_httpx_client()
try:
response = client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
response.raise_for_status()
@@ -95,10 +103,18 @@ class BedrockRerankHandler(BaseAWSLLM):
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
- return BedrockRerankConfig()._transform_response(response.json())
+ logging_obj.post_call(
+ original_response=response.text,
+ api_key="",
+ )
+
+ response_json = response.json()
+
+ return BedrockRerankConfig()._transform_response(response_json)
def _prepare_request(
self,
+ model: str,
api_base: Optional[str],
extra_headers: Optional[dict],
data: dict,
@@ -110,7 +126,7 @@ class BedrockRerankHandler(BaseAWSLLM):
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
- optional_params
+ optional_params, model
)
### SET RUNTIME ENDPOINT ###
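
Because `rerank` now threads an optional client through to the HTTP call, callers and tests can inject their own handler instead of the module-level default. A usage sketch:

    from litellm import rerank
    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    # Assumption: reusing one HTTPHandler lets repeated calls share its connection pool.
    client = HTTPHandler()
    response = rerank(
        model="bedrock/amazon.rerank-v1:0",
        query="hello",
        documents=["hello", "world"],
        client=client,  # injected; a mismatched client type falls back to the default
    )
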
diff --git a/litellm/llms/bedrock/rerank/transformation.py b/litellm/llms/bedrock/rerank/transformation.py
index 7dc9b0aab1..a5380febe9 100644
--- a/litellm/llms/bedrock/rerank/transformation.py
+++ b/litellm/llms/bedrock/rerank/transformation.py
@@ -91,7 +91,9 @@ class BedrockRerankConfig:
example input:
{"results":[{"index":0,"relevanceScore":0.6847912669181824},{"index":1,"relevanceScore":0.5980774760246277}]}
"""
- _billed_units = RerankBilledUnits(**response.get("usage", {}))
+ _billed_units = RerankBilledUnits(
+ **response.get("usage", {"search_units": 1})
+ ) # by default 1 search unit
_tokens = RerankTokens(**response.get("usage", {}))
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
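
With this default, a Bedrock rerank response that omits `usage` is still billed as one search unit, so cost tracking does not silently report $0. A worked sketch mirroring the line above:

    from litellm.types.rerank import RerankBilledUnits

    # Illustrative: no "usage" block in the raw response.
    raw_response: dict = {"results": [{"index": 0, "relevanceScore": 0.68}]}
    billed = RerankBilledUnits(**raw_response.get("usage", {"search_units": 1}))
    # At amazon.rerank-v1:0's $0.001 per query (see the pricing entries below),
    # one default search unit tracks as $0.001.
    assert (billed.get("search_units") or 0) * 0.001 == 0.001
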
diff --git a/litellm/llms/cohere/cost_calculator.py b/litellm/llms/cohere/cost_calculator.py
deleted file mode 100644
index 224dd5cfa8..0000000000
--- a/litellm/llms/cohere/cost_calculator.py
+++ /dev/null
@@ -1,31 +0,0 @@
-"""
-Custom cost calculator for Cohere rerank models
-"""
-
-from typing import Tuple
-
-from litellm.utils import get_model_info
-
-
-def cost_per_query(model: str, num_queries: int = 1) -> Tuple[float, float]:
- """
- Calculates the cost per query for a given rerank model.
-
- Input:
- - model: str, the model name without provider prefix
-
- Returns:
- Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
- """
-
- model_info = get_model_info(model=model, custom_llm_provider="cohere")
-
- if (
- "input_cost_per_query" not in model_info
- or model_info["input_cost_per_query"] is None
- ):
- return 0.0, 0.0
-
- prompt_cost = model_info["input_cost_per_query"] * num_queries
-
- return prompt_cost, 0.0
diff --git a/litellm/llms/jina_ai/rerank/handler.py b/litellm/llms/jina_ai/rerank/handler.py
index 355624cd2a..94076da4f3 100644
--- a/litellm/llms/jina_ai/rerank/handler.py
+++ b/litellm/llms/jina_ai/rerank/handler.py
@@ -1,92 +1,3 @@
"""
-Re rank api
-
-LiteLLM supports the re rank API format, no paramter transformation occurs
+HTTP calling migrated to `llm_http_handler.py`
"""
-
-from typing import Any, Dict, List, Optional, Union
-
-import litellm
-from litellm.llms.base import BaseLLM
-from litellm.llms.custom_httpx.http_handler import (
- _get_httpx_client,
- get_async_httpx_client,
-)
-from litellm.llms.jina_ai.rerank.transformation import JinaAIRerankConfig
-from litellm.types.rerank import RerankRequest, RerankResponse
-
-
-class JinaAIRerank(BaseLLM):
- def rerank(
- self,
- model: str,
- api_key: str,
- query: str,
- documents: List[Union[str, Dict[str, Any]]],
- top_n: Optional[int] = None,
- rank_fields: Optional[List[str]] = None,
- return_documents: Optional[bool] = True,
- max_chunks_per_doc: Optional[int] = None,
- _is_async: Optional[bool] = False,
- ) -> RerankResponse:
- client = _get_httpx_client()
-
- request_data = RerankRequest(
- model=model,
- query=query,
- top_n=top_n,
- documents=documents,
- rank_fields=rank_fields,
- return_documents=return_documents,
- )
-
- # exclude None values from request_data
- request_data_dict = request_data.dict(exclude_none=True)
-
- if _is_async:
- return self.async_rerank(request_data_dict, api_key) # type: ignore # Call async method
-
- response = client.post(
- "https://api.jina.ai/v1/rerank",
- headers={
- "accept": "application/json",
- "content-type": "application/json",
- "authorization": f"Bearer {api_key}",
- },
- json=request_data_dict,
- )
-
- if response.status_code != 200:
- raise Exception(response.text)
-
- _json_response = response.json()
-
- return JinaAIRerankConfig()._transform_response(_json_response)
-
- async def async_rerank( # New async method
- self,
- request_data_dict: Dict[str, Any],
- api_key: str,
- ) -> RerankResponse:
- client = get_async_httpx_client(
- llm_provider=litellm.LlmProviders.JINA_AI
- ) # Use async client
-
- response = await client.post(
- "https://api.jina.ai/v1/rerank",
- headers={
- "accept": "application/json",
- "content-type": "application/json",
- "authorization": f"Bearer {api_key}",
- },
- json=request_data_dict,
- )
-
- if response.status_code != 200:
- raise Exception(response.text)
-
- _json_response = response.json()
-
- return JinaAIRerankConfig()._transform_response(_json_response)
-
- pass
diff --git a/litellm/llms/jina_ai/rerank/transformation.py b/litellm/llms/jina_ai/rerank/transformation.py
index a6c0a810c7..4adb9cb0ec 100644
--- a/litellm/llms/jina_ai/rerank/transformation.py
+++ b/litellm/llms/jina_ai/rerank/transformation.py
@@ -7,30 +7,136 @@ Docs - https://jina.ai/reranker
"""
import uuid
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Tuple, Union
+from httpx import URL, Response
+
+from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.types.rerank import (
+ OptionalRerankParams,
RerankBilledUnits,
RerankResponse,
RerankResponseMeta,
RerankTokens,
)
+from litellm.types.utils import ModelInfo
-class JinaAIRerankConfig:
- def _transform_response(self, response: dict) -> RerankResponse:
+class JinaAIRerankConfig(BaseRerankConfig):
+ def get_supported_cohere_rerank_params(self, model: str) -> list:
+ return [
+ "query",
+ "top_n",
+ "documents",
+ "return_documents",
+ ]
- _billed_units = RerankBilledUnits(**response.get("usage", {}))
- _tokens = RerankTokens(**response.get("usage", {}))
+ def map_cohere_rerank_params(
+ self,
+ non_default_params: dict,
+ model: str,
+ drop_params: bool,
+ query: str,
+ documents: List[Union[str, Dict[str, Any]]],
+ custom_llm_provider: Optional[str] = None,
+ top_n: Optional[int] = None,
+ rank_fields: Optional[List[str]] = None,
+ return_documents: Optional[bool] = True,
+ max_chunks_per_doc: Optional[int] = None,
+ ) -> OptionalRerankParams:
+ optional_params = {}
+ supported_params = self.get_supported_cohere_rerank_params(model)
+ for k, v in non_default_params.items():
+ if k in supported_params:
+ optional_params[k] = v
+ return OptionalRerankParams(
+ **optional_params,
+ )
+
+ def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+ base_path = "/v1/rerank"
+
+ if api_base is None:
+ return "https://api.jina.ai/v1/rerank"
+ base = URL(api_base)
+ # Reconstruct URL with cleaned path
+ cleaned_base = str(base.copy_with(path=base_path))
+
+ return cleaned_base
+
+ def transform_rerank_request(
+ self, model: str, optional_rerank_params: OptionalRerankParams, headers: Dict
+ ) -> Dict:
+ return {"model": model, **optional_rerank_params}
+
+ def transform_rerank_response(
+ self,
+ model: str,
+ raw_response: Response,
+ model_response: RerankResponse,
+ logging_obj: LiteLLMLoggingObj,
+ api_key: Optional[str] = None,
+ request_data: Dict = {},
+ optional_params: Dict = {},
+ litellm_params: Dict = {},
+ ) -> RerankResponse:
+ if raw_response.status_code != 200:
+ raise Exception(raw_response.text)
+
+ logging_obj.post_call(original_response=raw_response.text)
+
+ _json_response = raw_response.json()
+
+ _billed_units = RerankBilledUnits(**_json_response.get("usage", {}))
+ _tokens = RerankTokens(**_json_response.get("usage", {}))
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
- _results: Optional[List[dict]] = response.get("results")
+ _results: Optional[List[dict]] = _json_response.get("results")
if _results is None:
- raise ValueError(f"No results found in the response={response}")
+ raise ValueError(f"No results found in the response={_json_response}")
return RerankResponse(
- id=response.get("id") or str(uuid.uuid4()),
+ id=_json_response.get("id") or str(uuid.uuid4()),
results=_results, # type: ignore
meta=rerank_meta,
) # Return response
+
+ def validate_environment(
+ self, headers: Dict, model: str, api_key: Optional[str] = None
+ ) -> Dict:
+ if api_key is None:
+ raise ValueError(
+ "api_key is required. Set via `api_key` parameter or `JINA_API_KEY` environment variable."
+ )
+ return {
+ "accept": "application/json",
+ "content-type": "application/json",
+ "authorization": f"Bearer {api_key}",
+ }
+
+ def calculate_rerank_cost(
+ self,
+ model: str,
+ custom_llm_provider: Optional[str] = None,
+ billed_units: Optional[RerankBilledUnits] = None,
+ model_info: Optional[ModelInfo] = None,
+ ) -> Tuple[float, float]:
+ """
+ Jina AI reranker is priced at $0.000000018 per token.
+ """
+ if (
+ model_info is None
+ or "input_cost_per_token" not in model_info
+ or model_info["input_cost_per_token"] is None
+ or billed_units is None
+ ):
+ return 0.0, 0.0
+
+ total_tokens = billed_units.get("total_tokens")
+ if total_tokens is None:
+ return 0.0, 0.0
+
+ input_cost = model_info["input_cost_per_token"] * total_tokens
+ return input_cost, 0.0
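
So Jina AI reranking is metered on tokens rather than search units: 10,000 billed tokens at $0.000000018 per token come to $0.00018. A sketch exercising the config directly, with an inline `model_info` standing in for the pricing-map entry:

    from litellm.llms.jina_ai.rerank.transformation import JinaAIRerankConfig
    from litellm.types.rerank import RerankBilledUnits

    prompt_cost, completion_cost = JinaAIRerankConfig().calculate_rerank_cost(
        model="jina-reranker-v2-base-multilingual",
        custom_llm_provider="jina_ai",
        billed_units=RerankBilledUnits(total_tokens=10_000),
        model_info={"input_cost_per_token": 1.8e-08},  # type: ignore
    )
    assert abs(prompt_cost - 0.00018) < 1e-12
    assert completion_cost == 0.0
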
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index c2c4140933..04175ec502 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -5982,6 +5982,19 @@
"litellm_provider": "bedrock",
"mode": "chat"
},
+ "amazon.rerank-v1:0": {
+ "max_tokens": 32000,
+ "max_input_tokens": 32000,
+ "max_output_tokens": 32000,
+ "max_query_tokens": 32000,
+ "max_document_chunks_per_query": 100,
+ "max_tokens_per_document_chunk": 512,
+ "input_cost_per_token": 0.0,
+ "input_cost_per_query": 0.001,
+ "output_cost_per_token": 0.0,
+ "litellm_provider": "bedrock",
+ "mode": "rerank"
+ },
"amazon.titan-text-lite-v1": {
"max_tokens": 4000,
"max_input_tokens": 42000,
@@ -7022,6 +7035,19 @@
"mode": "chat",
"supports_tool_choice": true
},
+ "cohere.rerank-v3-5:0": {
+ "max_tokens": 32000,
+ "max_input_tokens": 32000,
+ "max_output_tokens": 32000,
+ "max_query_tokens": 32000,
+ "max_document_chunks_per_query": 100,
+ "max_tokens_per_document_chunk": 512,
+ "input_cost_per_token": 0.0,
+ "input_cost_per_query": 0.002,
+ "output_cost_per_token": 0.0,
+ "litellm_provider": "bedrock",
+ "mode": "rerank"
+ },
"cohere.command-text-v14": {
"max_tokens": 4096,
"max_input_tokens": 4096,
@@ -9154,5 +9180,15 @@
"input_cost_per_second": 0.00003333,
"output_cost_per_second": 0.00,
"litellm_provider": "assemblyai"
+ },
+ "jina-reranker-v2-base-multilingual": {
+ "max_tokens": 1024,
+ "max_input_tokens": 1024,
+ "max_output_tokens": 1024,
+ "max_document_chunks_per_query": 2048,
+ "input_cost_per_token": 0.000000018,
+ "output_cost_per_token": 0.000000018,
+ "litellm_provider": "jina_ai",
+ "mode": "rerank"
}
}
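
These additions make the new rerank models resolvable through the standard pricing lookup. A quick sketch, assuming the bundled cost map above is loaded:

    import litellm

    info = litellm.get_model_info(
        model="amazon.rerank-v1:0", custom_llm_provider="bedrock"
    )
    assert info["mode"] == "rerank"
    assert info["input_cost_per_query"] == 0.001
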
diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py
index 8ec05ddadf..9a6eaeb0a7 100644
--- a/litellm/rerank_api/main.py
+++ b/litellm/rerank_api/main.py
@@ -9,7 +9,6 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
-from litellm.llms.jina_ai.rerank.handler import JinaAIRerank
from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
from litellm.rerank_api.rerank_utils import get_optional_rerank_params
from litellm.secret_managers.main import get_secret, get_secret_str
@@ -20,7 +19,6 @@ from litellm.utils import ProviderConfigManager, client, exception_type
####### ENVIRONMENT VARIABLES ###################
# Initialize any necessary instances or variables here
together_rerank = TogetherAIRerank()
-jina_ai_rerank = JinaAIRerank()
bedrock_rerank = BedrockRerankHandler()
base_llm_http_handler = BaseLLMHTTPHandler()
#################################################
@@ -264,16 +262,26 @@ def rerank( # noqa: PLR0915
raise ValueError(
"Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment"
)
- response = jina_ai_rerank.rerank(
+
+ api_base = (
+ dynamic_api_base
+ or optional_params.api_base
+ or litellm.api_base
+ or get_secret("BEDROCK_API_BASE") # type: ignore
+ )
+
+ response = base_llm_http_handler.rerank(
model=model,
- api_key=dynamic_api_key,
- query=query,
- documents=documents,
- top_n=top_n,
- rank_fields=rank_fields,
- return_documents=return_documents,
- max_chunks_per_doc=max_chunks_per_doc,
+ custom_llm_provider=_custom_llm_provider,
+ optional_rerank_params=optional_rerank_params,
+ logging_obj=litellm_logging_obj,
+ timeout=optional_params.timeout,
+ api_key=dynamic_api_key or optional_params.api_key,
+ api_base=api_base,
_is_async=_is_async,
+ headers=headers or litellm.headers or {},
+ client=client,
+ model_response=model_response,
)
elif _custom_llm_provider == "bedrock":
api_base = (
@@ -295,6 +303,7 @@ def rerank( # noqa: PLR0915
optional_params=optional_params.model_dump(exclude_unset=True),
api_base=api_base,
logging_obj=litellm_logging_obj,
+ client=client,
)
else:
raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
diff --git a/litellm/rerank_api/rerank_utils.py b/litellm/rerank_api/rerank_utils.py
index c3e5fda56e..00fb1c5ece 100644
--- a/litellm/rerank_api/rerank_utils.py
+++ b/litellm/rerank_api/rerank_utils.py
@@ -17,6 +17,17 @@ def get_optional_rerank_params(
max_chunks_per_doc: Optional[int] = None,
non_default_params: Optional[dict] = None,
) -> OptionalRerankParams:
+ all_non_default_params = non_default_params or {}
+ if query is not None:
+ all_non_default_params["query"] = query
+ if top_n is not None:
+ all_non_default_params["top_n"] = top_n
+ if documents is not None:
+ all_non_default_params["documents"] = documents
+ if return_documents is not None:
+ all_non_default_params["return_documents"] = return_documents
+ if max_chunks_per_doc is not None:
+ all_non_default_params["max_chunks_per_doc"] = max_chunks_per_doc
return rerank_provider_config.map_cohere_rerank_params(
model=model,
drop_params=drop_params,
@@ -27,5 +38,5 @@ def get_optional_rerank_params(
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
- non_default_params=non_default_params,
+ non_default_params=all_non_default_params,
)
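
The merge means explicitly passed arguments are folded into `non_default_params` before provider mapping, so configs that filter on `non_default_params` (like `JinaAIRerankConfig.map_cohere_rerank_params` above) actually see them. A restatement of the merge with illustrative inputs:

    # Illustrative inputs: one provider-specific kwarg plus explicit arguments.
    non_default_params = {"return_documents": False}
    explicit_args = {"query": "hello", "top_n": 2, "documents": ["a", "b"]}

    all_non_default_params = dict(non_default_params)
    for key, value in explicit_args.items():
        if value is not None:
            all_non_default_params[key] = value

    assert set(all_non_default_params) == {"return_documents", "query", "top_n", "documents"}
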
diff --git a/litellm/utils.py b/litellm/utils.py
index d18cfed20d..3414f289d3 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -6198,6 +6198,8 @@ class ProviderConfigManager:
return litellm.AzureAIRerankConfig()
elif litellm.LlmProviders.INFINITY == provider:
return litellm.InfinityRerankConfig()
+ elif litellm.LlmProviders.JINA_AI == provider:
+ return litellm.JinaAIRerankConfig()
return litellm.CohereRerankConfig()
@staticmethod
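
With this branch in place, provider lookup returns the Jina config rather than falling through to Cohere's default. A one-line sanity sketch:

    import litellm
    from litellm.types.utils import LlmProviders
    from litellm.utils import ProviderConfigManager

    config = ProviderConfigManager.get_provider_rerank_config(
        model="jina-reranker-v2-base-multilingual", provider=LlmProviders.JINA_AI
    )
    assert isinstance(config, litellm.JinaAIRerankConfig)
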
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index c2c4140933..04175ec502 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -5982,6 +5982,19 @@
"litellm_provider": "bedrock",
"mode": "chat"
},
+ "amazon.rerank-v1:0": {
+ "max_tokens": 32000,
+ "max_input_tokens": 32000,
+ "max_output_tokens": 32000,
+ "max_query_tokens": 32000,
+ "max_document_chunks_per_query": 100,
+ "max_tokens_per_document_chunk": 512,
+ "input_cost_per_token": 0.0,
+ "input_cost_per_query": 0.001,
+ "output_cost_per_token": 0.0,
+ "litellm_provider": "bedrock",
+ "mode": "rerank"
+ },
"amazon.titan-text-lite-v1": {
"max_tokens": 4000,
"max_input_tokens": 42000,
@@ -7022,6 +7035,19 @@
"mode": "chat",
"supports_tool_choice": true
},
+ "cohere.rerank-v3-5:0": {
+ "max_tokens": 32000,
+ "max_input_tokens": 32000,
+ "max_output_tokens": 32000,
+ "max_query_tokens": 32000,
+ "max_document_chunks_per_query": 100,
+ "max_tokens_per_document_chunk": 512,
+ "input_cost_per_token": 0.0,
+ "input_cost_per_query": 0.002,
+ "output_cost_per_token": 0.0,
+ "litellm_provider": "bedrock",
+ "mode": "rerank"
+ },
"cohere.command-text-v14": {
"max_tokens": 4096,
"max_input_tokens": 4096,
@@ -9154,5 +9180,15 @@
"input_cost_per_second": 0.00003333,
"output_cost_per_second": 0.00,
"litellm_provider": "assemblyai"
+ },
+ "jina-reranker-v2-base-multilingual": {
+ "max_tokens": 1024,
+ "max_input_tokens": 1024,
+ "max_output_tokens": 1024,
+ "max_document_chunks_per_query": 2048,
+ "input_cost_per_token": 0.000000018,
+ "output_cost_per_token": 0.000000018,
+ "litellm_provider": "jina_ai",
+ "mode": "rerank"
}
}
diff --git a/tests/litellm/llms/bedrock/rerank/transformation.py b/tests/litellm/llms/bedrock/rerank/transformation.py
new file mode 100644
index 0000000000..870a7cb1f1
--- /dev/null
+++ b/tests/litellm/llms/bedrock/rerank/transformation.py
@@ -0,0 +1,14 @@
+import json
+import os
+import sys
+
+import pytest
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+ 0, os.path.abspath("../../..")
+) # Adds the parent directory to the system path
+from unittest.mock import MagicMock, patch
+
+from litellm import rerank
+from litellm.llms.custom_httpx.http_handler import HTTPHandler
diff --git a/tests/litellm/rerank_api/test_main.py b/tests/litellm/rerank_api/test_main.py
new file mode 100644
index 0000000000..d3051b81f9
--- /dev/null
+++ b/tests/litellm/rerank_api/test_main.py
@@ -0,0 +1,51 @@
+import json
+import os
+import sys
+
+import pytest
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+ 0, os.path.abspath("../../..")
+) # Adds the parent directory to the system path
+from unittest.mock import MagicMock, patch
+
+from litellm import rerank
+from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+
+def test_rerank_infer_region_from_model_arn(monkeypatch):
+ mock_response = MagicMock()
+
+ monkeypatch.setenv("AWS_REGION_NAME", "us-east-1")
+ args = {
+ "model": "bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
+ "query": "hello",
+ "documents": ["hello", "world"],
+ }
+
+ def return_val():
+ return {
+ "results": [
+ {"index": 0, "relevanceScore": 0.6716859340667725},
+ {"index": 1, "relevanceScore": 0.0004994205664843321},
+ ]
+ }
+
+ mock_response.json = return_val
+ mock_response.headers = {"key": "value"}
+ mock_response.status_code = 200
+
+ client = HTTPHandler()
+
+ with patch.object(client, "post", return_value=mock_response) as mock_post:
+ rerank(
+ model=args["model"],
+ query=args["query"],
+ documents=args["documents"],
+ client=client,
+ )
+ mock_post.assert_called_once()
+ print(f"mock_post.call_args: {mock_post.call_args.kwargs}")
+ assert "us-west-2" in mock_post.call_args.kwargs["url"]
+ assert "us-east-1" not in mock_post.call_args.kwargs["url"]
diff --git a/tests/llm_translation/base_rerank_unit_tests.py b/tests/llm_translation/base_rerank_unit_tests.py
index 54f6009fc6..cff4a02753 100644
--- a/tests/llm_translation/base_rerank_unit_tests.py
+++ b/tests/llm_translation/base_rerank_unit_tests.py
@@ -33,6 +33,7 @@ def assert_response_shape(response, custom_llm_provider):
expected_api_version_shape = {"version": str}
expected_billed_units_shape = {"search_units": int}
+ expected_billed_units_total_tokens_shape = {"total_tokens": int}
assert isinstance(response.id, expected_response_shape["id"])
assert isinstance(response.results, expected_response_shape["results"])
@@ -52,9 +53,15 @@ def assert_response_shape(response, custom_llm_provider):
response.meta["api_version"]["version"],
expected_api_version_shape["version"],
)
+ assert isinstance(
+ response.meta["billed_units"], expected_meta_shape["billed_units"]
+ )
+ if "total_tokens" in response.meta["billed_units"]:
assert isinstance(
- response.meta["billed_units"], expected_meta_shape["billed_units"]
+ response.meta["billed_units"]["total_tokens"],
+ expected_billed_units_total_tokens_shape["total_tokens"],
)
+ else:
assert isinstance(
response.meta["billed_units"]["search_units"],
expected_billed_units_shape["search_units"],
@@ -79,7 +86,9 @@ class BaseLLMRerankTest(ABC):
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_basic_rerank(self, sync_mode):
- litellm.set_verbose = True
+ litellm._turn_on_debug()
+ os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+ litellm.model_cost = litellm.get_model_cost_map(url="")
rerank_call_args = self.get_base_rerank_call_args()
custom_llm_provider = self.get_custom_llm_provider()
if sync_mode is True:
@@ -95,6 +104,9 @@ class BaseLLMRerankTest(ABC):
assert response.id is not None
assert response.results is not None
+ assert response._hidden_params["response_cost"] is not None
+ assert response._hidden_params["response_cost"] > 0
+
assert_response_shape(
response=response, custom_llm_provider=custom_llm_provider.value
)
diff --git a/tests/llm_translation/test_azure_ai.py b/tests/llm_translation/test_azure_ai.py
index efb183bda0..c22c9edafa 100644
--- a/tests/llm_translation/test_azure_ai.py
+++ b/tests/llm_translation/test_azure_ai.py
@@ -14,6 +14,7 @@ from litellm.llms.anthropic.chat import ModelResponseIterator
import httpx
import json
from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from base_rerank_unit_tests import BaseLLMRerankTest
load_dotenv()
import io
@@ -185,6 +186,7 @@ def test_completion_azure_ai_command_r():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
+
def test_azure_deepseek_reasoning_content():
import json
@@ -192,37 +194,48 @@ def test_azure_deepseek_reasoning_content():
with patch.object(client, "post") as mock_post:
mock_response = MagicMock()
-
+
mock_response.text = json.dumps(
{
"choices": [
{
- "finish_reason": "stop",
- "index": 0,
- "message": {
- "content": "I am thinking here\n\nThe sky is a canvas of blue",
- "role": "assistant",
- }
+ "finish_reason": "stop",
+ "index": 0,
+ "message": {
+ "content": "I am thinking here\n\nThe sky is a canvas of blue",
+ "role": "assistant",
+ },
}
],
}
- )
-
+ )
+
mock_response.status_code = 200
# Add required response attributes
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json = lambda: json.loads(mock_response.text)
mock_post.return_value = mock_response
-
response = litellm.completion(
- model='azure_ai/deepseek-r1',
+ model="azure_ai/deepseek-r1",
messages=[{"role": "user", "content": "Hello, world!"}],
api_base="https://litellm8397336933.services.ai.azure.com/models/chat/completions",
api_key="my-fake-api-key",
- client=client
- )
+ client=client,
+ )
print(response)
- assert(response.choices[0].message.reasoning_content == "I am thinking here")
- assert(response.choices[0].message.content == "\n\nThe sky is a canvas of blue")
+ assert response.choices[0].message.reasoning_content == "I am thinking here"
+ assert response.choices[0].message.content == "\n\nThe sky is a canvas of blue"
+
+
+class TestAzureAIRerank(BaseLLMRerankTest):
+ def get_custom_llm_provider(self) -> litellm.LlmProviders:
+ return litellm.LlmProviders.AZURE_AI
+
+ def get_base_rerank_call_args(self) -> dict:
+ return {
+ "model": "azure_ai/cohere-rerank-v3-english",
+ "api_base": os.getenv("AZURE_AI_COHERE_API_BASE"),
+ "api_key": os.getenv("AZURE_AI_COHERE_API_KEY"),
+ }
diff --git a/tests/llm_translation/test_bedrock_completion.py b/tests/llm_translation/test_bedrock_completion.py
index 685c5e5409..7fca7b5f1a 100644
--- a/tests/llm_translation/test_bedrock_completion.py
+++ b/tests/llm_translation/test_bedrock_completion.py
@@ -2186,6 +2186,16 @@ class TestBedrockRerank(BaseLLMRerankTest):
}
+class TestBedrockCohereRerank(BaseLLMRerankTest):
+ def get_custom_llm_provider(self) -> litellm.LlmProviders:
+ return litellm.LlmProviders.BEDROCK
+
+ def get_base_rerank_call_args(self) -> dict:
+ return {
+ "model": "bedrock/arn:aws:bedrock:us-west-2::foundation-model/cohere.rerank-v3-5:0",
+ }
+
+
@pytest.mark.parametrize(
"messages, continue_message_index",
[
diff --git a/tests/llm_translation/test_rerank.py b/tests/llm_translation/test_rerank.py
index 82efa92dfd..d9bcc7ee2d 100644
--- a/tests/llm_translation/test_rerank.py
+++ b/tests/llm_translation/test_rerank.py
@@ -66,6 +66,7 @@ def assert_response_shape(response, custom_llm_provider):
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.flaky(retries=3, delay=1)
async def test_basic_rerank(sync_mode):
litellm.set_verbose = True
if sync_mode is True:
@@ -311,6 +312,7 @@ def test_complete_base_url_cohere():
(3, None, False),
],
)
+@pytest.mark.flaky(retries=3, delay=1)
async def test_basic_rerank_caching(sync_mode, top_n_1, top_n_2, expect_cache_hit):
from litellm.caching.caching import Cache
diff --git a/tests/local_testing/test_completion_cost.py b/tests/local_testing/test_completion_cost.py
index 124a222fda..978ff75a4c 100644
--- a/tests/local_testing/test_completion_cost.py
+++ b/tests/local_testing/test_completion_cost.py
@@ -1574,7 +1574,11 @@ def test_completion_cost_azure_ai_rerank(model):
"relevance_score": 0.990732,
},
],
- meta={},
+ meta={
+ "billed_units": {
+ "search_units": 1,
+ }
+ },
)
print("response", response)
model = model