diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py
index a66a800026..1eb4d0eb94 100644
--- a/litellm/cost_calculator.py
+++ b/litellm/cost_calculator.py
@@ -22,6 +22,7 @@ from litellm.litellm_core_utils.llm_cost_calc.utils import _generic_cost_per_cha
 from litellm.llms.anthropic.cost_calculation import (
     cost_per_token as anthropic_cost_per_token,
 )
+from litellm.rerank_api.types import RerankResponse
 from litellm.types.llms.openai import HttpxBinaryResponseContent
 from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
 from litellm.types.utils import PassthroughCallTypes, Usage
@@ -93,6 +94,8 @@ def cost_per_token(
         "transcription",
         "aspeech",
         "speech",
+        "rerank",
+        "arerank",
     ] = "completion",
 ) -> Tuple[float, float]:
     """
@@ -487,6 +490,8 @@ def completion_cost(
         "transcription",
         "aspeech",
         "speech",
+        "rerank",
+        "arerank",
     ] = "completion",
     ### REGION ###
     custom_llm_provider=None,
@@ -747,6 +752,7 @@ def response_cost_calculator(
         TranscriptionResponse,
         TextCompletionResponse,
         HttpxBinaryResponseContent,
+        RerankResponse,
     ],
     model: str,
     custom_llm_provider: Optional[str],
@@ -765,6 +771,8 @@ def response_cost_calculator(
         "transcription",
         "aspeech",
         "speech",
+        "rerank",
+        "arerank",
     ],
     optional_params: dict,
     cache_hit: Optional[bool] = None,
@@ -789,6 +797,15 @@ def response_cost_calculator(
                     call_type=call_type,
                     custom_llm_provider=custom_llm_provider,
                 )
+            elif isinstance(response_object, RerankResponse) and (
+                call_type == "arerank" or call_type == "rerank"
+            ):
+                response_cost = rerank_cost(
+                    rerank_response=response_object,
+                    model=model,
+                    call_type=call_type,
+                    custom_llm_provider=custom_llm_provider,
+                )
             else:
                 if custom_pricing is True:  # override defaults if custom pricing is set
                     base_model = model
@@ -820,3 +837,28 @@ def response_cost_calculator(
             )
         )
         return None
+
+
+def rerank_cost(
+    rerank_response: RerankResponse,
+    model: str,
+    call_type: Literal["rerank", "arerank"],
+    custom_llm_provider: Optional[str],
+) -> float:
+    """
+    Returns
+    - float: cost of the rerank response
+ """ + _, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model) + + try: + if custom_llm_provider == "cohere": + return 0.002 + raise ValueError( + f"invalid custom_llm_provider for rerank model: {model}, custom_llm_provider: {custom_llm_provider}" + ) + except Exception as e: + verbose_logger.exception( + f"litellm.cost_calculator.py::rerank_cost - Exception occurred - {str(e)}" + ) + raise e diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 7af0a1cade..f57dd3b812 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -30,6 +30,7 @@ from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.redact_messages import ( redact_message_input_output_from_logging, ) +from litellm.rerank_api.types import RerankResponse from litellm.types.llms.openai import HttpxBinaryResponseContent from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS from litellm.types.utils import ( @@ -525,6 +526,7 @@ class Logging: TranscriptionResponse, TextCompletionResponse, HttpxBinaryResponseContent, + RerankResponse, ], cache_hit: Optional[bool] = None, ): @@ -588,6 +590,7 @@ class Logging: or isinstance(result, TranscriptionResponse) or isinstance(result, TextCompletionResponse) or isinstance(result, HttpxBinaryResponseContent) # tts + or isinstance(result, RerankResponse) ): ## RESPONSE COST ## self.model_call_details["response_cost"] = ( diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py index 41de82ab66..462208cfcd 100644 --- a/litellm/rerank_api/main.py +++ b/litellm/rerank_api/main.py @@ -9,7 +9,7 @@ from litellm.llms.cohere.rerank import CohereRerank from litellm.llms.togetherai.rerank import TogetherAIRerank from litellm.secret_managers.main import get_secret from litellm.types.router import * -from litellm.utils import supports_httpx_timeout +from litellm.utils import client, supports_httpx_timeout from .types import RerankRequest, RerankResponse @@ -20,6 +20,7 @@ together_rerank = TogetherAIRerank() ################################################# +@client async def arerank( model: str, query: str, @@ -64,6 +65,7 @@ async def arerank( raise e +@client def rerank( model: str, query: str, diff --git a/litellm/rerank_api/types.py b/litellm/rerank_api/types.py index 605e25a2ec..00cb32c180 100644 --- a/litellm/rerank_api/types.py +++ b/litellm/rerank_api/types.py @@ -23,3 +23,16 @@ class RerankResponse(BaseModel): id: str results: List[dict] # Contains index and relevance_score meta: dict # Contains api_version and billed_units + _hidden_params: dict = {} + + class Config: + underscore_attrs_are_private = True + + def __getitem__(self, key): + return self.__dict__[key] + + def get(self, key, default=None): + return self.__dict__.get(key, default) + + def __contains__(self, key): + return key in self.__dict__ diff --git a/litellm/tests/test_rerank.py b/litellm/tests/test_rerank.py index 4e70424bc3..4d0fdfb344 100644 --- a/litellm/tests/test_rerank.py +++ b/litellm/tests/test_rerank.py @@ -1,3 +1,4 @@ +import asyncio import json import os import sys @@ -20,6 +21,7 @@ import pytest import litellm from litellm import RateLimitError, Timeout, completion, completion_cost, embedding +from litellm.integrations.custom_logger import CustomLogger from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler @@ -177,3 +179,37 @@ async def test_rerank_custom_api_base(): assert response.results is not None 
         assert_response_shape(response, custom_llm_provider="cohere")
+
+
+class TestLogger(CustomLogger):
+
+    def __init__(self):
+        self.kwargs = None
+        self.response_obj = None
+        super().__init__()
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        print("in success event for rerank, kwargs = ", kwargs)
+        print("in success event for rerank, response_obj = ", response_obj)
+        self.kwargs = kwargs
+        self.response_obj = response_obj
+
+
+@pytest.mark.asyncio()
+async def test_rerank_custom_callbacks():
+    custom_logger = TestLogger()
+    litellm.callbacks = [custom_logger]
+    response = await litellm.arerank(
+        model="cohere/rerank-english-v3.0",
+        query="hello",
+        documents=["hello", "world"],
+        top_n=3,
+    )
+
+    await asyncio.sleep(5)
+
+    print("async re rank response: ", response)
+    assert custom_logger.kwargs is not None
+    assert custom_logger.kwargs.get("response_cost") > 0.0
+    assert custom_logger.response_obj is not None
+    assert custom_logger.response_obj.results is not None
diff --git a/litellm/utils.py b/litellm/utils.py
index d8aa51bd51..6b7b94a70c 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -745,6 +745,7 @@ def client(original_function):
             or kwargs.get("amoderation", False) == True
             or kwargs.get("atext_completion", False) == True
             or kwargs.get("atranscription", False) == True
+            or kwargs.get("arerank", False) == True
         ):
             # [OPTIONAL] CHECK MAX RETRIES / REQUEST
             if litellm.num_retries_per_request is not None:
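
For context, a minimal caller-side sketch of how the new rerank cost tracking surfaces through the `@client` decorator and custom callbacks. It is not part of the patch; it simply mirrors `test_rerank_custom_callbacks` above and assumes a valid `COHERE_API_KEY` in the environment:

```python
import asyncio

import litellm
from litellm.integrations.custom_logger import CustomLogger


class CostLogger(CustomLogger):
    """Captures kwargs["response_cost"], which rerank_cost() now populates."""

    def __init__(self):
        self.response_cost = None
        super().__init__()

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        self.response_cost = kwargs.get("response_cost")


async def main():
    cost_logger = CostLogger()
    litellm.callbacks = [cost_logger]
    response = await litellm.arerank(
        model="cohere/rerank-english-v3.0",
        query="hello",
        documents=["hello", "world"],
        top_n=2,
    )
    await asyncio.sleep(1)  # give the async success callback time to run
    # Cohere rerank is flat-priced at 0.002 per call by rerank_cost() above
    print(response.results, cost_logger.response_cost)


asyncio.run(main())
```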