add cost tracking for rerank+ test

This commit is contained in:
Ishaan Jaff 2024-09-06 16:06:19 -07:00
parent 186f3dad7f
commit 8bd57b6167
2 changed files with 12 additions and 1 deletions

View file

@ -30,6 +30,7 @@ from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.redact_messages import (
redact_message_input_output_from_logging,
)
from litellm.rerank_api.types import RerankResponse
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
from litellm.types.utils import (
@ -524,6 +525,7 @@ class Logging:
TranscriptionResponse,
TextCompletionResponse,
HttpxBinaryResponseContent,
RerankResponse,
],
cache_hit: Optional[bool] = None,
):
@ -585,6 +587,7 @@ class Logging:
or isinstance(result, TranscriptionResponse)
or isinstance(result, TextCompletionResponse)
or isinstance(result, HttpxBinaryResponseContent) # tts
or isinstance(result, RerankResponse)
):
## RESPONSE COST ##
self.model_call_details["response_cost"] = (

View file

@ -1,3 +1,4 @@
import asyncio
import json
import os
import sys
@ -184,11 +185,14 @@ class TestLogger(CustomLogger):
def __init__(self):
self.kwargs = None
self.response_obj = None
super().__init__()
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print("in success event for rerank, kwargs = ", kwargs)
print("in success event for rerank, response_obj = ", response_obj)
self.kwargs = kwargs
self.response_obj = response_obj
@pytest.mark.asyncio()
@ -202,6 +206,10 @@ async def test_rerank_custom_callbacks():
top_n=3,
)
assert self.kwargs is not None
await asyncio.sleep(5)
print("async re rank response: ", response)
assert custom_logger.kwargs is not None
assert custom_logger.kwargs.get("response_cost") > 0.0
assert custom_logger.response_obj is not None
assert custom_logger.response_obj.results is not None