feat: add cost tracking + caching for transcription calls

Krrish Dholakia 2024-03-09 15:43:38 -08:00
parent e10991e02b
commit fa45c569fd
8 changed files with 225 additions and 37 deletions

View file

@@ -10,7 +10,7 @@
 import litellm
 import time, logging, asyncio
 import json, traceback, ast, hashlib
-from typing import Optional, Literal, List, Union, Any
+from typing import Optional, Literal, List, Union, Any, BinaryIO
 from openai._models import BaseModel as OpenAIObject
 from litellm._logging import verbose_logger
@@ -764,8 +764,24 @@ class Cache:
         password: Optional[str] = None,
         similarity_threshold: Optional[float] = None,
         supported_call_types: Optional[
-            List[Literal["completion", "acompletion", "embedding", "aembedding"]]
-        ] = ["completion", "acompletion", "embedding", "aembedding"],
+            List[
+                Literal[
+                    "completion",
+                    "acompletion",
+                    "embedding",
+                    "aembedding",
+                    "atranscription",
+                    "transcription",
+                ]
+            ]
+        ] = [
+            "completion",
+            "acompletion",
+            "embedding",
+            "aembedding",
+            "atranscription",
+            "transcription",
+        ],
         # s3 Bucket, boto3 configuration
         s3_bucket_name: Optional[str] = None,
         s3_region_name: Optional[str] = None,
@@ -880,9 +896,18 @@ class Cache:
             "input",
             "encoding_format",
         ]  # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs
+        transcription_only_kwargs = [
+            "model",
+            "file",
+            "language",
+            "prompt",
+            "response_format",
+            "temperature",
+        ]
         # combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set()
-        combined_kwargs = completion_kwargs + embedding_only_kwargs
+        combined_kwargs = (
+            completion_kwargs + embedding_only_kwargs + transcription_only_kwargs
+        )
         for param in combined_kwargs:
             # ignore litellm params here
             if param in kwargs:
@@ -914,6 +939,17 @@ class Cache:
                     param_value = (
                         caching_group or model_group or kwargs[param]
                     )  # use caching_group, if set then model_group if it exists, else use kwargs["model"]
+                elif param == "file":
+                    metadata_file_name = kwargs.get("metadata", {}).get(
+                        "file_name", None
+                    )
+                    litellm_params_file_name = kwargs.get("litellm_params", {}).get(
+                        "file_name", None
+                    )
+                    if metadata_file_name is not None:
+                        param_value = metadata_file_name
+                    elif litellm_params_file_name is not None:
+                        param_value = litellm_params_file_name
                 else:
                     if kwargs[param] is None:
                         continue  # ignore None params
@@ -1143,8 +1179,24 @@ def enable_cache(
     port: Optional[str] = None,
     password: Optional[str] = None,
     supported_call_types: Optional[
-        List[Literal["completion", "acompletion", "embedding", "aembedding"]]
-    ] = ["completion", "acompletion", "embedding", "aembedding"],
+        List[
+            Literal[
+                "completion",
+                "acompletion",
+                "embedding",
+                "aembedding",
+                "atranscription",
+                "transcription",
+            ]
+        ]
+    ] = [
+        "completion",
+        "acompletion",
+        "embedding",
+        "aembedding",
+        "atranscription",
+        "transcription",
+    ],
     **kwargs,
 ):
     """
@@ -1192,8 +1244,24 @@ def update_cache(
     port: Optional[str] = None,
     password: Optional[str] = None,
     supported_call_types: Optional[
-        List[Literal["completion", "acompletion", "embedding", "aembedding"]]
-    ] = ["completion", "acompletion", "embedding", "aembedding"],
+        List[
+            Literal[
+                "completion",
+                "acompletion",
+                "embedding",
+                "aembedding",
+                "atranscription",
+                "transcription",
+            ]
+        ]
+    ] = [
+        "completion",
+        "acompletion",
+        "embedding",
+        "aembedding",
+        "atranscription",
+        "transcription",
+    ],
     **kwargs,
 ):
     """

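The caching changes above extend supported_call_types and the cache-key logic so audio transcription results can be cached, keyed on the uploaded file's name (taken from metadata or litellm_params) rather than the file handle. A minimal usage sketch, assuming an in-memory cache and an illustrative local audio file (the file path and the litellm.transcription entry point are assumptions, not part of this diff):

import litellm
from litellm.caching import Cache

# Cache transcription calls alongside completions/embeddings
# ("transcription"/"atranscription" are accepted per the signature change above).
litellm.cache = Cache(
    type="local",
    supported_call_types=["completion", "acompletion", "transcription", "atranscription"],
)

# Repeated transcriptions of the same file can now hit the cache; the key is
# built from model, file name, language, prompt, response_format, temperature.
audio_file = open("speech.wav", "rb")  # illustrative file, not from the diff
transcript = litellm.transcription(model="whisper-1", file=audio_file)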
View file

@@ -861,7 +861,8 @@ class AzureChatCompletion(BaseLLM):
                 additional_args={"complete_input_dict": data},
                 original_response=stringified_response,
             )
-            final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+            hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"}
+            final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
             return final_response

     async def async_audio_transcriptions(
@@ -921,7 +922,8 @@ class AzureChatCompletion(BaseLLM):
                 },
                 original_response=stringified_response,
             )
-            response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+            hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"}
+            response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
             return response
         except Exception as e:
             ## LOGGING

View file

@@ -824,7 +824,8 @@ class OpenAIChatCompletion(BaseLLM):
                 additional_args={"complete_input_dict": data},
                 original_response=stringified_response,
             )
-            final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+            hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
+            final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
             return final_response

     async def async_audio_transcriptions(
@@ -862,7 +863,8 @@ class OpenAIChatCompletion(BaseLLM):
                 additional_args={"complete_input_dict": data},
                 original_response=stringified_response,
             )
-            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+            hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
+            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
         except Exception as e:
             ## LOGGING
             logging_obj.post_call(

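Both the Azure and OpenAI transcription handlers now attach hidden params (model and custom_llm_provider) to the response via convert_to_model_response_object, which is what lets completion_cost pick the right provider and model later. A rough sketch of reading those fields back off a transcription response; the helper name is hypothetical, and it only assumes the _hidden_params / _response_ms attributes set in this diff:

def describe_transcription_response(transcript) -> dict:
    # _hidden_params is populated by convert_to_model_response_object(hidden_params=...)
    hidden = getattr(transcript, "_hidden_params", {}) or {}
    return {
        "provider": hidden.get("custom_llm_provider"),  # e.g. "openai" or "azure"
        "model": hidden.get("model"),                   # e.g. "whisper-1"
        "latency_ms": getattr(transcript, "_response_ms", None),
    }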
View file

@@ -3282,6 +3282,7 @@ async def audio_transcriptions(
             user_api_key_dict, "team_id", None
         )
         data["metadata"]["endpoint"] = str(request.url)
+        data["metadata"]["file_name"] = file.filename

         ### TEAM-SPECIFIC PARAMS ###
         if user_api_key_dict.team_id is not None:
@@ -3316,7 +3317,7 @@ async def audio_transcriptions(
         data = await proxy_logging_obj.pre_call_hook(
             user_api_key_dict=user_api_key_dict,
             data=data,
-            call_type="moderation",
+            call_type="audio_transcription",
         )

         ## ROUTE TO CORRECT ENDPOINT ##

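On the proxy, the transcription route now records the uploaded file's name in request metadata (giving the cache key above something stable to hash) and passes the correct call_type to the pre-call hook. A hypothetical client call against a running proxy; the base URL, API key, and audio file are placeholders, not values from this commit:

from openai import OpenAI

client = OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")  # placeholder proxy address/key

with open("speech.wav", "rb") as audio_file:  # illustrative file
    transcript = client.audio.transcriptions.create(model="whisper-1", file=audio_file)

print(transcript.text)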
View file

@@ -96,7 +96,11 @@ class ProxyLogging:
         user_api_key_dict: UserAPIKeyAuth,
         data: dict,
         call_type: Literal[
-            "completion", "embeddings", "image_generation", "moderation"
+            "completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
         ],
     ):
         """

View file

@@ -6,7 +6,12 @@ sys.path.insert(
 )  # Adds the parent directory to the system path
 import time
 import litellm
-from litellm import get_max_tokens, model_cost, open_ai_chat_completion_models
+from litellm import (
+    get_max_tokens,
+    model_cost,
+    open_ai_chat_completion_models,
+    TranscriptionResponse,
+)
 import pytest
@@ -238,3 +243,57 @@ def test_cost_bedrock_pricing_actual_calls():
         messages=[{"role": "user", "content": "Hey, how's it going?"}],
     )
     assert cost > 0


+def test_whisper_openai():
+    litellm.set_verbose = True
+    transcription = TranscriptionResponse(
+        text="Four score and seven years ago, our fathers brought forth on this continent a new nation, conceived in liberty and dedicated to the proposition that all men are created equal. Now we are engaged in a great civil war, testing whether that nation, or any nation so conceived and so dedicated, can long endure."
+    )
+    transcription._hidden_params = {
+        "model": "whisper-1",
+        "custom_llm_provider": "openai",
+        "optional_params": {},
+        "model_id": None,
+    }
+    _total_time_in_seconds = 3
+    transcription._response_ms = _total_time_in_seconds * 1000
+    cost = litellm.completion_cost(model="whisper-1", completion_response=transcription)
+    print(f"cost: {cost}")
+    print(f"whisper dict: {litellm.model_cost['whisper-1']}")
+    expected_cost = round(
+        litellm.model_cost["whisper-1"]["output_cost_per_second"]
+        * _total_time_in_seconds,
+        5,
+    )
+    assert cost == expected_cost
+
+
+def test_whisper_azure():
+    litellm.set_verbose = True
+    transcription = TranscriptionResponse(
+        text="Four score and seven years ago, our fathers brought forth on this continent a new nation, conceived in liberty and dedicated to the proposition that all men are created equal. Now we are engaged in a great civil war, testing whether that nation, or any nation so conceived and so dedicated, can long endure."
+    )
+    transcription._hidden_params = {
+        "model": "whisper-1",
+        "custom_llm_provider": "azure",
+        "optional_params": {},
+        "model_id": None,
+    }
+    _total_time_in_seconds = 3
+    transcription._response_ms = _total_time_in_seconds * 1000
+    cost = litellm.completion_cost(
+        model="azure/azure-whisper", completion_response=transcription
+    )
+    print(f"cost: {cost}")
+    print(f"whisper dict: {litellm.model_cost['whisper-1']}")
+    expected_cost = round(
+        litellm.model_cost["whisper-1"]["output_cost_per_second"]
+        * _total_time_in_seconds,
+        5,
+    )
+    assert cost == expected_cost

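The new tests assert that transcription cost is time-based: output_cost_per_second multiplied by the response duration in seconds. Illustrative arithmetic below; the per-second rate is an assumption for the example only, while the tests read the real value from litellm.model_cost["whisper-1"]:

# Hypothetical rate for illustration; not taken from this commit.
output_cost_per_second = 0.0001
response_time_seconds = 3

expected_cost = round(output_cost_per_second * response_time_seconds, 5)
print(expected_cost)  # 0.0003 with the assumed rate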
View file

@@ -1168,6 +1168,7 @@ class Logging:
                     isinstance(result, ModelResponse)
                     or isinstance(result, EmbeddingResponse)
                     or isinstance(result, ImageResponse)
+                    or isinstance(result, TranscriptionResponse)
                 )
                 and self.stream != True
             ):  # handle streaming separately
@@ -1203,9 +1204,6 @@ class Logging:
                                 model=base_model,
                             )
                         )
-                        verbose_logger.debug(
-                            f"Model={self.model}; cost={self.model_call_details['response_cost']}"
-                        )
                     except litellm.NotFoundError as e:
                         verbose_logger.debug(
                             f"Model={self.model} not found in completion cost map."
@@ -1236,7 +1234,7 @@ class Logging:
     def success_handler(
         self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs
     ):
-        verbose_logger.debug(f"Logging Details LiteLLM-Success Call: {cache_hit}")
+        print_verbose(f"Logging Details LiteLLM-Success Call: {cache_hit}")
         start_time, end_time, result = self._success_handler_helper_fn(
             start_time=start_time,
             end_time=end_time,
@@ -1681,6 +1679,7 @@ class Logging:
         """
         Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
         """
+        print_verbose(f"Logging Details LiteLLM-Async Success Call: {cache_hit}")
         start_time, end_time, result = self._success_handler_helper_fn(
             start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
         )
@@ -2473,6 +2472,7 @@ def client(original_function):
                 and kwargs.get("aembedding", False) != True
                 and kwargs.get("acompletion", False) != True
                 and kwargs.get("aimg_generation", False) != True
+                and kwargs.get("atranscription", False) != True
             ):  # allow users to control returning cached responses from the completion function
                 # checking cache
                 print_verbose(f"INSIDE CHECKING CACHE")
@@ -2875,6 +2875,19 @@ def client(original_function):
                             model_response_object=EmbeddingResponse(),
                             response_type="embedding",
                         )
+                    elif call_type == CallTypes.atranscription.value and isinstance(
+                        cached_result, dict
+                    ):
+                        hidden_params = {
+                            "model": "whisper-1",
+                            "custom_llm_provider": custom_llm_provider,
+                        }
+                        cached_result = convert_to_model_response_object(
+                            response_object=cached_result,
+                            model_response_object=TranscriptionResponse(),
+                            response_type="audio_transcription",
+                            hidden_params=hidden_params,
+                        )
                     if kwargs.get("stream", False) == False:
                         # LOG SUCCESS
                         asyncio.create_task(
@@ -3001,6 +3014,20 @@ def client(original_function):
                 else:
                     return result

+            # ADD HIDDEN PARAMS - additional call metadata
+            if hasattr(result, "_hidden_params"):
+                result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
+                    "id", None
+                )
+            if (
+                isinstance(result, ModelResponse)
+                or isinstance(result, EmbeddingResponse)
+                or isinstance(result, TranscriptionResponse)
+            ):
+                result._response_ms = (
+                    end_time - start_time
+                ).total_seconds() * 1000  # return response latency in ms like openai
+
             ### POST-CALL RULES ###
             post_call_processing(original_response=result, model=model)
@@ -3013,8 +3040,10 @@ def client(original_function):
                 )
                 and (kwargs.get("cache", {}).get("no-store", False) != True)
             ):
-                if isinstance(result, litellm.ModelResponse) or isinstance(
-                    result, litellm.EmbeddingResponse
+                if (
+                    isinstance(result, litellm.ModelResponse)
+                    or isinstance(result, litellm.EmbeddingResponse)
+                    or isinstance(result, TranscriptionResponse)
                 ):
                     if (
                         isinstance(result, EmbeddingResponse)
@@ -3058,18 +3087,7 @@ def client(original_function):
                         args=(result, start_time, end_time),
                     ).start()

-            # RETURN RESULT
-            if hasattr(result, "_hidden_params"):
-                result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
-                    "id", None
-                )
-            if isinstance(result, ModelResponse) or isinstance(
-                result, EmbeddingResponse
-            ):
-                result._response_ms = (
-                    end_time - start_time
-                ).total_seconds() * 1000  # return response latency in ms like openai
+            # REBUILD EMBEDDING CACHING
             if (
                 isinstance(result, EmbeddingResponse)
                 and final_embedding_cached_response is not None
@@ -3575,6 +3593,20 @@ def cost_per_token(
         completion_tokens_cost_usd_dollar = (
             model_cost_ref[model]["output_cost_per_token"] * completion_tokens
         )
+    elif (
+        model_cost_ref[model].get("output_cost_per_second", None) is not None
+        and response_time_ms is not None
+    ):
+        print_verbose(
+            f"For model={model} - output_cost_per_second: {model_cost_ref[model].get('output_cost_per_second')}; response time: {response_time_ms}"
+        )
+        ## COST PER SECOND ##
+        prompt_tokens_cost_usd_dollar = 0
+        completion_tokens_cost_usd_dollar = (
+            model_cost_ref[model]["output_cost_per_second"]
+            * response_time_ms
+            / 1000
+        )
     elif (
         model_cost_ref[model].get("input_cost_per_second", None) is not None
         and response_time_ms is not None
@@ -3659,6 +3691,8 @@ def completion_cost(
         "text_completion",
         "image_generation",
         "aimage_generation",
+        "transcription",
+        "atranscription",
     ] = "completion",
     ### REGION ###
     custom_llm_provider=None,
@@ -3703,6 +3737,7 @@ def completion_cost(
             and custom_llm_provider == "azure"
         ):
             model = "dall-e-2"  # for dall-e-2, azure expects an empty model name
+
         # Handle Inputs to completion_cost
         prompt_tokens = 0
         completion_tokens = 0
@@ -3717,10 +3752,11 @@ def completion_cost(
             verbose_logger.debug(
                 f"completion_response response ms: {completion_response.get('_response_ms')} "
             )
-            model = (
-                model or completion_response["model"]
+            model = model or completion_response.get(
+                "model", None
             )  # check if user passed an override for model, if it's none check completion_response['model']
             if hasattr(completion_response, "_hidden_params"):
+                model = completion_response._hidden_params.get("model", model)
                 custom_llm_provider = completion_response._hidden_params.get(
                     "custom_llm_provider", ""
                 )
@@ -3801,6 +3837,7 @@ def completion_cost(
         # see https://replicate.com/pricing
         elif model in litellm.replicate_models or "replicate" in model:
             return get_replicate_completion_pricing(completion_response, total_time)
+
         (
             prompt_tokens_cost_usd_dollar,
             completion_tokens_cost_usd_dollar,
@@ -6314,6 +6351,7 @@ def convert_to_model_response_object(
     stream=False,
     start_time=None,
     end_time=None,
+    hidden_params: Optional[dict] = None,
 ):
     try:
         if response_type == "completion" and (
@@ -6373,6 +6411,9 @@ def convert_to_model_response_object(
                     end_time - start_time
                 ).total_seconds() * 1000

+            if hidden_params is not None:
+                model_response_object._hidden_params = hidden_params
+
             return model_response_object
         elif response_type == "embedding" and (
             model_response_object is None
@@ -6402,6 +6443,9 @@ def convert_to_model_response_object(
                     end_time - start_time
                 ).total_seconds() * 1000  # return response latency in ms like openai

+            if hidden_params is not None:
+                model_response_object._hidden_params = hidden_params
+
             return model_response_object
         elif response_type == "image_generation" and (
             model_response_object is None
@@ -6419,6 +6463,9 @@ def convert_to_model_response_object(
             if "data" in response_object:
                 model_response_object.data = response_object["data"]

+            if hidden_params is not None:
+                model_response_object._hidden_params = hidden_params
+
             return model_response_object
         elif response_type == "audio_transcription" and (
             model_response_object is None
@@ -6432,6 +6479,9 @@ def convert_to_model_response_object(
             if "text" in response_object:
                 model_response_object.text = response_object["text"]

+            if hidden_params is not None:
+                model_response_object._hidden_params = hidden_params
+
             return model_response_object
     except Exception as e:
         raise Exception(f"Invalid response object {traceback.format_exc()}")

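The cost_per_token change above adds a duration-based pricing branch: when a model defines output_cost_per_second and a response time is available, cost is that rate times the response time in seconds (response_time_ms / 1000), with the prompt-token cost set to zero. A minimal standalone sketch of that branch; model_cost_ref stands in for litellm.model_cost, and the rate in the example call is an assumption:

from typing import Optional

def per_second_output_cost(
    model_cost_ref: dict, model: str, response_time_ms: Optional[float]
) -> float:
    """Duration-based cost, mirroring the output_cost_per_second branch above."""
    cost_per_second = model_cost_ref.get(model, {}).get("output_cost_per_second")
    if cost_per_second is None or response_time_ms is None:
        return 0.0
    return cost_per_second * response_time_ms / 1000

# e.g. a 3s whisper call at an assumed $0.0001/second:
print(per_second_output_cost({"whisper-1": {"output_cost_per_second": 0.0001}}, "whisper-1", 3000))  # 0.0003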
View file

@@ -31,7 +31,8 @@ def test_transcription():
         model="whisper-1",
         file=audio_file,
     )
-    print(f"transcript: {transcript}")
+    print(f"transcript: {transcript.model_dump()}")
+    print(f"transcript: {transcript._hidden_params}")


 # test_transcription()
@@ -47,6 +48,7 @@ def test_transcription_azure():
         api_version="2024-02-15-preview",
     )
+    print(f"transcript: {transcript}")
     assert transcript.text is not None
     assert isinstance(transcript.text, str)