mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Merge pull request #3928 from BerriAI/litellm_audio_speech_endpoint
feat(main.py): support openai tts endpoint
This commit is contained in:
commit
d3a247bf20
11 changed files with 574 additions and 7 deletions
|
@ -79,6 +79,9 @@ def generate_feedback_box():
|
|||
|
||||
|
||||
import litellm
|
||||
from litellm.types.llms.openai import (
|
||||
HttpxBinaryResponseContent,
|
||||
)
|
||||
from litellm.proxy.utils import (
|
||||
PrismaClient,
|
||||
DBClient,
|
||||
|
@ -4883,6 +4886,169 @@ async def image_generation(
|
|||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/audio/speech",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["audio"],
|
||||
)
|
||||
@router.post(
|
||||
"/audio/speech",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["audio"],
|
||||
)
|
||||
async def audio_speech(
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Same params as:
|
||||
|
||||
https://platform.openai.com/docs/api-reference/audio/createSpeech
|
||||
"""
|
||||
global proxy_logging_obj
|
||||
data: Dict = {}
|
||||
try:
|
||||
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
||||
body = await request.body()
|
||||
data = orjson.loads(body)
|
||||
|
||||
# Include original request and headers in the data
|
||||
data["proxy_server_request"] = { # type: ignore
|
||||
"url": str(request.url),
|
||||
"method": request.method,
|
||||
"headers": dict(request.headers),
|
||||
"body": copy.copy(data), # use copy instead of deepcopy
|
||||
}
|
||||
|
||||
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
|
||||
data["user"] = user_api_key_dict.user_id
|
||||
|
||||
if user_model:
|
||||
data["model"] = user_model
|
||||
|
||||
if "metadata" not in data:
|
||||
data["metadata"] = {}
|
||||
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
|
||||
data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
|
||||
_headers = dict(request.headers)
|
||||
_headers.pop(
|
||||
"authorization", None
|
||||
) # do not store the original `sk-..` api key in the db
|
||||
data["metadata"]["headers"] = _headers
|
||||
data["metadata"]["user_api_key_alias"] = getattr(
|
||||
user_api_key_dict, "key_alias", None
|
||||
)
|
||||
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
|
||||
data["metadata"]["user_api_key_team_id"] = getattr(
|
||||
user_api_key_dict, "team_id", None
|
||||
)
|
||||
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
|
||||
"global_max_parallel_requests", None
|
||||
)
|
||||
data["metadata"]["user_api_key_team_alias"] = getattr(
|
||||
user_api_key_dict, "team_alias", None
|
||||
)
|
||||
data["metadata"]["endpoint"] = str(request.url)
|
||||
|
||||
### TEAM-SPECIFIC PARAMS ###
|
||||
if user_api_key_dict.team_id is not None:
|
||||
team_config = await proxy_config.load_team_config(
|
||||
team_id=user_api_key_dict.team_id
|
||||
)
|
||||
if len(team_config) == 0:
|
||||
pass
|
||||
else:
|
||||
team_id = team_config.pop("team_id", None)
|
||||
data["metadata"]["team_id"] = team_id
|
||||
data = {
|
||||
**team_config,
|
||||
**data,
|
||||
} # add the team-specific configs to the completion call
|
||||
|
||||
router_model_names = llm_router.model_names if llm_router is not None else []
|
||||
|
||||
### CALL HOOKS ### - modify incoming data / reject request before calling the model
|
||||
data = await proxy_logging_obj.pre_call_hook(
|
||||
user_api_key_dict=user_api_key_dict, data=data, call_type="image_generation"
|
||||
)
|
||||
|
||||
## ROUTE TO CORRECT ENDPOINT ##
|
||||
# skip router if user passed their key
|
||||
if "api_key" in data:
|
||||
response = await litellm.aspeech(**data)
|
||||
elif (
|
||||
llm_router is not None and data["model"] in router_model_names
|
||||
): # model in router model list
|
||||
response = await llm_router.aspeech(**data)
|
||||
elif (
|
||||
llm_router is not None and data["model"] in llm_router.deployment_names
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
response = await llm_router.aspeech(**data, specific_deployment=True)
|
||||
elif (
|
||||
llm_router is not None
|
||||
and llm_router.model_group_alias is not None
|
||||
and data["model"] in llm_router.model_group_alias
|
||||
): # model set in model_group_alias
|
||||
response = await llm_router.aspeech(
|
||||
**data
|
||||
) # ensure this goes the llm_router, router will do the correct alias mapping
|
||||
elif (
|
||||
llm_router is not None
|
||||
and data["model"] not in router_model_names
|
||||
and llm_router.default_deployment is not None
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
response = await llm_router.aspeech(**data)
|
||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
||||
response = await litellm.aspeech(**data)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"error": "audio_speech: Invalid model name passed in model="
|
||||
+ data.get("model", "")
|
||||
},
|
||||
)
|
||||
|
||||
### ALERTING ###
|
||||
data["litellm_status"] = "success" # used for alerting
|
||||
|
||||
### RESPONSE HEADERS ###
|
||||
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
||||
model_id = hidden_params.get("model_id", None) or ""
|
||||
cache_key = hidden_params.get("cache_key", None) or ""
|
||||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
# Printing each chunk size
|
||||
async def generate(_response: HttpxBinaryResponseContent):
|
||||
_generator = await _response.aiter_bytes(chunk_size=1024)
|
||||
async for chunk in _generator:
|
||||
yield chunk
|
||||
|
||||
custom_headers = get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
api_base=api_base,
|
||||
version=version,
|
||||
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
||||
fastest_response_batch_completion=None,
|
||||
)
|
||||
|
||||
selected_data_generator = select_data_generator(
|
||||
response=response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=data,
|
||||
)
|
||||
return StreamingResponse(
|
||||
generate(response), media_type="audio/mpeg", headers=custom_headers
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/audio/transcriptions",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue