Add OpenAI gpt-4o-transcribe support (#9517)

* refactor: introduce new transformation config for gpt-4o-transcribe models

* refactor: expose new transformation configs for audio transcription

* ci: fix config yml

* feat(openai/transcriptions): support provider config transformation on openai audio transcriptions

allows gpt-4o and whisper audio request transformation to work as expected

* refactor: migrate fireworks ai + deepgram to new transform request pattern

* feat(openai/): working support for gpt-4o-transcribe

* build(model_prices_and_context_window.json): add gpt-4o-transcribe to model cost map

* build(model_prices_and_context_window.json): specify what endpoints are supported for `/audio/transcriptions`

* fix(get_supported_openai_params.py): fix return

* refactor(deepgram/): migrate unit test to deepgram handler

* refactor: cleanup unused imports

* fix(get_supported_openai_params.py): fix linting error

* test: update test
Authored by Krish Dholakia on 2025-03-26 23:10:25 -07:00; committed by GitHub
parent 109add7946
commit c0845fec1f
20 changed files with 402 additions and 92 deletions
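
For context, a minimal sketch of how the new models are expected to be called once this lands (assumes `OPENAI_API_KEY` is set; the audio file name is a placeholder):

import litellm

# gpt-4o-transcribe is routed through the new OpenAI GPT transcription config
with open("gettysburg.wav", "rb") as audio_file:
    response = litellm.transcription(
        model="openai/gpt-4o-transcribe",
        file=audio_file,
        response_format="json",
    )
    print(response.text)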

View file

@@ -950,6 +950,12 @@ openaiOSeriesConfig = OpenAIOSeriesConfig()
from .llms.openai.chat.gpt_transformation import (
OpenAIGPTConfig,
)
from .llms.openai.transcriptions.whisper_transformation import (
OpenAIWhisperAudioTranscriptionConfig,
)
from .llms.openai.transcriptions.gpt_transformation import (
OpenAIGPTAudioTranscriptionConfig,
)
openAIGPTConfig = OpenAIGPTConfig()
from .llms.openai.chat.gpt_audio_transformation import (

View file

@@ -79,6 +79,22 @@ def get_supported_openai_params( # noqa: PLR0915
elif custom_llm_provider == "maritalk":
return litellm.MaritalkConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "openai":
if request_type == "transcription":
transcription_provider_config = (
litellm.ProviderConfigManager.get_provider_audio_transcription_config(
model=model, provider=LlmProviders.OPENAI
)
)
if isinstance(
transcription_provider_config, litellm.OpenAIGPTAudioTranscriptionConfig
):
return transcription_provider_config.get_supported_openai_params(
model=model
)
else:
raise ValueError(
f"Unsupported provider config: {transcription_provider_config} for model: {model}"
)
return litellm.OpenAIConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "azure":
if litellm.AzureOpenAIO1Config().is_o_series_model(model=model):
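
Roughly, what the new transcription branch above enables (a sketch, assuming the top-level `litellm.get_supported_openai_params` helper forwards `request_type="transcription"` as this code implies):

from litellm import get_supported_openai_params

# For gpt-4o-transcribe this should resolve to OpenAIGPTAudioTranscriptionConfig,
# so the result includes "include" but not "timestamp_granularities".
params = get_supported_openai_params(
    model="gpt-4o-transcribe",
    custom_llm_provider="openai",
    request_type="transcription",
)
print(params)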

View file

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional
from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
@@ -8,7 +8,7 @@ from litellm.types.llms.openai import (
AllMessageValues,
OpenAIAudioTranscriptionOptionalParams,
)
from litellm.types.utils import ModelResponse
from litellm.types.utils import FileTypes, ModelResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -42,6 +42,18 @@ class BaseAudioTranscriptionConfig(BaseConfig, ABC):
"""
return api_base or ""
@abstractmethod
def transform_audio_transcription_request(
self,
model: str,
audio_file: FileTypes,
optional_params: dict,
litellm_params: dict,
) -> Union[dict, bytes]:
raise NotImplementedError(
"AudioTranscriptionConfig needs a request transformation for audio transcription models"
)
def transform_request(
self,
model: str,
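
With this abstract method, every audio-transcription provider config must declare how it turns the incoming file into a request body: raw bytes or a dict. A hypothetical minimal subclass (illustrative only, not part of this PR; the other required config methods are omitted for brevity):

from litellm.llms.base_llm.audio_transcription.transformation import (
    BaseAudioTranscriptionConfig,
)
from litellm.types.utils import FileTypes

class MyProviderAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
    def transform_audio_transcription_request(
        self,
        model: str,
        audio_file: FileTypes,
        optional_params: dict,
        litellm_params: dict,
    ) -> dict:
        # dict payloads are sent as structured data; return bytes instead
        # to upload the raw audio body directly
        return {"model": model, "file": audio_file, **optional_params}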

View file

@@ -1,4 +1,3 @@
import io
import json
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Union
@@ -8,6 +7,9 @@ import litellm
import litellm.litellm_core_utils
import litellm.types
import litellm.types.utils
from litellm.llms.base_llm.audio_transcription.transformation import (
BaseAudioTranscriptionConfig,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
@@ -852,54 +854,12 @@ class BaseLLMHTTPHandler:
request_data=request_data,
)
def handle_audio_file(self, audio_file: FileTypes) -> bytes:
"""
Processes the audio file input based on its type and returns the binary data.
Args:
audio_file: Can be a file path (str), a tuple (filename, file_content), or binary data (bytes).
Returns:
The binary data of the audio file.
"""
binary_data: bytes # Explicitly declare the type
# Handle the audio file based on type
if isinstance(audio_file, str):
# If it's a file path
with open(audio_file, "rb") as f:
binary_data = f.read() # `f.read()` always returns `bytes`
elif isinstance(audio_file, tuple):
# Handle tuple case
_, file_content = audio_file[:2]
if isinstance(file_content, str):
with open(file_content, "rb") as f:
binary_data = f.read() # `f.read()` always returns `bytes`
elif isinstance(file_content, bytes):
binary_data = file_content
else:
raise TypeError(
f"Unexpected type in tuple: {type(file_content)}. Expected str or bytes."
)
elif isinstance(audio_file, bytes):
# Assume it's already binary data
binary_data = audio_file
elif isinstance(audio_file, io.BufferedReader) or isinstance(
audio_file, io.BytesIO
):
# Handle file-like objects
binary_data = audio_file.read()
else:
raise TypeError(f"Unsupported type for audio_file: {type(audio_file)}")
return binary_data
def audio_transcriptions(
self,
model: str,
audio_file: FileTypes,
optional_params: dict,
litellm_params: dict,
model_response: TranscriptionResponse,
timeout: float,
max_retries: int,
@@ -910,11 +870,8 @@ class BaseLLMHTTPHandler:
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
atranscription: bool = False,
headers: dict = {},
litellm_params: dict = {},
provider_config: Optional[BaseAudioTranscriptionConfig] = None,
) -> TranscriptionResponse:
provider_config = ProviderConfigManager.get_provider_audio_transcription_config(
model=model, provider=litellm.LlmProviders(custom_llm_provider)
)
if provider_config is None:
raise ValueError(
f"No provider config found for model: {model} and provider: {custom_llm_provider}"
@@ -938,7 +895,18 @@
)
# Handle the audio file based on type
binary_data = self.handle_audio_file(audio_file)
data = provider_config.transform_audio_transcription_request(
model=model,
audio_file=audio_file,
optional_params=optional_params,
litellm_params=litellm_params,
)
binary_data: Optional[bytes] = None
json_data: Optional[dict] = None
if isinstance(data, bytes):
binary_data = data
else:
json_data = data
try:
# Make the POST request
@@ -946,6 +914,7 @@
url=complete_url,
headers=headers,
content=binary_data,
json=json_data,
timeout=timeout,
)
except Exception as e:
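
The handler no longer normalizes audio input itself; it just dispatches on the payload type returned by the provider config: bytes become the raw request body (Deepgram-style), while a dict is sent as structured data (OpenAI-style). A simplified sketch of that dispatch, not the exact handler code:

import httpx

def post_transcription(url: str, headers: dict, data) -> httpx.Response:
    # bytes -> raw body upload; dict -> JSON payload
    if isinstance(data, bytes):
        return httpx.post(url, headers=headers, content=data)
    return httpx.post(url, headers=headers, json=data)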

View file

@@ -2,6 +2,7 @@
Translates from OpenAI's `/v1/audio/transcriptions` to Deepgram's `/v1/listen`
"""
import io
from typing import List, Optional, Union
from httpx import Headers, Response
@@ -12,7 +13,7 @@ from litellm.types.llms.openai import (
AllMessageValues,
OpenAIAudioTranscriptionOptionalParams,
)
from litellm.types.utils import TranscriptionResponse
from litellm.types.utils import FileTypes, TranscriptionResponse
from ...base_llm.audio_transcription.transformation import (
BaseAudioTranscriptionConfig,
@@ -47,6 +48,55 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
message=error_message, status_code=status_code, headers=headers
)
def transform_audio_transcription_request(
self,
model: str,
audio_file: FileTypes,
optional_params: dict,
litellm_params: dict,
) -> Union[dict, bytes]:
"""
Processes the audio file input based on its type and returns the binary data.
Args:
audio_file: Can be a file path (str), a tuple (filename, file_content), or binary data (bytes).
Returns:
The binary data of the audio file.
"""
binary_data: bytes # Explicitly declare the type
# Handle the audio file based on type
if isinstance(audio_file, str):
# If it's a file path
with open(audio_file, "rb") as f:
binary_data = f.read() # `f.read()` always returns `bytes`
elif isinstance(audio_file, tuple):
# Handle tuple case
_, file_content = audio_file[:2]
if isinstance(file_content, str):
with open(file_content, "rb") as f:
binary_data = f.read() # `f.read()` always returns `bytes`
elif isinstance(file_content, bytes):
binary_data = file_content
else:
raise TypeError(
f"Unexpected type in tuple: {type(file_content)}. Expected str or bytes."
)
elif isinstance(audio_file, bytes):
# Assume it's already binary data
binary_data = audio_file
elif isinstance(audio_file, io.BufferedReader) or isinstance(
audio_file, io.BytesIO
):
# Handle file-like objects
binary_data = audio_file.read()
else:
raise TypeError(f"Unsupported type for audio_file: {type(audio_file)}")
return binary_data
def transform_audio_transcription_response(
self,
model: str,
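
The byte-normalization logic that previously lived in `BaseLLMHTTPHandler.handle_audio_file` now lives on the Deepgram config. A quick sketch of the behaviour (the model name is illustrative; it is not used by this method):

from litellm.llms.deepgram.audio_transcription.transformation import (
    DeepgramAudioTranscriptionConfig,
)

config = DeepgramAudioTranscriptionConfig()
# bytes pass straight through; file paths, (name, content) tuples and
# file-like objects are read into bytes
payload = config.transform_audio_transcription_request(
    model="nova-2",
    audio_file=b"raw audio bytes",
    optional_params={},
    litellm_params={},
)
assert payload == b"raw audio bytes"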

View file

@@ -2,27 +2,16 @@ from typing import List
from litellm.types.llms.openai import OpenAIAudioTranscriptionOptionalParams
from ...base_llm.audio_transcription.transformation import BaseAudioTranscriptionConfig
from ...openai.transcriptions.whisper_transformation import (
OpenAIWhisperAudioTranscriptionConfig,
)
from ..common_utils import FireworksAIMixin
class FireworksAIAudioTranscriptionConfig(
FireworksAIMixin, BaseAudioTranscriptionConfig
FireworksAIMixin, OpenAIWhisperAudioTranscriptionConfig
):
def get_supported_openai_params(
self, model: str
) -> List[OpenAIAudioTranscriptionOptionalParams]:
return ["language", "prompt", "response_format", "timestamp_granularities"]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
supported_params = self.get_supported_openai_params(model)
for k, v in non_default_params.items():
if k in supported_params:
optional_params[k] = v
return optional_params

View file

@@ -0,0 +1,34 @@
from typing import List
from litellm.types.llms.openai import OpenAIAudioTranscriptionOptionalParams
from litellm.types.utils import FileTypes
from .whisper_transformation import OpenAIWhisperAudioTranscriptionConfig
class OpenAIGPTAudioTranscriptionConfig(OpenAIWhisperAudioTranscriptionConfig):
def get_supported_openai_params(
self, model: str
) -> List[OpenAIAudioTranscriptionOptionalParams]:
"""
Get the supported OpenAI params for the `gpt-4o-transcribe` models
"""
return [
"language",
"prompt",
"response_format",
"temperature",
"include",
]
def transform_audio_transcription_request(
self,
model: str,
audio_file: FileTypes,
optional_params: dict,
litellm_params: dict,
) -> dict:
"""
Transform the audio transcription request
"""
return {"model": model, "file": audio_file, **optional_params}

View file

@@ -7,6 +7,9 @@ from pydantic import BaseModel
import litellm
from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_name
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.audio_transcription.transformation import (
BaseAudioTranscriptionConfig,
)
from litellm.types.utils import FileTypes
from litellm.utils import (
TranscriptionResponse,
@@ -75,6 +78,7 @@ class OpenAIAudioTranscription(OpenAIChatCompletion):
model: str,
audio_file: FileTypes,
optional_params: dict,
litellm_params: dict,
model_response: TranscriptionResponse,
timeout: float,
max_retries: int,
@@ -83,16 +87,24 @@ class OpenAIAudioTranscription(OpenAIChatCompletion):
api_base: Optional[str],
client=None,
atranscription: bool = False,
provider_config: Optional[BaseAudioTranscriptionConfig] = None,
) -> TranscriptionResponse:
data = {"model": model, "file": audio_file, **optional_params}
if "response_format" not in data or (
data["response_format"] == "text" or data["response_format"] == "json"
):
data["response_format"] = (
"verbose_json" # ensures 'duration' is received - used for cost calculation
"""
Handle audio transcription request
"""
if provider_config is not None:
data = provider_config.transform_audio_transcription_request(
model=model,
audio_file=audio_file,
optional_params=optional_params,
litellm_params=litellm_params,
)
if isinstance(data, bytes):
raise ValueError("OpenAI transformation route requires a dict")
else:
data = {"model": model, "file": audio_file, **optional_params}
if atranscription is True:
return self.async_audio_transcriptions( # type: ignore
audio_file=audio_file,

View file

@@ -0,0 +1,97 @@
from typing import List, Optional, Union
from httpx import Headers
from litellm.llms.base_llm.audio_transcription.transformation import (
BaseAudioTranscriptionConfig,
)
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
AllMessageValues,
OpenAIAudioTranscriptionOptionalParams,
)
from litellm.types.utils import FileTypes
from ..common_utils import OpenAIError
class OpenAIWhisperAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
def get_supported_openai_params(
self, model: str
) -> List[OpenAIAudioTranscriptionOptionalParams]:
"""
Get the supported OpenAI params for the `whisper-1` models
"""
return [
"language",
"prompt",
"response_format",
"temperature",
"timestamp_granularities",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
"""
Map the OpenAI params to the Whisper params
"""
supported_params = self.get_supported_openai_params(model)
for k, v in non_default_params.items():
if k in supported_params:
optional_params[k] = v
return optional_params
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
api_key = api_key or get_secret_str("OPENAI_API_KEY")
auth_header = {
"Authorization": f"Bearer {api_key}",
}
headers.update(auth_header)
return headers
def transform_audio_transcription_request(
self,
model: str,
audio_file: FileTypes,
optional_params: dict,
litellm_params: dict,
) -> dict:
"""
Transform the audio transcription request
"""
data = {"model": model, "file": audio_file, **optional_params}
if "response_format" not in data or (
data["response_format"] == "text" or data["response_format"] == "json"
):
data["response_format"] = (
"verbose_json" # ensures 'duration' is received - used for cost calculation
)
return data
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, Headers]
) -> BaseLLMException:
return OpenAIError(
status_code=status_code,
message=error_message,
headers=headers,
)
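
This preserves the previous OpenAI handler behaviour for Whisper: a missing `response_format`, `text`, or `json` is upgraded to `verbose_json` so the response carries `duration` for cost calculation. A quick sketch:

from litellm import OpenAIWhisperAudioTranscriptionConfig

config = OpenAIWhisperAudioTranscriptionConfig()
data = config.transform_audio_transcription_request(
    model="whisper-1",
    audio_file=("speech.wav", b"..."),
    optional_params={"response_format": "json"},
    litellm_params={},
)
assert data["response_format"] == "verbose_json"  # upgraded for duration/cost tracking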

View file

@@ -5095,6 +5095,12 @@ def transcription(
response: Optional[
Union[TranscriptionResponse, Coroutine[Any, Any, TranscriptionResponse]]
] = None
provider_config = ProviderConfigManager.get_provider_audio_transcription_config(
model=model,
provider=LlmProviders(custom_llm_provider),
)
if custom_llm_provider == "azure":
# azure configs
api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
@@ -5161,12 +5167,15 @@
max_retries=max_retries,
api_base=api_base,
api_key=api_key,
provider_config=provider_config,
litellm_params=litellm_params_dict,
)
elif custom_llm_provider == "deepgram":
response = base_llm_http_handler.audio_transcriptions(
model=model,
audio_file=file,
optional_params=optional_params,
litellm_params=litellm_params_dict,
model_response=model_response,
atranscription=atranscription,
client=(
@@ -5185,6 +5194,7 @@
api_key=api_key,
custom_llm_provider="deepgram",
headers={},
provider_config=provider_config,
)
if response is None:
raise ValueError("Unmapped provider passed in. Unable to get the response.")

View file

@@ -1176,21 +1176,40 @@
"output_cost_per_pixel": 0.0,
"litellm_provider": "openai"
},
"gpt-4o-transcribe": {
"mode": "audio_transcription",
"input_cost_per_token": 0.0000025,
"input_cost_per_audio_token": 0.000006,
"output_cost_per_token": 0.00001,
"litellm_provider": "openai",
"supported_endpoints": ["/v1/audio/transcriptions"]
},
"gpt-4o-mini-transcribe": {
"mode": "audio_transcription",
"input_cost_per_token": 0.00000125,
"input_cost_per_audio_token": 0.000003,
"output_cost_per_token": 0.000005,
"litellm_provider": "openai",
"supported_endpoints": ["/v1/audio/transcriptions"]
},
"whisper-1": {
"mode": "audio_transcription",
"input_cost_per_second": 0.0001,
"output_cost_per_second": 0.0001,
"litellm_provider": "openai"
"litellm_provider": "openai",
"supported_endpoints": ["/v1/audio/transcriptions"]
},
"tts-1": {
"mode": "audio_speech",
"input_cost_per_character": 0.000015,
"litellm_provider": "openai"
"litellm_provider": "openai",
"supported_endpoints": ["/v1/audio/speech"]
},
"tts-1-hd": {
"mode": "audio_speech",
"input_cost_per_character": 0.000030,
"litellm_provider": "openai"
"litellm_provider": "openai",
"supported_endpoints": ["/v1/audio/speech"]
},
"azure/gpt-4o-mini-realtime-preview-2024-12-17": {
"max_tokens": 4096,

View file

@@ -9,6 +9,10 @@ model_list:
litellm_params:
model: gpt-4o-mini
api_key: os.environ/OPENAI_API_KEY
- model_name: "openai/*"
litellm_params:
model: openai/*
api_key: os.environ/OPENAI_API_KEY
- model_name: "bedrock-nova"
litellm_params:
model: us.amazon.nova-pro-v1:0
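
With the `openai/*` wildcard route, the proxy can serve the new transcription models without listing them individually. A rough sketch of calling a locally running proxy via the OpenAI SDK (base URL, key, and file name are placeholders):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

with open("gettysburg.wav", "rb") as f:
    transcript = client.audio.transcriptions.create(
        model="gpt-4o-transcribe",
        file=f,
    )
    print(transcript.text)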

View file

@@ -779,7 +779,12 @@ class LiteLLMFineTuningJobCreate(FineTuningJobCreate):
AllEmbeddingInputValues = Union[str, List[str], List[int], List[List[int]]]
OpenAIAudioTranscriptionOptionalParams = Literal[
"language", "prompt", "temperature", "response_format", "timestamp_granularities"
"language",
"prompt",
"temperature",
"response_format",
"timestamp_granularities",
"include",
]

View file

@@ -6364,6 +6364,11 @@ class ProviderConfigManager:
return litellm.FireworksAIAudioTranscriptionConfig()
elif litellm.LlmProviders.DEEPGRAM == provider:
return litellm.DeepgramAudioTranscriptionConfig()
elif litellm.LlmProviders.OPENAI == provider:
if "gpt-4o" in model:
return litellm.OpenAIGPTAudioTranscriptionConfig()
else:
return litellm.OpenAIWhisperAudioTranscriptionConfig()
return None
@staticmethod
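
Config selection for OpenAI is name-based: models containing "gpt-4o" get the GPT transcription config, everything else falls back to Whisper. Mirroring the new unit test:

from litellm.types.utils import LlmProviders
from litellm.utils import ProviderConfigManager

gpt_cfg = ProviderConfigManager.get_provider_audio_transcription_config(
    model="gpt-4o-mini-transcribe", provider=LlmProviders.OPENAI
)
whisper_cfg = ProviderConfigManager.get_provider_audio_transcription_config(
    model="whisper-1", provider=LlmProviders.OPENAI
)
# gpt_cfg -> OpenAIGPTAudioTranscriptionConfig; whisper_cfg -> OpenAIWhisperAudioTranscriptionConfig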

View file

@@ -1176,21 +1176,40 @@
"output_cost_per_pixel": 0.0,
"litellm_provider": "openai"
},
"gpt-4o-transcribe": {
"mode": "audio_transcription",
"input_cost_per_token": 0.0000025,
"input_cost_per_audio_token": 0.000006,
"output_cost_per_token": 0.00001,
"litellm_provider": "openai",
"supported_endpoints": ["/v1/audio/transcriptions"]
},
"gpt-4o-mini-transcribe": {
"mode": "audio_transcription",
"input_cost_per_token": 0.00000125,
"input_cost_per_audio_token": 0.000003,
"output_cost_per_token": 0.000005,
"litellm_provider": "openai",
"supported_endpoints": ["/v1/audio/transcriptions"]
},
"whisper-1": {
"mode": "audio_transcription",
"input_cost_per_second": 0.0001,
"output_cost_per_second": 0.0001,
"litellm_provider": "openai"
"litellm_provider": "openai",
"supported_endpoints": ["/v1/audio/transcriptions"]
},
"tts-1": {
"mode": "audio_speech",
"input_cost_per_character": 0.000015,
"litellm_provider": "openai"
"litellm_provider": "openai",
"supported_endpoints": ["/v1/audio/speech"]
},
"tts-1-hd": {
"mode": "audio_speech",
"input_cost_per_character": 0.000030,
"litellm_provider": "openai"
"litellm_provider": "openai",
"supported_endpoints": ["/v1/audio/speech"]
},
"azure/gpt-4o-mini-realtime-preview-2024-12-17": {
"max_tokens": 4096,

View file

@@ -1,47 +1,57 @@
import os
import io
import os
import pathlib
import sys
import pytest
sys.path.insert(
0, os.path.abspath("../../../..")
0, os.path.abspath("../../../../..")
) # Adds the parent directory to the system path
import litellm
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
import litellm
from litellm.llms.deepgram.audio_transcription.transformation import (
DeepgramAudioTranscriptionConfig,
)
@pytest.fixture
def test_bytes():
return b'litellm', b'litellm'
return b"litellm", b"litellm"
@pytest.fixture
def test_io_bytes(test_bytes):
return io.BytesIO(test_bytes[0]), test_bytes[1]
@pytest.fixture
def test_file():
pwd = os.path.dirname(os.path.realpath(__file__))
pwd_path = pathlib.Path(pwd)
test_root = pwd_path.parents[2]
test_root = pwd_path.parents[3]
print(f"test_root: {test_root}")
file_path = os.path.join(test_root, "gettysburg.wav")
f = open(file_path, "rb")
content = f.read()
f.seek(0)
return f, content
@pytest.mark.parametrize(
"fixture_name",
[
"test_bytes",
"test_io_bytes",
"test_file",
]
],
)
def test_audio_file_handling(fixture_name, request):
handler = BaseLLMHTTPHandler()
handler = DeepgramAudioTranscriptionConfig()
(audio_file, expected_output) = request.getfixturevalue(fixture_name)
assert expected_output == handler.handle_audio_file(audio_file)
assert expected_output == handler.transform_audio_transcription_request(
model="deepseek-audio-transcription",
audio_file=audio_file,
optional_params={},
litellm_params={},
)

View file

@@ -2074,3 +2074,13 @@ def test_delta_object():
assert delta.role == "user"
assert not hasattr(delta, "thinking_blocks")
assert not hasattr(delta, "reasoning_content")
def test_get_provider_audio_transcription_config():
from litellm.utils import ProviderConfigManager
from litellm.types.utils import LlmProviders
for provider in LlmProviders:
config = ProviderConfigManager.get_provider_audio_transcription_config(
model="whisper-1", provider=provider
)

View file

@@ -22,6 +22,7 @@ from litellm.types.llms.openai import (
ChatCompletionAnnotation,
ChatCompletionAnnotationURLCitation,
)
from base_audio_transcription_unit_tests import BaseLLMAudioTranscriptionTest
def test_openai_prediction_param():
@@ -458,3 +459,13 @@ def test_openai_web_search_streaming():
# Assert this request has at-least one web search annotation
assert test_openai_web_search is not None
validate_web_search_annotations(test_openai_web_search)
class TestOpenAIGPT4OAudioTranscription(BaseLLMAudioTranscriptionTest):
def get_base_audio_transcription_call_args(self) -> dict:
return {
"model": "openai/gpt-4o-transcribe",
}
def get_custom_llm_provider(self) -> litellm.LlmProviders:
return litellm.LlmProviders.OPENAI

View file

@@ -520,6 +520,8 @@ def test_aaamodel_prices_and_context_window_json_is_valid():
"/v1/images/variations",
"/v1/images/edits",
"/v1/batch",
"/v1/audio/transcriptions",
"/v1/audio/speech",
],
},
},

View file

@@ -134,3 +134,33 @@ async def test_whisper_log_pre_call():
file=audio_file,
)
mock_log_pre_call.assert_called_once()
@pytest.mark.asyncio
async def test_whisper_log_pre_call():
from litellm.litellm_core_utils.litellm_logging import Logging
from datetime import datetime
from unittest.mock import patch, MagicMock
from litellm.integrations.custom_logger import CustomLogger
custom_logger = CustomLogger()
litellm.callbacks = [custom_logger]
with patch.object(custom_logger, "log_pre_api_call") as mock_log_pre_call:
await litellm.atranscription(
model="whisper-1",
file=audio_file,
)
mock_log_pre_call.assert_called_once()
@pytest.mark.asyncio
async def test_gpt_4o_transcribe():
from litellm.litellm_core_utils.litellm_logging import Logging
from datetime import datetime
from unittest.mock import patch, MagicMock
await litellm.atranscription(
model="openai/gpt-4o-transcribe", file=audio_file, response_format="json"
)