Litellm dev 12 28 2024 p3 (#7464)

* feat(deepgram/): initial e2e support for deepgram stt

Uses deepgram's `/listen` endpoint to transcribe speech to text

 Closes https://github.com/BerriAI/litellm/issues/4875

* fix: fix linting errors

* test: fix test
This commit is contained in:
Krish Dholakia 2024-12-28 19:18:58 -08:00 committed by GitHub
parent 480d838790
commit ebc28b1921
10 changed files with 303 additions and 5 deletions

View file

@ -1107,6 +1107,9 @@ from .llms.cohere.chat.transformation import CohereChatConfig
from .llms.bedrock.embed.cohere_transformation import BedrockCohereEmbeddingConfig
from .llms.openai.openai import OpenAIConfig, MistralEmbeddingConfig
from .llms.deepinfra.chat.transformation import DeepInfraConfig
from .llms.deepgram.audio_transcription.transformation import (
DeepgramAudioTranscriptionConfig,
)
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
from .llms.groq.chat.transformation import GroqChatConfig
from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig

View file

@ -192,4 +192,12 @@ def get_supported_openai_params( # noqa: PLR0915
)
else:
return litellm.TritonConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "deepgram":
if request_type == "transcription":
return (
litellm.DeepgramAudioTranscriptionConfig().get_supported_openai_params(
model=model
)
)
return None

View file

@ -492,16 +492,18 @@ class HTTPHandler:
headers: Optional[dict] = None,
stream: bool = False,
timeout: Optional[Union[float, httpx.Timeout]] = None,
files: Optional[dict] = None,
content: Any = None,
):
try:
if timeout is not None:
req = self.client.build_request(
"POST", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore
"POST", url, data=data, json=json, params=params, headers=headers, timeout=timeout, files=files, content=content # type: ignore
)
else:
req = self.client.build_request(
"POST", url, data=data, json=json, params=params, headers=headers # type: ignore
"POST", url, data=data, json=json, params=params, headers=headers, files=files, content=content # type: ignore
)
response = self.client.send(req, stream=stream)
response.raise_for_status()
@ -513,7 +515,6 @@ class HTTPHandler:
llm_provider="litellm-httpx-handler",
)
except httpx.HTTPStatusError as e:
if stream is True:
setattr(e, "message", mask_sensitive_info(e.response.read()))
setattr(e, "text", mask_sensitive_info(e.response.read()))

View file

@ -1,3 +1,4 @@
import io
import json
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
@ -17,7 +18,7 @@ from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
)
from litellm.types.rerank import OptionalRerankParams, RerankResponse
from litellm.types.utils import EmbeddingResponse
from litellm.types.utils import EmbeddingResponse, FileTypes, TranscriptionResponse
from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
if TYPE_CHECKING:
@ -667,6 +668,115 @@ class BaseLLMHTTPHandler:
request_data=request_data,
)
def handle_audio_file(self, audio_file: FileTypes) -> bytes:
"""
Processes the audio file input based on its type and returns the binary data.
Args:
audio_file: Can be a file path (str), a tuple (filename, file_content), or binary data (bytes).
Returns:
The binary data of the audio file.
"""
binary_data: bytes # Explicitly declare the type
# Handle the audio file based on type
if isinstance(audio_file, str):
# If it's a file path
with open(audio_file, "rb") as f:
binary_data = f.read() # `f.read()` always returns `bytes`
elif isinstance(audio_file, tuple):
# Handle tuple case
_, file_content = audio_file[:2]
if isinstance(file_content, str):
with open(file_content, "rb") as f:
binary_data = f.read() # `f.read()` always returns `bytes`
elif isinstance(file_content, bytes):
binary_data = file_content
else:
raise TypeError(
f"Unexpected type in tuple: {type(file_content)}. Expected str or bytes."
)
elif isinstance(audio_file, bytes):
# Assume it's already binary data
binary_data = audio_file
elif isinstance(audio_file, io.BufferedReader):
# Handle file-like objects
binary_data = audio_file.read()
else:
raise TypeError(f"Unsupported type for audio_file: {type(audio_file)}")
return binary_data
def audio_transcriptions(
self,
model: str,
audio_file: FileTypes,
optional_params: dict,
model_response: TranscriptionResponse,
timeout: float,
max_retries: int,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str],
api_base: Optional[str],
custom_llm_provider: str,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
atranscription: bool = False,
headers: dict = {},
) -> TranscriptionResponse:
provider_config = ProviderConfigManager.get_provider_audio_transcription_config(
model=model, provider=litellm.LlmProviders(custom_llm_provider)
)
if provider_config is None:
raise ValueError(
f"No provider config found for model: {model} and provider: {custom_llm_provider}"
)
headers = provider_config.validate_environment(
api_key=api_key,
headers=headers,
model=model,
messages=[],
optional_params=optional_params,
)
if client is None or not isinstance(client, HTTPHandler):
client = _get_httpx_client()
complete_url = provider_config.get_complete_url(
api_base=api_base,
model=model,
optional_params=optional_params,
)
# Handle the audio file based on type
binary_data = self.handle_audio_file(audio_file)
try:
# Make the POST request
response = client.post(
url=complete_url,
headers=headers,
content=binary_data,
timeout=timeout,
)
except Exception as e:
raise self._handle_error(e=e, provider_config=provider_config)
if isinstance(provider_config, litellm.DeepgramAudioTranscriptionConfig):
returned_response = provider_config.transform_audio_transcription_response(
model=model,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
request_data={},
optional_params=optional_params,
litellm_params={},
api_key=api_key,
)
return returned_response
return model_response
def _handle_error(
self, e: Exception, provider_config: Union[BaseConfig, BaseRerankConfig]
):

