diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index e19023b03..022266996 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -770,9 +770,8 @@ class AzureChatCompletion(BaseLLM):
         api_version: Optional[str] = None,
         client=None,
         azure_ad_token: Optional[str] = None,
-        max_retries=None,
         logging_obj=None,
-        atranscriptions: bool = False,
+        atranscription: bool = False,
     ):
         data = {"model": model, "file": audio_file, **optional_params}
 
@@ -781,9 +780,11 @@ class AzureChatCompletion(BaseLLM):
             "api_version": api_version,
             "azure_endpoint": api_base,
             "azure_deployment": model,
-            "max_retries": max_retries,
             "timeout": timeout,
         }
+
+        max_retries = optional_params.pop("max_retries", None)
+
         azure_client_params = select_azure_base_url_or_endpoint(
             azure_client_params=azure_client_params
         )
@@ -792,7 +793,10 @@ class AzureChatCompletion(BaseLLM):
         elif azure_ad_token is not None:
             azure_client_params["azure_ad_token"] = azure_ad_token
 
-        if atranscriptions == True:
+        if max_retries is not None:
+            azure_client_params["max_retries"] = max_retries
+
+        if atranscription == True:
             return self.async_audio_transcriptions(
                 audio_file=audio_file,
                 data=data,
@@ -845,18 +849,29 @@ class AzureChatCompletion(BaseLLM):
                 )
             else:
                 async_azure_client = client
+
             response = await async_azure_client.audio.transcriptions.create(
                 **data, timeout=timeout
             )  # type: ignore
+
             stringified_response = response.model_dump()
+
             ## LOGGING
             logging_obj.post_call(
                 input=audio_file.name,
                 api_key=api_key,
-                additional_args={"complete_input_dict": data},
+                additional_args={
+                    "headers": {
+                        "Authorization": f"Bearer {async_azure_client.api_key}"
+                    },
+                    "api_base": async_azure_client._base_url._uri_reference,
+                    "atranscription": True,
+                    "complete_input_dict": data,
+                },
                 original_response=stringified_response,
             )
-            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="image_generation")  # type: ignore
+            response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+            return response
         except Exception as e:
             ## LOGGING
             logging_obj.post_call(
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index fca950d31..a90d2457a 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -779,10 +779,10 @@ class OpenAIChatCompletion(BaseLLM):
         client=None,
         max_retries=None,
         logging_obj=None,
-        atranscriptions: bool = False,
+        atranscription: bool = False,
     ):
         data = {"model": model, "file": audio_file, **optional_params}
-        if atranscriptions == True:
+        if atranscription == True:
             return self.async_audio_transcriptions(
                 audio_file=audio_file,
                 data=data,
diff --git a/litellm/main.py b/litellm/main.py
index a04dba4ec..6deaf653f 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -3385,7 +3385,7 @@ def transcription(
 
     Allows router to load balance between them
     """
-    atranscriptions = kwargs.get("atranscriptions", False)
+    atranscription = kwargs.get("atranscription", False)
     litellm_call_id = kwargs.get("litellm_call_id", None)
     logger_fn = kwargs.get("logger_fn", None)
     proxy_server_request = kwargs.get("proxy_server_request", None)
@@ -3421,12 +3421,13 @@ def transcription(
             or litellm.azure_key
             or get_secret("AZURE_API_KEY")
         )
+
         response = azure_chat_completions.audio_transcriptions(
             model=model,
             audio_file=file,
             optional_params=optional_params,
             model_response=model_response,
-            atranscriptions=atranscriptions,
+            atranscription=atranscription,
             timeout=timeout,
             logging_obj=litellm_logging_obj,
             api_base=api_base,
@@ -3440,7 +3441,7 @@ def transcription(
             audio_file=file,
             optional_params=optional_params,
             model_response=model_response,
-            atranscriptions=atranscriptions,
+            atranscription=atranscription,
             timeout=timeout,
             logging_obj=litellm_logging_obj,
         )
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 27eeea573..7a04c8e7d 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -120,6 +120,8 @@ from fastapi import (
     Header,
     Response,
     Form,
+    UploadFile,
+    File,
 )
 from fastapi.routing import APIRouter
 from fastapi.security import OAuth2PasswordBearer
@@ -3216,17 +3218,16 @@ async def image_generation(
 @router.post(
     "/v1/audio/transcriptions",
     dependencies=[Depends(user_api_key_auth)],
-    response_class=ORJSONResponse,
     tags=["audio"],
 )
 @router.post(
     "/audio/transcriptions",
     dependencies=[Depends(user_api_key_auth)],
-    response_class=ORJSONResponse,
     tags=["audio"],
 )
 async def audio_transcriptions(
     request: Request,
+    file: UploadFile = File(...),
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
     """
@@ -3237,11 +3238,11 @@ async def audio_transcriptions(
     global proxy_logging_obj
     try:
         # Use orjson to parse JSON data, orjson speeds up requests significantly
-        body = await request.body()
-        data = orjson.loads(body)
+        form_data = await request.form()
+        data: Dict = {key: value for key, value in form_data.items() if key != "file"}
 
         # Include original request and headers in the data
-        data["proxy_server_request"] = {
+        data["proxy_server_request"] = {  # type: ignore
             "url": str(request.url),
             "method": request.method,
             "headers": dict(request.headers),
@@ -3298,44 +3299,60 @@ async def audio_transcriptions(
             else []
         )
 
-        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
-        data = await proxy_logging_obj.pre_call_hook(
-            user_api_key_dict=user_api_key_dict, data=data, call_type="moderation"
-        )
+        assert (
+            file.filename is not None
+        )  # make sure filename passed in (needed for type)
 
-        start_time = time.time()
+        with open(file.filename, "wb+") as f:
+            f.write(await file.read())
 
+        try:
+            data["file"] = open(file.filename, "rb")
+            ### CALL HOOKS ### - modify incoming data / reject request before calling the model
+            data = await proxy_logging_obj.pre_call_hook(
+                user_api_key_dict=user_api_key_dict,
+                data=data,
+                call_type="moderation",
+            )
-        ## ROUTE TO CORRECT ENDPOINT ##
-        # skip router if user passed their key
-        if "api_key" in data:
-            response = await litellm.atranscription(**data)
-        elif (
-            llm_router is not None and data["model"] in router_model_names
-        ):  # model in router model list
-            response = await llm_router.atranscription(**data)
-        elif (
-            llm_router is not None and data["model"] in llm_router.deployment_names
-        ):  # model in router deployments, calling a specific deployment on the router
-            response = await llm_router.atranscription(**data, specific_deployment=True)
-        elif (
-            llm_router is not None
-            and llm_router.model_group_alias is not None
-            and data["model"] in llm_router.model_group_alias
-        ):  # model set in model_group_alias
-            response = await llm_router.atranscription(
-                **data
-            )  # ensure this goes the llm_router, router will do the correct alias mapping
-        elif user_model is not None:  # `litellm --model <your-model>`
-            response = await litellm.atranscription(**data)
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail={"error": "Invalid model name passed in"},
-            )
+            ## ROUTE TO CORRECT ENDPOINT ##
+            # skip router if user passed their key
+            if "api_key" in data:
+                response = await litellm.atranscription(**data)
+            elif (
+                llm_router is not None and data["model"] in router_model_names
+            ):  # model in router model list
+                response = await llm_router.atranscription(**data)
+
+            elif (
+                llm_router is not None
+                and data["model"] in llm_router.deployment_names
+            ):  # model in router deployments, calling a specific deployment on the router
+                response = await llm_router.atranscription(
+                    **data, specific_deployment=True
+                )
+            elif (
+                llm_router is not None
+                and llm_router.model_group_alias is not None
+                and data["model"] in llm_router.model_group_alias
+            ):  # model set in model_group_alias
+                response = await llm_router.atranscription(
+                    **data
+                )  # ensure this goes the llm_router, router will do the correct alias mapping
+            elif user_model is not None:  # `litellm --model <your-model>`
+                response = await litellm.atranscription(**data)
+            else:
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail={"error": "Invalid model name passed in"},
+                )
+
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+        finally:
+            os.remove(file.filename)  # Delete the saved file
 
         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting
-
         return response
     except Exception as e:
         await proxy_logging_obj.post_call_failure_hook(
@@ -3344,7 +3361,7 @@ async def audio_transcriptions(
         traceback.print_exc()
         if isinstance(e, HTTPException):
             raise ProxyException(
-                message=getattr(e, "message", str(e)),
+                message=getattr(e, "message", str(e.detail)),
                 type=getattr(e, "type", "None"),
                 param=getattr(e, "param", "None"),
                 code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
diff --git a/litellm/utils.py b/litellm/utils.py
index 330903f5a..3285f3a08 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2309,7 +2309,7 @@ def client(original_function):
             or call_type == CallTypes.transcription.value
         ):
             _file_name: BinaryIO = args[1] if len(args) > 1 else kwargs["file"]
-            messages = _file_name.name
+            messages = "audio_file"
         stream = True if "stream" in kwargs and kwargs["stream"] == True else False
         logging_obj = Logging(
             model=model,
@@ -2607,6 +2607,8 @@ def client(original_function):
                 return result
             elif "aimg_generation" in kwargs and kwargs["aimg_generation"] == True:
                 return result
+            elif "atranscription" in kwargs and kwargs["atranscription"] == True:
+                return result
 
             ### POST-CALL RULES ###
             post_call_processing(original_response=result, model=model or None)
@@ -7834,7 +7836,9 @@ def exception_type(
                         message=f"AzureException - {original_exception.message}",
                         llm_provider="azure",
                         model=model,
-                        request=original_exception.request,
+                        request=httpx.Request(
+                            method="POST", url="https://openai.com/"
+                        ),
                     )
                 else:
                     # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors
@@ -7842,7 +7846,11 @@ def exception_type(
                         __cause__=original_exception.__cause__,
                         llm_provider="azure",
                         model=model,
-                        request=original_exception.request,
+                        request=getattr(
+                            original_exception,
+                            "request",
+                            httpx.Request(method="POST", url="https://openai.com/"),
+                        ),
                     )
                 if (
                     "BadRequestError.__init__() missing 1 required positional argument: 'param'"
diff --git a/tests/test_whisper.py b/tests/test_whisper.py
index bb0971732..5cb651951 100644
--- a/tests/test_whisper.py
+++ b/tests/test_whisper.py
@@ -98,7 +98,7 @@ async def test_transcription_on_router():
                 "model": "azure/azure-whisper",
                 "api_base": os.getenv("AZURE_EUROPE_API_BASE"),
os.getenv("AZURE_EUROPE_API_KEY"), - "api_version": os.getenv("2024-02-15-preview"), + "api_version": "2024-02-15-preview", }, }, ]