Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 03:04:13 +00:00.
[Bug Fix] Timestamp Granularities are not properly passed to whisper in Azure (#10299)
* test fix form data parsing
* fix types
This commit is contained in:
parent 5de101ab7b
commit 164017119d
3 changed files with 90 additions and 18 deletions
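Background for the fix: when a list-valued parameter such as timestamp_granularities is sent through the OpenAI SDK, each element arrives as its own multipart form field whose key carries a "[]" suffix, while the downstream transcription call expects an actual list. A minimal sketch of the two shapes (illustrative values only, not part of the commit):

    # What the OpenAI SDK puts on the wire for a list parameter:
    # repeated form fields, each keyed with a "[]" suffix.
    wire_form = [
        ("timestamp_granularities[]", "word"),
        ("timestamp_granularities[]", "segment"),
    ]

    # What the downstream Whisper call needs to receive:
    expected = {"timestamp_granularities": ["word", "segment"]}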
litellm/proxy/common_utils/http_parsing_utils.py

@@ -1,5 +1,5 @@
 import json
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional

 import orjson
 from fastapi import Request, UploadFile, status
(Formatting-only reflow of the deployment lookup; no behavior change.)

@@ -147,10 +147,10 @@ def check_file_size_under_limit(

     if llm_router is not None and request_data["model"] in router_model_names:
         try:
-            deployment: Optional[
-                Deployment
-            ] = llm_router.get_deployment_by_model_group_name(
-                model_group_name=request_data["model"]
-            )
+            deployment: Optional[Deployment] = (
+                llm_router.get_deployment_by_model_group_name(
+                    model_group_name=request_data["model"]
+                )
+            )
             if (
                 deployment
The new helper that folds "[]"-suffixed form keys into lists:

@@ -185,3 +185,23 @@ def check_file_size_under_limit(
         )

     return True
+
+
+async def get_form_data(request: Request) -> Dict[str, Any]:
+    """
+    Read form data from request
+
+    Handles when OpenAI SDKs pass form keys as `timestamp_granularities[]="word"` instead of `timestamp_granularities=["word", "sentence"]`
+    """
+    form = await request.form()
+    form_data = dict(form)
+    parsed_form_data: dict[str, Any] = {}
+    for key, value in form_data.items():
+
+        # OpenAI SDKs pass form keys as `timestamp_granularities[]="word"` instead of `timestamp_granularities=["word", "sentence"]`
+        if key.endswith("[]"):
+            clean_key = key[:-2]
+            parsed_form_data.setdefault(clean_key, []).append(value)
+        else:
+            parsed_form_data[key] = value
+    return parsed_form_data
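To see the folding rule in isolation, here is the same loop as a self-contained sketch (a pure-Python stand-in with no FastAPI request; fold_array_keys and the example values are illustrative, not part of the commit):

    from typing import Any, Dict

    def fold_array_keys(form_data: Dict[str, Any]) -> Dict[str, Any]:
        """Same rule as get_form_data: strip a trailing '[]' and collect values into a list."""
        parsed: Dict[str, Any] = {}
        for key, value in form_data.items():
            if key.endswith("[]"):
                parsed.setdefault(key[:-2], []).append(value)
            else:
                parsed[key] = value
        return parsed

    print(fold_array_keys({"model": "whisper-1", "timestamp_granularities[]": "word"}))
    # -> {'model': 'whisper-1', 'timestamp_granularities': ['word']}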
litellm/proxy/proxy_server.py

@@ -179,6 +179,7 @@ from litellm.proxy.common_utils.html_forms.ui_login import html_form
 from litellm.proxy.common_utils.http_parsing_utils import (
     _read_request_body,
     check_file_size_under_limit,
+    get_form_data,
 )
 from litellm.proxy.common_utils.load_config_utils import (
     get_config_file_contents_from_gcs,
(Most of the remaining hunks in this file are formatter-style reflows with no behavior change; the functional edits are the get_form_data import above and the one-line change in audio_transcriptions below.)

@@ -804,9 +805,9 @@ model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
     dual_cache=user_api_key_cache
 )
 litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter)
-redis_usage_cache: Optional[
-    RedisCache
-] = None  # redis cache used for tracking spend, tpm/rpm limits
+redis_usage_cache: Optional[RedisCache] = (
+    None  # redis cache used for tracking spend, tpm/rpm limits
+)
 user_custom_auth = None
 user_custom_key_generate = None
 user_custom_sso = None
@@ -1132,9 +1133,9 @@ async def update_cache(  # noqa: PLR0915
         _id = "team_id:{}".format(team_id)
         try:
             # Fetch the existing cost for the given user
-            existing_spend_obj: Optional[
-                LiteLLM_TeamTable
-            ] = await user_api_key_cache.async_get_cache(key=_id)
+            existing_spend_obj: Optional[LiteLLM_TeamTable] = (
+                await user_api_key_cache.async_get_cache(key=_id)
+            )
             if existing_spend_obj is None:
                 # do nothing if team not in api key cache
                 return
@@ -2806,9 +2807,9 @@ async def initialize(  # noqa: PLR0915
         user_api_base = api_base
         dynamic_config[user_model]["api_base"] = api_base
     if api_version:
-        os.environ[
-            "AZURE_API_VERSION"
-        ] = api_version  # set this for azure - litellm can read this from the env
+        os.environ["AZURE_API_VERSION"] = (
+            api_version  # set this for azure - litellm can read this from the env
+        )
     if max_tokens:  # model-specific param
         dynamic_config[user_model]["max_tokens"] = max_tokens
     if temperature:  # model-specific param
@@ -4120,7 +4121,7 @@ async def audio_transcriptions(
     data: Dict = {}
     try:
         # Use orjson to parse JSON data, orjson speeds up requests significantly
-        form_data = await request.form()
+        form_data = await get_form_data(request)
         data = {key: value for key, value in form_data.items() if key != "file"}

         # Include original request and headers in the data
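This one-line swap is the actual bug fix: request.form() handed back the raw multipart keys, so a parameter the OpenAI SDK sent as timestamp_granularities[] was forwarded under that literal key and never reached the Azure Whisper call as a timestamp_granularities list. Routing the form through get_form_data folds the "[]" keys into proper lists before the request is dispatched.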
@@ -7758,9 +7759,9 @@ async def get_config_list(
                     hasattr(sub_field_info, "description")
                     and sub_field_info.description is not None
                 ):
-                    nested_fields[
-                        idx
-                    ].field_description = sub_field_info.description
+                    nested_fields[idx].field_description = (
+                        sub_field_info.description
+                    )
                 idx += 1

             _stored_in_db = None
test_http_parsing_utils.py (the test module for the helpers above)

@@ -18,6 +18,7 @@ from litellm.proxy.common_utils.http_parsing_utils import (
     _read_request_body,
     _safe_get_request_parsed_body,
     _safe_set_request_parsed_body,
+    get_form_data,
 )
The new test for the helper:

@@ -147,3 +148,53 @@ async def test_circular_reference_handling():
     assert (
         "proxy_server_request" not in result2
     )  # This will pass, showing the cache pollution
+
+
+@pytest.mark.asyncio
+async def test_get_form_data():
+    """
+    Test that get_form_data correctly handles form data with array notation.
+    Tests audio transcription parameters as a specific example.
+    """
+    # Create a mock request with transcription form data
+    mock_request = MagicMock()
+
+    # Create mock form data with array notation for timestamp_granularities
+    mock_form_data = {
+        "file": "file_object",  # In a real request this would be an UploadFile
+        "model": "gpt-4o-transcribe",
+        "include[]": "logprobs",  # Array notation
+        "language": "en",
+        "prompt": "Transcribe this audio file",
+        "response_format": "json",
+        "stream": "false",
+        "temperature": "0.2",
+        "timestamp_granularities[]": "word",  # First array item
+        "timestamp_granularities[]": "segment",  # Second array item (would overwrite in dict, but handled by the function)
+    }
+
+    # Mock the form method to return the test data
+    mock_request.form = AsyncMock(return_value=mock_form_data)
+
+    # Call the function being tested
+    result = await get_form_data(mock_request)
+
+    # Verify regular form fields are preserved
+    assert result["file"] == "file_object"
+    assert result["model"] == "gpt-4o-transcribe"
+    assert result["language"] == "en"
+    assert result["prompt"] == "Transcribe this audio file"
+    assert result["response_format"] == "json"
+    assert result["stream"] == "false"
+    assert result["temperature"] == "0.2"
+
+    # Verify array fields are correctly parsed
+    assert "include" in result
+    assert isinstance(result["include"], list)
+    assert "logprobs" in result["include"]
+
+    assert "timestamp_granularities" in result
+    assert isinstance(result["timestamp_granularities"], list)
+    # Note: In a real MultiDict, both values would be present
+    # But in our mock dictionary the second value overwrites the first
+    assert "segment" in result["timestamp_granularities"]
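Aside: the mock above is a plain dict, so the duplicate "timestamp_granularities[]" key collapses before get_form_data ever sees it, which is why the test only asserts on "segment". A small sketch of the same situation with Starlette's FormData, the type request.form() actually returns (assuming Starlette's multidict API; not part of the commit):

    from starlette.datastructures import FormData

    # Duplicate keys, as the OpenAI SDK would send them in a multipart body.
    form = FormData([
        ("timestamp_granularities[]", "word"),
        ("timestamp_granularities[]", "segment"),
    ])

    print(form.getlist("timestamp_granularities[]"))  # ['word', 'segment'] -- the multidict keeps both
    print(dict(form))  # {'timestamp_granularities[]': 'segment'} -- a plain dict keeps only the last value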