[Bug Fix] Timestamp Granularities are not properly passed to whisper in Azure (#10299)

* test fix form data parsing

* test fix form data parsing

* fix types
Ishaan Jaff, 2025-04-24 18:57:11 -07:00 (committed by GitHub)
commit 164017119d, parent 5de101ab7b
3 changed files with 90 additions and 18 deletions

litellm/proxy/common_utils/http_parsing_utils.py

@@ -1,5 +1,5 @@
 import json
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 import orjson
 from fastapi import Request, UploadFile, status
@@ -147,10 +147,10 @@ def check_file_size_under_limit(
     if llm_router is not None and request_data["model"] in router_model_names:
         try:
-            deployment: Optional[
-                Deployment
-            ] = llm_router.get_deployment_by_model_group_name(
-                model_group_name=request_data["model"]
-            )
+            deployment: Optional[Deployment] = (
+                llm_router.get_deployment_by_model_group_name(
+                    model_group_name=request_data["model"]
+                )
+            )
             if (
                 deployment
@@ -185,3 +185,23 @@ def check_file_size_under_limit(
         )
     return True
+
+
+async def get_form_data(request: Request) -> Dict[str, Any]:
+    """
+    Read form data from request
+
+    Handles when OpenAI SDKs pass form keys as `timestamp_granularities[]="word"` instead of `timestamp_granularities=["word", "sentence"]`
+    """
+    form = await request.form()
+    form_data = dict(form)
+    parsed_form_data: dict[str, Any] = {}
+    for key, value in form_data.items():
+
+        # OpenAI SDKs pass form keys as `timestamp_granularities[]="word"` instead of `timestamp_granularities=["word", "sentence"]`
+        if key.endswith("[]"):
+            clean_key = key[:-2]
+            parsed_form_data.setdefault(clean_key, []).append(value)
+        else:
+            parsed_form_data[key] = value
+    return parsed_form_data
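
For context on what the new helper does: form keys ending in `[]` are collected into a list under the bracket-free name, and every other key passes through unchanged. A minimal, self-contained sketch of the same rule (illustrative only, not part of the commit) applied to a plain list of key/value pairs:

    from typing import Any, Dict, List, Tuple

    def parse_bracket_keys(items: List[Tuple[str, Any]]) -> Dict[str, Any]:
        # Apply the OpenAI-SDK "key[]" convention: collect bracket keys
        # into lists, pass scalar keys through unchanged.
        parsed: Dict[str, Any] = {}
        for key, value in items:
            if key.endswith("[]"):
                parsed.setdefault(key[:-2], []).append(value)
            else:
                parsed[key] = value
        return parsed

    print(parse_bracket_keys([
        ("model", "whisper-1"),
        ("timestamp_granularities[]", "word"),
        ("timestamp_granularities[]", "segment"),
    ]))
    # {'model': 'whisper-1', 'timestamp_granularities': ['word', 'segment']}

One caveat: `get_form_data` builds `dict(form)` before looping, and a plain dict keeps only the last value for a repeated key, so a request that sends the same `[]` key twice produces a one-element list here; Starlette's `FormData.multi_items()` is the API that exposes every repeated value.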

litellm/proxy/proxy_server.py

@@ -179,6 +179,7 @@ from litellm.proxy.common_utils.html_forms.ui_login import html_form
 from litellm.proxy.common_utils.http_parsing_utils import (
     _read_request_body,
     check_file_size_under_limit,
+    get_form_data,
 )
 from litellm.proxy.common_utils.load_config_utils import (
     get_config_file_contents_from_gcs,
@@ -804,9 +805,9 @@ model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
     dual_cache=user_api_key_cache
 )
 litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter)
-redis_usage_cache: Optional[
-    RedisCache
-] = None  # redis cache used for tracking spend, tpm/rpm limits
+redis_usage_cache: Optional[RedisCache] = (
+    None  # redis cache used for tracking spend, tpm/rpm limits
+)
 user_custom_auth = None
 user_custom_key_generate = None
 user_custom_sso = None
@@ -1132,9 +1133,9 @@ async def update_cache(  # noqa: PLR0915
         _id = "team_id:{}".format(team_id)
         try:
             # Fetch the existing cost for the given user
-            existing_spend_obj: Optional[
-                LiteLLM_TeamTable
-            ] = await user_api_key_cache.async_get_cache(key=_id)
+            existing_spend_obj: Optional[LiteLLM_TeamTable] = (
+                await user_api_key_cache.async_get_cache(key=_id)
+            )
             if existing_spend_obj is None:
                 # do nothing if team not in api key cache
                 return
@@ -2806,9 +2807,9 @@ async def initialize(  # noqa: PLR0915
         user_api_base = api_base
         dynamic_config[user_model]["api_base"] = api_base
     if api_version:
-        os.environ[
-            "AZURE_API_VERSION"
-        ] = api_version  # set this for azure - litellm can read this from the env
+        os.environ["AZURE_API_VERSION"] = (
+            api_version  # set this for azure - litellm can read this from the env
+        )
     if max_tokens:  # model-specific param
         dynamic_config[user_model]["max_tokens"] = max_tokens
     if temperature:  # model-specific param
@@ -4120,7 +4121,7 @@ async def audio_transcriptions(
     data: Dict = {}
     try:
         # Use orjson to parse JSON data, orjson speeds up requests significantly
-        form_data = await request.form()
+        form_data = await get_form_data(request)
         data = {key: value for key, value in form_data.items() if key != "file"}
 
         # Include original request and headers in the data
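
The `[]` keys this change handles come from the client side: the OpenAI SDKs serialize array parameters in multipart form data with a `[]` suffix. For example, a transcription call like the following (a sketch; the base URL and API key are placeholders for a LiteLLM proxy deployment) is sent over the wire with two repeated `timestamp_granularities[]` form fields:

    from openai import OpenAI

    # Placeholder URL/key - point the SDK at a running LiteLLM proxy.
    client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

    transcript = client.audio.transcriptions.create(
        model="whisper-1",
        file=open("speech.mp3", "rb"),
        response_format="verbose_json",  # required for timestamp granularities
        timestamp_granularities=["word", "segment"],
    )

Previously `request.form()` passed the raw `timestamp_granularities[]` key straight into the request data, so the Azure Whisper handler never saw a `timestamp_granularities` parameter; `get_form_data` now normalizes it back into a proper list.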
@@ -7758,9 +7759,9 @@ async def get_config_list(
                     hasattr(sub_field_info, "description")
                     and sub_field_info.description is not None
                 ):
-                    nested_fields[
-                        idx
-                    ].field_description = sub_field_info.description
+                    nested_fields[idx].field_description = (
+                        sub_field_info.description
+                    )
                 idx += 1
 
     _stored_in_db = None

tests/litellm/proxy/common_utils/test_http_parsing_utils.py

@@ -18,6 +18,7 @@ from litellm.proxy.common_utils.http_parsing_utils import (
     _read_request_body,
     _safe_get_request_parsed_body,
     _safe_set_request_parsed_body,
+    get_form_data,
 )
@@ -147,3 +148,53 @@ async def test_circular_reference_handling():
     assert (
         "proxy_server_request" not in result2
     )  # This will pass, showing the cache pollution
+
+
+@pytest.mark.asyncio
+async def test_get_form_data():
+    """
+    Test that get_form_data correctly handles form data with array notation.
+    Tests audio transcription parameters as a specific example.
+    """
+    # Create a mock request with transcription form data
+    mock_request = MagicMock()
+
+    # Create mock form data with array notation for timestamp_granularities
+    mock_form_data = {
+        "file": "file_object",  # In a real request this would be an UploadFile
+        "model": "gpt-4o-transcribe",
+        "include[]": "logprobs",  # Array notation
+        "language": "en",
+        "prompt": "Transcribe this audio file",
+        "response_format": "json",
+        "stream": "false",
+        "temperature": "0.2",
+        "timestamp_granularities[]": "word",  # First array item
+        "timestamp_granularities[]": "segment",  # Second array item (would overwrite in dict, but handled by the function)
+    }
+
+    # Mock the form method to return the test data
+    mock_request.form = AsyncMock(return_value=mock_form_data)
+
+    # Call the function being tested
+    result = await get_form_data(mock_request)
+
+    # Verify regular form fields are preserved
+    assert result["file"] == "file_object"
+    assert result["model"] == "gpt-4o-transcribe"
+    assert result["language"] == "en"
+    assert result["prompt"] == "Transcribe this audio file"
+    assert result["response_format"] == "json"
+    assert result["stream"] == "false"
+    assert result["temperature"] == "0.2"
+
+    # Verify array fields are correctly parsed
+    assert "include" in result
+    assert isinstance(result["include"], list)
+    assert "logprobs" in result["include"]
+
+    assert "timestamp_granularities" in result
+    assert isinstance(result["timestamp_granularities"], list)
+    # Note: In a real MultiDict, both values would be present
+    # But in our mock dictionary the second value overwrites the first
+    assert "segment" in result["timestamp_granularities"]