forked from phoenix/litellm-mirror
fix(whisper---handle-openai/azure-vtt-response-format): Fixes https://github.com/BerriAI/litellm/issues/4595
This commit is contained in:
parent d5564dd81f
commit 298505c47c

10 changed files with 252 additions and 84 deletions
@@ -829,6 +829,7 @@ from .llms.openai import (
     MistralConfig,
     MistralEmbeddingConfig,
     DeepInfraConfig,
+    GroqConfig,
     AzureAIStudioConfig,
 )
 from .llms.nvidia_nim import NvidiaNimConfig
@@ -19,6 +19,7 @@ from typing import (
 import httpx  # type: ignore
 import requests
 from openai import AsyncAzureOpenAI, AzureOpenAI
+from pydantic import BaseModel
 from typing_extensions import overload

 import litellm
@@ -1534,7 +1535,12 @@ class AzureChatCompletion(BaseLLM):
         response = azure_client.audio.transcriptions.create(
             **data, timeout=timeout  # type: ignore
         )
+
+        if isinstance(response, BaseModel):
             stringified_response = response.model_dump()
+        else:
+            stringified_response = TranscriptionResponse(text=response).model_dump()
+
         ## LOGGING
         logging_obj.post_call(
             input=audio_file.name,
@@ -1587,7 +1593,10 @@ class AzureChatCompletion(BaseLLM):
             **data, timeout=timeout
         )  # type: ignore

+        if isinstance(response, BaseModel):
             stringified_response = response.model_dump()
+        else:
+            stringified_response = TranscriptionResponse(text=response).model_dump()

         ## LOGGING
         logging_obj.post_call(
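The two Azure hunks above (and the matching OpenAI hunks further down) all add the same guard: for `response_format` values such as "text", "srt" or "vtt", the OpenAI SDK returns a plain string rather than a pydantic object, so calling `.model_dump()` on it failed with the error reported in issue 4595. A minimal sketch of the pattern, not part of the diff; `normalize_transcription` is a hypothetical helper name and the `TranscriptionResponse` here is a simplified stand-in for litellm's type:

    from typing import Union

    from pydantic import BaseModel


    class TranscriptionResponse(BaseModel):
        # simplified stand-in for litellm's TranscriptionResponse
        text: str


    def normalize_transcription(response: Union[BaseModel, str]) -> dict:
        # JSON-style formats come back as a pydantic object; "text"/"srt"/"vtt" come back as a raw string
        if isinstance(response, BaseModel):
            return response.model_dump()
        return TranscriptionResponse(text=response).model_dump()


    print(normalize_transcription("WEBVTT\n\n00:00.000 --> 00:04.000\nhello world"))
    # -> {'text': 'WEBVTT\n\n00:00.000 --> 00:04.000\nhello world'}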
@@ -348,6 +348,104 @@ class DeepInfraConfig:
         return optional_params


+class GroqConfig:
+    """
+    Reference: https://deepinfra.com/docs/advanced/openai_api
+
+    The class `DeepInfra` provides configuration for the DeepInfra's Chat Completions API interface. Below are the parameters:
+    """
+
+    frequency_penalty: Optional[int] = None
+    function_call: Optional[Union[str, dict]] = None
+    functions: Optional[list] = None
+    logit_bias: Optional[dict] = None
+    max_tokens: Optional[int] = None
+    n: Optional[int] = None
+    presence_penalty: Optional[int] = None
+    stop: Optional[Union[str, list]] = None
+    temperature: Optional[int] = None
+    top_p: Optional[int] = None
+    response_format: Optional[dict] = None
+    tools: Optional[list] = None
+    tool_choice: Optional[Union[str, dict]] = None
+
+    def __init__(
+        self,
+        frequency_penalty: Optional[int] = None,
+        function_call: Optional[Union[str, dict]] = None,
+        functions: Optional[list] = None,
+        logit_bias: Optional[dict] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[int] = None,
+        stop: Optional[Union[str, list]] = None,
+        temperature: Optional[int] = None,
+        top_p: Optional[int] = None,
+        response_format: Optional[dict] = None,
+        tools: Optional[list] = None,
+        tool_choice: Optional[Union[str, dict]] = None,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params_stt(self):
+        return [
+            "prompt",
+            "response_format",
+            "temperature",
+            "language",
+        ]
+
+    def get_supported_openai_response_formats_stt(self) -> List[str]:
+        return ["json", "verbose_json", "text"]
+
+    def map_openai_params_stt(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        response_formats = self.get_supported_openai_response_formats_stt()
+        for param, value in non_default_params.items():
+            if param == "response_format":
+                if value in response_formats:
+                    optional_params[param] = value
+                else:
+                    if litellm.drop_params is True or drop_params is True:
+                        pass
+                    else:
+                        raise litellm.utils.UnsupportedParamsError(
+                            message="Groq doesn't support response_format={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
+                                value
+                            ),
+                            status_code=400,
+                        )
+            else:
+                optional_params[param] = value
+        return optional_params


 class OpenAIConfig:
     """
     Reference: https://platform.openai.com/docs/api-reference/chat/create
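`GroqConfig.map_openai_params_stt` only forwards a transcription parameter when Groq supports it: an unsupported `response_format` (e.g. "vtt") is either silently dropped when `drop_params` is set, or rejected with `UnsupportedParamsError`. A rough sketch of that behaviour, using a hypothetical standalone function rather than the litellm class:

    SUPPORTED_STT_PARAMS = ["prompt", "response_format", "temperature", "language"]
    SUPPORTED_STT_RESPONSE_FORMATS = ["json", "verbose_json", "text"]


    def map_groq_stt_params(non_default_params: dict, drop_params: bool) -> dict:
        """Sketch of the mapping above: keep supported values, drop or reject the rest."""
        optional_params = {}
        for param, value in non_default_params.items():
            if param == "response_format" and value not in SUPPORTED_STT_RESPONSE_FORMATS:
                if drop_params:
                    continue  # silently drop e.g. response_format="vtt"
                raise ValueError(f"Groq doesn't support response_format={value}")
            optional_params[param] = value
        return optional_params


    print(map_groq_stt_params({"response_format": "vtt", "language": "en"}, drop_params=True))
    # -> {'language': 'en'}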
@@ -1360,7 +1458,11 @@ class OpenAIChatCompletion(BaseLLM):
             **data, timeout=timeout  # type: ignore
         )
+
+        if isinstance(response, BaseModel):
             stringified_response = response.model_dump()
+        else:
+            stringified_response = TranscriptionResponse(text=response).model_dump()

         ## LOGGING
         logging_obj.post_call(
             input=audio_file.name,
@@ -1400,7 +1502,10 @@ class OpenAIChatCompletion(BaseLLM):
             timeout=timeout,
         )
         logging_obj.model_call_details["response_headers"] = headers
+
+        if isinstance(response, BaseModel):
             stringified_response = response.model_dump()
+        else:
+            stringified_response = TranscriptionResponse(text=response).model_dump()
         ## LOGGING
         logging_obj.post_call(
             input=audio_file.name,
@@ -61,6 +61,7 @@ from litellm.utils import (
     get_llm_provider,
     get_optional_params_embeddings,
     get_optional_params_image_gen,
+    get_optional_params_transcription,
     get_secret,
     mock_completion_streaming_obj,
     read_config_args,
@@ -4279,7 +4280,7 @@ def image_generation(


 @client
-async def atranscription(*args, **kwargs):
+async def atranscription(*args, **kwargs) -> TranscriptionResponse:
     """
     Calls openai + azure whisper endpoints.

@@ -4304,9 +4305,9 @@ async def atranscription(*args, **kwargs):

         # Await normally
         init_response = await loop.run_in_executor(None, func_with_context)
-        if isinstance(init_response, dict) or isinstance(
-            init_response, TranscriptionResponse
-        ):  ## CACHING SCENARIO
+        if isinstance(init_response, dict):
+            response = TranscriptionResponse(**init_response)
+        elif isinstance(init_response, TranscriptionResponse):  ## CACHING SCENARIO
             response = init_response
         elif asyncio.iscoroutine(init_response):
             response = await init_response
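With `atranscription` now annotated to return `TranscriptionResponse`, the caching branch rehydrates dict results instead of returning them raw. A small sketch of that branching, assuming `TranscriptionResponse` is importable from `litellm.utils` (import path assumed) and using a hypothetical helper name:

    from litellm.utils import TranscriptionResponse


    def rehydrate_cached_transcription(init_response):
        """Sketch of the caching branch shown above."""
        if isinstance(init_response, dict):
            # cached results may be stored as plain dicts; rebuild the typed response
            return TranscriptionResponse(**init_response)
        if isinstance(init_response, TranscriptionResponse):
            # in-memory cache hit: already typed, pass through unchanged
            return init_response
        return init_response  # anything else (e.g. a coroutine) is awaited by the caller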
@@ -4346,7 +4347,7 @@ def transcription(
     litellm_logging_obj: Optional[LiteLLMLoggingObj] = None,
     custom_llm_provider=None,
     **kwargs,
-):
+) -> TranscriptionResponse:
     """
     Calls openai + azure whisper endpoints.

@@ -4358,6 +4359,7 @@ def transcription(
     proxy_server_request = kwargs.get("proxy_server_request", None)
     model_info = kwargs.get("model_info", None)
     metadata = kwargs.get("metadata", {})
+    drop_params = kwargs.get("drop_params", None)
     client: Optional[
         Union[
             openai.AsyncOpenAI,
@@ -4379,12 +4381,22 @@ def transcription(

     if dynamic_api_key is not None:
         api_key = dynamic_api_key
-    optional_params = {
-        "language": language,
-        "prompt": prompt,
-        "response_format": response_format,
-        "temperature": None,  # openai defaults this to 0
-    }
+    optional_params = get_optional_params_transcription(
+        model=model,
+        language=language,
+        prompt=prompt,
+        response_format=response_format,
+        temperature=temperature,
+        custom_llm_provider=custom_llm_provider,
+        drop_params=drop_params,
+    )
+    # optional_params = {
+    #     "language": language,
+    #     "prompt": prompt,
+    #     "response_format": response_format,
+    #     "temperature": None,  # openai defaults this to 0
+    # }

     if custom_llm_provider == "azure":
         # azure configs
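After this change, the transcription params are resolved per provider before the request is dispatched, so a caller can ask for "vtt" and let litellm drop it for providers that only accept JSON-style formats. A hedged usage sketch (the audio file name is illustrative and OPENAI_API_KEY is assumed to be set):

    import litellm

    # any readable audio file object works; the name here is illustrative
    audio_file = open("gettysburg.wav", "rb")

    transcript = litellm.transcription(
        model="whisper-1",
        file=audio_file,
        response_format="vtt",   # passed through to OpenAI/Azure
        drop_params=True,        # dropped for providers that don't support it (e.g. groq)
    )
    # with this fix, the raw VTT captions are available as text instead of raising
    print(transcript.text)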
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -5,6 +5,16 @@ model_list:
   - model_name: gemini-1.5-flash
     litellm_params:
       model: gemini/gemini-1.5-flash
+  - model_name: whisper
+    litellm_params:
+      model: azure/azure-whisper
+      api_version: 2024-02-15-preview
+      api_base: os.environ/AZURE_EUROPE_API_BASE
+      api_key: os.environ/AZURE_EUROPE_API_KEY
+    model_info:
+      mode: audio_transcription
+
+
 general_settings:
   alerting: ["slack"]
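The new `whisper` entry exposes the Azure deployment through the proxy's OpenAI-compatible audio transcription route. A hedged sketch of calling it with the OpenAI Python client; the base URL and key are placeholders for a locally running proxy started with the config above:

    from openai import OpenAI

    # placeholders for a local litellm proxy
    client = OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

    with open("gettysburg.wav", "rb") as audio_file:
        transcript = client.audio.transcriptions.create(
            model="whisper",        # model_name from the proxy config
            file=audio_file,
            response_format="vtt",  # the format exercised by the new tests
        )
    print(transcript)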
@@ -35,75 +35,44 @@ import litellm
 from litellm import Router


-def test_transcription():
+@pytest.mark.parametrize(
+    "model, api_key, api_base",
+    [
+        ("whisper-1", None, None),
+        # ("groq/whisper-large-v3", None, None),
+        (
+            "azure/azure-whisper",
+            os.getenv("AZURE_EUROPE_API_KEY"),
+            "https://my-endpoint-europe-berri-992.openai.azure.com/",
+        ),
+    ],
+)
+@pytest.mark.parametrize("response_format", ["json", "vtt"])
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_transcription(model, api_key, api_base, response_format, sync_mode):
+    if sync_mode:
         transcript = litellm.transcription(
-            model="whisper-1",
+            model=model,
             file=audio_file,
+            api_key=api_key,
+            api_base=api_base,
+            response_format=response_format,
+            drop_params=True,
+        )
+    else:
+        transcript = await litellm.atranscription(
+            model=model,
+            file=audio_file,
+            api_key=api_key,
+            api_base=api_base,
+            response_format=response_format,
+            drop_params=True,
         )
     print(f"transcript: {transcript.model_dump()}")
     print(f"transcript: {transcript._hidden_params}")


-# test_transcription()
-
-
-def test_transcription_groq():
-    litellm.set_verbose = True
-    transcript = litellm.transcription(
-        model="groq/whisper-large-v3",
-        file=audio_file,
-    )
-    print(f"response=: {transcript.model_dump()}")
-    print(f"hidden_params: {transcript._hidden_params}")
-
-
-# test_transcription()
-
-
-def test_transcription_azure():
-    litellm.set_verbose = True
-    transcript = litellm.transcription(
-        model="azure/azure-whisper",
-        file=audio_file,
-        api_base="https://my-endpoint-europe-berri-992.openai.azure.com/",
-        api_key=os.getenv("AZURE_EUROPE_API_KEY"),
-        api_version="2024-02-15-preview",
-    )
-
-    print(f"transcript: {transcript}")
     assert transcript.text is not None
-    assert isinstance(transcript.text, str)
-
-
-# test_transcription_azure()
-
-
-@pytest.mark.asyncio
-async def test_transcription_async_azure():
-    transcript = await litellm.atranscription(
-        model="azure/azure-whisper",
-        file=audio_file,
-        api_base="https://my-endpoint-europe-berri-992.openai.azure.com/",
-        api_key=os.getenv("AZURE_EUROPE_API_KEY"),
-        api_version="2024-02-15-preview",
-    )
-
-    assert transcript.text is not None
-    assert isinstance(transcript.text, str)
-
-
-# asyncio.run(test_transcription_async_azure())
-
-
-@pytest.mark.asyncio
-async def test_transcription_async_openai():
-    transcript = await litellm.atranscription(
-        model="whisper-1",
-        file=audio_file,
-    )
-
-    assert transcript.text is not None
-    assert isinstance(transcript.text, str)
 # This file includes the custom callbacks for LiteLLM Proxy
@@ -2144,6 +2144,71 @@ def get_litellm_params(
     return litellm_params


+def get_optional_params_transcription(
+    model: str,
+    language: Optional[str] = None,
+    prompt: Optional[str] = None,
+    response_format: Optional[str] = None,
+    temperature: Optional[int] = None,
+    custom_llm_provider: Optional[str] = None,
+    drop_params: Optional[bool] = None,
+    **kwargs,
+):
+    # retrieve all parameters passed to the function
+    passed_params = locals()
+    custom_llm_provider = passed_params.pop("custom_llm_provider")
+    drop_params = passed_params.pop("drop_params")
+    special_params = passed_params.pop("kwargs")
+    for k, v in special_params.items():
+        passed_params[k] = v
+
+    default_params = {
+        "language": None,
+        "prompt": None,
+        "response_format": None,
+        "temperature": None,  # openai defaults this to 0
+    }
+
+    non_default_params = {
+        k: v
+        for k, v in passed_params.items()
+        if (k in default_params and v != default_params[k])
+    }
+    optional_params = {}
+
+    ## raise exception if non-default value passed for non-openai/azure embedding calls
+    def _check_valid_arg(supported_params):
+        if len(non_default_params.keys()) > 0:
+            keys = list(non_default_params.keys())
+            for k in keys:
+                if (
+                    drop_params is True or litellm.drop_params is True
+                ) and k not in supported_params:  # drop the unsupported non-default values
+                    non_default_params.pop(k, None)
+                elif k not in supported_params:
+                    raise UnsupportedParamsError(
+                        status_code=500,
+                        message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
+                    )
+        return non_default_params
+
+    if custom_llm_provider == "openai" or custom_llm_provider == "azure":
+        optional_params = non_default_params
+    elif custom_llm_provider == "groq":
+        supported_params = litellm.GroqConfig().get_supported_openai_params_stt()
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.GroqConfig().map_openai_params_stt(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=drop_params if drop_params is not None else False,
+        )
+    for k in passed_params.keys():  # pass additional kwargs without modification
+        if k not in default_params.keys():
+            optional_params[k] = passed_params[k]
+    return optional_params
+
+
 def get_optional_params_image_gen(
     n: Optional[int] = None,
     quality: Optional[str] = None,
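`get_optional_params_transcription` mirrors the other `get_optional_params_*` helpers: it collects the non-default OpenAI params, passes them through unchanged for openai/azure, and routes them through `GroqConfig` for groq so unsupported values can be dropped or rejected. A rough illustration; the commented results are what the code above implies, not captured runs, and note the final loop also copies `model` into the result:

    from litellm.utils import get_optional_params_transcription

    # openai/azure: non-default values pass through untouched
    print(
        get_optional_params_transcription(
            model="whisper-1",
            response_format="vtt",
            language="en",
            custom_llm_provider="openai",
        )
    )
    # expected: {'response_format': 'vtt', 'language': 'en', 'model': 'whisper-1'}

    # groq: "vtt" is not in GroqConfig's supported formats, so with drop_params it is removed
    print(
        get_optional_params_transcription(
            model="whisper-large-v3",
            response_format="vtt",
            language="en",
            custom_llm_provider="groq",
            drop_params=True,
        )
    )
    # expected: {'language': 'en', 'model': 'whisper-large-v3'}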
@@ -7559,7 +7624,7 @@ def exception_type(
             else:
                 # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors
                 raise APIConnectionError(
-                    message=f"{exception_provider} APIConnectionError - {message}",
+                    message=f"{exception_provider} APIConnectionError - {message}\n{traceback.format_exc()}",
                     llm_provider="azure",
                     model=model,
                     litellm_debug_info=extra_information,