mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
add cost tracking for rerank+ test
This commit is contained in:
parent
186f3dad7f
commit
8bd57b6167
2 changed files with 12 additions and 1 deletions
|
@ -30,6 +30,7 @@ from litellm.integrations.custom_logger import CustomLogger
|
||||||
from litellm.litellm_core_utils.redact_messages import (
|
from litellm.litellm_core_utils.redact_messages import (
|
||||||
redact_message_input_output_from_logging,
|
redact_message_input_output_from_logging,
|
||||||
)
|
)
|
||||||
|
from litellm.rerank_api.types import RerankResponse
|
||||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||||
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
|
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
|
||||||
from litellm.types.utils import (
|
from litellm.types.utils import (
|
||||||
|
@ -524,6 +525,7 @@ class Logging:
|
||||||
TranscriptionResponse,
|
TranscriptionResponse,
|
||||||
TextCompletionResponse,
|
TextCompletionResponse,
|
||||||
HttpxBinaryResponseContent,
|
HttpxBinaryResponseContent,
|
||||||
|
RerankResponse,
|
||||||
],
|
],
|
||||||
cache_hit: Optional[bool] = None,
|
cache_hit: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
|
@ -585,6 +587,7 @@ class Logging:
|
||||||
or isinstance(result, TranscriptionResponse)
|
or isinstance(result, TranscriptionResponse)
|
||||||
or isinstance(result, TextCompletionResponse)
|
or isinstance(result, TextCompletionResponse)
|
||||||
or isinstance(result, HttpxBinaryResponseContent) # tts
|
or isinstance(result, HttpxBinaryResponseContent) # tts
|
||||||
|
or isinstance(result, RerankResponse)
|
||||||
):
|
):
|
||||||
## RESPONSE COST ##
|
## RESPONSE COST ##
|
||||||
self.model_call_details["response_cost"] = (
|
self.model_call_details["response_cost"] = (
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
@ -184,11 +185,14 @@ class TestLogger(CustomLogger):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.kwargs = None
|
self.kwargs = None
|
||||||
|
self.response_obj = None
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
print("in success event for rerank, kwargs = ", kwargs)
|
print("in success event for rerank, kwargs = ", kwargs)
|
||||||
|
print("in success event for rerank, response_obj = ", response_obj)
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
self.response_obj = response_obj
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio()
|
@pytest.mark.asyncio()
|
||||||
|
@ -202,6 +206,10 @@ async def test_rerank_custom_callbacks():
|
||||||
top_n=3,
|
top_n=3,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert self.kwargs is not None
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
print("async re rank response: ", response)
|
print("async re rank response: ", response)
|
||||||
|
assert custom_logger.kwargs is not None
|
||||||
|
assert custom_logger.kwargs.get("response_cost") > 0.0
|
||||||
|
assert custom_logger.response_obj is not None
|
||||||
|
assert custom_logger.response_obj.results is not None
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue