fix(whisper): handle OpenAI/Azure vtt response format (fixes https://github.com/BerriAI/litellm/issues/4595)

Krrish Dholakia 2024-07-08 09:05:29 -07:00
parent d5564dd81f
commit 298505c47c
10 changed files with 252 additions and 84 deletions
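For context, a minimal sketch of the call path this fix targets (the audio file path is a placeholder and an OpenAI key is assumed to be configured): with a non-JSON response_format such as "vtt", the provider SDK hands back a plain string rather than a pydantic object, which previously broke the response.model_dump() call downstream.

import litellm

audio_file = open("./example.wav", "rb")  # placeholder sample file

# After this fix, the raw VTT string is wrapped in a TranscriptionResponse
# instead of crashing on model_dump().
transcript = litellm.transcription(
    model="whisper-1",
    file=audio_file,
    response_format="vtt",
)
print(transcript.text)  # the raw VTT payload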

View file

@@ -829,6 +829,7 @@ from .llms.openai import (
     MistralConfig,
     MistralEmbeddingConfig,
     DeepInfraConfig,
+    GroqConfig,
     AzureAIStudioConfig,
 )
 from .llms.nvidia_nim import NvidiaNimConfig

View file

@@ -19,6 +19,7 @@ from typing import (
 import httpx  # type: ignore
 import requests
 from openai import AsyncAzureOpenAI, AzureOpenAI
+from pydantic import BaseModel
 from typing_extensions import overload

 import litellm
@@ -1534,7 +1535,12 @@ class AzureChatCompletion(BaseLLM):
             response = azure_client.audio.transcriptions.create(
                 **data, timeout=timeout  # type: ignore
             )
-            stringified_response = response.model_dump()
+
+            if isinstance(response, BaseModel):
+                stringified_response = response.model_dump()
+            else:
+                stringified_response = TranscriptionResponse(text=response).model_dump()
+
             ## LOGGING
             logging_obj.post_call(
                 input=audio_file.name,
@@ -1587,7 +1593,10 @@
                 **data, timeout=timeout
             )  # type: ignore
-            stringified_response = response.model_dump()
+            if isinstance(response, BaseModel):
+                stringified_response = response.model_dump()
+            else:
+                stringified_response = TranscriptionResponse(text=response).model_dump()
             ## LOGGING
             logging_obj.post_call(
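The same guard is applied to both the sync and async Azure transcription paths above. As a standalone sketch (import locations are assumed, not taken from the diff):

from pydantic import BaseModel

from litellm.utils import TranscriptionResponse  # assumed import location


def stringify_transcription(response) -> dict:
    # JSON-style response formats return a pydantic model; "text", "srt" and
    # "vtt" return a plain string, so only a BaseModel gets model_dump().
    if isinstance(response, BaseModel):
        return response.model_dump()
    return TranscriptionResponse(text=response).model_dump()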

View file

@@ -348,6 +348,104 @@ class DeepInfraConfig:
         return optional_params


+class GroqConfig:
+    """
+    Reference: https://console.groq.com/docs
+
+    The class `GroqConfig` provides configuration for Groq's OpenAI-compatible API interface. Below are the parameters:
+    """
+
+    frequency_penalty: Optional[int] = None
+    function_call: Optional[Union[str, dict]] = None
+    functions: Optional[list] = None
+    logit_bias: Optional[dict] = None
+    max_tokens: Optional[int] = None
+    n: Optional[int] = None
+    presence_penalty: Optional[int] = None
+    stop: Optional[Union[str, list]] = None
+    temperature: Optional[int] = None
+    top_p: Optional[int] = None
+    response_format: Optional[dict] = None
+    tools: Optional[list] = None
+    tool_choice: Optional[Union[str, dict]] = None
+
+    def __init__(
+        self,
+        frequency_penalty: Optional[int] = None,
+        function_call: Optional[Union[str, dict]] = None,
+        functions: Optional[list] = None,
+        logit_bias: Optional[dict] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[int] = None,
+        stop: Optional[Union[str, list]] = None,
+        temperature: Optional[int] = None,
+        top_p: Optional[int] = None,
+        response_format: Optional[dict] = None,
+        tools: Optional[list] = None,
+        tool_choice: Optional[Union[str, dict]] = None,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params_stt(self):
+        return [
+            "prompt",
+            "response_format",
+            "temperature",
+            "language",
+        ]
+
+    def get_supported_openai_response_formats_stt(self) -> List[str]:
+        return ["json", "verbose_json", "text"]
+
+    def map_openai_params_stt(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        response_formats = self.get_supported_openai_response_formats_stt()
+        for param, value in non_default_params.items():
+            if param == "response_format":
+                if value in response_formats:
+                    optional_params[param] = value
+                else:
+                    if litellm.drop_params is True or drop_params is True:
+                        pass
+                    else:
+                        raise litellm.utils.UnsupportedParamsError(
+                            message="Groq doesn't support response_format={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
+                                value
+                            ),
+                            status_code=400,
+                        )
+            else:
+                optional_params[param] = value
+        return optional_params
+
+
 class OpenAIConfig:
     """
     Reference: https://platform.openai.com/docs/api-reference/chat/create
@@ -1360,7 +1458,11 @@ class OpenAIChatCompletion(BaseLLM):
                 **data, timeout=timeout  # type: ignore
             )
-            stringified_response = response.model_dump()
+            if isinstance(response, BaseModel):
+                stringified_response = response.model_dump()
+            else:
+                stringified_response = TranscriptionResponse(text=response).model_dump()
+
             ## LOGGING
             logging_obj.post_call(
                 input=audio_file.name,
@@ -1400,7 +1502,10 @@ class OpenAIChatCompletion(BaseLLM):
                 timeout=timeout,
             )
             logging_obj.model_call_details["response_headers"] = headers
-            stringified_response = response.model_dump()
+            if isinstance(response, BaseModel):
+                stringified_response = response.model_dump()
+            else:
+                stringified_response = TranscriptionResponse(text=response).model_dump()
             ## LOGGING
             logging_obj.post_call(
                 input=audio_file.name,
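A quick illustration of the new Groq STT mapping (illustrative, not part of the diff): "vtt" is not in Groq's supported response formats, so with drop_params=True it is dropped instead of raising UnsupportedParamsError.

import litellm

params = litellm.GroqConfig().map_openai_params_stt(
    non_default_params={"response_format": "vtt", "language": "en"},
    optional_params={},
    model="whisper-large-v3",
    drop_params=True,
)
print(params)  # {'language': 'en'} ("vtt" was dropped)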

View file

@@ -61,6 +61,7 @@ from litellm.utils import (
     get_llm_provider,
     get_optional_params_embeddings,
     get_optional_params_image_gen,
+    get_optional_params_transcription,
     get_secret,
     mock_completion_streaming_obj,
     read_config_args,
@@ -4279,7 +4280,7 @@ def image_generation(
 @client
-async def atranscription(*args, **kwargs):
+async def atranscription(*args, **kwargs) -> TranscriptionResponse:
     """
     Calls openai + azure whisper endpoints.
@@ -4304,9 +4305,9 @@
         # Await normally
         init_response = await loop.run_in_executor(None, func_with_context)
-        if isinstance(init_response, dict) or isinstance(
-            init_response, TranscriptionResponse
-        ):  ## CACHING SCENARIO
+        if isinstance(init_response, dict):
+            response = TranscriptionResponse(**init_response)
+        elif isinstance(init_response, TranscriptionResponse):  ## CACHING SCENARIO
             response = init_response
         elif asyncio.iscoroutine(init_response):
             response = await init_response
@@ -4346,7 +4347,7 @@
     litellm_logging_obj: Optional[LiteLLMLoggingObj] = None,
     custom_llm_provider=None,
     **kwargs,
-):
+) -> TranscriptionResponse:
     """
     Calls openai + azure whisper endpoints.
@@ -4358,6 +4359,7 @@
     proxy_server_request = kwargs.get("proxy_server_request", None)
     model_info = kwargs.get("model_info", None)
     metadata = kwargs.get("metadata", {})
+    drop_params = kwargs.get("drop_params", None)
     client: Optional[
         Union[
             openai.AsyncOpenAI,
@@ -4379,12 +4381,22 @@
     if dynamic_api_key is not None:
         api_key = dynamic_api_key
-    optional_params = {
-        "language": language,
-        "prompt": prompt,
-        "response_format": response_format,
-        "temperature": None,  # openai defaults this to 0
-    }
+
+    optional_params = get_optional_params_transcription(
+        model=model,
+        language=language,
+        prompt=prompt,
+        response_format=response_format,
+        temperature=temperature,
+        custom_llm_provider=custom_llm_provider,
+        drop_params=drop_params,
+    )
+    # optional_params = {
+    #     "language": language,
+    #     "prompt": prompt,
+    #     "response_format": response_format,
+    #     "temperature": None,  # openai defaults this to 0
+    # }
+
     if custom_llm_provider == "azure":
         # azure configs
File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@@ -5,6 +5,16 @@ model_list:
   - model_name: gemini-1.5-flash
     litellm_params:
       model: gemini/gemini-1.5-flash
+  - model_name: whisper
+    litellm_params:
+      model: azure/azure-whisper
+      api_version: 2024-02-15-preview
+      api_base: os.environ/AZURE_EUROPE_API_BASE
+      api_key: os.environ/AZURE_EUROPE_API_KEY
+    model_info:
+      mode: audio_transcription
+
 general_settings:
   alerting: ["slack"]
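With the whisper entry above, the proxy can serve transcription requests. A sketch using the OpenAI client against a locally running proxy (the base URL and key are assumptions, not part of the config):

from openai import OpenAI

# Assumes the proxy is running locally on its default port with this config.
client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")  # placeholder key

transcript = client.audio.transcriptions.create(
    model="whisper",                   # model_name from the config above
    file=open("./example.wav", "rb"),  # placeholder path
    response_format="vtt",
)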

View file

@@ -35,75 +35,44 @@ import litellm
 from litellm import Router

-def test_transcription():
-    transcript = litellm.transcription(
-        model="whisper-1",
-        file=audio_file,
-    )
+@pytest.mark.parametrize(
+    "model, api_key, api_base",
+    [
+        ("whisper-1", None, None),
+        # ("groq/whisper-large-v3", None, None),
+        (
+            "azure/azure-whisper",
+            os.getenv("AZURE_EUROPE_API_KEY"),
+            "https://my-endpoint-europe-berri-992.openai.azure.com/",
+        ),
+    ],
+)
+@pytest.mark.parametrize("response_format", ["json", "vtt"])
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_transcription(model, api_key, api_base, response_format, sync_mode):
+    if sync_mode:
+        transcript = litellm.transcription(
+            model=model,
+            file=audio_file,
+            api_key=api_key,
+            api_base=api_base,
+            response_format=response_format,
+            drop_params=True,
+        )
+    else:
+        transcript = await litellm.atranscription(
+            model=model,
+            file=audio_file,
+            api_key=api_key,
+            api_base=api_base,
+            response_format=response_format,
+            drop_params=True,
+        )
+
     print(f"transcript: {transcript.model_dump()}")
     print(f"transcript: {transcript._hidden_params}")

-
-# test_transcription()
-
-
-def test_transcription_groq():
-    litellm.set_verbose = True
-    transcript = litellm.transcription(
-        model="groq/whisper-large-v3",
-        file=audio_file,
-    )
-    print(f"response=: {transcript.model_dump()}")
-    print(f"hidden_params: {transcript._hidden_params}")
-
-
-# test_transcription()
-
-
-def test_transcription_azure():
-    litellm.set_verbose = True
-    transcript = litellm.transcription(
-        model="azure/azure-whisper",
-        file=audio_file,
-        api_base="https://my-endpoint-europe-berri-992.openai.azure.com/",
-        api_key=os.getenv("AZURE_EUROPE_API_KEY"),
-        api_version="2024-02-15-preview",
-    )
-    print(f"transcript: {transcript}")
     assert transcript.text is not None
-    assert isinstance(transcript.text, str)
-
-
-# test_transcription_azure()
-
-
-@pytest.mark.asyncio
-async def test_transcription_async_azure():
-    transcript = await litellm.atranscription(
-        model="azure/azure-whisper",
-        file=audio_file,
-        api_base="https://my-endpoint-europe-berri-992.openai.azure.com/",
-        api_key=os.getenv("AZURE_EUROPE_API_KEY"),
-        api_version="2024-02-15-preview",
-    )
-    assert transcript.text is not None
-    assert isinstance(transcript.text, str)
-
-
-# asyncio.run(test_transcription_async_azure())
-
-
-@pytest.mark.asyncio
-async def test_transcription_async_openai():
-    transcript = await litellm.atranscription(
-        model="whisper-1",
-        file=audio_file,
-    )
-    assert transcript.text is not None
-    assert isinstance(transcript.text, str)

 # This file includes the custom callbacks for LiteLLM Proxy

View file

@@ -2144,6 +2144,71 @@ def get_litellm_params(
     return litellm_params


+def get_optional_params_transcription(
+    model: str,
+    language: Optional[str] = None,
+    prompt: Optional[str] = None,
+    response_format: Optional[str] = None,
+    temperature: Optional[int] = None,
+    custom_llm_provider: Optional[str] = None,
+    drop_params: Optional[bool] = None,
+    **kwargs,
+):
+    # retrieve all parameters passed to the function
+    passed_params = locals()
+    custom_llm_provider = passed_params.pop("custom_llm_provider")
+    drop_params = passed_params.pop("drop_params")
+    special_params = passed_params.pop("kwargs")
+    for k, v in special_params.items():
+        passed_params[k] = v
+
+    default_params = {
+        "language": None,
+        "prompt": None,
+        "response_format": None,
+        "temperature": None,  # openai defaults this to 0
+    }
+
+    non_default_params = {
+        k: v
+        for k, v in passed_params.items()
+        if (k in default_params and v != default_params[k])
+    }
+    optional_params = {}
+
+    ## raise exception if non-default value passed for non-openai/azure transcription calls
+    def _check_valid_arg(supported_params):
+        if len(non_default_params.keys()) > 0:
+            keys = list(non_default_params.keys())
+            for k in keys:
+                if (
+                    drop_params is True or litellm.drop_params is True
+                ) and k not in supported_params:  # drop the unsupported non-default values
+                    non_default_params.pop(k, None)
+                elif k not in supported_params:
+                    raise UnsupportedParamsError(
+                        status_code=500,
+                        message=f"Setting `{k}` is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
+                    )
+        return non_default_params
+
+    if custom_llm_provider == "openai" or custom_llm_provider == "azure":
+        optional_params = non_default_params
+    elif custom_llm_provider == "groq":
+        supported_params = litellm.GroqConfig().get_supported_openai_params_stt()
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.GroqConfig().map_openai_params_stt(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=drop_params if drop_params is not None else False,
+        )
+
+    for k in passed_params.keys():  # pass additional kwargs without modification
+        if k not in default_params.keys():
+            optional_params[k] = passed_params[k]
+    return optional_params
+
+
 def get_optional_params_image_gen(
     n: Optional[int] = None,
     quality: Optional[str] = None,
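Rough behavior of the new helper, per the code shown above (values illustrative):

from litellm.utils import get_optional_params_transcription

params = get_optional_params_transcription(
    model="whisper-large-v3",
    response_format="vtt",       # not a Groq-supported format
    custom_llm_provider="groq",
    drop_params=True,            # drop it instead of raising
)
# "response_format" is removed by GroqConfig.map_openai_params_stt; keys outside
# the OpenAI STT defaults (here just model) pass through unchanged.
print(params)  # {'model': 'whisper-large-v3'}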
@@ -7559,7 +7624,7 @@
                 else:
                     # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors
                     raise APIConnectionError(
-                        message=f"{exception_provider} APIConnectionError - {message}",
+                        message=f"{exception_provider} APIConnectionError - {message}\n{traceback.format_exc()}",
                         llm_provider="azure",
                         model=model,
                         litellm_debug_info=extra_information,