fix: support streaming custom cost completion tracking

Krrish Dholakia 2024-01-22 13:41:22 -08:00
parent 82bbf336d5
commit 074ea17325
4 changed files with 58 additions and 11 deletions


@@ -3334,7 +3334,9 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]
     return response
 
 
-def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
+def stream_chunk_builder(
+    chunks: list, messages: Optional[list] = None, start_time=None, end_time=None
+):
     model_response = litellm.ModelResponse()
     # set hidden params from chunk to model_response
     if model_response is not None and hasattr(model_response, "_hidden_params"):
@@ -3509,5 +3511,8 @@ def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
         response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
     )
     return convert_to_model_response_object(
-        response_object=response, model_response_object=model_response
+        response_object=response,
+        model_response_object=model_response,
+        start_time=start_time,
+        end_time=end_time,
     )
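For orientation, a minimal usage sketch of the new signature — the chunk-collection loop and the datetime bookkeeping are assumptions for illustration, not code from this commit:

from datetime import datetime
import litellm

messages = [{"role": "user", "content": "hi"}]

start_time = datetime.now()
chunks = []
for chunk in litellm.completion(model="gpt-3.5-turbo", messages=messages, stream=True):
    chunks.append(chunk)
end_time = datetime.now()

# With this change, the rebuilt response can carry latency in _response_ms,
# computed downstream as (end_time - start_time).total_seconds() * 1000.
complete_response = litellm.stream_chunk_builder(
    chunks, messages=messages, start_time=start_time, end_time=end_time
)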


@@ -577,7 +577,7 @@ async def track_cost_callback(
                 "user_api_key_user_id", None
             )
-            verbose_proxy_logger.debug(
+            verbose_proxy_logger.info(
                 f"streaming response_cost {response_cost}, for user_id {user_id}"
             )
             if user_api_key and (
@@ -602,7 +602,7 @@ async def track_cost_callback(
             user_id = user_id or kwargs["litellm_params"]["metadata"].get(
                 "user_api_key_user_id", None
             )
-            verbose_proxy_logger.debug(
+            verbose_proxy_logger.info(
                 f"response_cost {response_cost}, for user_id {user_id}"
             )
             if user_api_key and (


@@ -449,6 +449,7 @@ class PrismaClient:
                     "update": {},  # don't do anything if it already exists
                 },
             )
+            verbose_proxy_logger.info(f"Data Inserted into Keys Table")
             return new_verification_token
         elif table_name == "user":
             db_data = self.jsonify_object(data=data)
@@ -459,6 +460,7 @@ class PrismaClient:
                     "update": {},  # don't do anything if it already exists
                 },
             )
+            verbose_proxy_logger.info(f"Data Inserted into User Table")
             return new_user_row
         elif table_name == "config":
             """
@@ -483,6 +485,7 @@ class PrismaClient:
                 tasks.append(updated_table_row)
             await asyncio.gather(*tasks)
+            verbose_proxy_logger.info(f"Data Inserted into Config Table")
         elif table_name == "spend":
             db_data = self.jsonify_object(data=data)
             new_spend_row = await self.db.litellm_spendlogs.upsert(
@@ -492,6 +495,7 @@ class PrismaClient:
                     "update": {},  # don't do anything if it already exists
                 },
             )
+            verbose_proxy_logger.info(f"Data Inserted into Spend Table")
             return new_spend_row
         except Exception as e:


@@ -1105,7 +1105,7 @@ class Logging:
                 self.sync_streaming_chunks.append(result)
             if complete_streaming_response:
-                verbose_logger.info(
+                verbose_logger.debug(
                     f"Logging Details LiteLLM-Success Call streaming complete"
                 )
                 self.model_call_details[
@@ -1305,7 +1305,9 @@ class Logging:
                         )
                         == False
                     ):  # custom logger class
-                        print_verbose(f"success callbacks: Running Custom Logger Class")
+                        verbose_logger.info(
+                            f"success callbacks: Running SYNC Custom Logger Class"
+                        )
                         if self.stream and complete_streaming_response is None:
                             callback.log_stream_event(
                                 kwargs=self.model_call_details,
@@ -1327,7 +1329,17 @@ class Logging:
                                 start_time=start_time,
                                 end_time=end_time,
                             )
-                    if callable(callback):  # custom logger functions
+                    elif (
+                        callable(callback) == True
+                        and self.model_call_details.get("litellm_params", {}).get(
+                            "acompletion", False
+                        )
+                        == False
+                        and self.model_call_details.get("litellm_params", {}).get(
+                            "aembedding", False
+                        )
+                        == False
+                    ):  # custom logger functions
                         print_verbose(
                             f"success callbacks: Running Custom Callback Function"
                         )
@@ -1362,6 +1374,9 @@ class Logging:
         Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
         """
         print_verbose(f"Async success callbacks: {litellm._async_success_callback}")
+        start_time, end_time, result = self._success_handler_helper_fn(
+            start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
+        )
         ## BUILD COMPLETE STREAMED RESPONSE
         complete_streaming_response = None
         if self.stream:
@@ -1372,6 +1387,8 @@ class Logging:
                 complete_streaming_response = litellm.stream_chunk_builder(
                     self.streaming_chunks,
                     messages=self.model_call_details.get("messages", None),
+                    start_time=start_time,
+                    end_time=end_time,
                 )
             except Exception as e:
                 print_verbose(
@@ -1385,9 +1402,7 @@ class Logging:
             self.model_call_details[
                 "complete_streaming_response"
             ] = complete_streaming_response
-        start_time, end_time, result = self._success_handler_helper_fn(
-            start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
-        )
         for callback in litellm._async_success_callback:
             try:
                 if callback == "cache" and litellm.cache is not None:
@@ -1434,7 +1449,6 @@ class Logging:
                         end_time=end_time,
                     )
                 if callable(callback):  # custom logger functions
-                    print_verbose(f"Async success callbacks: async_log_event")
                     await customLogger.async_log_event(
                         kwargs=self.model_call_details,
                         response_obj=result,
@@ -2835,6 +2849,7 @@ def cost_per_token(
     verbose_logger.debug(f"Looking up model={model} in model_cost_map")
     if model in model_cost_ref:
+        verbose_logger.debug(f"Success: model={model} in model_cost_map")
         if (
             model_cost_ref[model].get("input_cost_per_token", None) is not None
             and model_cost_ref[model].get("output_cost_per_token", None) is not None
@@ -2850,11 +2865,17 @@ def cost_per_token(
             model_cost_ref[model].get("input_cost_per_second", None) is not None
             and response_time_ms is not None
         ):
+            verbose_logger.debug(
+                f"For model={model} - input_cost_per_second: {model_cost_ref[model].get('input_cost_per_second')}; response time: {response_time_ms}"
+            )
             ## COST PER SECOND ##
             prompt_tokens_cost_usd_dollar = (
                 model_cost_ref[model]["input_cost_per_second"] * response_time_ms / 1000
             )
             completion_tokens_cost_usd_dollar = 0.0
+            verbose_logger.debug(
+                f"Returned custom cost for model={model} - prompt_tokens_cost_usd_dollar: {prompt_tokens_cost_usd_dollar}, completion_tokens_cost_usd_dollar: {completion_tokens_cost_usd_dollar}"
+            )
             return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
     elif model_with_provider in model_cost_ref:
         print_verbose(f"Looking up model={model_with_provider} in model_cost_map")
@@ -2957,6 +2978,9 @@ def completion_cost(
                 "completion_tokens", 0
             )
             total_time = completion_response.get("_response_ms", 0)
+            verbose_logger.debug(
+                f"completion_response response ms: {completion_response.get('_response_ms')} "
+            )
             model = (
                 model or completion_response["model"]
             )  # check if user passed an override for model, if it's none check completion_response['model']
@@ -3026,6 +3050,7 @@ def register_model(model_cost: Union[str, dict]):
     for key, value in loaded_model_cost.items():
         ## override / add new keys to the existing model cost dictionary
         litellm.model_cost.setdefault(key, {}).update(value)
+        verbose_logger.debug(f"{key} added to model cost map")
         # add new model names to provider lists
         if value.get("litellm_provider") == "openai":
             if key not in litellm.open_ai_chat_completion_models:
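Custom per-second pricing of the kind cost_per_token reads above would typically enter litellm.model_cost through register_model; a hedged sketch of such an entry — the model name and rate are made up, and the exact field set is an assumption beyond input_cost_per_second and litellm_provider, which the diff itself references:

import litellm

# Hypothetical custom pricing entry. "input_cost_per_second" and
# "litellm_provider" are the fields referenced in the diff; the model name,
# rate, and "mode" field are assumptions for illustration.
litellm.register_model(
    {
        "my-custom-endpoint": {
            "input_cost_per_second": 0.001,
            "litellm_provider": "openai",
            "mode": "completion",
        }
    }
)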
@@ -5170,6 +5195,8 @@ def convert_to_model_response_object(
         "completion", "embedding", "image_generation"
     ] = "completion",
     stream=False,
+    start_time=None,
+    end_time=None,
 ):
     try:
         if response_type == "completion" and (
@@ -5223,6 +5250,12 @@ def convert_to_model_response_object(
             if "model" in response_object:
                 model_response_object.model = response_object["model"]
+            if start_time is not None and end_time is not None:
+                model_response_object._response_ms = (
+                    end_time - start_time
+                ).total_seconds() * 1000  # return response latency in ms like openai
             return model_response_object
         elif response_type == "embedding" and (
             model_response_object is None
@@ -5247,6 +5280,11 @@ def convert_to_model_response_object(
                 model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0)  # type: ignore
                 model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0)  # type: ignore
+            if start_time is not None and end_time is not None:
+                model_response_object._response_ms = (
+                    end_time - start_time
+                ).total_seconds() * 1000  # return response latency in ms like openai
             return model_response_object
         elif response_type == "image_generation" and (
             model_response_object is None
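The two _response_ms additions above mirror how OpenAI reports latency in milliseconds (per the comment in the diff); assuming start_time and end_time are datetime objects, the arithmetic looks like this:

from datetime import datetime, timedelta

# Assumed example: a call that took 1.25 seconds end to end.
start_time = datetime(2024, 1, 22, 13, 41, 22)
end_time = start_time + timedelta(seconds=1.25)

response_ms = (end_time - start_time).total_seconds() * 1000
print(response_ms)  # 1250.0 -- the value stored on model_response_object._response_ms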