fix: support streaming custom cost completion tracking

Krrish Dholakia 2024-01-22 13:41:22 -08:00
parent 82bbf336d5
commit 074ea17325
4 changed files with 58 additions and 11 deletions


@@ -1105,7 +1105,7 @@ class Logging:
self.sync_streaming_chunks.append(result)
if complete_streaming_response:
verbose_logger.info(
verbose_logger.debug(
f"Logging Details LiteLLM-Success Call streaming complete"
)
self.model_call_details[
@@ -1305,7 +1305,9 @@
)
== False
): # custom logger class
print_verbose(f"success callbacks: Running Custom Logger Class")
verbose_logger.info(
f"success callbacks: Running SYNC Custom Logger Class"
)
if self.stream and complete_streaming_response is None:
callback.log_stream_event(
kwargs=self.model_call_details,
@@ -1327,7 +1329,17 @@ class Logging:
start_time=start_time,
end_time=end_time,
)
if callable(callback): # custom logger functions
elif (
callable(callback) == True
and self.model_call_details.get("litellm_params", {}).get(
"acompletion", False
)
== False
and self.model_call_details.get("litellm_params", {}).get(
"aembedding", False
)
== False
): # custom logger functions
print_verbose(
f"success callbacks: Running Custom Callback Function"
)
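
For reference, the sync callback functions this guard now protects follow litellm's documented custom-callback signature of (kwargs, completion_response, start_time, end_time); the new condition simply leaves acompletion/aembedding traffic to the async handler. A minimal, illustrative sketch (the callback body is not part of this commit):

import litellm

def track_cost_callback(kwargs, completion_response, start_time, end_time):
    # kwargs is the model_call_details dict referenced above;
    # start_time/end_time are datetimes bracketing the request.
    response_ms = (end_time - start_time).total_seconds() * 1000
    print(f"model={kwargs.get('model')} latency_ms={response_ms}")

litellm.success_callback = [track_cost_callback]
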
@@ -1362,6 +1374,9 @@ class Logging:
Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
"""
print_verbose(f"Async success callbacks: {litellm._async_success_callback}")
start_time, end_time, result = self._success_handler_helper_fn(
start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
)
## BUILD COMPLETE STREAMED RESPONSE
complete_streaming_response = None
if self.stream:
@@ -1372,6 +1387,8 @@ class Logging:
complete_streaming_response = litellm.stream_chunk_builder(
self.streaming_chunks,
messages=self.model_call_details.get("messages", None),
start_time=start_time,
end_time=end_time,
)
except Exception as e:
print_verbose(
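
Threading start_time/end_time into stream_chunk_builder lets the rebuilt streaming response carry the request's latency, which the cost code later reads back as _response_ms. A rough sketch of that idea, assuming a simple chunk shape and not reflecting the builder's actual internals:

from datetime import datetime

def rebuild_with_latency(chunks: list, start_time: datetime, end_time: datetime) -> dict:
    # Join streamed text deltas into one completion-shaped dict (illustrative only)
    content = "".join(chunk.get("delta", "") for chunk in chunks)
    complete_response = {"choices": [{"message": {"content": content}}]}
    # Attach latency in milliseconds, OpenAI-style
    complete_response["_response_ms"] = (end_time - start_time).total_seconds() * 1000
    return complete_response
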
@@ -1385,9 +1402,7 @@ class Logging:
self.model_call_details[
"complete_streaming_response"
] = complete_streaming_response
start_time, end_time, result = self._success_handler_helper_fn(
start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
)
for callback in litellm._async_success_callback:
try:
if callback == "cache" and litellm.cache is not None:
@@ -1434,7 +1449,6 @@ class Logging:
end_time=end_time,
)
if callable(callback): # custom logger functions
print_verbose(f"Async success callbacks: async_log_event")
await customLogger.async_log_event(
kwargs=self.model_call_details,
response_obj=result,
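
The async path dispatches to coroutine callbacks in the same shape; a sketch of such a callback, assuming litellm routes coroutine functions registered via success_callback onto this _async_success_callback path (the body is illustrative):

async def async_track_cost_callback(kwargs, completion_response, start_time, end_time):
    # Runs on the event loop, so it may await I/O (e.g. persisting cost rows)
    # without blocking the completion call.
    response_ms = (end_time - start_time).total_seconds() * 1000
    print(f"async success: model={kwargs.get('model')} latency_ms={response_ms}")
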
@@ -2835,6 +2849,7 @@ def cost_per_token(
verbose_logger.debug(f"Looking up model={model} in model_cost_map")
if model in model_cost_ref:
verbose_logger.debug(f"Success: model={model} in model_cost_map")
if (
model_cost_ref[model].get("input_cost_per_token", None) is not None
and model_cost_ref[model].get("output_cost_per_token", None) is not None
@@ -2850,11 +2865,17 @@ def cost_per_token(
model_cost_ref[model].get("input_cost_per_second", None) is not None
and response_time_ms is not None
):
verbose_logger.debug(
f"For model={model} - input_cost_per_second: {model_cost_ref[model].get('input_cost_per_second')}; response time: {response_time_ms}"
)
## COST PER SECOND ##
prompt_tokens_cost_usd_dollar = (
model_cost_ref[model]["input_cost_per_second"] * response_time_ms / 1000
)
completion_tokens_cost_usd_dollar = 0.0
verbose_logger.debug(
f"Returned custom cost for model={model} - prompt_tokens_cost_usd_dollar: {prompt_tokens_cost_usd_dollar}, completion_tokens_cost_usd_dollar: {completion_tokens_cost_usd_dollar}"
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
elif model_with_provider in model_cost_ref:
print_verbose(f"Looking up model={model_with_provider} in model_cost_map")
@@ -2957,6 +2978,9 @@ def completion_cost(
"completion_tokens", 0
)
total_time = completion_response.get("_response_ms", 0)
verbose_logger.debug(
f"completion_response response ms: {completion_response.get('_response_ms')} "
)
model = (
model or completion_response["model"]
) # check if user passed an override for model, if it's none check completion_response['model']
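
With _response_ms now read off the completion response, completion_cost can price time-billed models end to end. A usage sketch; the model name is a placeholder assumed to have a per-second entry in the cost map:

import litellm

response = litellm.completion(
    model="sagemaker/my-endpoint",  # placeholder model name
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
# total_time is taken from response["_response_ms"] as shown above
cost = litellm.completion_cost(completion_response=response)
print(f"cost: ${cost:.6f}")
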
@@ -3026,6 +3050,7 @@ def register_model(model_cost: Union[str, dict]):
for key, value in loaded_model_cost.items():
## override / add new keys to the existing model cost dictionary
litellm.model_cost.setdefault(key, {}).update(value)
verbose_logger.debug(f"{key} added to model cost map")
# add new model names to provider lists
if value.get("litellm_provider") == "openai":
if key not in litellm.open_ai_chat_completion_models:
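
register_model is how such per-second pricing gets into the map in the first place; a sketch of registering a custom entry (keys and values are illustrative):

import litellm

litellm.register_model(
    {
        "sagemaker/my-endpoint": {
            "input_cost_per_second": 0.000420,  # billed by runtime, not tokens
            "litellm_provider": "sagemaker",
            "mode": "completion",
        }
    }
)
print(litellm.model_cost["sagemaker/my-endpoint"])
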
@@ -5170,6 +5195,8 @@ def convert_to_model_response_object(
"completion", "embedding", "image_generation"
] = "completion",
stream=False,
start_time=None,
end_time=None,
):
try:
if response_type == "completion" and (
@@ -5223,6 +5250,12 @@ def convert_to_model_response_object(
if "model" in response_object:
model_response_object.model = response_object["model"]
if start_time is not None and end_time is not None:
model_response_object._response_ms = (
end_time - start_time
).total_seconds() * 1000 # return response latency in ms like openai
return model_response_object
elif response_type == "embedding" and (
model_response_object is None
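
The latency arithmetic in both new branches is a plain datetime delta converted to milliseconds, matching the OpenAI-style response latency noted in the code comment. A standalone illustration:

from datetime import datetime, timedelta

start_time = datetime(2024, 1, 22, 13, 41, 22)
end_time = start_time + timedelta(seconds=1, milliseconds=250)

_response_ms = (end_time - start_time).total_seconds() * 1000
print(_response_ms)  # 1250.0
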
@@ -5247,6 +5280,11 @@ def convert_to_model_response_object(
model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore
model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
if start_time is not None and end_time is not None:
model_response_object._response_ms = (
end_time - start_time
).total_seconds() * 1000 # return response latency in ms like openai
return model_response_object
elif response_type == "image_generation" and (
model_response_object is None