Merge pull request #4009 from BerriAI/litellm_fix_streaming_cost_cal
fix(utils.py): fix cost calculation for openai-compatible streaming object
commit c544ba3654

9 changed files with 230 additions and 88 deletions

litellm/utils.py (showing 117 changed lines)
@@ -1501,51 +1501,21 @@ class Logging:
                 )
                 and self.stream != True
             ):  # handle streaming separately
-                try:
-                    if self.model_call_details.get("cache_hit", False) == True:
-                        self.model_call_details["response_cost"] = 0.0
-                    else:
-                        result._hidden_params["optional_params"] = self.optional_params
-                        if (
-                            self.call_type == CallTypes.aimage_generation.value
-                            or self.call_type == CallTypes.image_generation.value
-                        ):
-                            self.model_call_details["response_cost"] = (
-                                litellm.completion_cost(
-                                    completion_response=result,
-                                    model=self.model,
-                                    call_type=self.call_type,
-                                    custom_llm_provider=self.model_call_details.get(
-                                        "custom_llm_provider", None
-                                    ),  # set for img gen models
-                                )
-                            )
-                        else:
-                            base_model: Optional[str] = None
-                            # check if base_model set on azure
-                            base_model = _get_base_model_from_metadata(
-                                model_call_details=self.model_call_details
-                            )
-                            # litellm model name
-                            litellm_model = self.model_call_details["model"]
-                            if (
-                                litellm_model in litellm.model_cost
-                                and self.custom_pricing == True
-                            ):
-                                base_model = litellm_model
-                            # base_model defaults to None if not set on model_info
-                            self.model_call_details["response_cost"] = (
-                                litellm.completion_cost(
-                                    completion_response=result,
-                                    call_type=self.call_type,
-                                    model=base_model,
-                                )
-                            )
-                except litellm.NotFoundError as e:
-                    verbose_logger.debug(
-                        f"Model={self.model} not found in completion cost map."
-                    )
-                    self.model_call_details["response_cost"] = None
+                self.model_call_details["response_cost"] = (
+                    litellm.response_cost_calculator(
+                        response_object=result,
+                        model=self.model,
+                        cache_hit=self.model_call_details.get("cache_hit", False),
+                        custom_llm_provider=self.model_call_details.get(
+                            "custom_llm_provider", None
+                        ),
+                        base_model=_get_base_model_from_metadata(
+                            model_call_details=self.model_call_details
+                        ),
+                        call_type=self.call_type,
+                        optional_params=self.optional_params,
+                    )
+                )
             else:  # streaming chunks + image gen.
                 self.model_call_details["response_cost"] = None
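Note on the hunk above: the per-call-type try/except around litellm.completion_cost is collapsed into a single call to litellm.response_cost_calculator, which returns None instead of raising when a model is missing from the cost map. A minimal standalone sketch of that consolidated behavior (hypothetical helper and field names, not litellm internals):

from typing import Optional


def response_cost_sketch(
    response_object, model: str, cache_hit: bool, cost_map: dict
) -> Optional[float]:
    # Cache hits are free; unknown models yield None rather than raising.
    if cache_hit:
        return 0.0
    pricing = cost_map.get(model)
    usage = getattr(response_object, "usage", None)
    if pricing is None or usage is None:
        return None
    return (
        usage.prompt_tokens * pricing.get("input_cost_per_token", 0.0)
        + usage.completion_tokens * pricing.get("output_cost_per_token", 0.0)
    )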
@@ -1609,29 +1579,21 @@ class Logging:
                 self.model_call_details["complete_streaming_response"] = (
                     complete_streaming_response
                 )
-                try:
-                    if self.model_call_details.get("cache_hit", False) == True:
-                        self.model_call_details["response_cost"] = 0.0
-                    else:
-                        # check if base_model set on azure
-                        base_model = _get_base_model_from_metadata(
-                            model_call_details=self.model_call_details
-                        )
-                        # base_model defaults to None if not set on model_info
-                        self.model_call_details["response_cost"] = (
-                            litellm.completion_cost(
-                                completion_response=complete_streaming_response,
-                                model=base_model,
-                            )
-                        )
-                    verbose_logger.debug(
-                        f"Model={self.model}; cost={self.model_call_details['response_cost']}"
-                    )
-                except litellm.NotFoundError as e:
-                    verbose_logger.debug(
-                        f"Model={self.model} not found in completion cost map."
-                    )
-                    self.model_call_details["response_cost"] = None
+                self.model_call_details["response_cost"] = (
+                    litellm.response_cost_calculator(
+                        response_object=complete_streaming_response,
+                        model=self.model,
+                        cache_hit=self.model_call_details.get("cache_hit", False),
+                        custom_llm_provider=self.model_call_details.get(
+                            "custom_llm_provider", None
+                        ),
+                        base_model=_get_base_model_from_metadata(
+                            model_call_details=self.model_call_details
+                        ),
+                        call_type=self.call_type,
+                        optional_params=self.optional_params,
+                    )
+                )
         if self.dynamic_success_callbacks is not None and isinstance(
             self.dynamic_success_callbacks, list
         ):
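For the streaming path above, cost is computed from the rebuilt complete_streaming_response. A hedged usage sketch from the caller's side using litellm's public stream_chunk_builder and completion_cost (the model name and prompt are illustrative; assumes provider credentials such as OPENAI_API_KEY are configured):

import litellm

messages = [{"role": "user", "content": "Say hi"}]
# Stream a completion, collect the chunks, rebuild the full response, then price it.
stream = litellm.completion(model="gpt-3.5-turbo", messages=messages, stream=True)
chunks = [chunk for chunk in stream]
full_response = litellm.stream_chunk_builder(chunks, messages=messages)
cost = litellm.completion_cost(completion_response=full_response)
print(f"streaming response cost: {cost}")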
@@ -4579,16 +4541,20 @@ def completion_cost(
     completion="",
     total_time=0.0,  # used for replicate, sagemaker
     call_type: Literal[
-        "completion",
-        "acompletion",
         "embedding",
         "aembedding",
+        "completion",
+        "acompletion",
         "atext_completion",
         "text_completion",
         "image_generation",
         "aimage_generation",
-        "transcription",
+        "moderation",
+        "amoderation",
         "atranscription",
+        "transcription",
+        "aspeech",
+        "speech",
     ] = "completion",
     ### REGION ###
     custom_llm_provider=None,
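The hunk above only widens the set of accepted call_type values on completion_cost. As a side note, a Literal like this can also be checked at runtime; an illustrative sketch that mirrors the widened member list (not part of the change itself):

from typing import Literal, get_args

CallType = Literal[
    "embedding", "aembedding", "completion", "acompletion",
    "atext_completion", "text_completion", "image_generation",
    "aimage_generation", "moderation", "amoderation",
    "atranscription", "transcription", "aspeech", "speech",
]


def is_valid_call_type(value: str) -> bool:
    # get_args() exposes the Literal members for a runtime membership test.
    return value in get_args(CallType)


assert is_valid_call_type("speech")
assert not is_valid_call_type("rerank")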
@@ -5494,7 +5460,7 @@ def get_optional_params(
         optional_params["top_p"] = top_p
         if stop is not None:
             optional_params["stop_sequences"] = stop
-    elif custom_llm_provider == "huggingface":
+    elif custom_llm_provider == "huggingface" or custom_llm_provider == "predibase":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
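With predibase routed through the same branch as huggingface above, its supported parameter set comes from get_supported_openai_params. A hedged usage sketch (the model name is illustrative):

import litellm

# Ask litellm which OpenAI-style parameters it maps for this provider.
params = litellm.get_supported_openai_params(
    model="llama-3-8b-instruct", custom_llm_provider="predibase"
)
print(params)  # expected to be a list such as ["stream", "temperature", "max_tokens", ...]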
@@ -5949,7 +5915,6 @@ def get_optional_params(
             optional_params["logprobs"] = logprobs
         if top_logprobs is not None:
             optional_params["top_logprobs"] = top_logprobs
-
     elif custom_llm_provider == "openrouter":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
@@ -11106,8 +11071,16 @@ class CustomStreamWrapper:
             return ""

     def model_response_creator(self):
+        _model = self.model
+        _received_llm_provider = self.custom_llm_provider
+        _logging_obj_llm_provider = self.logging_obj.model_call_details.get("custom_llm_provider", None)  # type: ignore
+        if (
+            _received_llm_provider == "openai"
+            and _received_llm_provider != _logging_obj_llm_provider
+        ):
+            _model = "{}/{}".format(_logging_obj_llm_provider, _model)
         model_response = ModelResponse(
-            stream=True, model=self.model, stream_options=self.stream_options
+            stream=True, model=_model, stream_options=self.stream_options
         )
         if self.response_id is not None:
             model_response.id = self.response_id
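The prefixing rule added above is what lets streaming cost lookups resolve for openai-compatible providers: when the wrapper received provider "openai" but the logging object recorded a different provider, the chunk's model is rewritten as "provider/model". A standalone sketch of that rule (provider and model names are illustrative, not litellm internals):

def resolve_stream_model(received_provider: str, logged_provider: str, model: str) -> str:
    # Mirror the check in model_response_creator: only rewrite when the SDK-level
    # provider says "openai" but the call was actually routed elsewhere.
    if received_provider == "openai" and received_provider != logged_provider:
        return "{}/{}".format(logged_provider, model)
    return model


assert resolve_stream_model("openai", "deepinfra", "meta-llama/Meta-Llama-3-8B-Instruct") == (
    "deepinfra/meta-llama/Meta-Llama-3-8B-Instruct"
)
assert resolve_stream_model("openai", "openai", "gpt-4o") == "gpt-4o"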
@@ -11115,7 +11088,7 @@ class CustomStreamWrapper:
             self.response_id = model_response.id
         if self.system_fingerprint is not None:
             model_response.system_fingerprint = self.system_fingerprint
-        model_response._hidden_params["custom_llm_provider"] = self.custom_llm_provider
+        model_response._hidden_params["custom_llm_provider"] = _logging_obj_llm_provider
         model_response._hidden_params["created_at"] = time.time()
         model_response.choices = [StreamingChoices(finish_reason=None)]
         return model_response