mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
basic cohere rerank logging
This commit is contained in:
parent
4626c5a365
commit
1852e1cd9a
3 changed files with 32 additions and 1 deletions
|
@ -9,7 +9,7 @@ from litellm.llms.cohere.rerank import CohereRerank
|
|||
from litellm.llms.togetherai.rerank import TogetherAIRerank
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.router import *
|
||||
from litellm.utils import supports_httpx_timeout
|
||||
from litellm.utils import client, supports_httpx_timeout
|
||||
|
||||
from .types import RerankRequest, RerankResponse
|
||||
|
||||
|
@ -20,6 +20,7 @@ together_rerank = TogetherAIRerank()
|
|||
#################################################
|
||||
|
||||
|
||||
@client
|
||||
async def arerank(
|
||||
model: str,
|
||||
query: str,
|
||||
|
@ -64,6 +65,7 @@ async def arerank(
|
|||
raise e
|
||||
|
||||
|
||||
@client
|
||||
def rerank(
|
||||
model: str,
|
||||
query: str,
|
||||
|
|
|
@ -20,6 +20,7 @@ import pytest
|
|||
|
||||
import litellm
|
||||
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
|
||||
|
||||
|
@ -177,3 +178,30 @@ async def test_rerank_custom_api_base():
|
|||
assert response.results is not None
|
||||
|
||||
assert_response_shape(response, custom_llm_provider="cohere")
|
||||
|
||||
|
||||
class TestLogger(CustomLogger):
|
||||
|
||||
def __init__(self):
|
||||
self.kwargs = None
|
||||
super().__init__()
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print("in success event for rerank, kwargs = ", kwargs)
|
||||
self.kwargs = kwargs
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_rerank_custom_callbacks():
|
||||
custom_logger = TestLogger()
|
||||
litellm.callbacks = [custom_logger]
|
||||
response = await litellm.arerank(
|
||||
model="cohere/rerank-english-v3.0",
|
||||
query="hello",
|
||||
documents=["hello", "world"],
|
||||
top_n=3,
|
||||
)
|
||||
|
||||
assert self.kwargs is not None
|
||||
|
||||
print("async re rank response: ", response)
|
||||
|
|
|
@ -745,6 +745,7 @@ def client(original_function):
|
|||
or kwargs.get("amoderation", False) == True
|
||||
or kwargs.get("atext_completion", False) == True
|
||||
or kwargs.get("atranscription", False) == True
|
||||
or kwargs.get("arerank", False) == True
|
||||
):
|
||||
# [OPTIONAL] CHECK MAX RETRIES / REQUEST
|
||||
if litellm.num_retries_per_request is not None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue