Merge pull request #4009 from BerriAI/litellm_fix_streaming_cost_cal

fix(utils.py): fix cost calculation for openai-compatible streaming object
Krish Dholakia 2024-06-04 21:00:22 -07:00 committed by GitHub
commit e678dce88b
9 changed files with 230 additions and 88 deletions

.gitignore
View file

@@ -55,4 +55,7 @@ litellm/proxy/_super_secret_config.yaml
litellm/proxy/_super_secret_config.yaml
litellm/proxy/myenv/bin/activate
litellm/proxy/myenv/bin/Activate.ps1
myenv/*
+litellm/proxy/_experimental/out/404/index.html
+litellm/proxy/_experimental/out/model_hub/index.html
+litellm/proxy/_experimental/out/onboarding/index.html

View file

@@ -815,3 +815,4 @@ from .router import Router
from .assistants.main import *
from .batches.main import *
from .scheduler import *
+from .cost_calculator import response_cost_calculator

View file

@@ -0,0 +1,80 @@
# What is this?
## File for 'response_cost' calculation in Logging
from typing import Optional, Union, Literal

from litellm.utils import (
    ModelResponse,
    EmbeddingResponse,
    ImageResponse,
    TranscriptionResponse,
    TextCompletionResponse,
    CallTypes,
    completion_cost,
    print_verbose,
)
import litellm


def response_cost_calculator(
    response_object: Union[
        ModelResponse,
        EmbeddingResponse,
        ImageResponse,
        TranscriptionResponse,
        TextCompletionResponse,
    ],
    model: str,
    custom_llm_provider: str,
    call_type: Literal[
        "embedding",
        "aembedding",
        "completion",
        "acompletion",
        "atext_completion",
        "text_completion",
        "image_generation",
        "aimage_generation",
        "moderation",
        "amoderation",
        "atranscription",
        "transcription",
        "aspeech",
        "speech",
    ],
    optional_params: dict,
    cache_hit: Optional[bool] = None,
    base_model: Optional[str] = None,
    custom_pricing: Optional[bool] = None,
) -> Optional[float]:
    try:
        response_cost: float = 0.0
        if cache_hit is not None and cache_hit == True:
            response_cost = 0.0
        else:
            response_object._hidden_params["optional_params"] = optional_params
            if isinstance(response_object, ImageResponse):
                response_cost = completion_cost(
                    completion_response=response_object,
                    model=model,
                    call_type=call_type,
                    custom_llm_provider=custom_llm_provider,
                )
            else:
                if (
                    model in litellm.model_cost
                    and custom_pricing is not None
                    and custom_pricing == True
                ):  # override defaults if custom pricing is set
                    base_model = model
                # base_model defaults to None if not set on model_info
                response_cost = completion_cost(
                    completion_response=response_object,
                    call_type=call_type,
                    model=base_model,
                    custom_llm_provider=custom_llm_provider,
                )
        return response_cost
    except litellm.NotFoundError as e:
        print_verbose(
            f"Model={model} for LLM Provider={custom_llm_provider} not found in completion cost map."
        )
        return None
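A rough sketch of how the new helper is meant to be called; the groq model and token counts mirror the test added later in this PR and are illustrative, not fixed values:

import litellm
from litellm.utils import CallTypes, ModelResponse, Usage

# a minimal, already-assembled (non-streaming) response to price
resp = ModelResponse(
    model="llama3-70b-8192",
    usage=Usage(completion_tokens=46, prompt_tokens=17, total_tokens=63),
)
resp._hidden_params["custom_llm_provider"] = "groq"

cost = litellm.response_cost_calculator(
    response_object=resp,
    model="groq/llama3-70b-8192",
    custom_llm_provider="groq",
    call_type=CallTypes.completion.value,
    optional_params={},
)
# a float (USD) when the model is in the cost map; 0.0 when cache_hit=True is
# passed; None when the model is missing from litellm.model_cost
print(cost)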

View file

@@ -4536,7 +4536,7 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]
def stream_chunk_builder(
    chunks: list, messages: Optional[list] = None, start_time=None, end_time=None
-):
+) -> Union[ModelResponse, TextCompletionResponse]:
    model_response = litellm.ModelResponse()
    ### SORT CHUNKS BASED ON CREATED ORDER ##
    print_verbose("Goes into checking if chunk has hiddden created at param")

View file

@@ -5,6 +5,9 @@ model_list:
      model: openai/my-fake-model
      rpm: 800
    model_name: gpt-3.5-turbo-fake-model
+  - model_name: llama3-70b-8192
+    litellm_params:
+      model: groq/llama3-70b-8192
  # - litellm_params:
  #   api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
  #   api_key: os.environ/AZURE_EUROPE_API_KEY
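A minimal sketch of calling the newly added proxy model with the OpenAI client; the base_url and api_key are assumptions for a locally running `litellm --config` instance, not values from this PR:

import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
resp = client.chat.completions.create(
    model="llama3-70b-8192",
    messages=[{"role": "user", "content": "hello from the proxy"}],
)
print(resp.choices[0].message.content)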

View file

@@ -14,6 +14,7 @@ from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
from unittest.mock import patch, MagicMock
+from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler

# litellm.num_retries=3
litellm.cache = None
@@ -152,29 +153,63 @@ async def test_completion_databricks(sync_mode):
        response_format_tests(response=response)


+def predibase_mock_post(url, data=None, json=None, headers=None):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = {
+        "generated_text": " Is it to find happiness, to achieve success,",
+        "details": {
+            "finish_reason": "length",
+            "prompt_tokens": 8,
+            "generated_tokens": 10,
+            "seed": None,
+            "prefill": [],
+            "tokens": [
+                {"id": 2209, "text": " Is", "logprob": -1.7568359, "special": False},
+                {"id": 433, "text": " it", "logprob": -0.2220459, "special": False},
+                {"id": 311, "text": " to", "logprob": -0.6928711, "special": False},
+                {"id": 1505, "text": " find", "logprob": -0.6425781, "special": False},
+                {
+                    "id": 23871,
+                    "text": " happiness",
+                    "logprob": -0.07519531,
+                    "special": False,
+                },
+                {"id": 11, "text": ",", "logprob": -0.07110596, "special": False},
+                {"id": 311, "text": " to", "logprob": -0.79296875, "special": False},
+                {
+                    "id": 11322,
+                    "text": " achieve",
+                    "logprob": -0.7602539,
+                    "special": False,
+                },
+                {
+                    "id": 2450,
+                    "text": " success",
+                    "logprob": -0.03656006,
+                    "special": False,
+                },
+                {"id": 11, "text": ",", "logprob": -0.0011510849, "special": False},
+            ],
+        },
+    }
+    return mock_response
+
+
# @pytest.mark.skip(reason="local only test")
-@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
-async def test_completion_predibase(sync_mode):
+async def test_completion_predibase():
    try:
        litellm.set_verbose = True

-        if sync_mode:
+        with patch("requests.post", side_effect=predibase_mock_post):
            response = completion(
                model="predibase/llama-3-8b-instruct",
                tenant_id="c4768f95",
                api_key=os.getenv("PREDIBASE_API_KEY"),
                messages=[{"role": "user", "content": "What is the meaning of life?"}],
-            )
-            print(response)
-        else:
-            response = await litellm.acompletion(
-                model="predibase/llama-3-8b-instruct",
-                tenant_id="c4768f95",
-                api_base="https://serving.app.predibase.com",
-                api_key=os.getenv("PREDIBASE_API_KEY"),
-                messages=[{"role": "user", "content": "What is the meaning of life?"}],
+                max_tokens=10,
            )
            print(response)
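Stripped to its skeleton, the mocking pattern used above is plain unittest.mock; every name and URL below is illustrative:

from unittest.mock import MagicMock, patch

import requests


def fake_post(url, data=None, json=None, headers=None):
    # stand-in for the provider's HTTP response
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.json.return_value = {"generated_text": "hello"}
    return mock_response


with patch("requests.post", side_effect=fake_post):
    resp = requests.post("https://example.com/generate", json={"prompt": "hi"})
    assert resp.json()["generated_text"] == "hello"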

View file

@@ -470,3 +470,50 @@ def test_replicate_llama3_cost_tracking():
        5,
    )
    assert cost == expected_cost
+
+
+@pytest.mark.parametrize("is_streaming", [True, False])  #
+def test_groq_response_cost_tracking(is_streaming):
+    from litellm.utils import (
+        ModelResponse,
+        Choices,
+        Message,
+        Usage,
+        CallTypes,
+        StreamingChoices,
+        Delta,
+    )
+
+    response = ModelResponse(
+        id="chatcmpl-876cce24-e520-4cf8-8649-562a9be11c02",
+        choices=[
+            Choices(
+                finish_reason="stop",
+                index=0,
+                message=Message(
+                    content="Hi! I'm an AI, so I don't have emotions or feelings like humans do, but I'm functioning properly and ready to help with any questions or topics you'd like to discuss! How can I assist you today?",
+                    role="assistant",
+                ),
+            )
+        ],
+        created=1717519830,
+        model="llama3-70b-8192",
+        object="chat.completion",
+        system_fingerprint="fp_c1a4bcec29",
+        usage=Usage(completion_tokens=46, prompt_tokens=17, total_tokens=63),
+    )
+    response._hidden_params["custom_llm_provider"] = "groq"
+    print(response)
+
+    response_cost = litellm.response_cost_calculator(
+        response_object=response,
+        model="groq/llama3-70b-8192",
+        custom_llm_provider="groq",
+        call_type=CallTypes.acompletion.value,
+        optional_params={},
+    )
+
+    assert isinstance(response_cost, float)
+    assert response_cost > 0.0
+
+    print(f"response_cost: {response_cost}")

View file

@@ -885,6 +885,7 @@ def test_completion_mistral_api_mistral_large_function_call_with_streaming():
        idx = 0
        for chunk in response:
            print(f"chunk in response: {chunk}")
+            assert chunk._hidden_params["custom_llm_provider"] == "mistral"
            if idx == 0:
                assert (
                    chunk.choices[0].delta.tool_calls[0].function.arguments is not None
@@ -898,7 +899,6 @@ def test_completion_mistral_api_mistral_large_function_call_with_streaming():
            elif chunk.choices[0].finish_reason is not None:  # last chunk
                validate_final_streaming_function_calling_chunk(chunk=chunk)
            idx += 1
-        # raise Exception("it worked!")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

View file

@@ -1501,51 +1501,21 @@ class Logging:
                )
                and self.stream != True
            ):  # handle streaming separately
-                try:
-                    if self.model_call_details.get("cache_hit", False) == True:
-                        self.model_call_details["response_cost"] = 0.0
-                    else:
-                        result._hidden_params["optional_params"] = self.optional_params
-                        if (
-                            self.call_type == CallTypes.aimage_generation.value
-                            or self.call_type == CallTypes.image_generation.value
-                        ):
-                            self.model_call_details["response_cost"] = (
-                                litellm.completion_cost(
-                                    completion_response=result,
-                                    model=self.model,
-                                    call_type=self.call_type,
-                                    custom_llm_provider=self.model_call_details.get(
-                                        "custom_llm_provider", None
-                                    ),  # set for img gen models
-                                )
-                            )
-                        else:
-                            base_model: Optional[str] = None
-                            # check if base_model set on azure
-                            base_model = _get_base_model_from_metadata(
-                                model_call_details=self.model_call_details
-                            )
-                            # litellm model name
-                            litellm_model = self.model_call_details["model"]
-                            if (
-                                litellm_model in litellm.model_cost
-                                and self.custom_pricing == True
-                            ):
-                                base_model = litellm_model
-                            # base_model defaults to None if not set on model_info
-                            self.model_call_details["response_cost"] = (
-                                litellm.completion_cost(
-                                    completion_response=result,
-                                    call_type=self.call_type,
-                                    model=base_model,
-                                )
-                            )
-                except litellm.NotFoundError as e:
-                    verbose_logger.debug(
-                        f"Model={self.model} not found in completion cost map."
-                    )
-                    self.model_call_details["response_cost"] = None
+                self.model_call_details["response_cost"] = (
+                    litellm.response_cost_calculator(
+                        response_object=result,
+                        model=self.model,
+                        cache_hit=self.model_call_details.get("cache_hit", False),
+                        custom_llm_provider=self.model_call_details.get(
+                            "custom_llm_provider", None
+                        ),
+                        base_model=_get_base_model_from_metadata(
+                            model_call_details=self.model_call_details
+                        ),
+                        call_type=self.call_type,
+                        optional_params=self.optional_params,
+                    )
+                )
            else:  # streaming chunks + image gen.
                self.model_call_details["response_cost"] = None
@@ -1609,29 +1579,21 @@ class Logging:
                self.model_call_details["complete_streaming_response"] = (
                    complete_streaming_response
                )
-                try:
-                    if self.model_call_details.get("cache_hit", False) == True:
-                        self.model_call_details["response_cost"] = 0.0
-                    else:
-                        # check if base_model set on azure
-                        base_model = _get_base_model_from_metadata(
-                            model_call_details=self.model_call_details
-                        )
-                        # base_model defaults to None if not set on model_info
-                        self.model_call_details["response_cost"] = (
-                            litellm.completion_cost(
-                                completion_response=complete_streaming_response,
-                                model=base_model,
-                            )
-                        )
-                    verbose_logger.debug(
-                        f"Model={self.model}; cost={self.model_call_details['response_cost']}"
-                    )
-                except litellm.NotFoundError as e:
-                    verbose_logger.debug(
-                        f"Model={self.model} not found in completion cost map."
-                    )
-                    self.model_call_details["response_cost"] = None
+                self.model_call_details["response_cost"] = (
+                    litellm.response_cost_calculator(
+                        response_object=complete_streaming_response,
+                        model=self.model,
+                        cache_hit=self.model_call_details.get("cache_hit", False),
+                        custom_llm_provider=self.model_call_details.get(
+                            "custom_llm_provider", None
+                        ),
+                        base_model=_get_base_model_from_metadata(
+                            model_call_details=self.model_call_details
+                        ),
+                        call_type=self.call_type,
+                        optional_params=self.optional_params,
+                    )
+                )
                if self.dynamic_success_callbacks is not None and isinstance(
                    self.dynamic_success_callbacks, list
                ):
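Downstream consumers keep reading the computed value the same way; a rough sketch of pulling it out of a success callback (the callback wiring is standard litellm usage rather than part of this diff, and the model is illustrative):

import litellm


def track_cost_callback(kwargs, completion_response, start_time, end_time):
    # kwargs mirrors model_call_details, so the value set above surfaces here
    print("response_cost:", kwargs.get("response_cost"))


litellm.success_callback = [track_cost_callback]
litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "ping"}],
)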
@@ -4579,16 +4541,20 @@ def completion_cost(
    completion="",
    total_time=0.0,  # used for replicate, sagemaker
    call_type: Literal[
-        "completion",
-        "acompletion",
        "embedding",
        "aembedding",
+        "completion",
+        "acompletion",
        "atext_completion",
        "text_completion",
        "image_generation",
        "aimage_generation",
-        "transcription",
+        "moderation",
+        "amoderation",
        "atranscription",
+        "transcription",
+        "aspeech",
+        "speech",
    ] = "completion",
    ### REGION ###
    custom_llm_provider=None,
@@ -5494,7 +5460,7 @@ def get_optional_params(
            optional_params["top_p"] = top_p
        if stop is not None:
            optional_params["stop_sequences"] = stop
-    elif custom_llm_provider == "huggingface":
+    elif custom_llm_provider == "huggingface" or custom_llm_provider == "predibase":
        ## check if unsupported param passed in
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider
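Since predibase now shares the huggingface branch, the supported-parameter lookup decides which kwargs get mapped; a quick way to inspect it (output depends on the installed litellm version):

import litellm

print(
    litellm.get_supported_openai_params(
        model="predibase/llama-3-8b-instruct",
        custom_llm_provider="predibase",
    )
)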
@@ -5949,7 +5915,6 @@ def get_optional_params(
        optional_params["logprobs"] = logprobs
        if top_logprobs is not None:
            optional_params["top_logprobs"] = top_logprobs
-
    elif custom_llm_provider == "openrouter":
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider
@@ -11106,8 +11071,16 @@ class CustomStreamWrapper:
            return ""

    def model_response_creator(self):
+        _model = self.model
+        _received_llm_provider = self.custom_llm_provider
+        _logging_obj_llm_provider = self.logging_obj.model_call_details.get("custom_llm_provider", None)  # type: ignore
+        if (
+            _received_llm_provider == "openai"
+            and _received_llm_provider != _logging_obj_llm_provider
+        ):
+            _model = "{}/{}".format(_logging_obj_llm_provider, _model)
        model_response = ModelResponse(
-            stream=True, model=self.model, stream_options=self.stream_options
+            stream=True, model=_model, stream_options=self.stream_options
        )
        if self.response_id is not None:
            model_response.id = self.response_id
@@ -11115,7 +11088,7 @@ class CustomStreamWrapper:
            self.response_id = model_response.id
        if self.system_fingerprint is not None:
            model_response.system_fingerprint = self.system_fingerprint
-        model_response._hidden_params["custom_llm_provider"] = self.custom_llm_provider
+        model_response._hidden_params["custom_llm_provider"] = _logging_obj_llm_provider
        model_response._hidden_params["created_at"] = time.time()
        model_response.choices = [StreamingChoices(finish_reason=None)]
        return model_response
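Putting the pieces together, a rough end-to-end check of the behavior this PR targets; the model and prompt are illustrative and a GROQ_API_KEY is assumed to be set:

import litellm

messages = [{"role": "user", "content": "What is the capital of France?"}]
chunks = []
for chunk in litellm.completion(
    model="groq/llama3-70b-8192", messages=messages, stream=True
):
    # streamed chunks from an openai-compatible provider now report the real
    # provider instead of "openai", so cost lookups hit the right pricing entry
    assert chunk._hidden_params["custom_llm_provider"] == "groq"
    chunks.append(chunk)

rebuilt = litellm.stream_chunk_builder(chunks, messages=messages)
print(litellm.completion_cost(completion_response=rebuilt))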