fix(utils.py): fix cost tracking for cache hits (should be 0)
commit 1ed6842009 (parent 62ad6f19b7)
2 changed files with 129 additions and 19 deletions
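Two files change: a custom-logger test file (MyCustomHandler gains a response_cost field and a new Redis-cache cost test) and utils.py (cost calculation becomes cache-hit aware, and cached results in the client wrapper are now logged as successes). The rule being introduced reduces to the short sketch below; it is a simplified stand-in for the Logging-class code in the hunks that follow, not litellm's actual implementation, and response_cost_for is an invented name.

import litellm

def response_cost_for(model_call_details: dict, completion_response) -> float:
    # Hypothetical helper: condenses the cache-hit cost rule added in this commit.
    # A cached response costs nothing, so report 0.0 instead of pricing it again.
    if model_call_details.get("cache_hit", False):
        return 0.0
    # Non-cached responses keep using litellm's pricing helper.
    return litellm.completion_cost(completion_response=completion_response)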
@@ -31,6 +31,7 @@ class MyCustomHandler(CustomLogger):
         self.sync_stream_collected_response = None  # type: ignore
         self.user = None  # type: ignore
         self.data_sent_to_api: dict = {}
+        self.response_cost = 0

     def log_pre_api_call(self, model, messages, kwargs):
         print(f"Pre-API Call")
@@ -47,6 +48,8 @@ class MyCustomHandler(CustomLogger):
         self.success = True
         if kwargs.get("stream") == True:
             self.sync_stream_collected_response = response_obj
+        print(f"response cost in log_success_event: {kwargs.get('response_cost')}")
+        self.response_cost = kwargs.get("response_cost", 0)

     def log_failure_event(self, kwargs, response_obj, start_time, end_time):
         print(f"On Failure")
@@ -64,6 +67,10 @@ class MyCustomHandler(CustomLogger):
         self.stream_collected_response = response_obj
         self.async_completion_kwargs = kwargs
         self.user = kwargs.get("user", None)
+        print(
+            f"response cost in log_async_success_event: {kwargs.get('response_cost')}"
+        )
+        self.response_cost = kwargs.get("response_cost", 0)

     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         print(f"On Async Failure")
@@ -400,6 +407,50 @@ async def test_async_custom_handler_embedding_optional_param_bedrock():
     assert "user" not in customHandler_optional_params.data_sent_to_api


+@pytest.mark.asyncio
+async def test_cost_tracking_with_caching():
+    """
+    Important Test - This tests if that cost is 0 for cached responses
+    """
+    from litellm import Cache
+
+    litellm.set_verbose = False
+    litellm.cache = Cache(
+        type="redis",
+        host=os.environ["REDIS_HOST"],
+        port=os.environ["REDIS_PORT"],
+        password=os.environ["REDIS_PASSWORD"],
+    )
+    customHandler_optional_params = MyCustomHandler()
+    litellm.callbacks = [customHandler_optional_params]
+    messages = [
+        {
+            "role": "user",
+            "content": f"write a one sentence poem about: {time.time()}",
+        }
+    ]
+    response1 = await litellm.acompletion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        max_tokens=40,
+        temperature=0.2,
+        caching=True,
+    )
+    await asyncio.sleep(1)  # success callback is async
+    response_cost = customHandler_optional_params.response_cost
+    assert response_cost > 0
+    response2 = await litellm.acompletion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        max_tokens=40,
+        temperature=0.2,
+        caching=True,
+    )
+    await asyncio.sleep(1)  # success callback is async
+    response_cost_2 = customHandler_optional_params.response_cost
+    assert response_cost_2 == 0
+
+
 def test_redis_cache_completion_stream():
     from litellm import Cache

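From the caller's side, the new test above boils down to: register a callback, make the same request twice with caching on, and check that the second (cached) call reports a cost of 0. A trimmed, hypothetical variant is sketched below; CostProbe is an invented name, and it assumes an OpenAI API key plus a configured litellm.cache (the test uses the Redis setup shown above).

import asyncio

import litellm
from litellm.integrations.custom_logger import CustomLogger


class CostProbe(CustomLogger):  # invented name; mirrors MyCustomHandler above
    def __init__(self):
        self.response_cost = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # After this commit, response_cost is 0 when the response came from cache.
        self.response_cost = kwargs.get("response_cost", 0)


async def main():
    probe = CostProbe()
    litellm.callbacks = [probe]
    messages = [{"role": "user", "content": "write a one sentence poem"}]

    await litellm.acompletion(model="gpt-3.5-turbo", messages=messages, caching=True)
    await asyncio.sleep(1)  # success callbacks fire asynchronously
    print("first call cost:", probe.response_cost)  # expected > 0

    await litellm.acompletion(model="gpt-3.5-turbo", messages=messages, caching=True)
    await asyncio.sleep(1)
    print("second call cost:", probe.response_cost)  # expected 0 (cache hit)


asyncio.run(main())

The remaining hunks below are the utils.py side of the change.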
@@ -1060,7 +1060,12 @@ class Logging:
                 and self.stream != True
             ):  # handle streaming separately
                 try:
-                    self.model_call_details["response_cost"] = litellm.completion_cost(
+                    if self.model_call_details.get("cache_hit", False) == True:
+                        self.model_call_details["response_cost"] = 0.0
+                    else:
+                        self.model_call_details[
+                            "response_cost"
+                        ] = litellm.completion_cost(
                         completion_response=result,
                     )
                     verbose_logger.debug(
@@ -1096,7 +1101,7 @@ class Logging:
     def success_handler(
         self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs
     ):
-        verbose_logger.debug(f"Logging Details LiteLLM-Success Call")
+        verbose_logger.debug(f"Logging Details LiteLLM-Success Call: {cache_hit}")
         start_time, end_time, result = self._success_handler_helper_fn(
             start_time=start_time,
             end_time=end_time,
@@ -1134,7 +1139,12 @@ class Logging:
                         "complete_streaming_response"
                     ] = complete_streaming_response
                     try:
-                        self.model_call_details["response_cost"] = litellm.completion_cost(
+                        if self.model_call_details.get("cache_hit", False) == True:
+                            self.model_call_details["response_cost"] = 0.0
+                        else:
+                            self.model_call_details[
+                                "response_cost"
+                            ] = litellm.completion_cost(
                             completion_response=complete_streaming_response,
                         )
                         verbose_logger.debug(
@@ -1158,6 +1168,7 @@ class Logging:
                     callbacks.append(callback)
         else:
             callbacks = litellm.success_callback
+
         for callback in callbacks:
             try:
                 if callback == "lite_debugger":
@@ -1342,7 +1353,7 @@ class Logging:
                         end_time=end_time,
                         print_verbose=print_verbose,
                     )
-                elif (
+                if (
                     isinstance(callback, CustomLogger)
                     and self.model_call_details.get("litellm_params", {}).get(
                         "acompletion", False
@@ -1353,9 +1364,6 @@ class Logging:
                     )
                     == False
                 ):  # custom logger class
-                    verbose_logger.info(
-                        f"success callbacks: Running SYNC Custom Logger Class"
-                    )
                     if self.stream and complete_streaming_response is None:
                         callback.log_stream_event(
                             kwargs=self.model_call_details,
@@ -1377,7 +1385,7 @@ class Logging:
                         start_time=start_time,
                         end_time=end_time,
                     )
-                elif (
+                if (
                     callable(callback) == True
                     and self.model_call_details.get("litellm_params", {}).get(
                         "acompletion", False
@@ -1452,6 +1460,9 @@ class Logging:
                             "complete_streaming_response"
                         ] = complete_streaming_response
                     try:
+                        if self.model_call_details.get("cache_hit", False) == True:
+                            self.model_call_details["response_cost"] = 0.0
+                        else:
                             self.model_call_details["response_cost"] = litellm.completion_cost(
                                 completion_response=complete_streaming_response,
                             )
@@ -2217,7 +2228,7 @@ def client(original_function):
                     if call_type == CallTypes.completion.value and isinstance(
                         cached_result, dict
                     ):
-                        return convert_to_model_response_object(
+                        cached_result = convert_to_model_response_object(
                             response_object=cached_result,
                             model_response_object=ModelResponse(),
                             stream=kwargs.get("stream", False),
@@ -2225,11 +2236,59 @@ def client(original_function):
                     elif call_type == CallTypes.embedding.value and isinstance(
                         cached_result, dict
                     ):
-                        return convert_to_model_response_object(
+                        cached_result = convert_to_model_response_object(
                             response_object=cached_result,
                             response_type="embedding",
                         )
-                    else:
+                    # LOG SUCCESS
+                    cache_hit = True
+                    end_time = datetime.datetime.now()
+                    (
+                        model,
+                        custom_llm_provider,
+                        dynamic_api_key,
+                        api_base,
+                    ) = litellm.get_llm_provider(
+                        model=model,
+                        custom_llm_provider=kwargs.get(
+                            "custom_llm_provider", None
+                        ),
+                        api_base=kwargs.get("api_base", None),
+                        api_key=kwargs.get("api_key", None),
+                    )
+                    print_verbose(
+                        f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
+                    )
+                    logging_obj.update_environment_variables(
+                        model=model,
+                        user=kwargs.get("user", None),
+                        optional_params={},
+                        litellm_params={
+                            "logger_fn": kwargs.get("logger_fn", None),
+                            "acompletion": False,
+                            "metadata": kwargs.get("metadata", {}),
+                            "model_info": kwargs.get("model_info", {}),
+                            "proxy_server_request": kwargs.get(
+                                "proxy_server_request", None
+                            ),
+                            "preset_cache_key": kwargs.get(
+                                "preset_cache_key", None
+                            ),
+                            "stream_response": kwargs.get(
+                                "stream_response", {}
+                            ),
+                        },
+                        input=kwargs.get("messages", ""),
+                        api_key=kwargs.get("api_key", None),
+                        original_response=str(cached_result),
+                        additional_args=None,
+                        stream=kwargs.get("stream", False),
+                    )
+                    threading.Thread(
+                        target=logging_obj.success_handler,
+                        args=(cached_result, start_time, end_time, cache_hit),
+                    ).start()
                     return cached_result
+
             # CHECK MAX TOKENS

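Flattened out, the last hunk makes the synchronous wrapper treat a cache hit like any other successful call: convert the cached dict into a response object, fire the success callbacks with cache_hit=True (so the gate above reports a cost of 0.0), and only then return. A condensed, hypothetical restatement (handle_cached_result is not a real litellm function; the real logic is inlined in the wrapper shown above):

import datetime
import threading


def handle_cached_result(cached_result, logging_obj, start_time):
    # Invented helper: condenses the new cached-result branch of the client wrapper.
    cache_hit = True
    end_time = datetime.datetime.now()
    # Run the normal success callbacks in the background so cache hits are logged
    # like any other call; cache_hit=True is what drives response_cost to 0.0.
    threading.Thread(
        target=logging_obj.success_handler,
        args=(cached_result, start_time, end_time, cache_hit),
    ).start()
    return cached_result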