feat(proxy_server.py): working /audio/transcription endpoint

This commit is contained in:
Krrish Dholakia 2024-03-08 18:20:27 -08:00
parent cc0294b2f2
commit 0fb7afe820
6 changed files with 95 additions and 54 deletions

View file

@ -120,6 +120,8 @@ from fastapi import (
Header,
Response,
Form,
UploadFile,
File,
)
from fastapi.routing import APIRouter
from fastapi.security import OAuth2PasswordBearer
@ -3216,17 +3218,16 @@ async def image_generation(
@router.post(
"/v1/audio/transcriptions",
dependencies=[Depends(user_api_key_auth)],
response_class=ORJSONResponse,
tags=["audio"],
)
@router.post(
"/audio/transcriptions",
dependencies=[Depends(user_api_key_auth)],
response_class=ORJSONResponse,
tags=["audio"],
)
async def audio_transcriptions(
request: Request,
file: UploadFile = File(...),
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
@ -3237,11 +3238,11 @@ async def audio_transcriptions(
global proxy_logging_obj
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
data = orjson.loads(body)
form_data = await request.form()
data: Dict = {key: value for key, value in form_data.items() if key != "file"}
# Include original request and headers in the data
data["proxy_server_request"] = {
data["proxy_server_request"] = { # type: ignore
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
@ -3298,44 +3299,60 @@ async def audio_transcriptions(
else []
)
### CALL HOOKS ### - modify incoming data / reject request before calling the model
data = await proxy_logging_obj.pre_call_hook(
user_api_key_dict=user_api_key_dict, data=data, call_type="moderation"
)
assert (
file.filename is not None
) # make sure filename passed in (needed for type)
start_time = time.time()
with open(file.filename, "wb+") as f:
f.write(await file.read())
try:
data["file"] = open(file.filename, "rb")
### CALL HOOKS ### - modify incoming data / reject request before calling the model
data = await proxy_logging_obj.pre_call_hook(
user_api_key_dict=user_api_key_dict,
data=data,
call_type="moderation",
)
## ROUTE TO CORRECT ENDPOINT ##
# skip router if user passed their key
if "api_key" in data:
response = await litellm.atranscription(**data)
elif (
llm_router is not None and data["model"] in router_model_names
): # model in router model list
response = await llm_router.atranscription(**data)
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.atranscription(**data, specific_deployment=True)
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
response = await llm_router.atranscription(
**data
) # ensure this goes the llm_router, router will do the correct alias mapping
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.atranscription(**data)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "Invalid model name passed in"},
)
## ROUTE TO CORRECT ENDPOINT ##
# skip router if user passed their key
if "api_key" in data:
response = await litellm.atranscription(**data)
elif (
llm_router is not None and data["model"] in router_model_names
): # model in router model list
response = await llm_router.atranscription(**data)
elif (
llm_router is not None
and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.atranscription(
**data, specific_deployment=True
)
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
response = await llm_router.atranscription(
**data
) # ensure this goes the llm_router, router will do the correct alias mapping
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.atranscription(**data)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "Invalid model name passed in"},
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
finally:
os.remove(file.filename) # Delete the saved file
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
@ -3344,7 +3361,7 @@ async def audio_transcriptions(
traceback.print_exc()
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e)),
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),