mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
basic cohere rerank logging
This commit is contained in:
parent
4626c5a365
commit
1852e1cd9a
3 changed files with 32 additions and 1 deletions
|
@ -9,7 +9,7 @@ from litellm.llms.cohere.rerank import CohereRerank
|
||||||
from litellm.llms.togetherai.rerank import TogetherAIRerank
|
from litellm.llms.togetherai.rerank import TogetherAIRerank
|
||||||
from litellm.secret_managers.main import get_secret
|
from litellm.secret_managers.main import get_secret
|
||||||
from litellm.types.router import *
|
from litellm.types.router import *
|
||||||
from litellm.utils import supports_httpx_timeout
|
from litellm.utils import client, supports_httpx_timeout
|
||||||
|
|
||||||
from .types import RerankRequest, RerankResponse
|
from .types import RerankRequest, RerankResponse
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ together_rerank = TogetherAIRerank()
|
||||||
#################################################
|
#################################################
|
||||||
|
|
||||||
|
|
||||||
|
@client
|
||||||
async def arerank(
|
async def arerank(
|
||||||
model: str,
|
model: str,
|
||||||
query: str,
|
query: str,
|
||||||
|
@ -64,6 +65,7 @@ async def arerank(
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
@client
|
||||||
def rerank(
|
def rerank(
|
||||||
model: str,
|
model: str,
|
||||||
query: str,
|
query: str,
|
||||||
|
|
|
@ -20,6 +20,7 @@ import pytest
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
|
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
|
|
||||||
|
|
||||||
|
@ -177,3 +178,30 @@ async def test_rerank_custom_api_base():
|
||||||
assert response.results is not None
|
assert response.results is not None
|
||||||
|
|
||||||
assert_response_shape(response, custom_llm_provider="cohere")
|
assert_response_shape(response, custom_llm_provider="cohere")
|
||||||
|
|
||||||
|
|
||||||
|
class TestLogger(CustomLogger):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.kwargs = None
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
print("in success event for rerank, kwargs = ", kwargs)
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_rerank_custom_callbacks():
|
||||||
|
custom_logger = TestLogger()
|
||||||
|
litellm.callbacks = [custom_logger]
|
||||||
|
response = await litellm.arerank(
|
||||||
|
model="cohere/rerank-english-v3.0",
|
||||||
|
query="hello",
|
||||||
|
documents=["hello", "world"],
|
||||||
|
top_n=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert self.kwargs is not None
|
||||||
|
|
||||||
|
print("async re rank response: ", response)
|
||||||
|
|
|
@ -745,6 +745,7 @@ def client(original_function):
|
||||||
or kwargs.get("amoderation", False) == True
|
or kwargs.get("amoderation", False) == True
|
||||||
or kwargs.get("atext_completion", False) == True
|
or kwargs.get("atext_completion", False) == True
|
||||||
or kwargs.get("atranscription", False) == True
|
or kwargs.get("atranscription", False) == True
|
||||||
|
or kwargs.get("arerank", False) == True
|
||||||
):
|
):
|
||||||
# [OPTIONAL] CHECK MAX RETRIES / REQUEST
|
# [OPTIONAL] CHECK MAX RETRIES / REQUEST
|
||||||
if litellm.num_retries_per_request is not None:
|
if litellm.num_retries_per_request is not None:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue