Merge pull request #2426 from BerriAI/litellm_whisper_cost_tracking

feat: add cost tracking + caching for `/audio/transcription` calls
Krish Dholakia 2024-03-09 19:12:06 -08:00 committed by GitHub
commit c7d0af0a2e
11 changed files with 247 additions and 41 deletions
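
A minimal end-to-end sketch of what this PR enables from the SDK side (the `audio.wav` path is a placeholder and an `OPENAI_API_KEY` is assumed):

```python
import litellm
from litellm import Cache

# Transcription calls can now be cached alongside completion/embedding calls
litellm.cache = Cache(supported_call_types=["transcription", "atranscription"])

with open("audio.wav", "rb") as audio_file:  # placeholder local audio file
    transcript = litellm.transcription(model="whisper-1", file=audio_file)

# Cost is derived from the call duration via whisper-1's output_cost_per_second entry
cost = litellm.completion_cost(model="whisper-1", completion_response=transcript)
print(transcript.text, cost)
```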


@@ -10,7 +10,7 @@
import litellm
import time, logging, asyncio
import json, traceback, ast, hashlib
from typing import Optional, Literal, List, Union, Any
from typing import Optional, Literal, List, Union, Any, BinaryIO
from openai._models import BaseModel as OpenAIObject
from litellm._logging import verbose_logger
@@ -765,8 +765,24 @@ class Cache:
password: Optional[str] = None,
similarity_threshold: Optional[float] = None,
supported_call_types: Optional[
List[Literal["completion", "acompletion", "embedding", "aembedding"]]
] = ["completion", "acompletion", "embedding", "aembedding"],
List[
Literal[
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
]
]
] = [
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
],
# s3 Bucket, boto3 configuration
s3_bucket_name: Optional[str] = None,
s3_region_name: Optional[str] = None,
@@ -881,9 +897,14 @@ class Cache:
"input",
"encoding_format",
] # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs
transcription_only_kwargs = [
"file",
"language",
]
# combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set()
combined_kwargs = completion_kwargs + embedding_only_kwargs
combined_kwargs = (
completion_kwargs + embedding_only_kwargs + transcription_only_kwargs
)
for param in combined_kwargs:
# ignore litellm params here
if param in kwargs:
@@ -915,6 +936,17 @@ class Cache:
param_value = (
caching_group or model_group or kwargs[param]
) # use caching_group, if set then model_group if it exists, else use kwargs["model"]
elif param == "file":
metadata_file_name = kwargs.get("metadata", {}).get(
"file_name", None
)
litellm_params_file_name = kwargs.get("litellm_params", {}).get(
"file_name", None
)
if metadata_file_name is not None:
param_value = metadata_file_name
elif litellm_params_file_name is not None:
param_value = litellm_params_file_name
else:
if kwargs[param] is None:
continue # ignore None params
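
For the cache key, an uploaded file is represented by its name (taken from request metadata or litellm_params) rather than by the binary stream itself. A rough, hypothetical sketch of that idea (not the actual `get_cache_key` implementation):

```python
import hashlib
from typing import Optional

def transcription_cache_key(model: str, file_name: str, language: Optional[str] = None) -> str:
    # Hypothetical helper mirroring get_cache_key's rules: fixed param order, None values skipped, then hash
    parts = [f"model: {model}", f"file: {file_name}"]
    if language is not None:
        parts.append(f"language: {language}")
    return hashlib.sha256(",".join(parts).encode()).hexdigest()

# Two requests against the same file name map to the same key, so the second can be served from cache
assert transcription_cache_key("whisper-1", "audio.wav") == transcription_cache_key("whisper-1", "audio.wav")
```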
@@ -1144,8 +1176,24 @@ def enable_cache(
port: Optional[str] = None,
password: Optional[str] = None,
supported_call_types: Optional[
List[Literal["completion", "acompletion", "embedding", "aembedding"]]
] = ["completion", "acompletion", "embedding", "aembedding"],
List[
Literal[
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
]
]
] = [
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
],
**kwargs,
):
"""
@@ -1193,8 +1241,24 @@ def update_cache(
port: Optional[str] = None,
password: Optional[str] = None,
supported_call_types: Optional[
List[Literal["completion", "acompletion", "embedding", "aembedding"]]
] = ["completion", "acompletion", "embedding", "aembedding"],
List[
Literal[
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
]
]
] = [
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
],
**kwargs,
):
"""

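For setups that turn caching on at runtime, the expanded call-type list can also be passed to `enable_cache`; a sketch assuming a Redis instance on localhost:

```python
import litellm

litellm.enable_cache(
    type="redis",
    host="localhost",  # assumption: locally running Redis
    port="6379",
    supported_call_types=["completion", "acompletion", "transcription", "atranscription"],
)
```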

@@ -861,7 +861,8 @@ class AzureChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription") # type: ignore
hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"}
final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
return final_response
async def async_audio_transcriptions(
@@ -921,7 +922,8 @@ class AzureChatCompletion(BaseLLM):
},
original_response=stringified_response,
)
response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription") # type: ignore
hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"}
response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
return response
except Exception as e:
## LOGGING

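The Azure handler (and the OpenAI handler below) now stamps `hidden_params` onto the transcription response so cost tracking knows which pricing entry and provider to use, even when the deployment name differs from `whisper-1`. Reading them back might look like this (Azure credentials via `AZURE_API_KEY`/`AZURE_API_BASE` and the file path are assumptions):

```python
import litellm

with open("audio.wav", "rb") as audio_file:  # placeholder file
    transcript = litellm.transcription(
        model="azure/azure-whisper",
        file=audio_file,
        api_version="2024-02-15-preview",
    )

print(transcript._hidden_params.get("model"))                # "whisper-1"
print(transcript._hidden_params.get("custom_llm_provider"))  # "azure"
```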

@@ -753,6 +753,7 @@ class OpenAIChatCompletion(BaseLLM):
# return response
return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore
except OpenAIError as e:
exception_mapping_worked = True
## LOGGING
logging_obj.post_call(
@@ -824,7 +825,8 @@ class OpenAIChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription") # type: ignore
hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
return final_response
async def async_audio_transcriptions(
@@ -862,7 +864,8 @@ class OpenAIChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription") # type: ignore
hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
except Exception as e:
## LOGGING
logging_obj.post_call(


@@ -3295,6 +3295,7 @@ async def audio_transcriptions(
user_api_key_dict, "team_id", None
)
data["metadata"]["endpoint"] = str(request.url)
data["metadata"]["file_name"] = file.filename
### TEAM-SPECIFIC PARAMS ###
if user_api_key_dict.team_id is not None:
@@ -3329,7 +3330,7 @@ async def audio_transcriptions(
data = await proxy_logging_obj.pre_call_hook(
user_api_key_dict=user_api_key_dict,
data=data,
call_type="moderation",
call_type="audio_transcription",
)
## ROUTE TO CORRECT ENDPOINT ##

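On the proxy, the uploaded file's name is now recorded in request metadata (later used for the cache key and spend logs), and the pre-call hook receives the correct `audio_transcription` call type instead of `moderation`. A client call against a locally running proxy could look like this (the base URL and key are placeholders):

```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")  # placeholder proxy key/URL

with open("audio.wav", "rb") as audio_file:  # placeholder file
    transcript = client.audio.transcriptions.create(model="whisper-1", file=audio_file)

print(transcript.text)
```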

@@ -96,7 +96,11 @@ class ProxyLogging:
user_api_key_dict: UserAPIKeyAuth,
data: dict,
call_type: Literal[
"completion", "embeddings", "image_generation", "moderation"
"completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
],
):
"""


@@ -6,7 +6,12 @@ sys.path.insert(
) # Adds the parent directory to the system path
import time
import litellm
from litellm import get_max_tokens, model_cost, open_ai_chat_completion_models
from litellm import (
get_max_tokens,
model_cost,
open_ai_chat_completion_models,
TranscriptionResponse,
)
import pytest
@@ -238,3 +243,57 @@ def test_cost_bedrock_pricing_actual_calls():
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
assert cost > 0
def test_whisper_openai():
litellm.set_verbose = True
transcription = TranscriptionResponse(
text="Four score and seven years ago, our fathers brought forth on this continent a new nation, conceived in liberty and dedicated to the proposition that all men are created equal. Now we are engaged in a great civil war, testing whether that nation, or any nation so conceived and so dedicated, can long endure."
)
transcription._hidden_params = {
"model": "whisper-1",
"custom_llm_provider": "openai",
"optional_params": {},
"model_id": None,
}
_total_time_in_seconds = 3
transcription._response_ms = _total_time_in_seconds * 1000
cost = litellm.completion_cost(model="whisper-1", completion_response=transcription)
print(f"cost: {cost}")
print(f"whisper dict: {litellm.model_cost['whisper-1']}")
expected_cost = round(
litellm.model_cost["whisper-1"]["output_cost_per_second"]
* _total_time_in_seconds,
5,
)
assert cost == expected_cost
def test_whisper_azure():
litellm.set_verbose = True
transcription = TranscriptionResponse(
text="Four score and seven years ago, our fathers brought forth on this continent a new nation, conceived in liberty and dedicated to the proposition that all men are created equal. Now we are engaged in a great civil war, testing whether that nation, or any nation so conceived and so dedicated, can long endure."
)
transcription._hidden_params = {
"model": "whisper-1",
"custom_llm_provider": "azure",
"optional_params": {},
"model_id": None,
}
_total_time_in_seconds = 3
transcription._response_ms = _total_time_in_seconds * 1000
cost = litellm.completion_cost(
model="azure/azure-whisper", completion_response=transcription
)
print(f"cost: {cost}")
print(f"whisper dict: {litellm.model_cost['whisper-1']}")
expected_cost = round(
litellm.model_cost["whisper-1"]["output_cost_per_second"]
* _total_time_in_seconds,
5,
)
assert cost == expected_cost


@@ -973,6 +973,7 @@ def test_image_generation_openai():
print(f"customHandler_success.errors: {customHandler_success.errors}")
print(f"customHandler_success.states: {customHandler_success.states}")
time.sleep(2)
assert len(customHandler_success.errors) == 0
assert len(customHandler_success.states) == 3 # pre, post, success
# test failure callback


@@ -100,7 +100,7 @@ class TmpFunction:
def test_async_chat_openai_stream():
try:
tmp_function = TmpFunction()
# litellm.set_verbose = True
litellm.set_verbose = True
litellm.success_callback = [tmp_function.async_test_logging_fn]
complete_streaming_response = ""


@@ -336,6 +336,8 @@ def test_load_router_config():
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
] # init with all call types
litellm.disable_cache()


@@ -1168,6 +1168,7 @@ class Logging:
isinstance(result, ModelResponse)
or isinstance(result, EmbeddingResponse)
or isinstance(result, ImageResponse)
or isinstance(result, TranscriptionResponse)
)
and self.stream != True
): # handle streaming separately
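
With `TranscriptionResponse` included in the response-cost check, custom success callbacks should also see a cost for transcription calls; a minimal sketch using the standard custom-callback signature:

```python
import litellm

def track_cost_callback(kwargs, completion_response, start_time, end_time):
    # litellm's logging layer sets "response_cost" when the model is found in the cost map
    print("response_cost:", kwargs.get("response_cost"))

litellm.success_callback = [track_cost_callback]
```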
@@ -1203,9 +1204,6 @@
model=base_model,
)
)
verbose_logger.debug(
f"Model={self.model}; cost={self.model_call_details['response_cost']}"
)
except litellm.NotFoundError as e:
verbose_logger.debug(
f"Model={self.model} not found in completion cost map."
@@ -1236,7 +1234,7 @@
def success_handler(
self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs
):
verbose_logger.debug(f"Logging Details LiteLLM-Success Call: {cache_hit}")
print_verbose(f"Logging Details LiteLLM-Success Call: {cache_hit}")
start_time, end_time, result = self._success_handler_helper_fn(
start_time=start_time,
end_time=end_time,
@@ -1245,7 +1243,7 @@
)
# print(f"original response in success handler: {self.model_call_details['original_response']}")
try:
verbose_logger.debug(f"success callbacks: {litellm.success_callback}")
print_verbose(f"success callbacks: {litellm.success_callback}")
## BUILD COMPLETE STREAMED RESPONSE
complete_streaming_response = None
if self.stream and isinstance(result, ModelResponse):
@@ -1268,7 +1266,7 @@
self.sync_streaming_chunks.append(result)
if complete_streaming_response is not None:
verbose_logger.debug(
print_verbose(
f"Logging Details LiteLLM-Success Call streaming complete"
)
self.model_call_details["complete_streaming_response"] = (
@@ -1615,6 +1613,14 @@ class Logging:
"aembedding", False
)
== False
and self.model_call_details.get("litellm_params", {}).get(
"aimage_generation", False
)
== False
and self.model_call_details.get("litellm_params", {}).get(
"atranscription", False
)
== False
): # custom logger class
if self.stream and complete_streaming_response is None:
callback.log_stream_event(
@@ -1647,6 +1653,14 @@ class Logging:
"aembedding", False
)
== False
and self.model_call_details.get("litellm_params", {}).get(
"aimage_generation", False
)
== False
and self.model_call_details.get("litellm_params", {}).get(
"atranscription", False
)
== False
): # custom logger functions
print_verbose(
f"success callbacks: Running Custom Callback Function"
@@ -1681,6 +1695,7 @@ class Logging:
"""
Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
"""
print_verbose(f"Logging Details LiteLLM-Async Success Call: {cache_hit}")
start_time, end_time, result = self._success_handler_helper_fn(
start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
)
@@ -2473,6 +2488,7 @@ def client(original_function):
and kwargs.get("aembedding", False) != True
and kwargs.get("acompletion", False) != True
and kwargs.get("aimg_generation", False) != True
and kwargs.get("atranscription", False) != True
): # allow users to control returning cached responses from the completion function
# checking cache
print_verbose(f"INSIDE CHECKING CACHE")
@@ -2875,6 +2891,19 @@ def client(original_function):
model_response_object=EmbeddingResponse(),
response_type="embedding",
)
elif call_type == CallTypes.atranscription.value and isinstance(
cached_result, dict
):
hidden_params = {
"model": "whisper-1",
"custom_llm_provider": custom_llm_provider,
}
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=TranscriptionResponse(),
response_type="audio_transcription",
hidden_params=hidden_params,
)
if kwargs.get("stream", False) == False:
# LOG SUCCESS
asyncio.create_task(
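
Cached transcription results are stored as plain dicts and rebuilt into `TranscriptionResponse` objects (with `hidden_params` reattached) on a hit. A sketch of a repeated call that should be served from cache; passing the file name via `metadata` is an assumption here, mirroring what the proxy records:

```python
import litellm
from litellm import Cache

litellm.cache = Cache(supported_call_types=["transcription", "atranscription"])

for _ in range(2):  # the second iteration is expected to hit the cache
    with open("audio.wav", "rb") as audio_file:  # placeholder file
        transcript = litellm.transcription(
            model="whisper-1",
            file=audio_file,
            metadata={"file_name": "audio.wav"},  # assumption: gives the cache key a stable file name
        )
```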
@@ -3001,6 +3030,20 @@ def client(original_function):
else:
return result
# ADD HIDDEN PARAMS - additional call metadata
if hasattr(result, "_hidden_params"):
result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
"id", None
)
if (
isinstance(result, ModelResponse)
or isinstance(result, EmbeddingResponse)
or isinstance(result, TranscriptionResponse)
):
result._response_ms = (
end_time - start_time
).total_seconds() * 1000 # return response latency in ms like openai
### POST-CALL RULES ###
post_call_processing(original_response=result, model=model)
@@ -3013,8 +3056,10 @@ def client(original_function):
)
and (kwargs.get("cache", {}).get("no-store", False) != True)
):
if isinstance(result, litellm.ModelResponse) or isinstance(
result, litellm.EmbeddingResponse
if (
isinstance(result, litellm.ModelResponse)
or isinstance(result, litellm.EmbeddingResponse)
or isinstance(result, TranscriptionResponse)
):
if (
isinstance(result, EmbeddingResponse)
@@ -3058,18 +3103,7 @@ def client(original_function):
args=(result, start_time, end_time),
).start()
# RETURN RESULT
if hasattr(result, "_hidden_params"):
result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
"id", None
)
if isinstance(result, ModelResponse) or isinstance(
result, EmbeddingResponse
):
result._response_ms = (
end_time - start_time
).total_seconds() * 1000 # return response latency in ms like openai
# REBUILD EMBEDDING CACHING
if (
isinstance(result, EmbeddingResponse)
and final_embedding_cached_response is not None
@@ -3575,6 +3609,20 @@ def cost_per_token(
completion_tokens_cost_usd_dollar = (
model_cost_ref[model]["output_cost_per_token"] * completion_tokens
)
elif (
model_cost_ref[model].get("output_cost_per_second", None) is not None
and response_time_ms is not None
):
print_verbose(
f"For model={model} - output_cost_per_second: {model_cost_ref[model].get('output_cost_per_second')}; response time: {response_time_ms}"
)
## COST PER SECOND ##
prompt_tokens_cost_usd_dollar = 0
completion_tokens_cost_usd_dollar = (
model_cost_ref[model]["output_cost_per_second"]
* response_time_ms
/ 1000
)
elif (
model_cost_ref[model].get("input_cost_per_second", None) is not None
and response_time_ms is not None
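
`cost_per_token` now also understands per-second pricing: when a model defines `output_cost_per_second` and a response time is available, the cost is the rate multiplied by the elapsed seconds. Spelled out for a 3-second call (the rate itself comes from `litellm.model_cost`):

```python
import litellm

rate = litellm.model_cost["whisper-1"]["output_cost_per_second"]
response_time_ms = 3_000  # e.g. a 3-second transcription

cost = rate * response_time_ms / 1000
print(round(cost, 5))  # matches the expected value in test_whisper_openai above
```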
@@ -3659,6 +3707,8 @@ def completion_cost(
"text_completion",
"image_generation",
"aimage_generation",
"transcription",
"atranscription",
] = "completion",
### REGION ###
custom_llm_provider=None,
@@ -3694,7 +3744,6 @@ def completion_cost(
- If an error occurs during execution, the function returns 0.0 without blocking the user's execution path.
"""
try:
if (
(call_type == "aimage_generation" or call_type == "image_generation")
and model is not None
@@ -3717,10 +3766,15 @@
verbose_logger.debug(
f"completion_response response ms: {completion_response.get('_response_ms')} "
)
model = (
model or completion_response["model"]
model = model or completion_response.get(
"model", None
) # check if user passed an override for model, if it's none check completion_response['model']
if hasattr(completion_response, "_hidden_params"):
if (
completion_response._hidden_params.get("model", None) is not None
and len(completion_response._hidden_params["model"]) > 0
):
model = completion_response._hidden_params.get("model", model)
custom_llm_provider = completion_response._hidden_params.get(
"custom_llm_provider", ""
)
@@ -3801,6 +3855,7 @@ def completion_cost(
# see https://replicate.com/pricing
elif model in litellm.replicate_models or "replicate" in model:
return get_replicate_completion_pricing(completion_response, total_time)
(
prompt_tokens_cost_usd_dollar,
completion_tokens_cost_usd_dollar,
@@ -6314,6 +6369,7 @@ def convert_to_model_response_object(
stream=False,
start_time=None,
end_time=None,
hidden_params: Optional[dict] = None,
):
try:
if response_type == "completion" and (
@@ -6373,6 +6429,9 @@ def convert_to_model_response_object(
end_time - start_time
).total_seconds() * 1000
if hidden_params is not None:
model_response_object._hidden_params = hidden_params
return model_response_object
elif response_type == "embedding" and (
model_response_object is None
@@ -6402,6 +6461,9 @@ def convert_to_model_response_object(
end_time - start_time
).total_seconds() * 1000 # return response latency in ms like openai
if hidden_params is not None:
model_response_object._hidden_params = hidden_params
return model_response_object
elif response_type == "image_generation" and (
model_response_object is None
@@ -6419,6 +6481,9 @@ def convert_to_model_response_object(
if "data" in response_object:
model_response_object.data = response_object["data"]
if hidden_params is not None:
model_response_object._hidden_params = hidden_params
return model_response_object
elif response_type == "audio_transcription" and (
model_response_object is None
@@ -6432,6 +6497,9 @@ def convert_to_model_response_object(
if "text" in response_object:
model_response_object.text = response_object["text"]
if hidden_params is not None:
model_response_object._hidden_params = hidden_params
return model_response_object
except Exception as e:
raise Exception(f"Invalid response object {traceback.format_exc()}")

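`convert_to_model_response_object` now takes an optional `hidden_params` dict and attaches it to whichever response type it builds. A compact illustration using the transcription path (this is an internal helper, shown only to clarify the flow):

```python
from litellm import TranscriptionResponse
from litellm.utils import convert_to_model_response_object

raw = {"text": "hello world"}  # stand-in for the provider SDK's dict-ified response

resp = convert_to_model_response_object(
    response_object=raw,
    model_response_object=TranscriptionResponse(),
    response_type="audio_transcription",
    hidden_params={"model": "whisper-1", "custom_llm_provider": "openai"},
)

assert resp.text == "hello world"
assert resp._hidden_params["custom_llm_provider"] == "openai"
```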

@@ -31,7 +31,8 @@ def test_transcription():
model="whisper-1",
file=audio_file,
)
print(f"transcript: {transcript}")
print(f"transcript: {transcript.model_dump()}")
print(f"transcript: {transcript._hidden_params}")
# test_transcription()
@@ -47,6 +48,7 @@ def test_transcription_azure():
api_version="2024-02-15-preview",
)
print(f"transcript: {transcript}")
assert transcript.text is not None
assert isinstance(transcript.text, str)