feat: add cost tracking + caching for transcription calls

This commit is contained in:
Krrish Dholakia 2024-03-09 15:43:38 -08:00
parent e10991e02b
commit fa45c569fd
8 changed files with 225 additions and 37 deletions

View file

@ -10,7 +10,7 @@
import litellm
import time, logging, asyncio
import json, traceback, ast, hashlib
from typing import Optional, Literal, List, Union, Any
from typing import Optional, Literal, List, Union, Any, BinaryIO
from openai._models import BaseModel as OpenAIObject
from litellm._logging import verbose_logger
@ -764,8 +764,24 @@ class Cache:
password: Optional[str] = None,
similarity_threshold: Optional[float] = None,
supported_call_types: Optional[
List[Literal["completion", "acompletion", "embedding", "aembedding"]]
] = ["completion", "acompletion", "embedding", "aembedding"],
List[
Literal[
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
]
]
] = [
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
],
# s3 Bucket, boto3 configuration
s3_bucket_name: Optional[str] = None,
s3_region_name: Optional[str] = None,
@ -880,9 +896,18 @@ class Cache:
"input",
"encoding_format",
] # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs
transcription_only_kwargs = [
"model",
"file",
"language",
"prompt",
"response_format",
"temperature",
]
# combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set()
combined_kwargs = completion_kwargs + embedding_only_kwargs
combined_kwargs = (
completion_kwargs + embedding_only_kwargs + transcription_only_kwargs
)
for param in combined_kwargs:
# ignore litellm params here
if param in kwargs:
@ -914,6 +939,17 @@ class Cache:
param_value = (
caching_group or model_group or kwargs[param]
) # use caching_group, if set then model_group if it exists, else use kwargs["model"]
elif param == "file":
metadata_file_name = kwargs.get("metadata", {}).get(
"file_name", None
)
litellm_params_file_name = kwargs.get("litellm_params", {}).get(
"file_name", None
)
if metadata_file_name is not None:
param_value = metadata_file_name
elif litellm_params_file_name is not None:
param_value = litellm_params_file_name
else:
if kwargs[param] is None:
continue # ignore None params
@ -1143,8 +1179,24 @@ def enable_cache(
port: Optional[str] = None,
password: Optional[str] = None,
supported_call_types: Optional[
List[Literal["completion", "acompletion", "embedding", "aembedding"]]
] = ["completion", "acompletion", "embedding", "aembedding"],
List[
Literal[
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
]
]
] = [
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
],
**kwargs,
):
"""
@ -1192,8 +1244,24 @@ def update_cache(
port: Optional[str] = None,
password: Optional[str] = None,
supported_call_types: Optional[
List[Literal["completion", "acompletion", "embedding", "aembedding"]]
] = ["completion", "acompletion", "embedding", "aembedding"],
List[
Literal[
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
]
]
] = [
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
],
**kwargs,
):
"""