Merge pull request #4724 from BerriAI/litellm_Set_max_file_size_transc

[Feat] - set max file size on /audio/transcriptions
Ishaan Jaff 2024-07-15 20:42:24 -07:00 committed by GitHub
commit 254ac37f65
8 changed files with 144 additions and 5 deletions


@@ -27,6 +27,7 @@ This covers:
- ✅ IP address-based access control lists
- ✅ Track Request IP Address
- ✅ [Use LiteLLM keys/authentication on Pass Through Endpoints](./proxy/pass_through#✨-enterprise---use-litellm-keysauthentication-on-pass-through-endpoints)
- ✅ Set Max Request / File Size on Requests
- ✅ [Enforce Required Params for LLM Requests (ex. Reject requests missing ["metadata"]["generation_name"])](./proxy/enterprise#enforce-required-params-for-llm-requests)
- **Spend Tracking**
- ✅ [Tracking Spend for Custom Tags](./proxy/enterprise#tracking-spend-for-custom-tags)


@@ -21,6 +21,7 @@ Features:
- ✅ IP address-based access control lists
- ✅ Track Request IP Address
- ✅ [Use LiteLLM keys/authentication on Pass Through Endpoints](pass_through#✨-enterprise---use-litellm-keysauthentication-on-pass-through-endpoints)
- ✅ Set Max Request / File Size on Requests
- ✅ [Enforce Required Params for LLM Requests (ex. Reject requests missing ["metadata"]["generation_name"])](#enforce-required-params-for-llm-requests)
- **Spend Tracking**
- ✅ [Tracking Spend for Custom Tags](#tracking-spend-for-custom-tags)


@@ -109,4 +109,33 @@ response = speech(
    input="the quick brown fox jumped over the lazy dogs",
)
response.stream_to_file(speech_file_path)
```

## ✨ Enterprise LiteLLM Proxy - Set Max Request File Size

Use this when you want to limit the file size for requests sent to `/audio/transcriptions`.

```yaml
- model_name: whisper
  litellm_params:
    model: whisper-1
    api_key: sk-*******
    max_file_size_mb: 0.00001 # 👈 max file size in MB (set this intentionally very small for testing)
  model_info:
    mode: audio_transcription
```

Make a test request with a valid file:

```shell
curl --location 'http://localhost:4000/v1/audio/transcriptions' \
--header 'Authorization: Bearer sk-1234' \
--form 'file=@"/Users/ishaanjaffer/Github/litellm/tests/gettysburg.wav"' \
--form 'model="whisper"'
```

Expect to see the following response:

```shell
{"error":{"message":"File size is too large. Please check your file size. Passed file size: 0.7392807006835938 MB. Max file size: 0.0001 MB","type":"bad_request","param":"file","code":500}}
```
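
The same request can also be made with the OpenAI Python SDK pointed at the proxy. A minimal sketch, assuming a proxy on `localhost:4000` with master key `sk-1234`, the `whisper` model group above, and a local `gettysburg.wav` file:

```python
# Illustrative sketch: client-side transcription call routed through the LiteLLM proxy.
# Assumes a proxy at localhost:4000, master key sk-1234, and a local gettysburg.wav file.
from openai import OpenAI

client = OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

with open("gettysburg.wav", "rb") as audio_file:
    transcript = client.audio.transcriptions.create(
        model="whisper",
        file=audio_file,
    )

print(transcript.text)
```

If the uploaded file exceeds the deployment's `max_file_size_mb`, the proxy rejects the request with the `bad_request` error shown above instead of forwarding it to the provider.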


@@ -1,6 +1,11 @@
import ast
import json
from typing import List, Optional

from fastapi import Request, UploadFile, status

from litellm._logging import verbose_proxy_logger
from litellm.types.router import Deployment


async def _read_request_body(request: Optional[Request]) -> dict:
@@ -29,3 +34,66 @@ async def _read_request_body(request: Optional[Request]) -> dict:
        return request_data
    except:
        return {}


def check_file_size_under_limit(
    request_data: dict,
    file: UploadFile,
    router_model_names: List[str],
) -> bool:
    """
    Check if any files passed in request are under max_file_size_mb

    Returns True -> when file size is under max_file_size_mb limit
    Raises ProxyException -> when file size is over max_file_size_mb limit or not a premium_user
    """
    from litellm.proxy.proxy_server import (
        CommonProxyErrors,
        ProxyException,
        llm_router,
        premium_user,
    )

    file_contents_size = file.size or 0
    file_content_size_in_mb = file_contents_size / (1024 * 1024)

    # default: no limit configured for this model group
    max_file_size_mb: Optional[float] = None
    if llm_router is not None and request_data["model"] in router_model_names:
        try:
            deployment: Optional[Deployment] = (
                llm_router.get_deployment_by_model_group_name(
                    model_group_name=request_data["model"]
                )
            )
            if (
                deployment
                and deployment.litellm_params is not None
                and deployment.litellm_params.max_file_size_mb is not None
            ):
                max_file_size_mb = deployment.litellm_params.max_file_size_mb
        except Exception as e:
            verbose_proxy_logger.error(
                "Got error when checking file size: %s", (str(e))
            )

    if max_file_size_mb is not None:
        verbose_proxy_logger.debug(
            "Checking file size, file content size=%s, max_file_size_mb=%s",
            file_content_size_in_mb,
            max_file_size_mb,
        )
        if not premium_user:
            raise ProxyException(
                message=f"Tried setting max_file_size_mb for /audio/transcriptions. {CommonProxyErrors.not_premium_user.value}",
                code=status.HTTP_400_BAD_REQUEST,
                type="bad_request",
                param="file",
            )
        if file_content_size_in_mb > max_file_size_mb:
            raise ProxyException(
                message=f"File size is too large. Please check your file size. Passed file size: {file_content_size_in_mb} MB. Max file size: {max_file_size_mb} MB",
                code=status.HTTP_400_BAD_REQUEST,
                type="bad_request",
                param="file",
            )

    return True
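
For reference, a minimal sketch of exercising the new helper in isolation (not part of the commit): it assumes the proxy has already initialized `llm_router` and `premium_user`, that a `whisper` model group with `max_file_size_mb` is configured, and that the installed Starlette/FastAPI version accepts the `size` keyword on `UploadFile`; the file name and sizes are illustrative.

```python
# Illustrative sketch only: build an in-memory UploadFile and run the size check.
import io

from fastapi import UploadFile

from litellm.proxy.common_utils.http_parsing_utils import check_file_size_under_limit

two_mb = 2 * 1024 * 1024
upload = UploadFile(
    file=io.BytesIO(b"\x00" * two_mb),  # dummy audio bytes, ~2 MB
    filename="test.wav",
    size=two_mb,
)

# Raises ProxyException (HTTP 400) if the configured max_file_size_mb is below 2 MB;
# returns True (no-op) if no limit is configured for the model group.
check_file_size_under_limit(
    request_data={"model": "whisper"},
    file=upload,
    router_model_names=["whisper"],
)
```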


@@ -7,6 +7,13 @@ model_list:
  - model_name: gemini-flash
    litellm_params:
      model: gemini/gemini-1.5-flash
  - model_name: whisper
    litellm_params:
      model: whisper-1
      api_key: sk-*******
      max_file_size_mb: 1000
    model_info:
      mode: audio_transcription

general_settings:
  master_key: sk-1234


@@ -143,7 +143,10 @@ from litellm.proxy.common_utils.encrypt_decrypt_utils import (
    decrypt_value_helper,
    encrypt_value_helper,
)
from litellm.proxy.common_utils.http_parsing_utils import (
    _read_request_body,
    check_file_size_under_limit,
)
from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
from litellm.proxy.common_utils.openai_endpoint_utils import (
    remove_sensitive_info_from_deployment,
@@ -3796,7 +3799,13 @@ async def audio_transcriptions(
                param="file",
            )

        # Check if File can be read in memory before reading
        check_file_size_under_limit(
            request_data=data,
            file=file,
            router_model_names=router_model_names,
        )

        file_content = await file.read()
        file_object = io.BytesIO(file_content)
        file_object.name = file.filename


@@ -3684,6 +3684,24 @@ class Router:
            raise Exception("Model invalid format - {}".format(type(model)))
        return None

    def get_deployment_by_model_group_name(
        self, model_group_name: str
    ) -> Optional[Deployment]:
        """
        Returns -> Deployment or None

        Raise Exception -> if model found in invalid format
        """
        for model in self.model_list:
            if model["model_name"] == model_group_name:
                if isinstance(model, dict):
                    return Deployment(**model)
                elif isinstance(model, Deployment):
                    return model
                else:
                    raise Exception("Model Name invalid - {}".format(type(model)))
        return None

    def get_router_model_info(self, deployment: dict) -> ModelMapInfo:
        """
        For a given model id, return the model info (max tokens, input cost, output cost, etc.).
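
As a usage sketch (not from the commit), the new lookup can also be called programmatically to read a deployment's limit. This assumes a `Router` built with a `whisper` model group; the key and limit values are illustrative.

```python
# Illustrative sketch: read max_file_size_mb for a model group via the new lookup.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "whisper",
            "litellm_params": {
                "model": "whisper-1",
                "api_key": "sk-*******",
                "max_file_size_mb": 1.5,
            },
            "model_info": {"mode": "audio_transcription"},
        }
    ]
)

deployment = router.get_deployment_by_model_group_name(model_group_name="whisper")
if deployment is not None and deployment.litellm_params.max_file_size_mb is not None:
    print(deployment.litellm_params.max_file_size_mb)  # 1.5
```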


@@ -154,6 +154,8 @@ class GenericLiteLLMParams(BaseModel):
    input_cost_per_second: Optional[float] = None
    output_cost_per_second: Optional[float] = None

    max_file_size_mb: Optional[float] = None

    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

    def __init__(
@@ -185,6 +187,7 @@ class GenericLiteLLMParams(BaseModel):
        output_cost_per_token: Optional[float] = None,
        input_cost_per_second: Optional[float] = None,
        output_cost_per_second: Optional[float] = None,
        max_file_size_mb: Optional[float] = None,
        **params,
    ):
        args = locals()
@@ -243,6 +246,9 @@ class LiteLLM_Params(GenericLiteLLMParams):
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_region_name: Optional[str] = None,
        # OpenAI / Azure Whisper
        # set a max-size of file that can be passed to litellm proxy
        max_file_size_mb: Optional[float] = None,
        **params,
    ):
        args = locals()
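
A small sketch of the new field when building params programmatically (illustrative values, not part of the commit): `max_file_size_mb` is a plain optional float on the params model, so it can be set directly alongside the other deployment parameters.

```python
# Illustrative sketch: max_file_size_mb is an optional float on LiteLLM_Params.
from litellm.types.router import LiteLLM_Params

params = LiteLLM_Params(
    model="whisper-1",
    api_key="sk-*******",
    max_file_size_mb=0.5,  # proxy rejects transcription uploads larger than ~0.5 MB
)
print(params.max_file_size_mb)  # 0.5
```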