[Bug Fix] Timestamp Granularities are not properly passed to whisper in Azure (#10299)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 23s
Helm unit test / unit-test (push) Successful in 29s

* test fix form data parsing

* test fix form data parsing

* fix types
This commit is contained in:
Ishaan Jaff 2025-04-24 18:57:11 -07:00 committed by GitHub
parent 5de101ab7b
commit 164017119d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 90 additions and 18 deletions

View file

@ -1,5 +1,5 @@
import json
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional
import orjson
from fastapi import Request, UploadFile, status
@ -147,11 +147,11 @@ def check_file_size_under_limit(
if llm_router is not None and request_data["model"] in router_model_names:
try:
deployment: Optional[
Deployment
] = llm_router.get_deployment_by_model_group_name(
deployment: Optional[Deployment] = (
llm_router.get_deployment_by_model_group_name(
model_group_name=request_data["model"]
)
)
if (
deployment
and deployment.litellm_params is not None
@ -185,3 +185,23 @@ def check_file_size_under_limit(
)
return True
async def get_form_data(request: Request) -> Dict[str, Any]:
"""
Read form data from request
Handles when OpenAI SDKs pass form keys as `timestamp_granularities[]="word"` instead of `timestamp_granularities=["word", "sentence"]`
"""
form = await request.form()
form_data = dict(form)
parsed_form_data: dict[str, Any] = {}
for key, value in form_data.items():
# OpenAI SDKs pass form keys as `timestamp_granularities[]="word"` instead of `timestamp_granularities=["word", "sentence"]`
if key.endswith("[]"):
clean_key = key[:-2]
parsed_form_data.setdefault(clean_key, []).append(value)
else:
parsed_form_data[key] = value
return parsed_form_data

View file

@ -179,6 +179,7 @@ from litellm.proxy.common_utils.html_forms.ui_login import html_form
from litellm.proxy.common_utils.http_parsing_utils import (
_read_request_body,
check_file_size_under_limit,
get_form_data,
)
from litellm.proxy.common_utils.load_config_utils import (
get_config_file_contents_from_gcs,
@ -804,9 +805,9 @@ model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
dual_cache=user_api_key_cache
)
litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter)
redis_usage_cache: Optional[
RedisCache
] = None # redis cache used for tracking spend, tpm/rpm limits
redis_usage_cache: Optional[RedisCache] = (
None # redis cache used for tracking spend, tpm/rpm limits
)
user_custom_auth = None
user_custom_key_generate = None
user_custom_sso = None
@ -1132,9 +1133,9 @@ async def update_cache( # noqa: PLR0915
_id = "team_id:{}".format(team_id)
try:
# Fetch the existing cost for the given user
existing_spend_obj: Optional[
LiteLLM_TeamTable
] = await user_api_key_cache.async_get_cache(key=_id)
existing_spend_obj: Optional[LiteLLM_TeamTable] = (
await user_api_key_cache.async_get_cache(key=_id)
)
if existing_spend_obj is None:
# do nothing if team not in api key cache
return
@ -2806,9 +2807,9 @@ async def initialize( # noqa: PLR0915
user_api_base = api_base
dynamic_config[user_model]["api_base"] = api_base
if api_version:
os.environ[
"AZURE_API_VERSION"
] = api_version # set this for azure - litellm can read this from the env
os.environ["AZURE_API_VERSION"] = (
api_version # set this for azure - litellm can read this from the env
)
if max_tokens: # model-specific param
dynamic_config[user_model]["max_tokens"] = max_tokens
if temperature: # model-specific param
@ -4120,7 +4121,7 @@ async def audio_transcriptions(
data: Dict = {}
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
form_data = await request.form()
form_data = await get_form_data(request)
data = {key: value for key, value in form_data.items() if key != "file"}
# Include original request and headers in the data
@ -7758,9 +7759,9 @@ async def get_config_list(
hasattr(sub_field_info, "description")
and sub_field_info.description is not None
):
nested_fields[
idx
].field_description = sub_field_info.description
nested_fields[idx].field_description = (
sub_field_info.description
)
idx += 1
_stored_in_db = None

View file

@ -18,6 +18,7 @@ from litellm.proxy.common_utils.http_parsing_utils import (
_read_request_body,
_safe_get_request_parsed_body,
_safe_set_request_parsed_body,
get_form_data,
)
@ -147,3 +148,53 @@ async def test_circular_reference_handling():
assert (
"proxy_server_request" not in result2
) # This will pass, showing the cache pollution
@pytest.mark.asyncio
async def test_get_form_data():
"""
Test that get_form_data correctly handles form data with array notation.
Tests audio transcription parameters as a specific example.
"""
# Create a mock request with transcription form data
mock_request = MagicMock()
# Create mock form data with array notation for timestamp_granularities
mock_form_data = {
"file": "file_object", # In a real request this would be an UploadFile
"model": "gpt-4o-transcribe",
"include[]": "logprobs", # Array notation
"language": "en",
"prompt": "Transcribe this audio file",
"response_format": "json",
"stream": "false",
"temperature": "0.2",
"timestamp_granularities[]": "word", # First array item
"timestamp_granularities[]": "segment", # Second array item (would overwrite in dict, but handled by the function)
}
# Mock the form method to return the test data
mock_request.form = AsyncMock(return_value=mock_form_data)
# Call the function being tested
result = await get_form_data(mock_request)
# Verify regular form fields are preserved
assert result["file"] == "file_object"
assert result["model"] == "gpt-4o-transcribe"
assert result["language"] == "en"
assert result["prompt"] == "Transcribe this audio file"
assert result["response_format"] == "json"
assert result["stream"] == "false"
assert result["temperature"] == "0.2"
# Verify array fields are correctly parsed
assert "include" in result
assert isinstance(result["include"], list)
assert "logprobs" in result["include"]
assert "timestamp_granularities" in result
assert isinstance(result["timestamp_granularities"], list)
# Note: In a real MultiDict, both values would be present
# But in our mock dictionary the second value overwrites the first
assert "segment" in result["timestamp_granularities"]