From 1852e1cd9a4e2bcbd979299b6df3086fc596aec9 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 6 Sep 2024 15:18:16 -0700 Subject: [PATCH 1/5] basic cohere rerank logging --- litellm/rerank_api/main.py | 4 +++- litellm/tests/test_rerank.py | 28 ++++++++++++++++++++++++++++ litellm/utils.py | 1 + 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py index 41de82ab6..462208cfc 100644 --- a/litellm/rerank_api/main.py +++ b/litellm/rerank_api/main.py @@ -9,7 +9,7 @@ from litellm.llms.cohere.rerank import CohereRerank from litellm.llms.togetherai.rerank import TogetherAIRerank from litellm.secret_managers.main import get_secret from litellm.types.router import * -from litellm.utils import supports_httpx_timeout +from litellm.utils import client, supports_httpx_timeout from .types import RerankRequest, RerankResponse @@ -20,6 +20,7 @@ together_rerank = TogetherAIRerank() ################################################# +@client async def arerank( model: str, query: str, @@ -64,6 +65,7 @@ async def arerank( raise e +@client def rerank( model: str, query: str, diff --git a/litellm/tests/test_rerank.py b/litellm/tests/test_rerank.py index 4e70424bc..c3d6faed4 100644 --- a/litellm/tests/test_rerank.py +++ b/litellm/tests/test_rerank.py @@ -20,6 +20,7 @@ import pytest import litellm from litellm import RateLimitError, Timeout, completion, completion_cost, embedding +from litellm.integrations.custom_logger import CustomLogger from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler @@ -177,3 +178,30 @@ async def test_rerank_custom_api_base(): assert response.results is not None assert_response_shape(response, custom_llm_provider="cohere") + + +class TestLogger(CustomLogger): + + def __init__(self): + self.kwargs = None + super().__init__() + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + print("in success event for rerank, kwargs = ", kwargs) + self.kwargs = kwargs + + +@pytest.mark.asyncio() +async def test_rerank_custom_callbacks(): + custom_logger = TestLogger() + litellm.callbacks = [custom_logger] + response = await litellm.arerank( + model="cohere/rerank-english-v3.0", + query="hello", + documents=["hello", "world"], + top_n=3, + ) + + assert self.kwargs is not None + + print("async re rank response: ", response) diff --git a/litellm/utils.py b/litellm/utils.py index c362a7b5a..aecf2de4c 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -745,6 +745,7 @@ def client(original_function): or kwargs.get("amoderation", False) == True or kwargs.get("atext_completion", False) == True or kwargs.get("atranscription", False) == True + or kwargs.get("arerank", False) == True ): # [OPTIONAL] CHECK MAX RETRIES / REQUEST if litellm.num_retries_per_request is not None: From e095daf2e46c9eee017d5526f80084f1173d614b Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 6 Sep 2024 16:04:54 -0700 Subject: [PATCH 2/5] add cost tracking for rerank --- litellm/cost_calculator.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index a66a80002..30d3ed036 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -22,6 +22,7 @@ from litellm.litellm_core_utils.llm_cost_calc.utils import _generic_cost_per_cha from litellm.llms.anthropic.cost_calculation import ( cost_per_token as anthropic_cost_per_token, ) +from litellm.rerank_api.types import RerankResponse from litellm.types.llms.openai import HttpxBinaryResponseContent from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS from litellm.types.utils import PassthroughCallTypes, Usage @@ -747,6 +748,7 @@ def response_cost_calculator( TranscriptionResponse, TextCompletionResponse, HttpxBinaryResponseContent, + RerankResponse, ], model: str, custom_llm_provider: Optional[str], @@ -765,6 +767,8 @@ def response_cost_calculator( "transcription", "aspeech", "speech", + "rerank", + "arerank", ], optional_params: dict, cache_hit: Optional[bool] = None, @@ -789,6 +793,13 @@ def response_cost_calculator( call_type=call_type, custom_llm_provider=custom_llm_provider, ) + elif isinstance(response_object, RerankResponse): + response_cost = rerank_cost( + rerank_response=response_object, + model=model, + call_type=call_type, + custom_llm_provider=custom_llm_provider, + ) else: if custom_pricing is True: # override defaults if custom pricing is set base_model = model @@ -820,3 +831,22 @@ def response_cost_calculator( ) ) return None + + +def rerank_cost( + rerank_response: RerankResponse, + model: str, + call_type: Literal["rerank", "arerank"], + custom_llm_provider: Optional[str], +) -> Optional[float]: + """ + Returns + - float or None: cost of response OR none if error. + """ + _, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model) + + try: + if custom_llm_provider == "cohere": + return 0.002 + except Exception as e: + raise e From d1342c59917acdc633938d1ff4e36a95638e7b23 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 6 Sep 2024 16:05:19 -0700 Subject: [PATCH 3/5] fix RerankResponse type --- litellm/rerank_api/types.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/litellm/rerank_api/types.py b/litellm/rerank_api/types.py index 605e25a2e..3b95597c3 100644 --- a/litellm/rerank_api/types.py +++ b/litellm/rerank_api/types.py @@ -23,3 +23,16 @@ class RerankResponse(BaseModel): id: str results: List[dict] # Contains index and relevance_score meta: dict # Contains api_version and billed_units + _hidden_params: Optional[dict] = {} + + class Config: + underscore_attrs_are_private = True + + def __getitem__(self, key): + return self.__dict__[key] + + def get(self, key, default=None): + return self.__dict__.get(key, default) + + def __contains__(self, key): + return key in self.__dict__ From b659095f71a40cb6a4ad85bd215c5ed8c8067381 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 6 Sep 2024 16:06:19 -0700 Subject: [PATCH 4/5] add cost tracking for rerank+ test --- litellm/litellm_core_utils/litellm_logging.py | 3 +++ litellm/tests/test_rerank.py | 10 +++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 2ea3f23d3..486347a7a 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -30,6 +30,7 @@ from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.redact_messages import ( redact_message_input_output_from_logging, ) +from litellm.rerank_api.types import RerankResponse from litellm.types.llms.openai import HttpxBinaryResponseContent from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS from litellm.types.utils import ( @@ -524,6 +525,7 @@ class Logging: TranscriptionResponse, TextCompletionResponse, HttpxBinaryResponseContent, + RerankResponse, ], cache_hit: Optional[bool] = None, ): @@ -585,6 +587,7 @@ class Logging: or isinstance(result, TranscriptionResponse) or isinstance(result, TextCompletionResponse) or isinstance(result, HttpxBinaryResponseContent) # tts + or isinstance(result, RerankResponse) ): ## RESPONSE COST ## self.model_call_details["response_cost"] = ( diff --git a/litellm/tests/test_rerank.py b/litellm/tests/test_rerank.py index c3d6faed4..4d0fdfb34 100644 --- a/litellm/tests/test_rerank.py +++ b/litellm/tests/test_rerank.py @@ -1,3 +1,4 @@ +import asyncio import json import os import sys @@ -184,11 +185,14 @@ class TestLogger(CustomLogger): def __init__(self): self.kwargs = None + self.response_obj = None super().__init__() async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): print("in success event for rerank, kwargs = ", kwargs) + print("in success event for rerank, response_obj = ", response_obj) self.kwargs = kwargs + self.response_obj = response_obj @pytest.mark.asyncio() @@ -202,6 +206,10 @@ async def test_rerank_custom_callbacks(): top_n=3, ) - assert self.kwargs is not None + await asyncio.sleep(5) print("async re rank response: ", response) + assert custom_logger.kwargs is not None + assert custom_logger.kwargs.get("response_cost") > 0.0 + assert custom_logger.response_obj is not None + assert custom_logger.response_obj.results is not None From 3c16fcff1b0fb5e46d3f2f7755f428adf9ffe8e5 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 6 Sep 2024 16:41:47 -0700 Subject: [PATCH 5/5] fix linting errors --- litellm/cost_calculator.py | 16 ++++++++++++++-- litellm/rerank_api/types.py | 2 +- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index 30d3ed036..1eb4d0eb9 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -94,6 +94,8 @@ def cost_per_token( "transcription", "aspeech", "speech", + "rerank", + "arerank", ] = "completion", ) -> Tuple[float, float]: """ @@ -488,6 +490,8 @@ def completion_cost( "transcription", "aspeech", "speech", + "rerank", + "arerank", ] = "completion", ### REGION ### custom_llm_provider=None, @@ -793,7 +797,9 @@ def response_cost_calculator( call_type=call_type, custom_llm_provider=custom_llm_provider, ) - elif isinstance(response_object, RerankResponse): + elif isinstance(response_object, RerankResponse) and ( + call_type == "arerank" or call_type == "rerank" + ): response_cost = rerank_cost( rerank_response=response_object, model=model, @@ -838,7 +844,7 @@ def rerank_cost( model: str, call_type: Literal["rerank", "arerank"], custom_llm_provider: Optional[str], -) -> Optional[float]: +) -> float: """ Returns - float or None: cost of response OR none if error. @@ -848,5 +854,11 @@ def rerank_cost( try: if custom_llm_provider == "cohere": return 0.002 + raise ValueError( + f"invalid custom_llm_provider for rerank model: {model}, custom_llm_provider: {custom_llm_provider}" + ) except Exception as e: + verbose_logger.exception( + f"litellm.cost_calculator.py::rerank_cost - Exception occurred - {str(e)}" + ) raise e diff --git a/litellm/rerank_api/types.py b/litellm/rerank_api/types.py index 3b95597c3..00cb32c18 100644 --- a/litellm/rerank_api/types.py +++ b/litellm/rerank_api/types.py @@ -23,7 +23,7 @@ class RerankResponse(BaseModel): id: str results: List[dict] # Contains index and relevance_score meta: dict # Contains api_version and billed_units - _hidden_params: Optional[dict] = {} + _hidden_params: dict = {} class Config: underscore_attrs_are_private = True