Add cost tracking for rerank via bedrock (#8691)

* feat(bedrock/rerank): infer model region if model given as arn

* test: add unit testing to ensure bedrock region name inferred from arn on rerank

* feat(bedrock/rerank/transformation.py): include search units for bedrock rerank result

Resolves https://github.com/BerriAI/litellm/issues/7258#issuecomment-2671557137

* test(test_bedrock_completion.py): add testing for bedrock cohere rerank

* feat(cost_calculator.py): refactor rerank cost tracking to support bedrock cost tracking

* build(model_prices_and_context_window.json): add amazon.rerank model to model cost map

* fix(bedrock/common_utils.py): get base model from model w/ arn -> handles rerank model

* build(model_prices_and_context_window.json): add bedrock cohere rerank pricing

* feat(bedrock/rerank): migrate bedrock config to basererank config

* Revert "feat(bedrock/rerank): migrate bedrock config to basererank config"

This reverts commit 84fae1f167.

* test: add testing to ensure large doc / queries are correctly counted

* Revert "test: add testing to ensure large doc / queries are correctly counted"

This reverts commit 4337f1657e.

* fix(migrate-jina-ai-to-rerank-config): enables cost tracking

* refactor(jina_ai/): finish migrating jina ai to base rerank config

enables cost tracking

* fix(jina_ai/rerank): e2e jina ai rerank cost tracking

* fix: cleanup dead code

* fix: fix python3.8 compatibility error

* test: fix test

* test: add e2e testing for azure ai rerank

* fix: fix linting error

* test: mark cohere as flaky
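For illustration, a minimal usage sketch of the behavior this commit adds (the ARN-style model name comes from the new unit tests, the pricing from the new cost-map entries below; treat the exact numbers as indicative):

```python
import litellm

# Rerank via a Bedrock ARN -- the region (us-west-2) is inferred from the
# ARN itself instead of falling back to AWS_REGION_NAME.
response = litellm.rerank(
    model="bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
    query="hello",
    documents=["hello", "world"],
)

# Rerank calls now carry a tracked cost: 1 search unit x $0.001/query
# for amazon.rerank-v1:0 per the new model cost map entry.
print(response._hidden_params["response_cost"])
```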
Krish Dholakia 2025-02-20 21:00:18 -08:00 committed by GitHub
parent 4c9517fd78
commit b682dc4ec8
26 changed files with 524 additions and 296 deletions

View file

@@ -10,6 +10,7 @@ ALL Bedrock models (Anthropic, Meta, Deepseek, Mistral, Amazon, etc.) are Suppor
 | Provider Route on LiteLLM | `bedrock/`, [`bedrock/converse/`](#set-converse--invoke-route), [`bedrock/invoke/`](#set-invoke-route), [`bedrock/converse_like/`](#calling-via-internal-proxy), [`bedrock/llama/`](#deepseek-not-r1), [`bedrock/deepseek_r1/`](#deepseek-r1) |
 | Provider Doc | [Amazon Bedrock ↗](https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html) |
 | Supported OpenAI Endpoints | `/chat/completions`, `/completions`, `/embeddings`, `/images/generations` |
+| Rerank Endpoint | `/rerank` |
 | Pass-through Endpoint | [Supported](../pass_through/bedrock.md) |

View file

@@ -820,6 +820,7 @@ from .llms.cohere.completion.transformation import CohereTextConfig as CohereCon
 from .llms.cohere.rerank.transformation import CohereRerankConfig
 from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
 from .llms.infinity.rerank.transformation import InfinityRerankConfig
+from .llms.jina_ai.rerank.transformation import JinaAIRerankConfig
 from .llms.clarifai.chat.transformation import ClarifaiConfig
 from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
 from .llms.together_ai.chat import TogetherAIConfig

View file

@@ -16,15 +16,9 @@ from litellm.llms.anthropic.cost_calculation import (
 from litellm.llms.azure.cost_calculation import (
     cost_per_token as azure_openai_cost_per_token,
 )
-from litellm.llms.azure_ai.cost_calculator import (
-    cost_per_query as azure_ai_rerank_cost_per_query,
-)
 from litellm.llms.bedrock.image.cost_calculator import (
     cost_calculator as bedrock_image_cost_calculator,
 )
-from litellm.llms.cohere.cost_calculator import (
-    cost_per_query as cohere_rerank_cost_per_query,
-)
 from litellm.llms.databricks.cost_calculator import (
     cost_per_token as databricks_cost_per_token,
 )
@@ -51,10 +45,12 @@ from litellm.llms.vertex_ai.image_generation.cost_calculator import (
     cost_calculator as vertex_ai_image_cost_calculator,
 )
 from litellm.types.llms.openai import HttpxBinaryResponseContent
-from litellm.types.rerank import RerankResponse
+from litellm.types.rerank import RerankBilledUnits, RerankResponse
 from litellm.types.utils import (
     CallTypesLiteral,
+    LlmProviders,
     LlmProvidersSet,
+    ModelInfo,
     PassthroughCallTypes,
     Usage,
 )
@@ -64,6 +60,7 @@ from litellm.utils import (
     EmbeddingResponse,
     ImageResponse,
     ModelResponse,
+    ProviderConfigManager,
     TextCompletionResponse,
     TranscriptionResponse,
     _cached_get_model_info_helper,
@@ -114,6 +111,8 @@ def cost_per_token(  # noqa: PLR0915
     number_of_queries: Optional[int] = None,
     ### USAGE OBJECT ###
     usage_object: Optional[Usage] = None,  # just read the usage object if provided
+    ### BILLED UNITS ###
+    rerank_billed_units: Optional[RerankBilledUnits] = None,
     ### CALL TYPE ###
     call_type: CallTypesLiteral = "completion",
     audio_transcription_file_duration: float = 0.0,  # for audio transcription calls - the file time in seconds
@@ -238,6 +237,7 @@ def cost_per_token(  # noqa: PLR0915
         return rerank_cost(
             model=model,
             custom_llm_provider=custom_llm_provider,
+            billed_units=rerank_billed_units,
         )
     elif call_type == "atranscription" or call_type == "transcription":
         return openai_cost_per_second(
@@ -552,6 +552,7 @@ def completion_cost(  # noqa: PLR0915
     cost_per_token_usage_object: Optional[Usage] = _get_usage_object(
         completion_response=completion_response
     )
+    rerank_billed_units: Optional[RerankBilledUnits] = None
     model = _select_model_name_for_cost_calc(
         model=model,
         completion_response=completion_response,
@@ -698,6 +699,11 @@ def completion_cost(  # noqa: PLR0915
                 else:
                     billed_units = {}

+                rerank_billed_units = RerankBilledUnits(
+                    search_units=billed_units.get("search_units"),
+                    total_tokens=billed_units.get("total_tokens"),
+                )
+
                 search_units = (
                     billed_units.get("search_units") or 1
                 )  # cohere charges per request by default.
@@ -763,6 +769,7 @@ def completion_cost(  # noqa: PLR0915
             usage_object=cost_per_token_usage_object,
             call_type=call_type,
             audio_transcription_file_duration=audio_transcription_file_duration,
+            rerank_billed_units=rerank_billed_units,
         )
         _final_cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
@@ -836,27 +843,33 @@ def response_cost_calculator(
 def rerank_cost(
     model: str,
     custom_llm_provider: Optional[str],
+    billed_units: Optional[RerankBilledUnits] = None,
 ) -> Tuple[float, float]:
     """
     Returns
     - float or None: cost of response OR none if error.
     """
-    default_num_queries = 1
     _, custom_llm_provider, _, _ = litellm.get_llm_provider(
         model=model, custom_llm_provider=custom_llm_provider
     )

     try:
-        if custom_llm_provider == "cohere":
-            return cohere_rerank_cost_per_query(
-                model=model, num_queries=default_num_queries
-            )
-        elif custom_llm_provider == "azure_ai":
-            return azure_ai_rerank_cost_per_query(
-                model=model, num_queries=default_num_queries
-            )
-        raise ValueError(
-            f"invalid custom_llm_provider for rerank model: {model}, custom_llm_provider: {custom_llm_provider}"
-        )
+        config = ProviderConfigManager.get_provider_rerank_config(
+            model=model, provider=LlmProviders(custom_llm_provider)
+        )
+
+        try:
+            model_info: Optional[ModelInfo] = litellm.get_model_info(
+                model=model, custom_llm_provider=custom_llm_provider
+            )
+        except Exception:
+            model_info = None
+
+        return config.calculate_rerank_cost(
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            billed_units=billed_units,
+            model_info=model_info,
+        )
     except Exception as e:
         raise e

View file

@ -1,32 +0,0 @@
"""
Handles custom cost calculation for Azure AI models.
Custom cost calculation for Azure AI models only requied for rerank.
"""
from typing import Tuple
from litellm.utils import get_model_info
def cost_per_query(model: str, num_queries: int = 1) -> Tuple[float, float]:
"""
Calculates the cost per query for a given rerank model.
Input:
- model: str, the model name without provider prefix
Returns:
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
"""
model_info = get_model_info(model=model, custom_llm_provider="azure_ai")
if (
"input_cost_per_query" not in model_info
or model_info["input_cost_per_query"] is None
):
return 0.0, 0.0
prompt_cost = model_info["input_cost_per_query"] * num_queries
return prompt_cost, 0.0

View file

@@ -1,9 +1,10 @@
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 import httpx

-from litellm.types.rerank import OptionalRerankParams, RerankResponse
+from litellm.types.rerank import OptionalRerankParams, RerankBilledUnits, RerankResponse
+from litellm.types.utils import ModelInfo

 from ..chat.transformation import BaseLLMException

@@ -66,7 +67,7 @@ class BaseRerankConfig(ABC):
     @abstractmethod
     def map_cohere_rerank_params(
         self,
-        non_default_params: Optional[dict],
+        non_default_params: dict,
         model: str,
         drop_params: bool,
         query: str,
@@ -79,8 +80,48 @@ class BaseRerankConfig(ABC):
     ) -> OptionalRerankParams:
         pass

-    @abstractmethod
     def get_error_class(
         self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
     ) -> BaseLLMException:
-        pass
+        raise BaseLLMException(
+            status_code=status_code,
+            message=error_message,
+            headers=headers,
+        )
+
+    def calculate_rerank_cost(
+        self,
+        model: str,
+        custom_llm_provider: Optional[str] = None,
+        billed_units: Optional[RerankBilledUnits] = None,
+        model_info: Optional[ModelInfo] = None,
+    ) -> Tuple[float, float]:
+        """
+        Calculates the cost per query for a given rerank model.
+
+        Input:
+            - model: str, the model name without provider prefix
+            - custom_llm_provider: str, the provider used for the model. If provided, used to check if the litellm model info is for that provider.
+            - num_queries: int, the number of queries to calculate the cost for
+            - model_info: ModelInfo, the model info for the given model
+
+        Returns:
+            Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
+        """
+        if (
+            model_info is None
+            or "input_cost_per_query" not in model_info
+            or model_info["input_cost_per_query"] is None
+            or billed_units is None
+        ):
+            return 0.0, 0.0
+
+        search_units = billed_units.get("search_units")
+        if search_units is None:
+            return 0.0, 0.0
+
+        prompt_cost = model_info["input_cost_per_query"] * search_units
+
+        return prompt_cost, 0.0

View file

@@ -202,6 +202,61 @@ class BaseAWSLLM:
                 self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl)

         return credentials

+    def _get_aws_region_from_model_arn(self, model: Optional[str]) -> Optional[str]:
+        try:
+            # First check if the string contains the expected prefix
+            if not isinstance(model, str) or "arn:aws:bedrock" not in model:
+                return None
+
+            # Split the ARN and check if we have enough parts
+            parts = model.split(":")
+            if len(parts) < 4:
+                return None
+
+            # Get the region from the correct position
+            region = parts[3]
+            if not region:  # Check if region is empty
+                return None
+
+            return region
+        except Exception:
+            # Catch any unexpected errors and return None
+            return None
+
+    def _get_aws_region_name(
+        self, optional_params: dict, model: Optional[str] = None
+    ) -> str:
+        """
+        Get the AWS region name from the environment variables
+        """
+        aws_region_name = optional_params.get("aws_region_name", None)
+        ### SET REGION NAME ###
+        if aws_region_name is None:
+            # check model arn #
+            aws_region_name = self._get_aws_region_from_model_arn(model)
+            # check env #
+            litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
+
+            if (
+                aws_region_name is None
+                and litellm_aws_region_name is not None
+                and isinstance(litellm_aws_region_name, str)
+            ):
+                aws_region_name = litellm_aws_region_name
+
+            standard_aws_region_name = get_secret("AWS_REGION", None)
+            if (
+                aws_region_name is None
+                and standard_aws_region_name is not None
+                and isinstance(standard_aws_region_name, str)
+            ):
+                aws_region_name = standard_aws_region_name
+
+            if aws_region_name is None:
+                aws_region_name = "us-west-2"
+
+        return aws_region_name
+
     @tracer.wrap()
     def _auth_with_web_identity_token(
         self,
@@ -423,7 +478,7 @@ class BaseAWSLLM:
         return endpoint_url, proxy_endpoint_url

     def _get_boto_credentials_from_optional_params(
-        self, optional_params: dict
+        self, optional_params: dict, model: Optional[str] = None
     ) -> Boto3CredentialsInfo:
         """
         Get boto3 credentials from optional params
@@ -443,7 +498,7 @@ class BaseAWSLLM:
         aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
         aws_access_key_id = optional_params.pop("aws_access_key_id", None)
         aws_session_token = optional_params.pop("aws_session_token", None)
-        aws_region_name = optional_params.pop("aws_region_name", None)
+        aws_region_name = self._get_aws_region_name(optional_params, model)
        aws_role_name = optional_params.pop("aws_role_name", None)
         aws_session_name = optional_params.pop("aws_session_name", None)
         aws_profile_name = optional_params.pop("aws_profile_name", None)
@@ -453,25 +508,6 @@ class BaseAWSLLM:
             "aws_bedrock_runtime_endpoint", None
         )  # https://bedrock-runtime.{region_name}.amazonaws.com

-        ### SET REGION NAME ###
-        if aws_region_name is None:
-            # check env #
-            litellm_aws_region_name = get_secret_str("AWS_REGION_NAME", None)
-
-            if litellm_aws_region_name is not None and isinstance(
-                litellm_aws_region_name, str
-            ):
-                aws_region_name = litellm_aws_region_name
-
-            standard_aws_region_name = get_secret_str("AWS_REGION", None)
-            if standard_aws_region_name is not None and isinstance(
-                standard_aws_region_name, str
-            ):
-                aws_region_name = standard_aws_region_name
-
-            if aws_region_name is None:
-                aws_region_name = "us-west-2"
-
         credentials: Credentials = self.get_credentials(
             aws_access_key_id=aws_access_key_id,
             aws_secret_access_key=aws_secret_access_key,
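For reference, Bedrock ARNs follow `arn:aws:bedrock:<region>:<account>:<resource>`, so the region sits at index 3 when splitting on `:`. A quick standalone check of the parsing logic above:

```python
arn = "arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0"
parts = arn.split(":")
# ['arn', 'aws', 'bedrock', 'us-west-2', '', 'foundation-model/amazon.rerank-v1', '0']
assert parts[3] == "us-west-2"
```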

View file

@@ -27,7 +27,7 @@ from litellm.llms.custom_httpx.http_handler import (
 )
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ModelResponse, Usage
-from litellm.utils import CustomStreamWrapper, get_secret
+from litellm.utils import CustomStreamWrapper

 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -598,61 +598,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         )
         return modelId

-    def get_aws_region_from_model_arn(self, model: Optional[str]) -> Optional[str]:
-        try:
-            # First check if the string contains the expected prefix
-            if not isinstance(model, str) or "arn:aws:bedrock" not in model:
-                return None
-
-            # Split the ARN and check if we have enough parts
-            parts = model.split(":")
-            if len(parts) < 4:
-                return None
-
-            # Get the region from the correct position
-            region = parts[3]
-            if not region:  # Check if region is empty
-                return None
-
-            return region
-        except Exception:
-            # Catch any unexpected errors and return None
-            return None
-
-    def _get_aws_region_name(
-        self, optional_params: dict, model: Optional[str] = None
-    ) -> str:
-        """
-        Get the AWS region name from the environment variables
-        """
-        aws_region_name = optional_params.get("aws_region_name", None)
-        ### SET REGION NAME ###
-        if aws_region_name is None:
-            # check model arn #
-            aws_region_name = self.get_aws_region_from_model_arn(model)
-            # check env #
-            litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
-
-            if (
-                aws_region_name is None
-                and litellm_aws_region_name is not None
-                and isinstance(litellm_aws_region_name, str)
-            ):
-                aws_region_name = litellm_aws_region_name
-
-            standard_aws_region_name = get_secret("AWS_REGION", None)
-            if (
-                aws_region_name is None
-                and standard_aws_region_name is not None
-                and isinstance(standard_aws_region_name, str)
-            ):
-                aws_region_name = standard_aws_region_name
-
-            if aws_region_name is None:
-                aws_region_name = "us-west-2"
-
-        return aws_region_name
-
     def _get_model_id_from_model_with_spec(
         self,
         model: str,

View file

@@ -318,6 +318,23 @@ class BedrockModelInfo(BaseLLMModelInfo):
     global_config = AmazonBedrockGlobalConfig()
     all_global_regions = global_config.get_all_regions()

+    @staticmethod
+    def extract_model_name_from_arn(model: str) -> str:
+        """
+        Extract the model name from an AWS Bedrock ARN.
+        Returns the string after the last '/' if 'arn' is in the input string.
+
+        Args:
+            arn (str): The ARN string to parse
+
+        Returns:
+            str: The extracted model name if 'arn' is in the string,
+                otherwise returns the original string
+        """
+        if "arn" in model.lower():
+            return model.split("/")[-1]
+        return model
+
     @staticmethod
     def get_base_model(model: str) -> str:
         """
@@ -335,6 +352,8 @@ class BedrockModelInfo(BaseLLMModelInfo):
         if model.startswith("invoke/"):
             model = model.split("/", 1)[1]

+        model = BedrockModelInfo.extract_model_name_from_arn(model)
+
         potential_region = model.split(".", 1)[0]

         alt_potential_region = model.split("/", 1)[
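With this, cost lookup for an ARN-style model resolves to the plain model name before hitting the cost map. An illustrative trace of the normalization (plain string operations, outside the litellm call flow):

```python
model = "arn:aws:bedrock:us-west-2::foundation-model/cohere.rerank-v3-5:0"
# extract_model_name_from_arn: keep everything after the last '/'
base_model = model.split("/")[-1]
assert base_model == "cohere.rerank-v3-5:0"  # matches the new cost map key
```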

View file

@@ -163,7 +163,7 @@ class BedrockImageGeneration(BaseAWSLLM):
         except ImportError:
             raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
         boto3_credentials_info = self._get_boto_credentials_from_optional_params(
-            optional_params
+            optional_params, model
         )

         ### SET RUNTIME ENDPOINT ###

View file

@@ -6,6 +6,8 @@ import httpx
 import litellm
 from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
 from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
     _get_httpx_client,
     get_async_httpx_client,
 )
@@ -27,8 +29,10 @@ class BedrockRerankHandler(BaseAWSLLM):
     async def arerank(
         self,
         prepared_request: BedrockPreparedRequest,
+        client: Optional[AsyncHTTPHandler] = None,
     ):
-        client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
+        if client is None:
+            client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
         try:
             response = await client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"])  # type: ignore
             response.raise_for_status()
@@ -54,7 +58,9 @@ class BedrockRerankHandler(BaseAWSLLM):
         _is_async: Optional[bool] = False,
         api_base: Optional[str] = None,
         extra_headers: Optional[dict] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ) -> RerankResponse:
+
         request_data = RerankRequest(
             model=model,
             query=query,
@@ -66,6 +72,7 @@ class BedrockRerankHandler(BaseAWSLLM):
         data = BedrockRerankConfig()._transform_request(request_data)

         prepared_request = self._prepare_request(
+            model=model,
             optional_params=optional_params,
             api_base=api_base,
             extra_headers=extra_headers,
@@ -83,9 +90,10 @@ class BedrockRerankHandler(BaseAWSLLM):
         )

         if _is_async:
-            return self.arerank(prepared_request)  # type: ignore
+            return self.arerank(prepared_request, client=client if client is not None and isinstance(client, AsyncHTTPHandler) else None)  # type: ignore

-        client = _get_httpx_client()
+        if client is None or not isinstance(client, HTTPHandler):
+            client = _get_httpx_client()

         try:
             response = client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"])  # type: ignore
             response.raise_for_status()
@@ -95,10 +103,18 @@ class BedrockRerankHandler(BaseAWSLLM):
         except httpx.TimeoutException:
             raise BedrockError(status_code=408, message="Timeout error occurred.")

-        return BedrockRerankConfig()._transform_response(response.json())
+        logging_obj.post_call(
+            original_response=response.text,
+            api_key="",
+        )
+
+        response_json = response.json()
+
+        return BedrockRerankConfig()._transform_response(response_json)

     def _prepare_request(
         self,
+        model: str,
         api_base: Optional[str],
         extra_headers: Optional[dict],
         data: dict,
@@ -110,7 +126,7 @@ class BedrockRerankHandler(BaseAWSLLM):
         except ImportError:
             raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

         boto3_credentials_info = self._get_boto_credentials_from_optional_params(
-            optional_params
+            optional_params, model
         )

         ### SET RUNTIME ENDPOINT ###

View file

@@ -91,7 +91,9 @@ class BedrockRerankConfig:
         example input:
         {"results":[{"index":0,"relevanceScore":0.6847912669181824},{"index":1,"relevanceScore":0.5980774760246277}]}
         """
-        _billed_units = RerankBilledUnits(**response.get("usage", {}))
+        _billed_units = RerankBilledUnits(
+            **response.get("usage", {"search_units": 1})
+        )  # by default 1 search unit
         _tokens = RerankTokens(**response.get("usage", {}))
         rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
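Since the Bedrock response shown in the docstring carries no `usage` block, the transformation now falls back to one search unit so a cost is still attributed. A minimal sketch of that fallback:

```python
response = {"results": [{"index": 0, "relevanceScore": 0.68}]}  # no "usage" key
usage = response.get("usage", {"search_units": 1})
assert usage == {"search_units": 1}  # billed as a single search unit
```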

View file

@ -1,31 +0,0 @@
"""
Custom cost calculator for Cohere rerank models
"""
from typing import Tuple
from litellm.utils import get_model_info
def cost_per_query(model: str, num_queries: int = 1) -> Tuple[float, float]:
"""
Calculates the cost per query for a given rerank model.
Input:
- model: str, the model name without provider prefix
Returns:
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
"""
model_info = get_model_info(model=model, custom_llm_provider="cohere")
if (
"input_cost_per_query" not in model_info
or model_info["input_cost_per_query"] is None
):
return 0.0, 0.0
prompt_cost = model_info["input_cost_per_query"] * num_queries
return prompt_cost, 0.0

View file

@ -1,92 +1,3 @@
""" """
Re rank api HTTP calling migrated to `llm_http_handler.py`
LiteLLM supports the re rank API format, no paramter transformation occurs
""" """
from typing import Any, Dict, List, Optional, Union
import litellm
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.llms.jina_ai.rerank.transformation import JinaAIRerankConfig
from litellm.types.rerank import RerankRequest, RerankResponse
class JinaAIRerank(BaseLLM):
def rerank(
self,
model: str,
api_key: str,
query: str,
documents: List[Union[str, Dict[str, Any]]],
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
_is_async: Optional[bool] = False,
) -> RerankResponse:
client = _get_httpx_client()
request_data = RerankRequest(
model=model,
query=query,
top_n=top_n,
documents=documents,
rank_fields=rank_fields,
return_documents=return_documents,
)
# exclude None values from request_data
request_data_dict = request_data.dict(exclude_none=True)
if _is_async:
return self.async_rerank(request_data_dict, api_key) # type: ignore # Call async method
response = client.post(
"https://api.jina.ai/v1/rerank",
headers={
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {api_key}",
},
json=request_data_dict,
)
if response.status_code != 200:
raise Exception(response.text)
_json_response = response.json()
return JinaAIRerankConfig()._transform_response(_json_response)
async def async_rerank( # New async method
self,
request_data_dict: Dict[str, Any],
api_key: str,
) -> RerankResponse:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.JINA_AI
) # Use async client
response = await client.post(
"https://api.jina.ai/v1/rerank",
headers={
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {api_key}",
},
json=request_data_dict,
)
if response.status_code != 200:
raise Exception(response.text)
_json_response = response.json()
return JinaAIRerankConfig()._transform_response(_json_response)
pass

View file

@@ -7,30 +7,136 @@ Docs - https://jina.ai/reranker
 """

 import uuid
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Tuple, Union

+from httpx import URL, Response
+
+from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
 from litellm.types.rerank import (
+    OptionalRerankParams,
     RerankBilledUnits,
     RerankResponse,
     RerankResponseMeta,
     RerankTokens,
 )
+from litellm.types.utils import ModelInfo


-class JinaAIRerankConfig:
-    def _transform_response(self, response: dict) -> RerankResponse:
-
-        _billed_units = RerankBilledUnits(**response.get("usage", {}))
-        _tokens = RerankTokens(**response.get("usage", {}))
+class JinaAIRerankConfig(BaseRerankConfig):
+    def get_supported_cohere_rerank_params(self, model: str) -> list:
+        return [
+            "query",
+            "top_n",
+            "documents",
+            "return_documents",
+        ]
+
+    def map_cohere_rerank_params(
+        self,
+        non_default_params: dict,
+        model: str,
+        drop_params: bool,
+        query: str,
+        documents: List[Union[str, Dict[str, Any]]],
+        custom_llm_provider: Optional[str] = None,
+        top_n: Optional[int] = None,
+        rank_fields: Optional[List[str]] = None,
+        return_documents: Optional[bool] = True,
+        max_chunks_per_doc: Optional[int] = None,
+    ) -> OptionalRerankParams:
+        optional_params = {}
+        supported_params = self.get_supported_cohere_rerank_params(model)
+        for k, v in non_default_params.items():
+            if k in supported_params:
+                optional_params[k] = v
+        return OptionalRerankParams(
+            **optional_params,
+        )
+
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        base_path = "/v1/rerank"
+
+        if api_base is None:
+            return "https://api.jina.ai/v1/rerank"
+        base = URL(api_base)
+        # Reconstruct URL with cleaned path
+        cleaned_base = str(base.copy_with(path=base_path))
+
+        return cleaned_base
+
+    def transform_rerank_request(
+        self, model: str, optional_rerank_params: OptionalRerankParams, headers: Dict
+    ) -> Dict:
+        return {"model": model, **optional_rerank_params}
+
+    def transform_rerank_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: RerankResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: Dict = {},
+        optional_params: Dict = {},
+        litellm_params: Dict = {},
+    ) -> RerankResponse:
+        if raw_response.status_code != 200:
+            raise Exception(raw_response.text)
+
+        logging_obj.post_call(original_response=raw_response.text)
+
+        _json_response = raw_response.json()
+
+        _billed_units = RerankBilledUnits(**_json_response.get("usage", {}))
+        _tokens = RerankTokens(**_json_response.get("usage", {}))
         rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)

-        _results: Optional[List[dict]] = response.get("results")
+        _results: Optional[List[dict]] = _json_response.get("results")

         if _results is None:
-            raise ValueError(f"No results found in the response={response}")
+            raise ValueError(f"No results found in the response={_json_response}")

         return RerankResponse(
-            id=response.get("id") or str(uuid.uuid4()),
+            id=_json_response.get("id") or str(uuid.uuid4()),
             results=_results,  # type: ignore
             meta=rerank_meta,
         )  # Return response
+
+    def validate_environment(
+        self, headers: Dict, model: str, api_key: Optional[str] = None
+    ) -> Dict:
+        if api_key is None:
+            raise ValueError(
+                "api_key is required. Set via `api_key` parameter or `JINA_API_KEY` environment variable."
+            )
+        return {
+            "accept": "application/json",
+            "content-type": "application/json",
+            "authorization": f"Bearer {api_key}",
+        }
+
+    def calculate_rerank_cost(
+        self,
+        model: str,
+        custom_llm_provider: Optional[str] = None,
+        billed_units: Optional[RerankBilledUnits] = None,
+        model_info: Optional[ModelInfo] = None,
+    ) -> Tuple[float, float]:
+        """
+        Jina AI reranker is priced at $0.000000018 per token.
+        """
+        if (
+            model_info is None
+            or "input_cost_per_token" not in model_info
+            or model_info["input_cost_per_token"] is None
+            or billed_units is None
+        ):
+            return 0.0, 0.0
+
+        total_tokens = billed_units.get("total_tokens")
+        if total_tokens is None:
+            return 0.0, 0.0
+
+        input_cost = model_info["input_cost_per_token"] * total_tokens
+        return input_cost, 0.0
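Jina AI bills per token rather than per search unit, which is why its config overrides `calculate_rerank_cost`. A worked example with the cost-map rate added below (the token count is illustrative):

```python
input_cost_per_token = 0.000000018  # jina-reranker-v2-base-multilingual
total_tokens = 1_000  # would come from meta.billed_units.total_tokens
print(input_cost_per_token * total_tokens)  # 1.8e-05 USD
```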

View file

@@ -5982,6 +5982,19 @@
         "litellm_provider": "bedrock",
         "mode": "chat"
     },
+    "amazon.rerank-v1:0": {
+        "max_tokens": 32000,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 32000,
+        "max_query_tokens": 32000,
+        "max_document_chunks_per_query": 100,
+        "max_tokens_per_document_chunk": 512,
+        "input_cost_per_token": 0.0,
+        "input_cost_per_query": 0.001,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "bedrock",
+        "mode": "rerank"
+    },
     "amazon.titan-text-lite-v1": {
         "max_tokens": 4000,
         "max_input_tokens": 42000,
@@ -7022,6 +7035,19 @@
         "mode": "chat",
         "supports_tool_choice": true
     },
+    "cohere.rerank-v3-5:0": {
+        "max_tokens": 32000,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 32000,
+        "max_query_tokens": 32000,
+        "max_document_chunks_per_query": 100,
+        "max_tokens_per_document_chunk": 512,
+        "input_cost_per_token": 0.0,
+        "input_cost_per_query": 0.002,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "bedrock",
+        "mode": "rerank"
+    },
     "cohere.command-text-v14": {
         "max_tokens": 4096,
         "max_input_tokens": 4096,
@@ -9154,5 +9180,15 @@
         "input_cost_per_second": 0.00003333,
         "output_cost_per_second": 0.00,
         "litellm_provider": "assemblyai"
+    },
+    "jina-reranker-v2-base-multilingual": {
+        "max_tokens": 1024,
+        "max_input_tokens": 1024,
+        "max_output_tokens": 1024,
+        "max_document_chunks_per_query": 2048,
+        "input_cost_per_token": 0.000000018,
+        "output_cost_per_token": 0.000000018,
+        "litellm_provider": "jina_ai",
+        "mode": "rerank"
     }
 }

View file

@@ -9,7 +9,6 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
 from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
 from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler
 from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
-from litellm.llms.jina_ai.rerank.handler import JinaAIRerank
 from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
 from litellm.rerank_api.rerank_utils import get_optional_rerank_params
 from litellm.secret_managers.main import get_secret, get_secret_str
@@ -20,7 +19,6 @@ from litellm.utils import ProviderConfigManager, client, exception_type
 ####### ENVIRONMENT VARIABLES ###################
 # Initialize any necessary instances or variables here
 together_rerank = TogetherAIRerank()
-jina_ai_rerank = JinaAIRerank()
 bedrock_rerank = BedrockRerankHandler()
 base_llm_http_handler = BaseLLMHTTPHandler()
 #################################################
@@ -264,16 +262,26 @@ def rerank(  # noqa: PLR0915
             raise ValueError(
                 "Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment"
             )
-        response = jina_ai_rerank.rerank(
+
+        api_base = (
+            dynamic_api_base
+            or optional_params.api_base
+            or litellm.api_base
+            or get_secret("BEDROCK_API_BASE")  # type: ignore
+        )
+
+        response = base_llm_http_handler.rerank(
             model=model,
-            api_key=dynamic_api_key,
-            query=query,
-            documents=documents,
-            top_n=top_n,
-            rank_fields=rank_fields,
-            return_documents=return_documents,
-            max_chunks_per_doc=max_chunks_per_doc,
+            custom_llm_provider=_custom_llm_provider,
+            optional_rerank_params=optional_rerank_params,
+            logging_obj=litellm_logging_obj,
+            timeout=optional_params.timeout,
+            api_key=dynamic_api_key or optional_params.api_key,
+            api_base=api_base,
             _is_async=_is_async,
+            headers=headers or litellm.headers or {},
+            client=client,
+            model_response=model_response,
         )
     elif _custom_llm_provider == "bedrock":
         api_base = (
@@ -295,6 +303,7 @@ def rerank(  # noqa: PLR0915
             optional_params=optional_params.model_dump(exclude_unset=True),
             api_base=api_base,
             logging_obj=litellm_logging_obj,
+            client=client,
         )
     else:
         raise ValueError(f"Unsupported provider: {_custom_llm_provider}")

View file

@@ -17,6 +17,17 @@ def get_optional_rerank_params(
     max_chunks_per_doc: Optional[int] = None,
     non_default_params: Optional[dict] = None,
 ) -> OptionalRerankParams:
+    all_non_default_params = non_default_params or {}
+    if query is not None:
+        all_non_default_params["query"] = query
+    if top_n is not None:
+        all_non_default_params["top_n"] = top_n
+    if documents is not None:
+        all_non_default_params["documents"] = documents
+    if return_documents is not None:
+        all_non_default_params["return_documents"] = return_documents
+    if max_chunks_per_doc is not None:
+        all_non_default_params["max_chunks_per_doc"] = max_chunks_per_doc
     return rerank_provider_config.map_cohere_rerank_params(
         model=model,
         drop_params=drop_params,
@@ -27,5 +38,5 @@ def get_optional_rerank_params(
         rank_fields=rank_fields,
         return_documents=return_documents,
         max_chunks_per_doc=max_chunks_per_doc,
-        non_default_params=non_default_params,
+        non_default_params=all_non_default_params,
     )
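The helper now folds the explicit keyword arguments into a single `non_default_params` dict before provider-specific mapping, so configs like Jina's can filter on one dict. A toy standalone illustration of the merge, mirroring the logic above:

```python
non_default_params = {"return_documents": False}
query, top_n = "hello", 2

all_non_default_params = dict(non_default_params)
if query is not None:
    all_non_default_params["query"] = query
if top_n is not None:
    all_non_default_params["top_n"] = top_n

# the provider config then keeps only its supported keys
print(all_non_default_params)
# {'return_documents': False, 'query': 'hello', 'top_n': 2}
```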

View file

@@ -6198,6 +6198,8 @@ class ProviderConfigManager:
             return litellm.AzureAIRerankConfig()
         elif litellm.LlmProviders.INFINITY == provider:
             return litellm.InfinityRerankConfig()
+        elif litellm.LlmProviders.JINA_AI == provider:
+            return litellm.JinaAIRerankConfig()
         return litellm.CohereRerankConfig()

     @staticmethod
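With Jina wired into `ProviderConfigManager`, the refactored `rerank_cost` dispatches to the right config per provider. A hedged lookup sketch (assumes a standard litellm install; the call signature matches its use in cost_calculator.py above):

```python
import litellm
from litellm.types.utils import LlmProviders
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_rerank_config(
    model="jina-reranker-v2-base-multilingual",
    provider=LlmProviders.JINA_AI,
)
print(type(config).__name__)  # JinaAIRerankConfig
```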

View file

@@ -5982,6 +5982,19 @@
         "litellm_provider": "bedrock",
         "mode": "chat"
     },
+    "amazon.rerank-v1:0": {
+        "max_tokens": 32000,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 32000,
+        "max_query_tokens": 32000,
+        "max_document_chunks_per_query": 100,
+        "max_tokens_per_document_chunk": 512,
+        "input_cost_per_token": 0.0,
+        "input_cost_per_query": 0.001,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "bedrock",
+        "mode": "rerank"
+    },
     "amazon.titan-text-lite-v1": {
         "max_tokens": 4000,
         "max_input_tokens": 42000,
@@ -7022,6 +7035,19 @@
         "mode": "chat",
         "supports_tool_choice": true
     },
+    "cohere.rerank-v3-5:0": {
+        "max_tokens": 32000,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 32000,
+        "max_query_tokens": 32000,
+        "max_document_chunks_per_query": 100,
+        "max_tokens_per_document_chunk": 512,
+        "input_cost_per_token": 0.0,
+        "input_cost_per_query": 0.002,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "bedrock",
+        "mode": "rerank"
+    },
     "cohere.command-text-v14": {
         "max_tokens": 4096,
         "max_input_tokens": 4096,
@@ -9154,5 +9180,15 @@
         "input_cost_per_second": 0.00003333,
         "output_cost_per_second": 0.00,
         "litellm_provider": "assemblyai"
+    },
+    "jina-reranker-v2-base-multilingual": {
+        "max_tokens": 1024,
+        "max_input_tokens": 1024,
+        "max_output_tokens": 1024,
+        "max_document_chunks_per_query": 2048,
+        "input_cost_per_token": 0.000000018,
+        "output_cost_per_token": 0.000000018,
+        "litellm_provider": "jina_ai",
+        "mode": "rerank"
     }
 }

View file

@@ -0,0 +1,14 @@
+import json
+import os
+import sys
+
+import pytest
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+    0, os.path.abspath("../../..")
+)  # Adds the parent directory to the system path
+from unittest.mock import MagicMock, patch
+
+from litellm import rerank
+from litellm.llms.custom_httpx.http_handler import HTTPHandler
View file

@@ -0,0 +1,51 @@
+import json
+import os
+import sys
+
+import pytest
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+    0, os.path.abspath("../../..")
+)  # Adds the parent directory to the system path
+from unittest.mock import MagicMock, patch
+
+from litellm import rerank
+from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+
+def test_rerank_infer_region_from_model_arn(monkeypatch):
+    mock_response = MagicMock()
+    monkeypatch.setenv("AWS_REGION_NAME", "us-east-1")
+    args = {
+        "model": "bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
+        "query": "hello",
+        "documents": ["hello", "world"],
+    }
+
+    def return_val():
+        return {
+            "results": [
+                {"index": 0, "relevanceScore": 0.6716859340667725},
+                {"index": 1, "relevanceScore": 0.0004994205664843321},
+            ]
+        }
+
+    mock_response.json = return_val
+    mock_response.headers = {"key": "value"}
+    mock_response.status_code = 200
+
+    client = HTTPHandler()
+
+    with patch.object(client, "post", return_value=mock_response) as mock_post:
+        rerank(
+            model=args["model"],
+            query=args["query"],
+            documents=args["documents"],
+            client=client,
+        )
+
+        mock_post.assert_called_once()
+        print(f"mock_post.call_args: {mock_post.call_args.kwargs}")
+        assert "us-west-2" in mock_post.call_args.kwargs["url"]
+        assert "us-east-1" not in mock_post.call_args.kwargs["url"]

View file

@@ -33,6 +33,7 @@ def assert_response_shape(response, custom_llm_provider):
     expected_api_version_shape = {"version": str}

     expected_billed_units_shape = {"search_units": int}
+    expected_billed_units_total_tokens_shape = {"total_tokens": int}

     assert isinstance(response.id, expected_response_shape["id"])
     assert isinstance(response.results, expected_response_shape["results"])
@@ -52,9 +53,15 @@ def assert_response_shape(response, custom_llm_provider):
             response.meta["api_version"]["version"],
             expected_api_version_shape["version"],
         )
-        assert isinstance(
-            response.meta["billed_units"], expected_meta_shape["billed_units"]
-        )
-        assert isinstance(
-            response.meta["billed_units"]["search_units"],
-            expected_billed_units_shape["search_units"],
-        )
+        assert isinstance(
+            response.meta["billed_units"], expected_meta_shape["billed_units"]
+        )
+        if "total_tokens" in response.meta["billed_units"]:
+            assert isinstance(
+                response.meta["billed_units"]["total_tokens"],
+                expected_billed_units_total_tokens_shape["total_tokens"],
+            )
+        else:
+            assert isinstance(
+                response.meta["billed_units"]["search_units"],
+                expected_billed_units_shape["search_units"],
+            )
@@ -79,7 +86,9 @@ class BaseLLMRerankTest(ABC):
     @pytest.mark.asyncio()
     @pytest.mark.parametrize("sync_mode", [True, False])
     async def test_basic_rerank(self, sync_mode):
-        litellm.set_verbose = True
+        litellm._turn_on_debug()
+        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+        litellm.model_cost = litellm.get_model_cost_map(url="")
         rerank_call_args = self.get_base_rerank_call_args()
         custom_llm_provider = self.get_custom_llm_provider()
         if sync_mode is True:
@@ -95,6 +104,9 @@ class BaseLLMRerankTest(ABC):
             assert response.id is not None
             assert response.results is not None

+            assert response._hidden_params["response_cost"] is not None
+            assert response._hidden_params["response_cost"] > 0
+
             assert_response_shape(
                 response=response, custom_llm_provider=custom_llm_provider.value
             )

View file

@@ -14,6 +14,7 @@ from litellm.llms.anthropic.chat import ModelResponseIterator
 import httpx
 import json
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from base_rerank_unit_tests import BaseLLMRerankTest

 load_dotenv()
 import io
@@ -185,6 +186,7 @@ def test_completion_azure_ai_command_r():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

+
 def test_azure_deepseek_reasoning_content():
     import json
@@ -197,12 +199,12 @@ def test_azure_deepseek_reasoning_content():
         {
             "choices": [
                 {
                     "finish_reason": "stop",
                     "index": 0,
                     "message": {
                         "content": "<think>I am thinking here</think>\n\nThe sky is a canvas of blue",
                         "role": "assistant",
-                    }
+                    },
                 }
             ],
         }
@@ -214,15 +216,26 @@ def test_azure_deepseek_reasoning_content():
         mock_response.json = lambda: json.loads(mock_response.text)
         mock_post.return_value = mock_response

         response = litellm.completion(
-            model='azure_ai/deepseek-r1',
+            model="azure_ai/deepseek-r1",
             messages=[{"role": "user", "content": "Hello, world!"}],
             api_base="https://litellm8397336933.services.ai.azure.com/models/chat/completions",
             api_key="my-fake-api-key",
-            client=client
+            client=client,
         )

         print(response)
-        assert(response.choices[0].message.reasoning_content == "I am thinking here")
-        assert(response.choices[0].message.content == "\n\nThe sky is a canvas of blue")
+        assert response.choices[0].message.reasoning_content == "I am thinking here"
+        assert response.choices[0].message.content == "\n\nThe sky is a canvas of blue"
+
+
+class TestAzureAIRerank(BaseLLMRerankTest):
+    def get_custom_llm_provider(self) -> litellm.LlmProviders:
+        return litellm.LlmProviders.AZURE_AI
+
+    def get_base_rerank_call_args(self) -> dict:
+        return {
+            "model": "azure_ai/cohere-rerank-v3-english",
+            "api_base": os.getenv("AZURE_AI_COHERE_API_BASE"),
+            "api_key": os.getenv("AZURE_AI_COHERE_API_KEY"),
+        }

View file

@@ -2186,6 +2186,16 @@ class TestBedrockRerank(BaseLLMRerankTest):
         }


+class TestBedrockCohereRerank(BaseLLMRerankTest):
+    def get_custom_llm_provider(self) -> litellm.LlmProviders:
+        return litellm.LlmProviders.BEDROCK
+
+    def get_base_rerank_call_args(self) -> dict:
+        return {
+            "model": "bedrock/arn:aws:bedrock:us-west-2::foundation-model/cohere.rerank-v3-5:0",
+        }
+
+
 @pytest.mark.parametrize(
     "messages, continue_message_index",
     [

View file

@@ -66,6 +66,7 @@ def assert_response_shape(response, custom_llm_provider):
 @pytest.mark.asyncio()
 @pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.flaky(retries=3, delay=1)
 async def test_basic_rerank(sync_mode):
     litellm.set_verbose = True
     if sync_mode is True:
@@ -311,6 +312,7 @@ def test_complete_base_url_cohere():
         (3, None, False),
     ],
 )
+@pytest.mark.flaky(retries=3, delay=1)
 async def test_basic_rerank_caching(sync_mode, top_n_1, top_n_2, expect_cache_hit):
     from litellm.caching.caching import Cache

View file

@@ -1574,7 +1574,11 @@ def test_completion_cost_azure_ai_rerank(model):
                 "relevance_score": 0.990732,
             },
         ],
-        meta={},
+        meta={
+            "billed_units": {
+                "search_units": 1,
+            }
+        },
     )
     print("response", response)
     model = model