mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
add cost tracking for rerank+ test
This commit is contained in:
parent
186f3dad7f
commit
8bd57b6167
2 changed files with 12 additions and 1 deletions
|
@ -30,6 +30,7 @@ from litellm.integrations.custom_logger import CustomLogger
|
|||
from litellm.litellm_core_utils.redact_messages import (
|
||||
redact_message_input_output_from_logging,
|
||||
)
|
||||
from litellm.rerank_api.types import RerankResponse
|
||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
|
||||
from litellm.types.utils import (
|
||||
|
@ -524,6 +525,7 @@ class Logging:
|
|||
TranscriptionResponse,
|
||||
TextCompletionResponse,
|
||||
HttpxBinaryResponseContent,
|
||||
RerankResponse,
|
||||
],
|
||||
cache_hit: Optional[bool] = None,
|
||||
):
|
||||
|
@ -585,6 +587,7 @@ class Logging:
|
|||
or isinstance(result, TranscriptionResponse)
|
||||
or isinstance(result, TextCompletionResponse)
|
||||
or isinstance(result, HttpxBinaryResponseContent) # tts
|
||||
or isinstance(result, RerankResponse)
|
||||
):
|
||||
## RESPONSE COST ##
|
||||
self.model_call_details["response_cost"] = (
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
@ -184,11 +185,14 @@ class TestLogger(CustomLogger):
|
|||
|
||||
def __init__(self):
|
||||
self.kwargs = None
|
||||
self.response_obj = None
|
||||
super().__init__()
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print("in success event for rerank, kwargs = ", kwargs)
|
||||
print("in success event for rerank, response_obj = ", response_obj)
|
||||
self.kwargs = kwargs
|
||||
self.response_obj = response_obj
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
|
@ -202,6 +206,10 @@ async def test_rerank_custom_callbacks():
|
|||
top_n=3,
|
||||
)
|
||||
|
||||
assert self.kwargs is not None
|
||||
await asyncio.sleep(5)
|
||||
|
||||
print("async re rank response: ", response)
|
||||
assert custom_logger.kwargs is not None
|
||||
assert custom_logger.kwargs.get("response_cost") > 0.0
|
||||
assert custom_logger.response_obj is not None
|
||||
assert custom_logger.response_obj.results is not None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue