diff --git a/docs/my-website/docs/providers/bedrock.md b/docs/my-website/docs/providers/bedrock.md
index 744be74c09..00fe45e99e 100644
--- a/docs/my-website/docs/providers/bedrock.md
+++ b/docs/my-website/docs/providers/bedrock.md
@@ -10,6 +10,7 @@ ALL Bedrock models (Anthropic, Meta, Deepseek, Mistral, Amazon, etc.) are Suppor
 | Provider Route on LiteLLM | `bedrock/`, [`bedrock/converse/`](#set-converse--invoke-route), [`bedrock/invoke/`](#set-invoke-route), [`bedrock/converse_like/`](#calling-via-internal-proxy), [`bedrock/llama/`](#deepseek-not-r1), [`bedrock/deepseek_r1/`](#deepseek-r1) |
 | Provider Doc | [Amazon Bedrock ↗](https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html) |
 | Supported OpenAI Endpoints | `/chat/completions`, `/completions`, `/embeddings`, `/images/generations` |
+| Rerank Endpoint | `/rerank` |
 | Pass-through Endpoint | [Supported](../pass_through/bedrock.md) |
diff --git a/litellm/__init__.py b/litellm/__init__.py
index cbbeca6750..10177bb53c 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -820,6 +820,7 @@ from .llms.cohere.completion.transformation import CohereTextConfig as CohereCon
 from .llms.cohere.rerank.transformation import CohereRerankConfig
 from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
 from .llms.infinity.rerank.transformation import InfinityRerankConfig
+from .llms.jina_ai.rerank.transformation import JinaAIRerankConfig
 from .llms.clarifai.chat.transformation import ClarifaiConfig
 from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
 from .llms.together_ai.chat import TogetherAIConfig
diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py
index edf45d77a8..488684f02b 100644
--- a/litellm/cost_calculator.py
+++ b/litellm/cost_calculator.py
@@ -16,15 +16,9 @@ from litellm.llms.anthropic.cost_calculation import (
 from litellm.llms.azure.cost_calculation import (
     cost_per_token as azure_openai_cost_per_token,
 )
-from litellm.llms.azure_ai.cost_calculator import (
-    cost_per_query as azure_ai_rerank_cost_per_query,
-)
 from litellm.llms.bedrock.image.cost_calculator import (
     cost_calculator as bedrock_image_cost_calculator,
 )
-from litellm.llms.cohere.cost_calculator import (
-    cost_per_query as cohere_rerank_cost_per_query,
-)
 from litellm.llms.databricks.cost_calculator import (
     cost_per_token as databricks_cost_per_token,
 )
@@ -51,10 +45,12 @@ from litellm.llms.vertex_ai.image_generation.cost_calculator import (
     cost_calculator as vertex_ai_image_cost_calculator,
 )
 from litellm.types.llms.openai import HttpxBinaryResponseContent
-from litellm.types.rerank import RerankResponse
+from litellm.types.rerank import RerankBilledUnits, RerankResponse
 from litellm.types.utils import (
     CallTypesLiteral,
+    LlmProviders,
     LlmProvidersSet,
+    ModelInfo,
     PassthroughCallTypes,
     Usage,
 )
@@ -64,6 +60,7 @@ from litellm.utils import (
     EmbeddingResponse,
     ImageResponse,
     ModelResponse,
+    ProviderConfigManager,
     TextCompletionResponse,
     TranscriptionResponse,
     _cached_get_model_info_helper,
@@ -114,6 +111,8 @@ def cost_per_token(  # noqa: PLR0915
     number_of_queries: Optional[int] = None,
     ### USAGE OBJECT ###
     usage_object: Optional[Usage] = None,  # just read the usage object if provided
+    ### BILLED UNITS ###
+    rerank_billed_units: Optional[RerankBilledUnits] = None,
     ### CALL TYPE ###
     call_type: CallTypesLiteral = "completion",
     audio_transcription_file_duration: float = 0.0,  # for audio transcription calls - the file time in seconds
@@ -238,6 +237,7 @@ def cost_per_token(  # noqa: PLR0915
         return rerank_cost(
             model=model,
             custom_llm_provider=custom_llm_provider,
+            billed_units=rerank_billed_units,
         )
     elif call_type == "atranscription" or call_type == "transcription":
         return openai_cost_per_second(
@@ -552,6 +552,7 @@ def completion_cost(  # noqa: PLR0915
     cost_per_token_usage_object: Optional[Usage] = _get_usage_object(
         completion_response=completion_response
     )
+    rerank_billed_units: Optional[RerankBilledUnits] = None
     model = _select_model_name_for_cost_calc(
         model=model,
         completion_response=completion_response,
@@ -698,6 +699,11 @@ def completion_cost(  # noqa: PLR0915
             else:
                 billed_units = {}
 
+            rerank_billed_units = RerankBilledUnits(
+                search_units=billed_units.get("search_units"),
+                total_tokens=billed_units.get("total_tokens"),
+            )
+
             search_units = (
                 billed_units.get("search_units") or 1
             )  # cohere charges per request by default.
@@ -763,6 +769,7 @@ def completion_cost(  # noqa: PLR0915
             usage_object=cost_per_token_usage_object,
             call_type=call_type,
            audio_transcription_file_duration=audio_transcription_file_duration,
+            rerank_billed_units=rerank_billed_units,
         )
         _final_cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
@@ -836,27 +843,33 @@ def response_cost_calculator(
 def rerank_cost(
     model: str,
     custom_llm_provider: Optional[str],
+    billed_units: Optional[RerankBilledUnits] = None,
 ) -> Tuple[float, float]:
     """
     Returns
     - float or None: cost of response OR none if error.
     """
-    default_num_queries = 1
     _, custom_llm_provider, _, _ = litellm.get_llm_provider(
         model=model, custom_llm_provider=custom_llm_provider
     )
     try:
-        if custom_llm_provider == "cohere":
-            return cohere_rerank_cost_per_query(
-                model=model, num_queries=default_num_queries
+        config = ProviderConfigManager.get_provider_rerank_config(
+            model=model, provider=LlmProviders(custom_llm_provider)
+        )
+
+        try:
+            model_info: Optional[ModelInfo] = litellm.get_model_info(
+                model=model, custom_llm_provider=custom_llm_provider
             )
-        elif custom_llm_provider == "azure_ai":
-            return azure_ai_rerank_cost_per_query(
-                model=model, num_queries=default_num_queries
-            )
-        raise ValueError(
-            f"invalid custom_llm_provider for rerank model: {model}, custom_llm_provider: {custom_llm_provider}"
+        except Exception:
+            model_info = None
+
+        return config.calculate_rerank_cost(
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            billed_units=billed_units,
+            model_info=model_info,
         )
     except Exception as e:
         raise e
diff --git a/litellm/llms/azure_ai/cost_calculator.py b/litellm/llms/azure_ai/cost_calculator.py
deleted file mode 100644
index 96d7018458..0000000000
--- a/litellm/llms/azure_ai/cost_calculator.py
+++ /dev/null
@@ -1,32 +0,0 @@
-"""
-Handles custom cost calculation for Azure AI models.
-
-Custom cost calculation for Azure AI models only requied for rerank.
-"""
-
-from typing import Tuple
-
-from litellm.utils import get_model_info
-
-
-def cost_per_query(model: str, num_queries: int = 1) -> Tuple[float, float]:
-    """
-    Calculates the cost per query for a given rerank model.
-
-    Input:
-        - model: str, the model name without provider prefix
-
-    Returns:
-        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
-    """
-    model_info = get_model_info(model=model, custom_llm_provider="azure_ai")
-
-    if (
-        "input_cost_per_query" not in model_info
-        or model_info["input_cost_per_query"] is None
-    ):
-        return 0.0, 0.0
-
-    prompt_cost = model_info["input_cost_per_query"] * num_queries
-
-    return prompt_cost, 0.0
diff --git a/litellm/llms/base_llm/rerank/transformation.py b/litellm/llms/base_llm/rerank/transformation.py
index d956c9a555..524ed0f8d9 100644
--- a/litellm/llms/base_llm/rerank/transformation.py
+++ b/litellm/llms/base_llm/rerank/transformation.py
@@ -1,9 +1,10 @@
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import httpx
 
-from litellm.types.rerank import OptionalRerankParams, RerankResponse
+from litellm.types.rerank import OptionalRerankParams, RerankBilledUnits, RerankResponse
+from litellm.types.utils import ModelInfo
 
 from ..chat.transformation import BaseLLMException
 
@@ -66,7 +67,7 @@ class BaseRerankConfig(ABC):
     @abstractmethod
     def map_cohere_rerank_params(
         self,
-        non_default_params: Optional[dict],
+        non_default_params: dict,
         model: str,
         drop_params: bool,
         query: str,
@@ -79,8 +80,48 @@ class BaseRerankConfig(ABC):
     ) -> OptionalRerankParams:
         pass
 
-    @abstractmethod
     def get_error_class(
         self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
     ) -> BaseLLMException:
-        pass
+        raise BaseLLMException(
+            status_code=status_code,
+            message=error_message,
+            headers=headers,
+        )
+
+    def calculate_rerank_cost(
+        self,
+        model: str,
+        custom_llm_provider: Optional[str] = None,
+        billed_units: Optional[RerankBilledUnits] = None,
+        model_info: Optional[ModelInfo] = None,
+    ) -> Tuple[float, float]:
+        """
+        Calculates the cost per query for a given rerank model.
+
+        Input:
+            - model: str, the model name without provider prefix
+            - custom_llm_provider: str, the provider used for the model. If provided, used to check if the litellm model info is for that provider.
+            - billed_units: RerankBilledUnits, the billed units (search units) to calculate the cost for
+            - model_info: ModelInfo, the model info for the given model
+
+        Returns:
+            Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
+        """
+
+        if (
+            model_info is None
+            or "input_cost_per_query" not in model_info
+            or model_info["input_cost_per_query"] is None
+            or billed_units is None
+        ):
+            return 0.0, 0.0
+
+        search_units = billed_units.get("search_units")
+
+        if search_units is None:
+            return 0.0, 0.0
+
+        prompt_cost = model_info["input_cost_per_query"] * search_units
+
+        return prompt_cost, 0.0
diff --git a/litellm/llms/bedrock/base_aws_llm.py b/litellm/llms/bedrock/base_aws_llm.py
index c46a5f8a0e..33597879c0 100644
--- a/litellm/llms/bedrock/base_aws_llm.py
+++ b/litellm/llms/bedrock/base_aws_llm.py
@@ -202,6 +202,61 @@ class BaseAWSLLM:
         self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl)
         return credentials
 
+    def _get_aws_region_from_model_arn(self, model: Optional[str]) -> Optional[str]:
+        try:
+            # First check if the string contains the expected prefix
+            if not isinstance(model, str) or "arn:aws:bedrock" not in model:
+                return None
+
+            # Split the ARN and check if we have enough parts
+            parts = model.split(":")
+            if len(parts) < 4:
+                return None
+
+            # Get the region from the correct position
+            region = parts[3]
+            if not region:  # Check if region is empty
+                return None
+
+            return region
+        except Exception:
+            # Catch any unexpected errors and return None
+            return None
+
+    def _get_aws_region_name(
+        self, optional_params: dict, model: Optional[str] = None
+    ) -> str:
+        """
+        Get the AWS region name from the environment variables
+        """
+        aws_region_name = optional_params.get("aws_region_name", None)
+        ### SET REGION NAME ###
+        if aws_region_name is None:
+            # check model arn #
+            aws_region_name = self._get_aws_region_from_model_arn(model)
+            # check env #
+            litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
+
+            if (
+                aws_region_name is None
+                and litellm_aws_region_name is not None
+                and isinstance(litellm_aws_region_name, str)
+            ):
+                aws_region_name = litellm_aws_region_name
+
+            standard_aws_region_name = get_secret("AWS_REGION", None)
+            if (
+                aws_region_name is None
+                and standard_aws_region_name is not None
+                and isinstance(standard_aws_region_name, str)
+            ):
+                aws_region_name = standard_aws_region_name
+
+            if aws_region_name is None:
+                aws_region_name = "us-west-2"
+
+        return aws_region_name
+
     @tracer.wrap()
     def _auth_with_web_identity_token(
         self,
@@ -423,7 +478,7 @@ class BaseAWSLLM:
         return endpoint_url, proxy_endpoint_url
 
     def _get_boto_credentials_from_optional_params(
-        self, optional_params: dict
+        self, optional_params: dict, model: Optional[str] = None
     ) -> Boto3CredentialsInfo:
         """
         Get boto3 credentials from optional params
@@ -443,7 +498,7 @@ class BaseAWSLLM:
         aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
         aws_access_key_id = optional_params.pop("aws_access_key_id", None)
         aws_session_token = optional_params.pop("aws_session_token", None)
-        aws_region_name = optional_params.pop("aws_region_name", None)
+        aws_region_name = self._get_aws_region_name(optional_params, model)
         aws_role_name = optional_params.pop("aws_role_name", None)
         aws_session_name = optional_params.pop("aws_session_name", None)
         aws_profile_name = optional_params.pop("aws_profile_name", None)
@@ -453,25 +508,6 @@ class BaseAWSLLM:
             "aws_bedrock_runtime_endpoint", None
         )  # https://bedrock-runtime.{region_name}.amazonaws.com
 
-        ### SET REGION NAME ###
-        if aws_region_name is None:
-            # check env #
-            litellm_aws_region_name = get_secret_str("AWS_REGION_NAME", None)
-
-            if litellm_aws_region_name is not None and isinstance(
-                litellm_aws_region_name, str
-            ):
-                aws_region_name = litellm_aws_region_name
-
-            standard_aws_region_name = get_secret_str("AWS_REGION", None)
-            if standard_aws_region_name is not None and isinstance(
-                standard_aws_region_name, str
-            ):
-                aws_region_name = standard_aws_region_name
-
-            if aws_region_name is None:
-                aws_region_name = "us-west-2"
-
         credentials: Credentials = self.get_credentials(
             aws_access_key_id=aws_access_key_id,
             aws_secret_access_key=aws_secret_access_key,
diff --git a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py
index f369057744..b7d4f0ae6d 100644
--- a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py
+++ b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py
@@ -27,7 +27,7 @@ from litellm.llms.custom_httpx.http_handler import (
 )
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ModelResponse, Usage
-from litellm.utils import CustomStreamWrapper, get_secret
+from litellm.utils import CustomStreamWrapper
 
 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -598,61 +598,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         )
         return modelId
 
-    def get_aws_region_from_model_arn(self, model: Optional[str]) -> Optional[str]:
-        try:
-            # First check if the string contains the expected prefix
-            if not isinstance(model, str) or "arn:aws:bedrock" not in model:
-                return None
-
-            # Split the ARN and check if we have enough parts
-            parts = model.split(":")
-            if len(parts) < 4:
-                return None
-
-            # Get the region from the correct position
-            region = parts[3]
-            if not region:  # Check if region is empty
-                return None
-
-            return region
-        except Exception:
-            # Catch any unexpected errors and return None
-            return None
-
-    def _get_aws_region_name(
-        self, optional_params: dict, model: Optional[str] = None
-    ) -> str:
-        """
-        Get the AWS region name from the environment variables
-        """
-        aws_region_name = optional_params.get("aws_region_name", None)
-        ### SET REGION NAME ###
-        if aws_region_name is None:
-            # check model arn #
-            aws_region_name = self.get_aws_region_from_model_arn(model)
-            # check env #
-            litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
-
-            if (
-                aws_region_name is None
-                and litellm_aws_region_name is not None
-                and isinstance(litellm_aws_region_name, str)
-            ):
-                aws_region_name = litellm_aws_region_name
-
-            standard_aws_region_name = get_secret("AWS_REGION", None)
-            if (
-                aws_region_name is None
-                and standard_aws_region_name is not None
-                and isinstance(standard_aws_region_name, str)
-            ):
-                aws_region_name = standard_aws_region_name
-
-            if aws_region_name is None:
-                aws_region_name = "us-west-2"
-
-        return aws_region_name
-
     def _get_model_id_from_model_with_spec(
         self,
         model: str,
diff --git a/litellm/llms/bedrock/common_utils.py b/litellm/llms/bedrock/common_utils.py
index 8a534f6eac..54be359897 100644
--- a/litellm/llms/bedrock/common_utils.py
+++ b/litellm/llms/bedrock/common_utils.py
@@ -318,6 +318,23 @@ class BedrockModelInfo(BaseLLMModelInfo):
     global_config = AmazonBedrockGlobalConfig()
     all_global_regions = global_config.get_all_regions()
 
+    @staticmethod
+    def extract_model_name_from_arn(model: str) -> str:
+        """
+        Extract the model name from an AWS Bedrock ARN.
+        Returns the string after the last '/' if 'arn' is in the input string.
+
+        Args:
+            model (str): The ARN string to parse
+
+        Returns:
+            str: The extracted model name if 'arn' is in the string,
+                 otherwise returns the original string
+        """
+        if "arn" in model.lower():
+            return model.split("/")[-1]
+        return model
+
     @staticmethod
     def get_base_model(model: str) -> str:
         """
@@ -335,6 +352,8 @@ class BedrockModelInfo(BaseLLMModelInfo):
         if model.startswith("invoke/"):
             model = model.split("/", 1)[1]
 
+        model = BedrockModelInfo.extract_model_name_from_arn(model)
+
         potential_region = model.split(".", 1)[0]
 
         alt_potential_region = model.split("/", 1)[
diff --git a/litellm/llms/bedrock/image/image_handler.py b/litellm/llms/bedrock/image/image_handler.py
index 5b14833f42..4bd63fd21b 100644
--- a/litellm/llms/bedrock/image/image_handler.py
+++ b/litellm/llms/bedrock/image/image_handler.py
@@ -163,7 +163,7 @@ class BedrockImageGeneration(BaseAWSLLM):
         except ImportError:
             raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
         boto3_credentials_info = self._get_boto_credentials_from_optional_params(
-            optional_params
+            optional_params, model
         )
 
         ### SET RUNTIME ENDPOINT ###
diff --git a/litellm/llms/bedrock/rerank/handler.py b/litellm/llms/bedrock/rerank/handler.py
index 3683be06b6..cd8be6912c 100644
--- a/litellm/llms/bedrock/rerank/handler.py
+++ b/litellm/llms/bedrock/rerank/handler.py
@@ -6,6 +6,8 @@ import httpx
 import litellm
 from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
 from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
     _get_httpx_client,
     get_async_httpx_client,
 )
@@ -27,8 +29,10 @@ class BedrockRerankHandler(BaseAWSLLM):
     async def arerank(
         self,
         prepared_request: BedrockPreparedRequest,
+        client: Optional[AsyncHTTPHandler] = None,
     ):
-        client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
+        if client is None:
+            client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
         try:
             response = await client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"])  # type: ignore
             response.raise_for_status()
@@ -54,7 +58,9 @@ class BedrockRerankHandler(BaseAWSLLM):
         _is_async: Optional[bool] = False,
         api_base: Optional[str] = None,
         extra_headers: Optional[dict] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ) -> RerankResponse:
+
         request_data = RerankRequest(
             model=model,
             query=query,
@@ -66,6 +72,7 @@ class BedrockRerankHandler(BaseAWSLLM):
         data = BedrockRerankConfig()._transform_request(request_data)
 
         prepared_request = self._prepare_request(
+            model=model,
            optional_params=optional_params,
            api_base=api_base,
            extra_headers=extra_headers,
@@ -83,9 +90,10 @@ class BedrockRerankHandler(BaseAWSLLM):
         )
 
         if _is_async:
-            return self.arerank(prepared_request)  # type: ignore
+            return self.arerank(prepared_request, client=client if client is not None and isinstance(client, AsyncHTTPHandler) else None)  # type: ignore
 
-        client = _get_httpx_client()
+        if client is None or not isinstance(client, HTTPHandler):
+            client = _get_httpx_client()
         try:
             response = client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"])  # type: ignore
             response.raise_for_status()
@@ -95,10 +103,18 @@ class BedrockRerankHandler(BaseAWSLLM):
         except httpx.TimeoutException:
             raise BedrockError(status_code=408, message="Timeout error occurred.")
 
-        return BedrockRerankConfig()._transform_response(response.json())
+        logging_obj.post_call(
+            original_response=response.text,
+            api_key="",
+        )
+
+        response_json = response.json()
+
+        return BedrockRerankConfig()._transform_response(response_json)
 
     def _prepare_request(
         self,
+        model: str,
         api_base: Optional[str],
         extra_headers: Optional[dict],
         data: dict,
@@ -110,7 +126,7 @@ class BedrockRerankHandler(BaseAWSLLM):
         except ImportError:
             raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
         boto3_credentials_info = self._get_boto_credentials_from_optional_params(
-            optional_params
+            optional_params, model
         )
 
         ### SET RUNTIME ENDPOINT ###
diff --git a/litellm/llms/bedrock/rerank/transformation.py b/litellm/llms/bedrock/rerank/transformation.py
index 7dc9b0aab1..a5380febe9 100644
--- a/litellm/llms/bedrock/rerank/transformation.py
+++ b/litellm/llms/bedrock/rerank/transformation.py
@@ -91,7 +91,9 @@ class BedrockRerankConfig:
         example input: {"results":[{"index":0,"relevanceScore":0.6847912669181824},{"index":1,"relevanceScore":0.5980774760246277}]}
         """
-        _billed_units = RerankBilledUnits(**response.get("usage", {}))
+        _billed_units = RerankBilledUnits(
+            **response.get("usage", {"search_units": 1})
+        )  # by default 1 search unit
         _tokens = RerankTokens(**response.get("usage", {}))
         rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
 
diff --git a/litellm/llms/cohere/cost_calculator.py b/litellm/llms/cohere/cost_calculator.py
deleted file mode 100644
index 224dd5cfa8..0000000000
--- a/litellm/llms/cohere/cost_calculator.py
+++ /dev/null
@@ -1,31 +0,0 @@
-"""
-Custom cost calculator for Cohere rerank models
-"""
-
-from typing import Tuple
-
-from litellm.utils import get_model_info
-
-
-def cost_per_query(model: str, num_queries: int = 1) -> Tuple[float, float]:
-    """
-    Calculates the cost per query for a given rerank model.
-
-    Input:
-        - model: str, the model name without provider prefix
-
-    Returns:
-        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
-    """
-
-    model_info = get_model_info(model=model, custom_llm_provider="cohere")
-
-    if (
-        "input_cost_per_query" not in model_info
-        or model_info["input_cost_per_query"] is None
-    ):
-        return 0.0, 0.0
-
-    prompt_cost = model_info["input_cost_per_query"] * num_queries
-
-    return prompt_cost, 0.0
diff --git a/litellm/llms/jina_ai/rerank/handler.py b/litellm/llms/jina_ai/rerank/handler.py
index 355624cd2a..94076da4f3 100644
--- a/litellm/llms/jina_ai/rerank/handler.py
+++ b/litellm/llms/jina_ai/rerank/handler.py
@@ -1,92 +1,3 @@
 """
-Re rank api
-
-LiteLLM supports the re rank API format, no paramter transformation occurs
+HTTP calling migrated to `llm_http_handler.py`
 """
-
-from typing import Any, Dict, List, Optional, Union
-
-import litellm
-from litellm.llms.base import BaseLLM
-from litellm.llms.custom_httpx.http_handler import (
-    _get_httpx_client,
-    get_async_httpx_client,
-)
-from litellm.llms.jina_ai.rerank.transformation import JinaAIRerankConfig
-from litellm.types.rerank import RerankRequest, RerankResponse
-
-
-class JinaAIRerank(BaseLLM):
-    def rerank(
-        self,
-        model: str,
-        api_key: str,
-        query: str,
-        documents: List[Union[str, Dict[str, Any]]],
-        top_n: Optional[int] = None,
-        rank_fields: Optional[List[str]] = None,
-        return_documents: Optional[bool] = True,
-        max_chunks_per_doc: Optional[int] = None,
-        _is_async: Optional[bool] = False,
-    ) -> RerankResponse:
-        client = _get_httpx_client()
-
-        request_data = RerankRequest(
-            model=model,
-            query=query,
-            top_n=top_n,
-            documents=documents,
-            rank_fields=rank_fields,
-            return_documents=return_documents,
-        )
-
-        # exclude None values from request_data
-        request_data_dict = request_data.dict(exclude_none=True)
-
-        if _is_async:
-            return self.async_rerank(request_data_dict, api_key)  # type: ignore # Call async method
-
-        response = client.post(
-            "https://api.jina.ai/v1/rerank",
-            headers={
-                "accept": "application/json",
-                "content-type": "application/json",
-                "authorization": f"Bearer {api_key}",
-            },
-            json=request_data_dict,
-        )
-
-        if response.status_code != 200:
-            raise Exception(response.text)
-
-        _json_response = response.json()
-
-        return JinaAIRerankConfig()._transform_response(_json_response)
-
-    async def async_rerank(  # New async method
-        self,
-        request_data_dict: Dict[str, Any],
-        api_key: str,
-    ) -> RerankResponse:
-        client = get_async_httpx_client(
-            llm_provider=litellm.LlmProviders.JINA_AI
-        )  # Use async client
-
-        response = await client.post(
-            "https://api.jina.ai/v1/rerank",
-            headers={
-                "accept": "application/json",
-                "content-type": "application/json",
-                "authorization": f"Bearer {api_key}",
-            },
-            json=request_data_dict,
-        )
-
-        if response.status_code != 200:
-            raise Exception(response.text)
-
-        _json_response = response.json()
-
-        return JinaAIRerankConfig()._transform_response(_json_response)
-
-    pass
diff --git a/litellm/llms/jina_ai/rerank/transformation.py b/litellm/llms/jina_ai/rerank/transformation.py
index a6c0a810c7..4adb9cb0ec 100644
--- a/litellm/llms/jina_ai/rerank/transformation.py
+++ b/litellm/llms/jina_ai/rerank/transformation.py
@@ -7,30 +7,136 @@ Docs - https://jina.ai/reranker
 """
 
 import uuid
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from httpx import URL, Response
+
+from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
 from litellm.types.rerank import (
+    OptionalRerankParams,
     RerankBilledUnits,
     RerankResponse,
     RerankResponseMeta,
     RerankTokens,
 )
+from litellm.types.utils import ModelInfo
 
 
-class JinaAIRerankConfig:
-    def _transform_response(self, response: dict) -> RerankResponse:
+class JinaAIRerankConfig(BaseRerankConfig):
+    def get_supported_cohere_rerank_params(self, model: str) -> list:
+        return [
+            "query",
+            "top_n",
+            "documents",
+            "return_documents",
+        ]
 
-        _billed_units = RerankBilledUnits(**response.get("usage", {}))
-        _tokens = RerankTokens(**response.get("usage", {}))
+    def map_cohere_rerank_params(
+        self,
+        non_default_params: dict,
+        model: str,
+        drop_params: bool,
+        query: str,
+        documents: List[Union[str, Dict[str, Any]]],
+        custom_llm_provider: Optional[str] = None,
+        top_n: Optional[int] = None,
+        rank_fields: Optional[List[str]] = None,
+        return_documents: Optional[bool] = True,
+        max_chunks_per_doc: Optional[int] = None,
+    ) -> OptionalRerankParams:
+        optional_params = {}
+        supported_params = self.get_supported_cohere_rerank_params(model)
+        for k, v in non_default_params.items():
+            if k in supported_params:
+                optional_params[k] = v
+        return OptionalRerankParams(
+            **optional_params,
+        )
+
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        base_path = "/v1/rerank"
+
+        if api_base is None:
+            return "https://api.jina.ai/v1/rerank"
+        base = URL(api_base)
+        # Reconstruct URL with cleaned path
+        cleaned_base = str(base.copy_with(path=base_path))
+
+        return cleaned_base
+
+    def transform_rerank_request(
+        self, model: str, optional_rerank_params: OptionalRerankParams, headers: Dict
+    ) -> Dict:
+        return {"model": model, **optional_rerank_params}
+
+    def transform_rerank_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: RerankResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: Dict = {},
+        optional_params: Dict = {},
+        litellm_params: Dict = {},
+    ) -> RerankResponse:
+        if raw_response.status_code != 200:
+            raise Exception(raw_response.text)
+
+        logging_obj.post_call(original_response=raw_response.text)
+
+        _json_response = raw_response.json()
+
+        _billed_units = RerankBilledUnits(**_json_response.get("usage", {}))
+        _tokens = RerankTokens(**_json_response.get("usage", {}))
         rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
 
-        _results: Optional[List[dict]] = response.get("results")
+        _results: Optional[List[dict]] = _json_response.get("results")
 
         if _results is None:
-            raise ValueError(f"No results found in the response={response}")
+            raise ValueError(f"No results found in the response={_json_response}")
 
         return RerankResponse(
-            id=response.get("id") or str(uuid.uuid4()),
+            id=_json_response.get("id") or str(uuid.uuid4()),
             results=_results,  # type: ignore
             meta=rerank_meta,
         )  # Return response
+
+    def validate_environment(
+        self, headers: Dict, model: str, api_key: Optional[str] = None
+    ) -> Dict:
+        if api_key is None:
+            raise ValueError(
+                "api_key is required. Set via `api_key` parameter or `JINA_AI_API_KEY` environment variable."
+            )
+        return {
+            "accept": "application/json",
+            "content-type": "application/json",
+            "authorization": f"Bearer {api_key}",
+        }
+
+    def calculate_rerank_cost(
+        self,
+        model: str,
+        custom_llm_provider: Optional[str] = None,
+        billed_units: Optional[RerankBilledUnits] = None,
+        model_info: Optional[ModelInfo] = None,
+    ) -> Tuple[float, float]:
+        """
+        Jina AI reranker is priced at $0.000000018 per token.
+        """
+        if (
+            model_info is None
+            or "input_cost_per_token" not in model_info
+            or model_info["input_cost_per_token"] is None
+            or billed_units is None
+        ):
+            return 0.0, 0.0
+
+        total_tokens = billed_units.get("total_tokens")
+        if total_tokens is None:
+            return 0.0, 0.0
+
+        input_cost = model_info["input_cost_per_token"] * total_tokens
+        return input_cost, 0.0
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index c2c4140933..04175ec502 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -5982,6 +5982,19 @@
         "litellm_provider": "bedrock",
         "mode": "chat"
     },
+    "amazon.rerank-v1:0": {
+        "max_tokens": 32000,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 32000,
+        "max_query_tokens": 32000,
+        "max_document_chunks_per_query": 100,
+        "max_tokens_per_document_chunk": 512,
+        "input_cost_per_token": 0.0,
+        "input_cost_per_query": 0.001,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "bedrock",
+        "mode": "rerank"
+    },
     "amazon.titan-text-lite-v1": {
         "max_tokens": 4000,
         "max_input_tokens": 42000,
@@ -7022,6 +7035,19 @@
         "mode": "chat",
         "supports_tool_choice": true
     },
+    "cohere.rerank-v3-5:0": {
+        "max_tokens": 32000,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 32000,
+        "max_query_tokens": 32000,
+        "max_document_chunks_per_query": 100,
+        "max_tokens_per_document_chunk": 512,
+        "input_cost_per_token": 0.0,
+        "input_cost_per_query": 0.002,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "bedrock",
+        "mode": "rerank"
+    },
     "cohere.command-text-v14": {
         "max_tokens": 4096,
         "max_input_tokens": 4096,
@@ -9154,5 +9180,15 @@
         "input_cost_per_second": 0.00003333,
         "output_cost_per_second": 0.00,
         "litellm_provider": "assemblyai"
+    },
+    "jina-reranker-v2-base-multilingual": {
+        "max_tokens": 1024,
+        "max_input_tokens": 1024,
+        "max_output_tokens": 1024,
+        "max_document_chunks_per_query": 2048,
+        "input_cost_per_token": 0.000000018,
+        "output_cost_per_token": 0.000000018,
+        "litellm_provider": "jina_ai",
+        "mode": "rerank"
+    }
 }
diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py
index 8ec05ddadf..9a6eaeb0a7 100644
--- a/litellm/rerank_api/main.py
+++ b/litellm/rerank_api/main.py
@@ -9,7 +9,6 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
 from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
 from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler
 from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
-from litellm.llms.jina_ai.rerank.handler import JinaAIRerank
 from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
 from litellm.rerank_api.rerank_utils import get_optional_rerank_params
 from litellm.secret_managers.main import get_secret, get_secret_str
@@ -20,7 +19,6 @@ from litellm.utils import ProviderConfigManager, client, exception_type
 ####### ENVIRONMENT VARIABLES ###################
 # Initialize any necessary instances or variables here
 together_rerank = TogetherAIRerank()
-jina_ai_rerank = JinaAIRerank()
 bedrock_rerank = BedrockRerankHandler()
 base_llm_http_handler = BaseLLMHTTPHandler()
 #################################################
@@ -264,16 +262,26 @@ def rerank(  # noqa: PLR0915
             raise ValueError(
                 "Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment"
             )
-        response = jina_ai_rerank.rerank(
+
+        api_base = (
+            dynamic_api_base
+            or optional_params.api_base
+            or litellm.api_base
+            or get_secret("BEDROCK_API_BASE")  # type: ignore
+        )
+
+        response = base_llm_http_handler.rerank(
             model=model,
-            api_key=dynamic_api_key,
-            query=query,
-            documents=documents,
-            top_n=top_n,
-            rank_fields=rank_fields,
-            return_documents=return_documents,
-            max_chunks_per_doc=max_chunks_per_doc,
+            custom_llm_provider=_custom_llm_provider,
+            optional_rerank_params=optional_rerank_params,
+            logging_obj=litellm_logging_obj,
+            timeout=optional_params.timeout,
+            api_key=dynamic_api_key or optional_params.api_key,
+            api_base=api_base,
             _is_async=_is_async,
+            headers=headers or litellm.headers or {},
+            client=client,
+            model_response=model_response,
         )
     elif _custom_llm_provider == "bedrock":
         api_base = (
@@ -295,6 +303,7 @@ def rerank(  # noqa: PLR0915
             optional_params=optional_params.model_dump(exclude_unset=True),
             api_base=api_base,
             logging_obj=litellm_logging_obj,
+            client=client,
         )
     else:
         raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
diff --git a/litellm/rerank_api/rerank_utils.py b/litellm/rerank_api/rerank_utils.py
index c3e5fda56e..00fb1c5ece 100644
--- a/litellm/rerank_api/rerank_utils.py
+++ b/litellm/rerank_api/rerank_utils.py
@@ -17,6 +17,17 @@ def get_optional_rerank_params(
     max_chunks_per_doc: Optional[int] = None,
     non_default_params: Optional[dict] = None,
 ) -> OptionalRerankParams:
+    all_non_default_params = non_default_params or {}
+    if query is not None:
+        all_non_default_params["query"] = query
+    if top_n is not None:
+        all_non_default_params["top_n"] = top_n
+    if documents is not None:
+        all_non_default_params["documents"] = documents
+    if return_documents is not None:
+        all_non_default_params["return_documents"] = return_documents
+    if max_chunks_per_doc is not None:
+        all_non_default_params["max_chunks_per_doc"] = max_chunks_per_doc
     return rerank_provider_config.map_cohere_rerank_params(
         model=model,
         drop_params=drop_params,
@@ -27,5 +38,5 @@ def get_optional_rerank_params(
         rank_fields=rank_fields,
         return_documents=return_documents,
         max_chunks_per_doc=max_chunks_per_doc,
-        non_default_params=non_default_params,
+        non_default_params=all_non_default_params,
     )
diff --git a/litellm/utils.py b/litellm/utils.py
index d18cfed20d..3414f289d3 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -6198,6 +6198,8 @@ class ProviderConfigManager:
             return litellm.AzureAIRerankConfig()
         elif litellm.LlmProviders.INFINITY == provider:
             return litellm.InfinityRerankConfig()
+        elif litellm.LlmProviders.JINA_AI == provider:
+            return litellm.JinaAIRerankConfig()
         return litellm.CohereRerankConfig()
 
     @staticmethod
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index c2c4140933..04175ec502 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -5982,6 +5982,19 @@
         "litellm_provider": "bedrock",
         "mode": "chat"
     },
+    "amazon.rerank-v1:0": {
+        "max_tokens": 32000,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 32000,
+        "max_query_tokens": 32000,
+        "max_document_chunks_per_query": 100,
+        "max_tokens_per_document_chunk": 512,
+        "input_cost_per_token": 0.0,
+        "input_cost_per_query": 0.001,
+        "output_cost_per_token": 0.0,
+ "litellm_provider": "bedrock", + "mode": "rerank" + }, "amazon.titan-text-lite-v1": { "max_tokens": 4000, "max_input_tokens": 42000, @@ -7022,6 +7035,19 @@ "mode": "chat", "supports_tool_choice": true }, + "cohere.rerank-v3-5:0": { + "max_tokens": 32000, + "max_input_tokens": 32000, + "max_output_tokens": 32000, + "max_query_tokens": 32000, + "max_document_chunks_per_query": 100, + "max_tokens_per_document_chunk": 512, + "input_cost_per_token": 0.0, + "input_cost_per_query": 0.002, + "output_cost_per_token": 0.0, + "litellm_provider": "bedrock", + "mode": "rerank" + }, "cohere.command-text-v14": { "max_tokens": 4096, "max_input_tokens": 4096, @@ -9154,5 +9180,15 @@ "input_cost_per_second": 0.00003333, "output_cost_per_second": 0.00, "litellm_provider": "assemblyai" + }, + "jina-reranker-v2-base-multilingual": { + "max_tokens": 1024, + "max_input_tokens": 1024, + "max_output_tokens": 1024, + "max_document_chunks_per_query": 2048, + "input_cost_per_token": 0.000000018, + "output_cost_per_token": 0.000000018, + "litellm_provider": "jina_ai", + "mode": "rerank" } } diff --git a/tests/litellm/llms/bedrock/rerank/transformation.py b/tests/litellm/llms/bedrock/rerank/transformation.py new file mode 100644 index 0000000000..870a7cb1f1 --- /dev/null +++ b/tests/litellm/llms/bedrock/rerank/transformation.py @@ -0,0 +1,14 @@ +import json +import os +import sys + +import pytest +from fastapi.testclient import TestClient + +sys.path.insert( + 0, os.path.abspath("../../..") +) # Adds the parent directory to the system path +from unittest.mock import MagicMock, patch + +from litellm import rerank +from litellm.llms.custom_httpx.http_handler import HTTPHandler diff --git a/tests/litellm/rerank_api/test_main.py b/tests/litellm/rerank_api/test_main.py new file mode 100644 index 0000000000..d3051b81f9 --- /dev/null +++ b/tests/litellm/rerank_api/test_main.py @@ -0,0 +1,51 @@ +import json +import os +import sys + +import pytest +from fastapi.testclient import TestClient + +sys.path.insert( + 0, os.path.abspath("../../..") +) # Adds the parent directory to the system path +from unittest.mock import MagicMock, patch + +from litellm import rerank +from litellm.llms.custom_httpx.http_handler import HTTPHandler + + +def test_rerank_infer_region_from_model_arn(monkeypatch): + mock_response = MagicMock() + + monkeypatch.setenv("AWS_REGION_NAME", "us-east-1") + args = { + "model": "bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0", + "query": "hello", + "documents": ["hello", "world"], + } + + def return_val(): + return { + "results": [ + {"index": 0, "relevanceScore": 0.6716859340667725}, + {"index": 1, "relevanceScore": 0.0004994205664843321}, + ] + } + + mock_response.json = return_val + mock_response.headers = {"key": "value"} + mock_response.status_code = 200 + + client = HTTPHandler() + + with patch.object(client, "post", return_value=mock_response) as mock_post: + rerank( + model=args["model"], + query=args["query"], + documents=args["documents"], + client=client, + ) + mock_post.assert_called_once() + print(f"mock_post.call_args: {mock_post.call_args.kwargs}") + assert "us-west-2" in mock_post.call_args.kwargs["url"] + assert "us-east-1" not in mock_post.call_args.kwargs["url"] diff --git a/tests/llm_translation/base_rerank_unit_tests.py b/tests/llm_translation/base_rerank_unit_tests.py index 54f6009fc6..cff4a02753 100644 --- a/tests/llm_translation/base_rerank_unit_tests.py +++ b/tests/llm_translation/base_rerank_unit_tests.py @@ -33,6 +33,7 @@ def assert_response_shape(response, 
@@ -33,6 +33,7 @@ def assert_response_shape(response, custom_llm_provider):
     expected_api_version_shape = {"version": str}
 
     expected_billed_units_shape = {"search_units": int}
+    expected_billed_units_total_tokens_shape = {"total_tokens": int}
 
     assert isinstance(response.id, expected_response_shape["id"])
     assert isinstance(response.results, expected_response_shape["results"])
@@ -52,9 +53,15 @@ def assert_response_shape(response, custom_llm_provider):
             response.meta["api_version"]["version"],
             expected_api_version_shape["version"],
         )
+    assert isinstance(
+        response.meta["billed_units"], expected_meta_shape["billed_units"]
+    )
+    if "total_tokens" in response.meta["billed_units"]:
         assert isinstance(
-            response.meta["billed_units"], expected_meta_shape["billed_units"]
+            response.meta["billed_units"]["total_tokens"],
+            expected_billed_units_total_tokens_shape["total_tokens"],
         )
+    else:
         assert isinstance(
             response.meta["billed_units"]["search_units"],
             expected_billed_units_shape["search_units"],
@@ -79,7 +86,9 @@ class BaseLLMRerankTest(ABC):
     @pytest.mark.asyncio()
     @pytest.mark.parametrize("sync_mode", [True, False])
     async def test_basic_rerank(self, sync_mode):
-        litellm.set_verbose = True
+        litellm._turn_on_debug()
+        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+        litellm.model_cost = litellm.get_model_cost_map(url="")
         rerank_call_args = self.get_base_rerank_call_args()
         custom_llm_provider = self.get_custom_llm_provider()
         if sync_mode is True:
@@ -95,6 +104,9 @@ class BaseLLMRerankTest(ABC):
             assert response.id is not None
             assert response.results is not None
 
+            assert response._hidden_params["response_cost"] is not None
+            assert response._hidden_params["response_cost"] > 0
+
             assert_response_shape(
                 response=response, custom_llm_provider=custom_llm_provider.value
             )
diff --git a/tests/llm_translation/test_azure_ai.py b/tests/llm_translation/test_azure_ai.py
index efb183bda0..c22c9edafa 100644
--- a/tests/llm_translation/test_azure_ai.py
+++ b/tests/llm_translation/test_azure_ai.py
@@ -14,6 +14,7 @@ from litellm.llms.anthropic.chat import ModelResponseIterator
 import httpx
 import json
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from base_rerank_unit_tests import BaseLLMRerankTest
 
 load_dotenv()
 import io
@@ -185,6 +186,7 @@ def test_completion_azure_ai_command_r():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+
 def test_azure_deepseek_reasoning_content():
     import json
 
@@ -192,37 +194,48 @@ def test_azure_deepseek_reasoning_content():
 
     with patch.object(client, "post") as mock_post:
         mock_response = MagicMock()
-
+
         mock_response.text = json.dumps(
             {
                 "choices": [
                     {
-                        "finish_reason": "stop",
-                        "index": 0,
-                        "message": {
-                            "content": "I am thinking here\n\nThe sky is a canvas of blue",
-                            "role": "assistant",
-                        }
+                        "finish_reason": "stop",
+                        "index": 0,
+                        "message": {
+                            "content": "I am thinking here\n\nThe sky is a canvas of blue",
+                            "role": "assistant",
+                        },
                     }
                 ],
             }
-        )
-
+        )
+
         mock_response.status_code = 200
         # Add required response attributes
         mock_response.headers = {"Content-Type": "application/json"}
         mock_response.json = lambda: json.loads(mock_response.text)
         mock_post.return_value = mock_response
-
+
         response = litellm.completion(
-            model='azure_ai/deepseek-r1',
+            model="azure_ai/deepseek-r1",
             messages=[{"role": "user", "content": "Hello, world!"}],
             api_base="https://litellm8397336933.services.ai.azure.com/models/chat/completions",
             api_key="my-fake-api-key",
-            client=client
-        )
+            client=client,
+        )
 
         print(response)
-        assert(response.choices[0].message.reasoning_content == "I am thinking here")
-        assert(response.choices[0].message.content == "\n\nThe sky is a canvas of blue")
+        assert response.choices[0].message.reasoning_content == "I am thinking here"
+        assert response.choices[0].message.content == "\n\nThe sky is a canvas of blue"
+
+
+class TestAzureAIRerank(BaseLLMRerankTest):
+    def get_custom_llm_provider(self) -> litellm.LlmProviders:
+        return litellm.LlmProviders.AZURE_AI
+
+    def get_base_rerank_call_args(self) -> dict:
+        return {
+            "model": "azure_ai/cohere-rerank-v3-english",
+            "api_base": os.getenv("AZURE_AI_COHERE_API_BASE"),
+            "api_key": os.getenv("AZURE_AI_COHERE_API_KEY"),
+        }
diff --git a/tests/llm_translation/test_bedrock_completion.py b/tests/llm_translation/test_bedrock_completion.py
index 685c5e5409..7fca7b5f1a 100644
--- a/tests/llm_translation/test_bedrock_completion.py
+++ b/tests/llm_translation/test_bedrock_completion.py
@@ -2186,6 +2186,16 @@ class TestBedrockRerank(BaseLLMRerankTest):
     }
 
 
+class TestBedrockCohereRerank(BaseLLMRerankTest):
+    def get_custom_llm_provider(self) -> litellm.LlmProviders:
+        return litellm.LlmProviders.BEDROCK
+
+    def get_base_rerank_call_args(self) -> dict:
+        return {
+            "model": "bedrock/arn:aws:bedrock:us-west-2::foundation-model/cohere.rerank-v3-5:0",
+        }
+
+
 @pytest.mark.parametrize(
     "messages, continue_message_index",
     [
diff --git a/tests/llm_translation/test_rerank.py b/tests/llm_translation/test_rerank.py
index 82efa92dfd..d9bcc7ee2d 100644
--- a/tests/llm_translation/test_rerank.py
+++ b/tests/llm_translation/test_rerank.py
@@ -66,6 +66,7 @@ def assert_response_shape(response, custom_llm_provider):
 
 @pytest.mark.asyncio()
 @pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.flaky(retries=3, delay=1)
 async def test_basic_rerank(sync_mode):
     litellm.set_verbose = True
     if sync_mode is True:
@@ -311,6 +312,7 @@ def test_complete_base_url_cohere():
         (3, None, False),
     ],
 )
+@pytest.mark.flaky(retries=3, delay=1)
 async def test_basic_rerank_caching(sync_mode, top_n_1, top_n_2, expect_cache_hit):
     from litellm.caching.caching import Cache
 
diff --git a/tests/local_testing/test_completion_cost.py b/tests/local_testing/test_completion_cost.py
index 124a222fda..978ff75a4c 100644
--- a/tests/local_testing/test_completion_cost.py
+++ b/tests/local_testing/test_completion_cost.py
@@ -1574,7 +1574,11 @@ def test_completion_cost_azure_ai_rerank(model):
                 "relevance_score": 0.990732,
             },
         ],
-        meta={},
+        meta={
+            "billed_units": {
+                "search_units": 1,
+            }
+        },
     )
     print("response", response)
     model = model
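
Example (illustrative; not part of the diff): a minimal sketch of the Bedrock behavior exercised by the new `test_rerank_infer_region_from_model_arn`, assuming valid AWS credentials are configured in the environment. The region is read out of the model ARN by the new `BaseAWSLLM._get_aws_region_name`, and `BedrockModelInfo.extract_model_name_from_arn` strips the ARN prefix so pricing resolves to the `amazon.rerank-v1:0` entry added above.

```python
import litellm

# Region us-west-2 is inferred from the ARN and takes priority over
# AWS_REGION_NAME / AWS_REGION, so the request is routed to the
# us-west-2 Bedrock runtime endpoint. Assumes AWS credentials are set.
response = litellm.rerank(
    model="bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
    query="hello",
    documents=["hello", "world"],
)

print(response.results)
# Rerank cost = billed search units (defaulting to 1 per request) *
# input_cost_per_query (0.001 USD for amazon.rerank-v1:0 above).
print(response._hidden_params["response_cost"])
```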
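And a sketch of the token-based Jina AI pricing path (`JinaAIRerankConfig.calculate_rerank_cost`), now routed through `base_llm_http_handler`. The API key value is a placeholder and the query/documents are arbitrary; the per-token price comes from the `jina-reranker-v2-base-multilingual` entry added above.

```python
import os

import litellm

os.environ["JINA_AI_API_KEY"] = "my-jina-api-key"  # placeholder; real key required

response = litellm.rerank(
    model="jina_ai/jina-reranker-v2-base-multilingual",
    query="organic skincare products for sensitive skin",
    documents=[
        "Organic skincare for sensitive skin with aloe vera and chamomile.",
        "New makeup trends focus on bold colors and innovative techniques.",
    ],
    return_documents=True,
)

# Jina AI bills per token: response_cost = billed total_tokens *
# input_cost_per_token (0.000000018 USD/token for this model).
print(response.meta["billed_units"])
print(response._hidden_params["response_cost"])
```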