View file

@ -0,0 +1,125 @@
"""
Translates from OpenAI's `/v1/audio/transcriptions` to Deepgram's `/v1/listen`
"""
from typing import List, Optional, Union
from httpx import Headers, Response
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
AllMessageValues,
OpenAIAudioTranscriptionOptionalParams,
)
from litellm.types.utils import TranscriptionResponse
from ...base_llm.audio_transcription.transformation import (
BaseAudioTranscriptionConfig,
LiteLLMLoggingObj,
)
from ..common_utils import DeepgramException
class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
def get_supported_openai_params(
self, model: str
) -> List[OpenAIAudioTranscriptionOptionalParams]:
return ["language"]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
supported_params = self.get_supported_openai_params(model)
for k, v in non_default_params.items():
if k in supported_params:
optional_params[k] = v
return optional_params
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, Headers]
) -> BaseLLMException:
return DeepgramException(
message=error_message, status_code=status_code, headers=headers
)
def transform_audio_transcription_response(
self,
model: str,
raw_response: Response,
model_response: TranscriptionResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
) -> TranscriptionResponse:
"""
Transforms the raw response from Deepgram to the TranscriptionResponse format
"""
try:
response_json = raw_response.json()
# Get the first alternative from the first channel
first_channel = response_json["results"]["channels"][0]
first_alternative = first_channel["alternatives"][0]
# Extract the full transcript
text = first_alternative["transcript"]
# Create TranscriptionResponse object
response = TranscriptionResponse(text=text)
# Add additional metadata matching OpenAI format
response["task"] = "transcribe"
response["language"] = (
"english" # Deepgram auto-detects but doesn't return language
)
response["duration"] = response_json["metadata"]["duration"]
# Transform words to match OpenAI format
if "words" in first_alternative:
response["words"] = [
{"word": word["word"], "start": word["start"], "end": word["end"]}
for word in first_alternative["words"]
]
# Store full response in hidden params
response._hidden_params = response_json
return response
except Exception as e:
raise ValueError(
f"Error transforming Deepgram response: {str(e)}\nResponse: {raw_response.text}"
)
def get_complete_url(
self,
api_base: Optional[str],
model: str,
optional_params: dict,
stream: Optional[bool] = None,
) -> str:
if api_base is None:
api_base = "https://api.deepgram.com/v1"
api_base = api_base.rstrip("/") # Remove trailing slash if present
return f"{api_base}/listen?model={model}"
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> dict:
api_key = api_key or get_secret_str("DEEPGRAM_API_KEY")
return {
"Authorization": f"Token {api_key}",
}

View file

@ -0,0 +1,5 @@
from litellm.llms.base_llm.chat.transformation import BaseLLMException
class DeepgramException(BaseLLMException):
pass

View file

@ -4867,7 +4867,30 @@ def transcription(
api_base=api_base,
api_key=api_key,
)
elif custom_llm_provider == "deepgram":
response = base_llm_http_handler.audio_transcriptions(
model=model,
audio_file=file,
optional_params=optional_params,
model_response=model_response,
atranscription=atranscription,
client=(
client
if client is not None
and (
isinstance(client, HTTPHandler)
or isinstance(client, AsyncHTTPHandler)
)
else None
),
timeout=timeout,
max_retries=max_retries,
logging_obj=litellm_logging_obj,
api_base=api_base,
api_key=api_key,
custom_llm_provider="deepgram",
headers={},
)
if response is None:
raise ValueError("Unmapped provider passed in. Unable to get the response.")
return response

View file

@ -1788,6 +1788,7 @@ class LlmProviders(str, Enum):
LM_STUDIO = "lm_studio"
GALADRIEL = "galadriel"
INFINITY = "infinity"
DEEPGRAM = "deepgram"
class LiteLLMLoggingBaseClass:

View file

@ -6323,6 +6323,8 @@ class ProviderConfigManager:
) -> Optional[BaseAudioTranscriptionConfig]:
if litellm.LlmProviders.FIREWORKS_AI == provider:
return litellm.FireworksAIAudioTranscriptionConfig()
elif litellm.LlmProviders.DEEPGRAM == provider:
return litellm.DeepgramAudioTranscriptionConfig()
return None
@staticmethod

View file

@ -0,0 +1,20 @@
import os
import sys
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from base_audio_transcription_unit_tests import BaseLLMAudioTranscriptionTest
class TestDeepgramAudioTranscription(BaseLLMAudioTranscriptionTest):
def get_base_audio_transcription_call_args(self) -> dict:
return {
"model": "deepgram/nova-2",
}
def get_custom_llm_provider(self) -> litellm.LlmProviders:
return litellm.LlmProviders.DEEPGRAM