fix(utils.py): fix cost calculation for openai-compatible streaming object

Krrish Dholakia 2024-06-04 10:36:25 -07:00
parent 7b474ec267
commit 52a2f5150c
9 changed files with 189 additions and 79 deletions

@@ -1499,51 +1499,21 @@ class Logging:
                 )
                 and self.stream != True
             ):  # handle streaming separately
-                try:
-                    if self.model_call_details.get("cache_hit", False) == True:
-                        self.model_call_details["response_cost"] = 0.0
-                    else:
-                        result._hidden_params["optional_params"] = self.optional_params
-                        if (
-                            self.call_type == CallTypes.aimage_generation.value
-                            or self.call_type == CallTypes.image_generation.value
-                        ):
-                            self.model_call_details["response_cost"] = (
-                                litellm.completion_cost(
-                                    completion_response=result,
-                                    model=self.model,
-                                    call_type=self.call_type,
-                                    custom_llm_provider=self.model_call_details.get(
-                                        "custom_llm_provider", None
-                                    ),  # set for img gen models
-                                )
-                            )
-                        else:
-                            base_model: Optional[str] = None
-                            # check if base_model set on azure
-                            base_model = _get_base_model_from_metadata(
-                                model_call_details=self.model_call_details
-                            )
-                            # litellm model name
-                            litellm_model = self.model_call_details["model"]
-                            if (
-                                litellm_model in litellm.model_cost
-                                and self.custom_pricing == True
-                            ):
-                                base_model = litellm_model
-                            # base_model defaults to None if not set on model_info
-                            self.model_call_details["response_cost"] = (
-                                litellm.completion_cost(
-                                    completion_response=result,
-                                    call_type=self.call_type,
-                                    model=base_model,
-                                )
-                            )
-                except litellm.NotFoundError as e:
-                    verbose_logger.debug(
-                        f"Model={self.model} not found in completion cost map."
-                    )
-                    self.model_call_details["response_cost"] = None
+                self.model_call_details["response_cost"] = (
+                    litellm.response_cost_calculator(
+                        response_object=result,
+                        model=self.model,
+                        cache_hit=self.model_call_details.get("cache_hit", False),
+                        custom_llm_provider=self.model_call_details.get(
+                            "custom_llm_provider", None
+                        ),
+                        base_model=_get_base_model_from_metadata(
+                            model_call_details=self.model_call_details
+                        ),
+                        call_type=self.call_type,
+                        optional_params=self.optional_params,
+                    )
+                )
             else:  # streaming chunks + image gen.
                 self.model_call_details["response_cost"] = None
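Note on the hunk above: it collapses the inline cache-hit short-circuit, the image-generation special case, and the litellm.NotFoundError fallback into a single call to litellm.response_cost_calculator. The helper's body is not part of this diff; what follows is a minimal sketch, assuming it centralizes the same behaviour the removed lines implemented (sketch_response_cost is a hypothetical stand-in, not the library function):

# Minimal sketch (hypothetical helper, not litellm's implementation) of the
# behaviour response_cost_calculator is assumed to centralize: cache hits
# cost 0.0, and models missing from the cost map fall back to None instead
# of raising.
from typing import Optional

import litellm


def sketch_response_cost(
    response_object,
    model: str,
    cache_hit: bool,
    custom_llm_provider: Optional[str],
    base_model: Optional[str],
    call_type: str,
) -> Optional[float]:
    if cache_hit:
        return 0.0  # cached responses are free
    try:
        return litellm.completion_cost(
            completion_response=response_object,
            model=base_model or model,
            call_type=call_type,
            custom_llm_provider=custom_llm_provider,
        )
    except litellm.NotFoundError:
        # model not in litellm.model_cost -> no cost recorded
        return None
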
@@ -1607,29 +1577,21 @@ class Logging:
             self.model_call_details["complete_streaming_response"] = (
                 complete_streaming_response
             )
-            try:
-                if self.model_call_details.get("cache_hit", False) == True:
-                    self.model_call_details["response_cost"] = 0.0
-                else:
-                    # check if base_model set on azure
-                    base_model = _get_base_model_from_metadata(
-                        model_call_details=self.model_call_details
-                    )
-                    # base_model defaults to None if not set on model_info
-                    self.model_call_details["response_cost"] = (
-                        litellm.completion_cost(
-                            completion_response=complete_streaming_response,
-                            model=base_model,
-                        )
-                    )
-                verbose_logger.debug(
-                    f"Model={self.model}; cost={self.model_call_details['response_cost']}"
-                )
-            except litellm.NotFoundError as e:
-                verbose_logger.debug(
-                    f"Model={self.model} not found in completion cost map."
-                )
-                self.model_call_details["response_cost"] = None
+            self.model_call_details["response_cost"] = (
+                litellm.response_cost_calculator(
+                    response_object=complete_streaming_response,
+                    model=self.model,
+                    cache_hit=self.model_call_details.get("cache_hit", False),
+                    custom_llm_provider=self.model_call_details.get(
+                        "custom_llm_provider", None
+                    ),
+                    base_model=_get_base_model_from_metadata(
+                        model_call_details=self.model_call_details
+                    ),
+                    call_type=self.call_type,
+                    optional_params=self.optional_params,
+                )
+            )
         if self.dynamic_success_callbacks is not None and isinstance(
             self.dynamic_success_callbacks, list
         ):
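On the streaming side, the cost is now computed once on the reassembled complete_streaming_response rather than per chunk. A caller-side illustration of the same idea, using litellm's documented stream_chunk_builder helper (model name, prompt, and credentials such as OPENAI_API_KEY are placeholders):

import litellm

# Stream a completion, collect the chunks, then price the rebuilt response.
messages = [{"role": "user", "content": "hello"}]
chunks = []
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=messages,
    stream=True,
)
for chunk in response:
    chunks.append(chunk)

# stream_chunk_builder reassembles the chunks into a regular ModelResponse,
# which completion_cost can price just like a non-streaming result.
complete_response = litellm.stream_chunk_builder(chunks, messages=messages)
print(litellm.completion_cost(completion_response=complete_response))
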
@@ -4576,16 +4538,20 @@ def completion_cost(
     completion="",
     total_time=0.0,  # used for replicate, sagemaker
     call_type: Literal[
-        "completion",
-        "acompletion",
         "embedding",
         "aembedding",
+        "completion",
+        "acompletion",
         "atext_completion",
         "text_completion",
         "image_generation",
         "aimage_generation",
-        "transcription",
+        "moderation",
+        "amoderation",
         "atranscription",
+        "transcription",
+        "aspeech",
+        "speech",
     ] = "completion",
     ### REGION ###
     custom_llm_provider=None,
@@ -11096,8 +11062,16 @@ class CustomStreamWrapper:
             return ""

     def model_response_creator(self):
+        _model = self.model
+        _received_llm_provider = self.custom_llm_provider
+        _logging_obj_llm_provider = self.logging_obj.model_call_details.get("custom_llm_provider", None)  # type: ignore
+        if (
+            _received_llm_provider == "openai"
+            and _received_llm_provider != _logging_obj_llm_provider
+        ):
+            _model = "{}/{}".format(_logging_obj_llm_provider, _model)
         model_response = ModelResponse(
-            stream=True, model=self.model, stream_options=self.stream_options
+            stream=True, model=_model, stream_options=self.stream_options
         )
         if self.response_id is not None:
             model_response.id = self.response_id
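The prefixing above matters for pricing: openai-compatible providers are keyed in litellm.model_cost as "<provider>/<model>", so a bare model name taken off the wire would miss the cost map. A small illustration of the guard added in model_response_creator (provider and model names are examples only):

import litellm

_model = "llama3-8b-8192"
_received_llm_provider = "openai"   # the wrapper speaks the OpenAI protocol
_logging_obj_llm_provider = "groq"  # provider recorded on the logging object

# Same guard as in model_response_creator: only rewrite the model name when
# the stream came through the openai-compatible path for another provider.
if (
    _received_llm_provider == "openai"
    and _received_llm_provider != _logging_obj_llm_provider
):
    _model = "{}/{}".format(_logging_obj_llm_provider, _model)

print(_model)                        # groq/llama3-8b-8192
print(_model in litellm.model_cost)  # True when the prefixed name is priced
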
@@ -11105,10 +11079,9 @@ class CustomStreamWrapper:
             self.response_id = model_response.id
         if self.system_fingerprint is not None:
             model_response.system_fingerprint = self.system_fingerprint
-        model_response._hidden_params["custom_llm_provider"] = self.custom_llm_provider
+        model_response._hidden_params["custom_llm_provider"] = _logging_obj_llm_provider
         model_response._hidden_params["created_at"] = time.time()
-        model_response.choices = [StreamingChoices()]
-        model_response.choices[0].finish_reason = None
+        model_response.choices = [StreamingChoices(finish_reason=None)]
         return model_response

     def is_delta_empty(self, delta: Delta) -> bool: