[Bug Fix] Timestamp Granularities are not properly passed to whisper in Azure (#10299)
* test fix form data parsing
* fix types
This commit is contained in:
parent: 5de101ab7b
commit: 164017119d

3 changed files with 90 additions and 18 deletions
litellm/proxy/common_utils/http_parsing_utils.py

@@ -1,5 +1,5 @@
 import json
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 import orjson
 from fastapi import Request, UploadFile, status

@@ -147,11 +147,11 @@ def check_file_size_under_limit(
 
     if llm_router is not None and request_data["model"] in router_model_names:
         try:
-            deployment: Optional[
-                Deployment
-            ] = llm_router.get_deployment_by_model_group_name(
-                model_group_name=request_data["model"]
-            )
+            deployment: Optional[Deployment] = (
+                llm_router.get_deployment_by_model_group_name(
+                    model_group_name=request_data["model"]
+                )
+            )
             if (
                 deployment
                 and deployment.litellm_params is not None

@@ -185,3 +185,23 @@ def check_file_size_under_limit(
             )
 
     return True
+
+
+async def get_form_data(request: Request) -> Dict[str, Any]:
+    """
+    Read form data from request
+
+    Handles when OpenAI SDKs pass form keys as `timestamp_granularities[]="word"` instead of `timestamp_granularities=["word", "sentence"]`
+    """
+    form = await request.form()
+    form_data = dict(form)
+    parsed_form_data: dict[str, Any] = {}
+    for key, value in form_data.items():
+
+        # OpenAI SDKs pass form keys as `timestamp_granularities[]="word"` instead of `timestamp_granularities=["word", "sentence"]`
+        if key.endswith("[]"):
+            clean_key = key[:-2]
+            parsed_form_data.setdefault(clean_key, []).append(value)
+        else:
+            parsed_form_data[key] = value
+    return parsed_form_data
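For context, the rule `get_form_data` applies is the common bracket-suffix convention: a form key ending in `[]` is collected into a list under the un-suffixed key, while plain keys pass through unchanged. A minimal standalone sketch of that rule (not part of the diff; the sample keys and values are illustrative):

# Illustrative only: fold "[]"-suffixed form keys into lists, as get_form_data does.
raw_form = {
    "model": "whisper-1",                 # plain field, passed through unchanged
    "timestamp_granularities[]": "word",  # bracketed field, collected into a list
}

parsed: dict = {}
for key, value in raw_form.items():
    if key.endswith("[]"):
        parsed.setdefault(key[:-2], []).append(value)
    else:
        parsed[key] = value

print(parsed)  # {'model': 'whisper-1', 'timestamp_granularities': ['word']}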
litellm/proxy/proxy_server.py

@@ -179,6 +179,7 @@ from litellm.proxy.common_utils.html_forms.ui_login import html_form
 from litellm.proxy.common_utils.http_parsing_utils import (
     _read_request_body,
     check_file_size_under_limit,
+    get_form_data,
 )
 from litellm.proxy.common_utils.load_config_utils import (
     get_config_file_contents_from_gcs,

@@ -804,9 +805,9 @@ model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
     dual_cache=user_api_key_cache
 )
 litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter)
-redis_usage_cache: Optional[
-    RedisCache
-] = None  # redis cache used for tracking spend, tpm/rpm limits
+redis_usage_cache: Optional[RedisCache] = (
+    None  # redis cache used for tracking spend, tpm/rpm limits
+)
 user_custom_auth = None
 user_custom_key_generate = None
 user_custom_sso = None

@@ -1132,9 +1133,9 @@ async def update_cache(  # noqa: PLR0915
         _id = "team_id:{}".format(team_id)
         try:
             # Fetch the existing cost for the given user
-            existing_spend_obj: Optional[
-                LiteLLM_TeamTable
-            ] = await user_api_key_cache.async_get_cache(key=_id)
+            existing_spend_obj: Optional[LiteLLM_TeamTable] = (
+                await user_api_key_cache.async_get_cache(key=_id)
+            )
             if existing_spend_obj is None:
                 # do nothing if team not in api key cache
                 return

@@ -2806,9 +2807,9 @@ async def initialize(  # noqa: PLR0915
         user_api_base = api_base
         dynamic_config[user_model]["api_base"] = api_base
     if api_version:
-        os.environ[
-            "AZURE_API_VERSION"
-        ] = api_version  # set this for azure - litellm can read this from the env
+        os.environ["AZURE_API_VERSION"] = (
+            api_version  # set this for azure - litellm can read this from the env
+        )
     if max_tokens:  # model-specific param
         dynamic_config[user_model]["max_tokens"] = max_tokens
     if temperature:  # model-specific param

@@ -4120,7 +4121,7 @@ async def audio_transcriptions(
     data: Dict = {}
     try:
         # Use orjson to parse JSON data, orjson speeds up requests significantly
-        form_data = await request.form()
+        form_data = await get_form_data(request)
         data = {key: value for key, value in form_data.items() if key != "file"}
 
         # Include original request and headers in the data

@@ -7758,9 +7759,9 @@ async def get_config_list(
                         hasattr(sub_field_info, "description")
                         and sub_field_info.description is not None
                     ):
-                        nested_fields[
-                            idx
-                        ].field_description = sub_field_info.description
+                        nested_fields[idx].field_description = (
+                            sub_field_info.description
+                        )
                     idx += 1
 
             _stored_in_db = None
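The practical effect of the `audio_transcriptions` change above: array-valued transcription parameters, which the OpenAI SDK encodes as `timestamp_granularities[]` multipart form fields, now reach the backing Whisper deployment instead of being dropped. A rough sketch of a client call that exercises this path, assuming a LiteLLM proxy on localhost:4000 where `whisper` is a hypothetical model name routed to Azure and the API key is a placeholder:

from openai import OpenAI

# Assumed setup: LiteLLM proxy running locally; "whisper" is a hypothetical
# model_name in the proxy config that routes to an Azure Whisper deployment.
client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

with open("speech.mp3", "rb") as audio_file:
    transcript = client.audio.transcriptions.create(
        model="whisper",
        file=audio_file,
        response_format="verbose_json",
        # The SDK sends this as a `timestamp_granularities[]` form field;
        # get_form_data() turns it back into a list before it is forwarded.
        timestamp_granularities=["word"],
    )

print(transcript)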
tests: test_http_parsing_utils.py

@@ -18,6 +18,7 @@ from litellm.proxy.common_utils.http_parsing_utils import (
     _read_request_body,
     _safe_get_request_parsed_body,
     _safe_set_request_parsed_body,
+    get_form_data,
 )
 
 

@@ -147,3 +148,53 @@ async def test_circular_reference_handling():
     assert (
         "proxy_server_request" not in result2
     )  # This will pass, showing the cache pollution
+
+
+@pytest.mark.asyncio
+async def test_get_form_data():
+    """
+    Test that get_form_data correctly handles form data with array notation.
+    Tests audio transcription parameters as a specific example.
+    """
+    # Create a mock request with transcription form data
+    mock_request = MagicMock()
+
+    # Create mock form data with array notation for timestamp_granularities
+    mock_form_data = {
+        "file": "file_object",  # In a real request this would be an UploadFile
+        "model": "gpt-4o-transcribe",
+        "include[]": "logprobs",  # Array notation
+        "language": "en",
+        "prompt": "Transcribe this audio file",
+        "response_format": "json",
+        "stream": "false",
+        "temperature": "0.2",
+        "timestamp_granularities[]": "word",  # First array item
+        "timestamp_granularities[]": "segment",  # Second array item (would overwrite in dict, but handled by the function)
+    }
+
+    # Mock the form method to return the test data
+    mock_request.form = AsyncMock(return_value=mock_form_data)
+
+    # Call the function being tested
+    result = await get_form_data(mock_request)
+
+    # Verify regular form fields are preserved
+    assert result["file"] == "file_object"
+    assert result["model"] == "gpt-4o-transcribe"
+    assert result["language"] == "en"
+    assert result["prompt"] == "Transcribe this audio file"
+    assert result["response_format"] == "json"
+    assert result["stream"] == "false"
+    assert result["temperature"] == "0.2"
+
+    # Verify array fields are correctly parsed
+    assert "include" in result
+    assert isinstance(result["include"], list)
+    assert "logprobs" in result["include"]
+
+    assert "timestamp_granularities" in result
+    assert isinstance(result["timestamp_granularities"], list)
+    # Note: In a real MultiDict, both values would be present
+    # But in our mock dictionary the second value overwrites the first
+    assert "segment" in result["timestamp_granularities"]
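As the comments in the new test already note, a Python dict literal cannot hold duplicate keys, so only the last `timestamp_granularities[]` entry in the mock actually reaches `get_form_data`, which is why the test asserts only on "segment". A quick illustration of that behaviour (standalone, not part of the diff):

# Duplicate keys in a dict literal: the later value silently replaces the earlier one.
mock_form_data = {
    "timestamp_granularities[]": "word",
    "timestamp_granularities[]": "segment",
}
print(mock_form_data)  # {'timestamp_granularities[]': 'segment'}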