diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py
index 41de82ab66..462208cfcd 100644
--- a/litellm/rerank_api/main.py
+++ b/litellm/rerank_api/main.py
@@ -9,7 +9,7 @@ from litellm.llms.cohere.rerank import CohereRerank
 from litellm.llms.togetherai.rerank import TogetherAIRerank
 from litellm.secret_managers.main import get_secret
 from litellm.types.router import *
-from litellm.utils import supports_httpx_timeout
+from litellm.utils import client, supports_httpx_timeout
 
 from .types import RerankRequest, RerankResponse
 
@@ -20,6 +20,7 @@ together_rerank = TogetherAIRerank()
 #################################################
 
 
+@client
 async def arerank(
     model: str,
     query: str,
@@ -64,6 +65,7 @@ async def arerank(
         raise e
 
 
+@client
 def rerank(
     model: str,
     query: str,
diff --git a/litellm/tests/test_rerank.py b/litellm/tests/test_rerank.py
index 4e70424bc3..c3d6faed42 100644
--- a/litellm/tests/test_rerank.py
+++ b/litellm/tests/test_rerank.py
@@ -20,6 +20,7 @@ import pytest
 
 import litellm
 from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
+from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 
 
@@ -177,3 +178,35 @@ async def test_rerank_custom_api_base():
     assert response.results is not None
 
     assert_response_shape(response, custom_llm_provider="cohere")
+
+
+class TestLogger(CustomLogger):
+    """Custom logger that records the kwargs passed to the async success hook."""
+
+    def __init__(self):
+        self.kwargs = None
+        super().__init__()
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        print("in success event for rerank, kwargs = ", kwargs)
+        self.kwargs = kwargs
+
+
+@pytest.mark.asyncio()
+async def test_rerank_custom_callbacks():
+    """Check that a logger registered in `litellm.callbacks` gets kwargs from `arerank`."""
+    custom_logger = TestLogger()
+    litellm.callbacks = [custom_logger]
+    response = await litellm.arerank(
+        model="cohere/rerank-english-v3.0",
+        query="hello",
+        documents=["hello", "world"],
+        top_n=3,
+    )
+
+    # The success hook stores the call kwargs on the logger instance.
+    # NOTE(review): if the hook runs after `arerank` returns, an extra await
+    # may be needed before this assertion - confirm.
+    assert custom_logger.kwargs is not None
+
+    print("async re rank response: ", response)
diff --git a/litellm/utils.py b/litellm/utils.py
index c362a7b5a0..aecf2de4c8 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -745,6 +745,7 @@ def client(original_function):
                 or kwargs.get("amoderation", False) == True
                 or kwargs.get("atext_completion", False) == True
                 or kwargs.get("atranscription", False) == True
+                or kwargs.get("arerank", False) == True
             ):
                 # [OPTIONAL] CHECK MAX RETRIES / REQUEST
                 if litellm.num_retries_per_request is not None: