basic cohere rerank logging

This commit is contained in:
Ishaan Jaff 2024-09-06 15:18:16 -07:00
parent 4626c5a365
commit 1852e1cd9a
3 changed files with 32 additions and 1 deletions

View file

@ -9,7 +9,7 @@ from litellm.llms.cohere.rerank import CohereRerank
from litellm.llms.togetherai.rerank import TogetherAIRerank from litellm.llms.togetherai.rerank import TogetherAIRerank
from litellm.secret_managers.main import get_secret from litellm.secret_managers.main import get_secret
from litellm.types.router import * from litellm.types.router import *
from litellm.utils import supports_httpx_timeout from litellm.utils import client, supports_httpx_timeout
from .types import RerankRequest, RerankResponse from .types import RerankRequest, RerankResponse
@ -20,6 +20,7 @@ together_rerank = TogetherAIRerank()
################################################# #################################################
@client
async def arerank( async def arerank(
model: str, model: str,
query: str, query: str,
@ -64,6 +65,7 @@ async def arerank(
raise e raise e
@client
def rerank( def rerank(
model: str, model: str,
query: str, query: str,

View file

@ -20,6 +20,7 @@ import pytest
import litellm import litellm
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
@ -177,3 +178,30 @@ async def test_rerank_custom_api_base():
assert response.results is not None assert response.results is not None
assert_response_shape(response, custom_llm_provider="cohere") assert_response_shape(response, custom_llm_provider="cohere")
class TestLogger(CustomLogger):
def __init__(self):
self.kwargs = None
super().__init__()
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print("in success event for rerank, kwargs = ", kwargs)
self.kwargs = kwargs
@pytest.mark.asyncio()
async def test_rerank_custom_callbacks():
custom_logger = TestLogger()
litellm.callbacks = [custom_logger]
response = await litellm.arerank(
model="cohere/rerank-english-v3.0",
query="hello",
documents=["hello", "world"],
top_n=3,
)
assert self.kwargs is not None
print("async re rank response: ", response)

View file

@ -745,6 +745,7 @@ def client(original_function):
or kwargs.get("amoderation", False) == True or kwargs.get("amoderation", False) == True
or kwargs.get("atext_completion", False) == True or kwargs.get("atext_completion", False) == True
or kwargs.get("atranscription", False) == True or kwargs.get("atranscription", False) == True
or kwargs.get("arerank", False) == True
): ):
# [OPTIONAL] CHECK MAX RETRIES / REQUEST # [OPTIONAL] CHECK MAX RETRIES / REQUEST
if litellm.num_retries_per_request is not None: if litellm.num_retries_per_request is not None: