Litellm dev 12 28 2024 p3 (#7464)

* feat(deepgram/): initial e2e support for deepgram stt

Uses Deepgram's `/listen` endpoint to transcribe speech to text (see the usage sketch below the commit message).

Closes https://github.com/BerriAI/litellm/issues/4875

* fix: fix linting errors

* test: fix test
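
A minimal usage sketch of the new provider, assuming a local `sample.wav` and a `DEEPGRAM_API_KEY` environment variable (`deepgram/nova-2` is the model name used in the new unit test; the file name and key are placeholders):

import os

import litellm

os.environ["DEEPGRAM_API_KEY"] = "..."  # placeholder; never hard-code real keys

# Pass an open binary file handle; the new handler also accepts a path (str),
# raw bytes, or a (filename, content) tuple.
with open("sample.wav", "rb") as audio_file:
    response = litellm.transcription(
        model="deepgram/nova-2",
        file=audio_file,
    )

print(response.text)  # OpenAI-style TranscriptionResponse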
Krish Dholakia 2024-12-28 19:18:58 -08:00 committed by GitHub
parent 480d838790
commit ebc28b1921
10 changed files with 303 additions and 5 deletions

View file

@@ -1107,6 +1107,9 @@ from .llms.cohere.chat.transformation import CohereChatConfig
from .llms.bedrock.embed.cohere_transformation import BedrockCohereEmbeddingConfig
from .llms.openai.openai import OpenAIConfig, MistralEmbeddingConfig
from .llms.deepinfra.chat.transformation import DeepInfraConfig
from .llms.deepgram.audio_transcription.transformation import (
    DeepgramAudioTranscriptionConfig,
)
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
from .llms.groq.chat.transformation import GroqChatConfig
from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig

View file

@@ -192,4 +192,12 @@ def get_supported_openai_params( # noqa: PLR0915
            )
        else:
            return litellm.TritonConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "deepgram":
        if request_type == "transcription":
            return (
                litellm.DeepgramAudioTranscriptionConfig().get_supported_openai_params(
                    model=model
                )
            )
    return None

View file

@@ -492,16 +492,18 @@ class HTTPHandler:
        headers: Optional[dict] = None,
        stream: bool = False,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        files: Optional[dict] = None,
        content: Any = None,
    ):
        try:
            if timeout is not None:
                req = self.client.build_request(
                    "POST", url, data=data, json=json, params=params, headers=headers, timeout=timeout, files=files, content=content  # type: ignore
                )
            else:
                req = self.client.build_request(
                    "POST", url, data=data, json=json, params=params, headers=headers, files=files, content=content  # type: ignore
                )
            response = self.client.send(req, stream=stream)
            response.raise_for_status()
@@ -513,7 +515,6 @@ class HTTPHandler:
                llm_provider="litellm-httpx-handler",
            )
        except httpx.HTTPStatusError as e:
            if stream is True:
                setattr(e, "message", mask_sensitive_info(e.response.read()))
                setattr(e, "text", mask_sensitive_info(e.response.read()))

View file

@@ -1,3 +1,4 @@
import io
import json
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
@@ -17,7 +18,7 @@ from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
)
from litellm.types.rerank import OptionalRerankParams, RerankResponse
from litellm.types.utils import EmbeddingResponse, FileTypes, TranscriptionResponse
from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
if TYPE_CHECKING:
@@ -667,6 +668,115 @@ class BaseLLMHTTPHandler:
            request_data=request_data,
        )

    def handle_audio_file(self, audio_file: FileTypes) -> bytes:
        """
        Processes the audio file input based on its type and returns the binary data.

        Args:
            audio_file: Can be a file path (str), a tuple (filename, file_content), or binary data (bytes).
        Returns:
            The binary data of the audio file.
        """
        binary_data: bytes  # Explicitly declare the type

        # Handle the audio file based on type
        if isinstance(audio_file, str):
            # If it's a file path
            with open(audio_file, "rb") as f:
                binary_data = f.read()  # `f.read()` always returns `bytes`
        elif isinstance(audio_file, tuple):
            # Handle tuple case
            _, file_content = audio_file[:2]
            if isinstance(file_content, str):
                with open(file_content, "rb") as f:
                    binary_data = f.read()  # `f.read()` always returns `bytes`
            elif isinstance(file_content, bytes):
                binary_data = file_content
            else:
                raise TypeError(
                    f"Unexpected type in tuple: {type(file_content)}. Expected str or bytes."
                )
        elif isinstance(audio_file, bytes):
            # Assume it's already binary data
            binary_data = audio_file
        elif isinstance(audio_file, io.BufferedReader):
            # Handle file-like objects
            binary_data = audio_file.read()
        else:
            raise TypeError(f"Unsupported type for audio_file: {type(audio_file)}")

        return binary_data

    def audio_transcriptions(
        self,
        model: str,
        audio_file: FileTypes,
        optional_params: dict,
        model_response: TranscriptionResponse,
        timeout: float,
        max_retries: int,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str],
        api_base: Optional[str],
        custom_llm_provider: str,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        atranscription: bool = False,
        headers: dict = {},
    ) -> TranscriptionResponse:
        provider_config = ProviderConfigManager.get_provider_audio_transcription_config(
            model=model, provider=litellm.LlmProviders(custom_llm_provider)
        )
        if provider_config is None:
            raise ValueError(
                f"No provider config found for model: {model} and provider: {custom_llm_provider}"
            )
        headers = provider_config.validate_environment(
            api_key=api_key,
            headers=headers,
            model=model,
            messages=[],
            optional_params=optional_params,
        )

        if client is None or not isinstance(client, HTTPHandler):
            client = _get_httpx_client()

        complete_url = provider_config.get_complete_url(
            api_base=api_base,
            model=model,
            optional_params=optional_params,
        )

        # Handle the audio file based on type
        binary_data = self.handle_audio_file(audio_file)

        try:
            # Make the POST request
            response = client.post(
                url=complete_url,
                headers=headers,
                content=binary_data,
                timeout=timeout,
            )
        except Exception as e:
            raise self._handle_error(e=e, provider_config=provider_config)

        if isinstance(provider_config, litellm.DeepgramAudioTranscriptionConfig):
            returned_response = provider_config.transform_audio_transcription_response(
                model=model,
                raw_response=response,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data={},
                optional_params=optional_params,
                litellm_params={},
                api_key=api_key,
            )
            return returned_response
        return model_response

    def _handle_error(
        self, e: Exception, provider_config: Union[BaseConfig, BaseRerankConfig]
    ):

View file

@@ -0,0 +1,125 @@
"""
Translates from OpenAI's `/v1/audio/transcriptions` to Deepgram's `/v1/listen`
"""

from typing import List, Optional, Union

from httpx import Headers, Response

from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
    AllMessageValues,
    OpenAIAudioTranscriptionOptionalParams,
)
from litellm.types.utils import TranscriptionResponse

from ...base_llm.audio_transcription.transformation import (
    BaseAudioTranscriptionConfig,
    LiteLLMLoggingObj,
)
from ..common_utils import DeepgramException


class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIAudioTranscriptionOptionalParams]:
        return ["language"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        supported_params = self.get_supported_openai_params(model)
        for k, v in non_default_params.items():
            if k in supported_params:
                optional_params[k] = v
        return optional_params

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        return DeepgramException(
            message=error_message, status_code=status_code, headers=headers
        )

    def transform_audio_transcription_response(
        self,
        model: str,
        raw_response: Response,
        model_response: TranscriptionResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
    ) -> TranscriptionResponse:
        """
        Transforms the raw response from Deepgram to the TranscriptionResponse format
        """
        try:
            response_json = raw_response.json()

            # Get the first alternative from the first channel
            first_channel = response_json["results"]["channels"][0]
            first_alternative = first_channel["alternatives"][0]

            # Extract the full transcript
            text = first_alternative["transcript"]

            # Create TranscriptionResponse object
            response = TranscriptionResponse(text=text)

            # Add additional metadata matching OpenAI format
            response["task"] = "transcribe"
            response["language"] = (
                "english"  # Deepgram auto-detects but doesn't return language
            )
            response["duration"] = response_json["metadata"]["duration"]

            # Transform words to match OpenAI format
            if "words" in first_alternative:
                response["words"] = [
                    {"word": word["word"], "start": word["start"], "end": word["end"]}
                    for word in first_alternative["words"]
                ]

            # Store full response in hidden params
            response._hidden_params = response_json

            return response
        except Exception as e:
            raise ValueError(
                f"Error transforming Deepgram response: {str(e)}\nResponse: {raw_response.text}"
            )

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        if api_base is None:
            api_base = "https://api.deepgram.com/v1"
        api_base = api_base.rstrip("/")  # Remove trailing slash if present
        return f"{api_base}/listen?model={model}"

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
    ) -> dict:
        api_key = api_key or get_secret_str("DEEPGRAM_API_KEY")
        return {
            "Authorization": f"Token {api_key}",
        }
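
To make the mapping concrete, here is an illustrative (made-up) Deepgram `/v1/listen` payload and the OpenAI-style fields that `transform_audio_transcription_response` above would derive from it:

# Hypothetical Deepgram response; values are invented for illustration only
deepgram_json = {
    "metadata": {"duration": 2.5},
    "results": {
        "channels": [
            {
                "alternatives": [
                    {
                        "transcript": "hello world",
                        "words": [
                            {"word": "hello", "start": 0.0, "end": 0.5},
                            {"word": "world", "start": 0.6, "end": 1.1},
                        ],
                    }
                ]
            }
        ]
    },
}

# Resulting TranscriptionResponse fields:
#   text     -> "hello world"
#   task     -> "transcribe"
#   language -> "english"  (hard-coded; Deepgram does not return a language here)
#   duration -> 2.5
#   words    -> [{"word": "hello", "start": 0.0, "end": 0.5}, ...]
#   _hidden_params holds the full raw Deepgram payload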

View file

@@ -0,0 +1,5 @@
from litellm.llms.base_llm.chat.transformation import BaseLLMException


class DeepgramException(BaseLLMException):
    pass

View file

@@ -4867,7 +4867,30 @@ def transcription(
            api_base=api_base,
            api_key=api_key,
        )
    elif custom_llm_provider == "deepgram":
        response = base_llm_http_handler.audio_transcriptions(
            model=model,
            audio_file=file,
            optional_params=optional_params,
            model_response=model_response,
            atranscription=atranscription,
            client=(
                client
                if client is not None
                and (
                    isinstance(client, HTTPHandler)
                    or isinstance(client, AsyncHTTPHandler)
                )
                else None
            ),
            timeout=timeout,
            max_retries=max_retries,
            logging_obj=litellm_logging_obj,
            api_base=api_base,
            api_key=api_key,
            custom_llm_provider="deepgram",
            headers={},
        )
    if response is None:
        raise ValueError("Unmapped provider passed in. Unable to get the response.")
    return response

View file

@@ -1788,6 +1788,7 @@ class LlmProviders(str, Enum):
    LM_STUDIO = "lm_studio"
    GALADRIEL = "galadriel"
    INFINITY = "infinity"
    DEEPGRAM = "deepgram"

class LiteLLMLoggingBaseClass:

View file

@@ -6323,6 +6323,8 @@ class ProviderConfigManager:
    ) -> Optional[BaseAudioTranscriptionConfig]:
        if litellm.LlmProviders.FIREWORKS_AI == provider:
            return litellm.FireworksAIAudioTranscriptionConfig()
        elif litellm.LlmProviders.DEEPGRAM == provider:
            return litellm.DeepgramAudioTranscriptionConfig()
        return None

    @staticmethod

View file

@@ -0,0 +1,20 @@
import os
import sys

import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import litellm
from base_audio_transcription_unit_tests import BaseLLMAudioTranscriptionTest


class TestDeepgramAudioTranscription(BaseLLMAudioTranscriptionTest):
    def get_base_audio_transcription_call_args(self) -> dict:
        return {
            "model": "deepgram/nova-2",
        }

    def get_custom_llm_provider(self) -> litellm.LlmProviders:
        return litellm.LlmProviders.DEEPGRAM