Merge pull request #4009 from BerriAI/litellm_fix_streaming_cost_cal

fix(utils.py): fix cost calculation for openai-compatible streaming object
Krish Dholakia authored 2024-06-04 21:00:22 -07:00 · committed by GitHub · commit c544ba3654
9 changed files with 230 additions and 88 deletions

litellm/utils.py

@@ -1501,51 +1501,21 @@ class Logging:
                 )
                 and self.stream != True
             ):  # handle streaming separately
-                try:
-                    if self.model_call_details.get("cache_hit", False) == True:
-                        self.model_call_details["response_cost"] = 0.0
-                    else:
-                        result._hidden_params["optional_params"] = self.optional_params
-                        if (
-                            self.call_type == CallTypes.aimage_generation.value
-                            or self.call_type == CallTypes.image_generation.value
-                        ):
-                            self.model_call_details["response_cost"] = (
-                                litellm.completion_cost(
-                                    completion_response=result,
-                                    model=self.model,
-                                    call_type=self.call_type,
-                                    custom_llm_provider=self.model_call_details.get(
-                                        "custom_llm_provider", None
-                                    ),  # set for img gen models
-                                )
-                            )
-                        else:
-                            base_model: Optional[str] = None
-                            # check if base_model set on azure
-                            base_model = _get_base_model_from_metadata(
-                                model_call_details=self.model_call_details
-                            )
-                            # litellm model name
-                            litellm_model = self.model_call_details["model"]
-                            if (
-                                litellm_model in litellm.model_cost
-                                and self.custom_pricing == True
-                            ):
-                                base_model = litellm_model
-                            # base_model defaults to None if not set on model_info
-                            self.model_call_details["response_cost"] = (
-                                litellm.completion_cost(
-                                    completion_response=result,
-                                    call_type=self.call_type,
-                                    model=base_model,
-                                )
-                            )
-                except litellm.NotFoundError as e:
-                    verbose_logger.debug(
-                        f"Model={self.model} not found in completion cost map."
-                    )
-                    self.model_call_details["response_cost"] = None
+                self.model_call_details["response_cost"] = (
+                    litellm.response_cost_calculator(
+                        response_object=result,
+                        model=self.model,
+                        cache_hit=self.model_call_details.get("cache_hit", False),
+                        custom_llm_provider=self.model_call_details.get(
+                            "custom_llm_provider", None
+                        ),
+                        base_model=_get_base_model_from_metadata(
+                            model_call_details=self.model_call_details
+                        ),
+                        call_type=self.call_type,
+                        optional_params=self.optional_params,
+                    )
+                )
             else:  # streaming chunks + image gen.
                 self.model_call_details["response_cost"] = None
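The hunk above collapses the inline cost logic (cache-hit short circuit, image-generation pricing, Azure base-model lookup, and the NotFoundError fallback) into a single litellm.response_cost_calculator call. Below is a minimal sketch of the behavior that call is expected to encapsulate; the helper name and body are illustrative, not the library implementation:

from typing import Optional

import litellm


def response_cost_calculator_sketch(
    response_object,
    model: str,
    cache_hit: bool,
    custom_llm_provider: Optional[str],
    base_model: Optional[str],
    call_type: str,
    optional_params: dict,
) -> Optional[float]:
    # optional_params is accepted for parity with the real helper; unused in this sketch
    if cache_hit:
        # cache hits cost nothing
        return 0.0
    try:
        # prefer the Azure/base model for pricing lookups when one is known
        return litellm.completion_cost(
            completion_response=response_object,
            model=base_model or model,
            call_type=call_type,
            custom_llm_provider=custom_llm_provider,
        )
    except litellm.NotFoundError:
        # model missing from the cost map: report no cost instead of raising
        return None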
@@ -1609,29 +1579,21 @@ class Logging:
                 self.model_call_details["complete_streaming_response"] = (
                     complete_streaming_response
                 )
-                try:
-                    if self.model_call_details.get("cache_hit", False) == True:
-                        self.model_call_details["response_cost"] = 0.0
-                    else:
-                        # check if base_model set on azure
-                        base_model = _get_base_model_from_metadata(
-                            model_call_details=self.model_call_details
-                        )
-                        # base_model defaults to None if not set on model_info
-                        self.model_call_details["response_cost"] = (
-                            litellm.completion_cost(
-                                completion_response=complete_streaming_response,
-                                model=base_model,
-                            )
-                        )
-                    verbose_logger.debug(
-                        f"Model={self.model}; cost={self.model_call_details['response_cost']}"
-                    )
-                except litellm.NotFoundError as e:
-                    verbose_logger.debug(
-                        f"Model={self.model} not found in completion cost map."
-                    )
-                    self.model_call_details["response_cost"] = None
+                self.model_call_details["response_cost"] = (
+                    litellm.response_cost_calculator(
+                        response_object=complete_streaming_response,
+                        model=self.model,
+                        cache_hit=self.model_call_details.get("cache_hit", False),
+                        custom_llm_provider=self.model_call_details.get(
+                            "custom_llm_provider", None
+                        ),
+                        base_model=_get_base_model_from_metadata(
+                            model_call_details=self.model_call_details
+                        ),
+                        call_type=self.call_type,
+                        optional_params=self.optional_params,
+                    )
+                )
                 if self.dynamic_success_callbacks is not None and isinstance(
                     self.dynamic_success_callbacks, list
                 ):
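The streaming branch now routes the rebuilt complete_streaming_response through the same helper, which is what fixes cost tracking for openai-compatible streaming objects. A hedged usage sketch of the same calculation done by hand, with a placeholder openai-compatible model route:

import litellm

messages = [{"role": "user", "content": "Hey, how's it going?"}]
response = litellm.completion(
    model="openrouter/openai/gpt-3.5-turbo",
    messages=messages,
    stream=True,
)

# collect the chunks, then rebuild a full ModelResponse from them
chunks = [chunk for chunk in response]
complete_response = litellm.stream_chunk_builder(chunks, messages=messages)
cost = litellm.completion_cost(completion_response=complete_response)
print(f"streamed response cost: ${cost:.6f}")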
@@ -4579,16 +4541,20 @@ def completion_cost(
     completion="",
     total_time=0.0,  # used for replicate, sagemaker
     call_type: Literal[
-        "completion",
-        "acompletion",
         "embedding",
         "aembedding",
+        "completion",
+        "acompletion",
         "atext_completion",
         "text_completion",
         "image_generation",
         "aimage_generation",
-        "transcription",
+        "moderation",
+        "amoderation",
         "atranscription",
+        "transcription",
+        "aspeech",
+        "speech",
     ] = "completion",
     ### REGION ###
     custom_llm_provider=None,
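The call_type literal now also lists moderation, speech, and their async variants, so those entry points can be priced through the same function. The pattern for a non-chat call type is unchanged; an illustrative call is below (the model name is a placeholder, and image_generation is used because its pricing path already appears in the diff above):

import litellm

img_response = litellm.image_generation(
    model="dall-e-3",
    prompt="a lighthouse at dusk, pencil sketch",
)
cost = litellm.completion_cost(
    completion_response=img_response,
    model="dall-e-3",
    call_type="image_generation",
)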
@@ -5494,7 +5460,7 @@ def get_optional_params(
             optional_params["top_p"] = top_p
         if stop is not None:
             optional_params["stop_sequences"] = stop
-    elif custom_llm_provider == "huggingface":
+    elif custom_llm_provider == "huggingface" or custom_llm_provider == "predibase":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
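With this change, predibase shares the huggingface parameter-mapping branch, so the same unsupported-parameter check applies to both providers. A hedged example of inspecting what that check allows (model name is a placeholder):

import litellm

supported = litellm.get_supported_openai_params(
    model="llama-3-8b",
    custom_llm_provider="predibase",
)
print(supported)  # e.g. ["stream", "temperature", "max_tokens", ...]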
@@ -5949,7 +5915,6 @@ def get_optional_params(
             optional_params["logprobs"] = logprobs
         if top_logprobs is not None:
             optional_params["top_logprobs"] = top_logprobs
-
     elif custom_llm_provider == "openrouter":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
@@ -11106,8 +11071,16 @@ class CustomStreamWrapper:
             return ""

     def model_response_creator(self):
+        _model = self.model
+        _received_llm_provider = self.custom_llm_provider
+        _logging_obj_llm_provider = self.logging_obj.model_call_details.get("custom_llm_provider", None)  # type: ignore
+        if (
+            _received_llm_provider == "openai"
+            and _received_llm_provider != _logging_obj_llm_provider
+        ):
+            _model = "{}/{}".format(_logging_obj_llm_provider, _model)
         model_response = ModelResponse(
-            stream=True, model=self.model, stream_options=self.stream_options
+            stream=True, model=_model, stream_options=self.stream_options
         )
         if self.response_id is not None:
             model_response.id = self.response_id
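When the stream wrapper was created through the OpenAI-compatible client path, custom_llm_provider arrives as "openai" even though the logging object knows the real provider; prefixing the model with that provider lets the cost map resolve the right price for the streamed chunks. A toy illustration of the rule (values are placeholders):

_received_llm_provider = "openai"          # provider seen by the stream wrapper
_logging_obj_llm_provider = "together_ai"  # provider recorded on the logging object
_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"

if (
    _received_llm_provider == "openai"
    and _received_llm_provider != _logging_obj_llm_provider
):
    _model = "{}/{}".format(_logging_obj_llm_provider, _model)

print(_model)  # together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1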
@@ -11115,7 +11088,7 @@ class CustomStreamWrapper:
             self.response_id = model_response.id
         if self.system_fingerprint is not None:
             model_response.system_fingerprint = self.system_fingerprint
-        model_response._hidden_params["custom_llm_provider"] = self.custom_llm_provider
+        model_response._hidden_params["custom_llm_provider"] = _logging_obj_llm_provider
         model_response._hidden_params["created_at"] = time.time()
         model_response.choices = [StreamingChoices(finish_reason=None)]
         return model_response
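End to end, the fix means a streamed call through an openai-compatible provider surfaces a non-None response_cost to success callbacks. A hedged sketch of checking that with a custom callback (provider and model strings are placeholders):

import litellm


def track_cost_callback(kwargs, completion_response, start_time, end_time):
    # response_cost is populated by the success handler via response_cost_calculator
    print("response_cost:", kwargs.get("response_cost"))


litellm.success_callback = [track_cost_callback]

response = litellm.completion(
    model="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1",
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
)
for _chunk in response:
    pass  # consume the stream so the complete response can be rebuilt and costed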