Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)
Litellm dev 12 28 2024 p3 (#7464)
* feat(deepgram/): initial e2e support for deepgram stt

  Uses deepgram's `/listen` endpoint to transcribe speech to text.

  Closes https://github.com/BerriAI/litellm/issues/4875

* fix: fix linting errors
* test: fix test
Parent: 480d838790
Commit: ebc28b1921

10 changed files with 303 additions and 5 deletions
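Before the per-file hunks, a minimal usage sketch of what this commit enables. It is illustrative only: it assumes a valid `DEEPGRAM_API_KEY` in the environment and a local `speech.wav`; the `deepgram/` model prefix and the `litellm.transcription` entry point are taken from the diff below.

```python
import litellm

# Transcribe a local audio file through Deepgram's /listen endpoint.
# The "deepgram/" prefix routes the call into the new provider branch.
with open("speech.wav", "rb") as audio_file:
    response = litellm.transcription(
        model="deepgram/nova-2",
        file=audio_file,
    )

print(response.text)  # transcript, returned in OpenAI's transcription response shape
```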
@@ -1107,6 +1107,9 @@ from .llms.cohere.chat.transformation import CohereChatConfig
 from .llms.bedrock.embed.cohere_transformation import BedrockCohereEmbeddingConfig
 from .llms.openai.openai import OpenAIConfig, MistralEmbeddingConfig
 from .llms.deepinfra.chat.transformation import DeepInfraConfig
+from .llms.deepgram.audio_transcription.transformation import (
+    DeepgramAudioTranscriptionConfig,
+)
 from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
 from .llms.groq.chat.transformation import GroqChatConfig
 from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
@@ -192,4 +192,12 @@ def get_supported_openai_params(  # noqa: PLR0915
             )
         else:
             return litellm.TritonConfig().get_supported_openai_params(model=model)
+    elif custom_llm_provider == "deepgram":
+        if request_type == "transcription":
+            return (
+                litellm.DeepgramAudioTranscriptionConfig().get_supported_openai_params(
+                    model=model
+                )
+            )
+
     return None
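A sketch of what this new branch reports, assuming the top-level `litellm.get_supported_openai_params` helper accepts these arguments; the `["language"]` value comes from `DeepgramAudioTranscriptionConfig`, shown later in this diff.

```python
from litellm import get_supported_openai_params

params = get_supported_openai_params(
    model="nova-2",
    custom_llm_provider="deepgram",
    request_type="transcription",
)
print(params)  # expected: ["language"]
```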
@@ -492,16 +492,18 @@ class HTTPHandler:
         headers: Optional[dict] = None,
         stream: bool = False,
         timeout: Optional[Union[float, httpx.Timeout]] = None,
+        files: Optional[dict] = None,
+        content: Any = None,
     ):
         try:

             if timeout is not None:
                 req = self.client.build_request(
-                    "POST", url, data=data, json=json, params=params, headers=headers, timeout=timeout  # type: ignore
+                    "POST", url, data=data, json=json, params=params, headers=headers, timeout=timeout, files=files, content=content  # type: ignore
                 )
             else:
                 req = self.client.build_request(
-                    "POST", url, data=data, json=json, params=params, headers=headers  # type: ignore
+                    "POST", url, data=data, json=json, params=params, headers=headers, files=files, content=content  # type: ignore
                 )
             response = self.client.send(req, stream=stream)
             response.raise_for_status()

@@ -513,7 +515,6 @@ class HTTPHandler:
                 llm_provider="litellm-httpx-handler",
             )
         except httpx.HTTPStatusError as e:
-
             if stream is True:
                 setattr(e, "message", mask_sensitive_info(e.response.read()))
                 setattr(e, "text", mask_sensitive_info(e.response.read()))
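A sketch of what the widened `post()` signature allows: sending a raw-bytes request body via `content=`, which is the path the Deepgram transcription handler below relies on. The URL and token here are placeholders.

```python
from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()
response = client.post(
    url="https://api.deepgram.com/v1/listen?model=nova-2",  # placeholder URL
    headers={"Authorization": "Token <DEEPGRAM_API_KEY>"},  # placeholder credential
    content=b"<raw audio bytes>",  # forwarded to httpx build_request(content=...)
    timeout=600.0,
)
```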
@@ -1,3 +1,4 @@
+import io
 import json
 from typing import TYPE_CHECKING, Any, Optional, Tuple, Union

@@ -17,7 +18,7 @@ from litellm.llms.custom_httpx.http_handler import (
     get_async_httpx_client,
 )
 from litellm.types.rerank import OptionalRerankParams, RerankResponse
-from litellm.types.utils import EmbeddingResponse
+from litellm.types.utils import EmbeddingResponse, FileTypes, TranscriptionResponse
 from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager

 if TYPE_CHECKING:
@@ -667,6 +668,115 @@ class BaseLLMHTTPHandler:
             request_data=request_data,
         )

+    def handle_audio_file(self, audio_file: FileTypes) -> bytes:
+        """
+        Processes the audio file input based on its type and returns the binary data.
+
+        Args:
+            audio_file: Can be a file path (str), a tuple (filename, file_content), or binary data (bytes).
+
+        Returns:
+            The binary data of the audio file.
+        """
+        binary_data: bytes  # Explicitly declare the type
+
+        # Handle the audio file based on type
+        if isinstance(audio_file, str):
+            # If it's a file path
+            with open(audio_file, "rb") as f:
+                binary_data = f.read()  # `f.read()` always returns `bytes`
+        elif isinstance(audio_file, tuple):
+            # Handle tuple case
+            _, file_content = audio_file[:2]
+            if isinstance(file_content, str):
+                with open(file_content, "rb") as f:
+                    binary_data = f.read()  # `f.read()` always returns `bytes`
+            elif isinstance(file_content, bytes):
+                binary_data = file_content
+            else:
+                raise TypeError(
+                    f"Unexpected type in tuple: {type(file_content)}. Expected str or bytes."
+                )
+        elif isinstance(audio_file, bytes):
+            # Assume it's already binary data
+            binary_data = audio_file
+        elif isinstance(audio_file, io.BufferedReader):
+            # Handle file-like objects
+            binary_data = audio_file.read()
+        else:
+            raise TypeError(f"Unsupported type for audio_file: {type(audio_file)}")
+
+        return binary_data
+
+    def audio_transcriptions(
+        self,
+        model: str,
+        audio_file: FileTypes,
+        optional_params: dict,
+        model_response: TranscriptionResponse,
+        timeout: float,
+        max_retries: int,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        custom_llm_provider: str,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        atranscription: bool = False,
+        headers: dict = {},
+    ) -> TranscriptionResponse:
+        provider_config = ProviderConfigManager.get_provider_audio_transcription_config(
+            model=model, provider=litellm.LlmProviders(custom_llm_provider)
+        )
+        if provider_config is None:
+            raise ValueError(
+                f"No provider config found for model: {model} and provider: {custom_llm_provider}"
+            )
+        headers = provider_config.validate_environment(
+            api_key=api_key,
+            headers=headers,
+            model=model,
+            messages=[],
+            optional_params=optional_params,
+        )
+
+        if client is None or not isinstance(client, HTTPHandler):
+            client = _get_httpx_client()
+
+        complete_url = provider_config.get_complete_url(
+            api_base=api_base,
+            model=model,
+            optional_params=optional_params,
+        )
+
+        # Handle the audio file based on type
+        binary_data = self.handle_audio_file(audio_file)
+
+        try:
+            # Make the POST request
+            response = client.post(
+                url=complete_url,
+                headers=headers,
+                content=binary_data,
+                timeout=timeout,
+            )
+        except Exception as e:
+            raise self._handle_error(e=e, provider_config=provider_config)
+
+        if isinstance(provider_config, litellm.DeepgramAudioTranscriptionConfig):
+            returned_response = provider_config.transform_audio_transcription_response(
+                model=model,
+                raw_response=response,
+                model_response=model_response,
+                logging_obj=logging_obj,
+                request_data={},
+                optional_params=optional_params,
+                litellm_params={},
+                api_key=api_key,
+            )
+            return returned_response
+        return model_response
+
     def _handle_error(
         self, e: Exception, provider_config: Union[BaseConfig, BaseRerankConfig]
     ):
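To summarize the `isinstance` branches in `handle_audio_file`, a sketch of the four accepted input shapes. The handler instantiation, file names, and the class's module path are illustrative.

```python
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler  # module path assumed

handler = BaseLLMHTTPHandler()

handler.handle_audio_file("speech.wav")                 # str: opened as a file path
handler.handle_audio_file(("speech.wav", b"\x00\x01"))  # tuple: second element is bytes or a path
handler.handle_audio_file(b"\x00\x01")                  # bytes: passed through unchanged
handler.handle_audio_file(open("speech.wav", "rb"))     # io.BufferedReader: consumed via .read()
```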
litellm/llms/deepgram/audio_transcription/transformation.py (new file, 125 lines)

@@ -0,0 +1,125 @@
+"""
+Translates from OpenAI's `/v1/audio/transcriptions` to Deepgram's `/v1/listen`
+"""
+
+from typing import List, Optional, Union
+
+from httpx import Headers, Response
+
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import (
+    AllMessageValues,
+    OpenAIAudioTranscriptionOptionalParams,
+)
+from litellm.types.utils import TranscriptionResponse
+
+from ...base_llm.audio_transcription.transformation import (
+    BaseAudioTranscriptionConfig,
+    LiteLLMLoggingObj,
+)
+from ..common_utils import DeepgramException
+
+
+class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
+    def get_supported_openai_params(
+        self, model: str
+    ) -> List[OpenAIAudioTranscriptionOptionalParams]:
+        return ["language"]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        supported_params = self.get_supported_openai_params(model)
+        for k, v in non_default_params.items():
+            if k in supported_params:
+                optional_params[k] = v
+        return optional_params
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, Headers]
+    ) -> BaseLLMException:
+        return DeepgramException(
+            message=error_message, status_code=status_code, headers=headers
+        )
+
+    def transform_audio_transcription_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: TranscriptionResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        optional_params: dict,
+        litellm_params: dict,
+        api_key: Optional[str] = None,
+    ) -> TranscriptionResponse:
+        """
+        Transforms the raw response from Deepgram to the TranscriptionResponse format
+        """
+        try:
+            response_json = raw_response.json()
+
+            # Get the first alternative from the first channel
+            first_channel = response_json["results"]["channels"][0]
+            first_alternative = first_channel["alternatives"][0]
+
+            # Extract the full transcript
+            text = first_alternative["transcript"]
+
+            # Create TranscriptionResponse object
+            response = TranscriptionResponse(text=text)
+
+            # Add additional metadata matching OpenAI format
+            response["task"] = "transcribe"
+            response["language"] = (
+                "english"  # Deepgram auto-detects but doesn't return language
+            )
+            response["duration"] = response_json["metadata"]["duration"]
+
+            # Transform words to match OpenAI format
+            if "words" in first_alternative:
+                response["words"] = [
+                    {"word": word["word"], "start": word["start"], "end": word["end"]}
+                    for word in first_alternative["words"]
+                ]
+
+            # Store full response in hidden params
+            response._hidden_params = response_json
+
+            return response
+
+        except Exception as e:
+            raise ValueError(
+                f"Error transforming Deepgram response: {str(e)}\nResponse: {raw_response.text}"
+            )
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        model: str,
+        optional_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        if api_base is None:
+            api_base = "https://api.deepgram.com/v1"
+        api_base = api_base.rstrip("/")  # Remove trailing slash if present
+
+        return f"{api_base}/listen?model={model}"
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        api_key = api_key or get_secret_str("DEEPGRAM_API_KEY")
+        return {
+            "Authorization": f"Token {api_key}",
+        }
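For reference, a representative `/listen` payload of the shape `transform_audio_transcription_response` consumes. The keys match the lookups in the code above; the values are invented.

```python
deepgram_json = {
    "metadata": {"duration": 2.5},  # -> response["duration"]
    "results": {
        "channels": [
            {
                "alternatives": [
                    {
                        "transcript": "hello world",  # -> response.text
                        "words": [
                            {"word": "hello", "start": 0.0, "end": 0.48},
                            {"word": "world", "start": 0.52, "end": 1.1},
                        ],
                    }
                ]
            }
        ]
    },
}
# The transform keeps the first alternative of the first channel, sets
# task="transcribe" and language="english", copies word timings, and stashes
# the full payload in response._hidden_params.
```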
litellm/llms/deepgram/common_utils.py (new file, 5 lines)

@@ -0,0 +1,5 @@
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+
+
+class DeepgramException(BaseLLMException):
+    pass
@@ -4867,7 +4867,30 @@ def transcription(
             api_base=api_base,
             api_key=api_key,
         )
+    elif custom_llm_provider == "deepgram":
+        response = base_llm_http_handler.audio_transcriptions(
+            model=model,
+            audio_file=file,
+            optional_params=optional_params,
+            model_response=model_response,
+            atranscription=atranscription,
+            client=(
+                client
+                if client is not None
+                and (
+                    isinstance(client, HTTPHandler)
+                    or isinstance(client, AsyncHTTPHandler)
+                )
+                else None
+            ),
+            timeout=timeout,
+            max_retries=max_retries,
+            logging_obj=litellm_logging_obj,
+            api_base=api_base,
+            api_key=api_key,
+            custom_llm_provider="deepgram",
+            headers={},
+        )
     if response is None:
         raise ValueError("Unmapped provider passed in. Unable to get the response.")
     return response
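The `atranscription` flag threaded through above serves the async path; a sketch of the async entry point, assuming `litellm.atranscription` mirrors `transcription` as it does for other providers:

```python
import asyncio

import litellm

async def main() -> None:
    with open("speech.wav", "rb") as f:  # illustrative file name
        response = await litellm.atranscription(model="deepgram/nova-2", file=f)
    print(response.text)

asyncio.run(main())
```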
@@ -1788,6 +1788,7 @@ class LlmProviders(str, Enum):
     LM_STUDIO = "lm_studio"
     GALADRIEL = "galadriel"
     INFINITY = "infinity"
+    DEEPGRAM = "deepgram"


 class LiteLLMLoggingBaseClass:
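Since `LlmProviders` is a `str`-valued enum, provider names round-trip from plain strings, which is how `audio_transcriptions` resolves the provider via `litellm.LlmProviders(custom_llm_provider)`:

```python
import litellm

assert litellm.LlmProviders("deepgram") is litellm.LlmProviders.DEEPGRAM
assert litellm.LlmProviders.DEEPGRAM.value == "deepgram"
```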
@@ -6323,6 +6323,8 @@ class ProviderConfigManager:
     ) -> Optional[BaseAudioTranscriptionConfig]:
         if litellm.LlmProviders.FIREWORKS_AI == provider:
             return litellm.FireworksAIAudioTranscriptionConfig()
+        elif litellm.LlmProviders.DEEPGRAM == provider:
+            return litellm.DeepgramAudioTranscriptionConfig()
         return None

     @staticmethod
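A sketch of the lookup this branch enables (the `litellm.utils` import path for `ProviderConfigManager` is assumed from the hunk context):

```python
import litellm
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_audio_transcription_config(
    model="nova-2",
    provider=litellm.LlmProviders.DEEPGRAM,
)
assert isinstance(config, litellm.DeepgramAudioTranscriptionConfig)
```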
tests/llm_translation/test_deepgram.py (new file, 20 lines)

@@ -0,0 +1,20 @@
+import os
+import sys
+
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+from base_audio_transcription_unit_tests import BaseLLMAudioTranscriptionTest
+
+
+class TestDeepgramAudioTranscription(BaseLLMAudioTranscriptionTest):
+    def get_base_audio_transcription_call_args(self) -> dict:
+        return {
+            "model": "deepgram/nova-2",
+        }
+
+    def get_custom_llm_provider(self) -> litellm.LlmProviders:
+        return litellm.LlmProviders.DEEPGRAM
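The class inherits its actual test cases from `BaseLLMAudioTranscriptionTest`; presumably the suite is run from within `tests/llm_translation/` (matching the relative `sys.path` insert) with `pytest test_deepgram.py` and a valid `DEEPGRAM_API_KEY` exported.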