add cost tracking for rerank + test

This commit is contained in:
Ishaan Jaff 2024-09-06 16:06:19 -07:00
parent 186f3dad7f
commit 8bd57b6167
2 changed files with 12 additions and 1 deletion

View file

@ -30,6 +30,7 @@ from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.redact_messages import ( from litellm.litellm_core_utils.redact_messages import (
redact_message_input_output_from_logging, redact_message_input_output_from_logging,
) )
from litellm.rerank_api.types import RerankResponse
from litellm.types.llms.openai import HttpxBinaryResponseContent from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
from litellm.types.utils import ( from litellm.types.utils import (
@ -524,6 +525,7 @@ class Logging:
TranscriptionResponse, TranscriptionResponse,
TextCompletionResponse, TextCompletionResponse,
HttpxBinaryResponseContent, HttpxBinaryResponseContent,
RerankResponse,
], ],
cache_hit: Optional[bool] = None, cache_hit: Optional[bool] = None,
): ):
@ -585,6 +587,7 @@ class Logging:
or isinstance(result, TranscriptionResponse) or isinstance(result, TranscriptionResponse)
or isinstance(result, TextCompletionResponse) or isinstance(result, TextCompletionResponse)
or isinstance(result, HttpxBinaryResponseContent) # tts or isinstance(result, HttpxBinaryResponseContent) # tts
or isinstance(result, RerankResponse)
): ):
## RESPONSE COST ## ## RESPONSE COST ##
self.model_call_details["response_cost"] = ( self.model_call_details["response_cost"] = (

View file

@ -1,3 +1,4 @@
import asyncio
import json import json
import os import os
import sys import sys
@ -184,11 +185,14 @@ class TestLogger(CustomLogger):
def __init__(self): def __init__(self):
self.kwargs = None self.kwargs = None
self.response_obj = None
super().__init__() super().__init__()
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print("in success event for rerank, kwargs = ", kwargs) print("in success event for rerank, kwargs = ", kwargs)
print("in success event for rerank, response_obj = ", response_obj)
self.kwargs = kwargs self.kwargs = kwargs
self.response_obj = response_obj
@pytest.mark.asyncio() @pytest.mark.asyncio()
@ -202,6 +206,10 @@ async def test_rerank_custom_callbacks():
top_n=3, top_n=3,
) )
assert self.kwargs is not None await asyncio.sleep(5)
print("async re rank response: ", response) print("async re rank response: ", response)
assert custom_logger.kwargs is not None
assert custom_logger.kwargs.get("response_cost") > 0.0
assert custom_logger.response_obj is not None
assert custom_logger.response_obj.results is not None