forked from phoenix/litellm-mirror
OpenAI /v1/realtime api support (#6047)

* feat(azure/realtime): initial working commit for proxy azure openai realtime endpoint support. Adds support for passing /v1/realtime calls via the litellm proxy.
* feat(realtime_api/main.py): abstraction for handling openai realtime api calls.
* feat(router.py): add `arealtime()` endpoint in router for realtime api calls. Allows using `model_list` in the proxy for realtime as well.
* fix: make the realtime api a private function. Structure might change based on feedback; make that clear to users.
* build(requirements.txt): add websockets to the requirements.txt.
* feat(openai/realtime): add openai /v1/realtime api support.
This commit is contained in:
parent 130842537f
commit f9d0bcc5a1
11 changed files with 350 additions and 7 deletions
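For context, a minimal client sketch of what this change enables: a websocket client talking to the proxy's new `/v1/realtime` route. The proxy host/port and model-group name are assumptions for illustration (they match the sample config further down), and the event payload follows OpenAI's realtime event schema.

# Hypothetical client for the new proxy endpoint; not part of this commit.
import asyncio
import json

import websockets


async def main():
    uri = "ws://localhost:4000/v1/realtime?model=gpt-4o-realtime-audio"
    async with websockets.connect(uri) as ws:
        # ask the session to respond with text only (OpenAI realtime event schema)
        await ws.send(
            json.dumps({"type": "session.update", "session": {"modalities": ["text"]}})
        )
        print(await ws.recv())  # first server event, e.g. "session.created"


asyncio.run(main())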
@@ -1033,6 +1033,7 @@ from .router import Router
 from .assistants.main import *
 from .batches.main import *
 from .rerank_api.main import *
+from .realtime_api.main import _arealtime
 from .fine_tuning.main import *
 from .files.main import *
 from .scheduler import *
@@ -311,6 +311,8 @@ def get_llm_provider(
                 dynamic_api_key
             )
         )
+        if dynamic_api_key is None and api_key is not None:
+            dynamic_api_key = api_key
         return model, custom_llm_provider, dynamic_api_key, api_base
     elif model.split("/", 1)[0] in litellm.provider_list:
         custom_llm_provider = model.split("/", 1)[0]
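The new two-line fallback ensures an explicitly passed `api_key` is not dropped when the provider lookup resolves no dynamic key. A hedged illustration (provider and key values are placeholders; the exact branch taken depends on the surrounding function):

# Sketch of the intended behavior: if the provider block left dynamic_api_key
# unset but the caller passed api_key, the caller's key is returned.
from litellm import get_llm_provider

model, provider, dynamic_key, api_base = get_llm_provider(
    model="groq/llama3-8b-8192",  # an openai-compatible provider (placeholder)
    api_key="my-explicit-key",
)
# with no GROQ_API_KEY in the environment, dynamic_key should now come back
# as "my-explicit-key" instead of None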

litellm/llms/AzureOpenAI/realtime/handler.py  (new file, 86 lines)
@@ -0,0 +1,86 @@
"""
This file contains the code for calling Azure OpenAI's `/openai/realtime` endpoint.

This requires websockets, and is currently only supported on LiteLLM Proxy.
"""

import asyncio
from typing import Any, Optional

from ..azure import AzureChatCompletion

# BACKEND_WS_URL = "ws://localhost:8080/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01"


async def forward_messages(client_ws: Any, backend_ws: Any):
    import websockets

    try:
        while True:
            message = await backend_ws.recv()
            await client_ws.send_text(message)
    except websockets.exceptions.ConnectionClosed:  # type: ignore
        pass


class AzureOpenAIRealtime(AzureChatCompletion):
    def _construct_url(self, api_base: str, model: str, api_version: str) -> str:
        """
        Example output:
        "wss://my-endpoint-sweden-berri992.openai.azure.com/openai/realtime?api-version=2024-10-01-preview&deployment=gpt-4o-realtime-preview"
        """
        api_base = api_base.replace("https://", "wss://")
        return (
            f"{api_base}/openai/realtime?api-version={api_version}&deployment={model}"
        )

    async def async_realtime(
        self,
        model: str,
        websocket: Any,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        api_version: Optional[str] = None,
        azure_ad_token: Optional[str] = None,
        client: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        import websockets

        if api_base is None:
            raise ValueError("api_base is required for Azure OpenAI calls")
        if api_version is None:
            raise ValueError("api_version is required for Azure OpenAI calls")

        url = self._construct_url(api_base, model, api_version)

        try:
            async with websockets.connect(  # type: ignore
                url,
                extra_headers={
                    "api-key": api_key,  # type: ignore
                },
            ) as backend_ws:
                forward_task = asyncio.create_task(
                    forward_messages(websocket, backend_ws)
                )

                try:
                    while True:
                        message = await websocket.receive_text()
                        await backend_ws.send(message)
                except websockets.exceptions.ConnectionClosed:  # type: ignore
                    forward_task.cancel()
                finally:
                    if not forward_task.done():
                        forward_task.cancel()
                        try:
                            await forward_task
                        except asyncio.CancelledError:
                            pass

        except websockets.exceptions.InvalidStatusCode as e:  # type: ignore
            await websocket.close(code=e.status_code, reason=str(e))
        except Exception as e:
            await websocket.close(code=1011, reason=f"Internal server error: {str(e)}")
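A hedged sanity check of the URL construction above. The endpoint name is a placeholder, and this assumes `AzureOpenAIRealtime` can be instantiated without arguments, as its parent class is elsewhere in litellm.

# Illustrative only; mirrors the docstring's example output.
from litellm.llms.AzureOpenAI.realtime.handler import AzureOpenAIRealtime

url = AzureOpenAIRealtime()._construct_url(
    api_base="https://my-endpoint.openai.azure.com",
    model="gpt-4o-realtime-preview",
    api_version="2024-10-01-preview",
)
print(url)
# wss://my-endpoint.openai.azure.com/openai/realtime?api-version=2024-10-01-preview&deployment=gpt-4o-realtime-preview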
litellm/llms/OpenAI/realtime/handler.py  (new file, 81 lines)
@@ -0,0 +1,81 @@
"""
This file contains the code for calling OpenAI's `/v1/realtime` endpoint.

This requires websockets, and is currently only supported on LiteLLM Proxy.
"""

import asyncio
from typing import Any, Optional

from ..openai import OpenAIChatCompletion


async def forward_messages(client_ws: Any, backend_ws: Any):
    import websockets

    try:
        while True:
            message = await backend_ws.recv()
            await client_ws.send_text(message)
    except websockets.exceptions.ConnectionClosed:  # type: ignore
        pass


class OpenAIRealtime(OpenAIChatCompletion):
    def _construct_url(self, api_base: str, model: str) -> str:
        """
        Example output:
        "wss://localhost:8080/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01"
        """
        api_base = api_base.replace("https://", "wss://")
        api_base = api_base.replace("http://", "ws://")
        return f"{api_base}/v1/realtime?model={model}"

    async def async_realtime(
        self,
        model: str,
        websocket: Any,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        client: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        import websockets

        if api_base is None:
            raise ValueError("api_base is required for OpenAI calls")
        if api_key is None:
            raise ValueError("api_key is required for OpenAI calls")

        url = self._construct_url(api_base, model)

        try:
            async with websockets.connect(  # type: ignore
                url,
                extra_headers={
                    "Authorization": f"Bearer {api_key}",  # type: ignore
                    "OpenAI-Beta": "realtime=v1",
                },
            ) as backend_ws:
                forward_task = asyncio.create_task(
                    forward_messages(websocket, backend_ws)
                )

                try:
                    while True:
                        message = await websocket.receive_text()
                        await backend_ws.send(message)
                except websockets.exceptions.ConnectionClosed:  # type: ignore
                    forward_task.cancel()
                finally:
                    if not forward_task.done():
                        forward_task.cancel()
                        try:
                            await forward_task
                        except asyncio.CancelledError:
                            pass

        except websockets.exceptions.InvalidStatusCode as e:  # type: ignore
            await websocket.close(code=e.status_code, reason=str(e))
        except Exception as e:
            await websocket.close(code=1011, reason=f"Internal server error: {str(e)}")
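The backend side of this handler reduces to the connection sketch below; both headers are required, exactly as the handler sets them. The API key and model name are placeholders.

# Sketch of the backend websocket connection the handler opens.
import websockets


async def connect_backend(api_key: str):
    return await websockets.connect(
        "wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01",
        extra_headers={
            "Authorization": f"Bearer {api_key}",
            "OpenAI-Beta": "realtime=v1",
        },
    )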

@@ -1,7 +1,13 @@
 model_list:
-  - model_name: whisper
+  - model_name: gpt-4o-realtime-audio
     litellm_params:
-      model: whisper-1
+      model: azure/gpt-4o-realtime-preview
+      api_key: os.environ/AZURE_SWEDEN_API_KEY
+      api_base: os.environ/AZURE_SWEDEN_API_BASE
+
+  - model_name: openai-gpt-4o-realtime-audio
+    litellm_params:
+      model: openai/gpt-4o-realtime-preview-2024-10-01
       api_key: os.environ/OPENAI_API_KEY
-    model_info:
-      mode: audio_transcription
+      api_base: http://localhost:8080
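Both model groups above are addressed through the same websocket route; only the `model` query parameter differs. A small hedged snippet (proxy host/port are assumptions):

# Hypothetical connection URLs for the two configured model groups.
AZURE_REALTIME_URI = "ws://localhost:4000/v1/realtime?model=gpt-4o-realtime-audio"
OPENAI_REALTIME_URI = "ws://localhost:4000/v1/realtime?model=openai-gpt-4o-realtime-audio"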
@@ -4143,6 +4143,45 @@ async def audio_transcriptions(
         )
 
 
+######################################################################
+
+#                        /v1/realtime Endpoints
+
+######################################################################
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+
+from litellm import _arealtime
+
+
+@app.websocket("/v1/realtime")
+async def websocket_endpoint(websocket: WebSocket, model: str):
+    import websockets
+
+    await websocket.accept()
+
+    data = {
+        "model": model,
+        "websocket": websocket,
+    }
+
+    ### ROUTE THE REQUEST ###
+    try:
+        llm_call = await route_request(
+            data=data,
+            route_type="_arealtime",
+            llm_router=llm_router,
+            user_model=user_model,
+        )
+
+        await llm_call
+    except websockets.exceptions.InvalidStatusCode as e:  # type: ignore
+        verbose_proxy_logger.exception("Invalid status code")
+        await websocket.close(code=e.status_code, reason="Invalid status code")
+    except Exception:
+        verbose_proxy_logger.exception("Internal server error")
+        await websocket.close(code=1011, reason="Internal server error")
+
+
 ######################################################################
 
 #                        /v1/assistant Endpoints
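Note that FastAPI binds `model` from the query string on websocket routes, so clients connect as `ws://<proxy>/v1/realtime?model=<model_group>`. A self-contained sketch of that binding (the app and route here are illustrative, not the proxy's own):

# Minimal FastAPI demo of query-parameter binding on a websocket route.
from fastapi import FastAPI, WebSocket

demo_app = FastAPI()


@demo_app.websocket("/v1/realtime")
async def demo_endpoint(websocket: WebSocket, model: str):
    # `model` is populated from ?model=... in the connection URL
    await websocket.accept()
    await websocket.send_text(f"routing to model group: {model}")
    await websocket.close()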
@@ -58,6 +58,7 @@ async def route_request(
         "atranscription",
         "amoderation",
         "arerank",
+        "_arealtime",  # private function for realtime API
     ],
 ):
     """
@@ -65,7 +66,6 @@ async def route_request(
-
     """
     router_model_names = llm_router.model_names if llm_router is not None else []
 
     if "api_key" in data or "api_base" in data:
         return getattr(litellm, f"{route_type}")(**data)
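When the caller supplies `api_key`/`api_base` directly, the dispatch above bypasses the router; for the new route type the lookup resolves to the module-level private function (a sketch of that lookup):

# Sketch of the getattr dispatch for the new route_type.
import litellm

fn = getattr(litellm, "_arealtime")  # the function imported in the __init__ hunk above
assert callable(fn)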
litellm/realtime_api/README.md  (new file, 1 line)
@@ -0,0 +1 @@
Abstraction / Routing logic for OpenAI's `/v1/realtime` endpoint.

litellm/realtime_api/main.py  (new file, 91 lines)
@@ -0,0 +1,91 @@
"""Abstraction function for OpenAI's realtime API"""

import os
from typing import Any, Optional

import litellm
from litellm import get_llm_provider
from litellm.secret_managers.main import get_secret_str
from litellm.types.router import GenericLiteLLMParams

from ..llms.AzureOpenAI.realtime.handler import AzureOpenAIRealtime
from ..llms.OpenAI.realtime.handler import OpenAIRealtime

azure_realtime = AzureOpenAIRealtime()
openai_realtime = OpenAIRealtime()


async def _arealtime(
    model: str,
    websocket: Any,  # fastapi websocket
    api_base: Optional[str] = None,
    api_key: Optional[str] = None,
    api_version: Optional[str] = None,
    azure_ad_token: Optional[str] = None,
    client: Optional[Any] = None,
    timeout: Optional[float] = None,
    **kwargs,
):
    """
    Private function to handle the realtime API call.

    For PROXY use only.
    """
    litellm_params = GenericLiteLLMParams(**kwargs)

    model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = get_llm_provider(
        model=model,
        api_base=api_base,
        api_key=api_key,
    )

    if _custom_llm_provider == "azure":
        api_base = (
            dynamic_api_base
            or litellm_params.api_base
            or litellm.api_base
            or get_secret_str("AZURE_API_BASE")
        )
        # set API KEY
        api_key = (
            dynamic_api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("AZURE_API_KEY")
        )

        await azure_realtime.async_realtime(
            model=model,
            websocket=websocket,
            api_base=api_base,
            api_key=api_key,
            api_version="2024-10-01-preview",
            azure_ad_token=None,
            client=None,
            timeout=timeout,
        )
    elif _custom_llm_provider == "openai":
        api_base = (
            dynamic_api_base
            or litellm_params.api_base
            or litellm.api_base
            or "https://api.openai.com/"
        )
        # set API KEY
        api_key = (
            dynamic_api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("OPENAI_API_KEY")
        )

        await openai_realtime.async_realtime(
            model=model,
            websocket=websocket,
            api_base=api_base,
            api_key=api_key,
            client=None,
            timeout=timeout,
        )
    else:
        raise ValueError(f"Unsupported model: {model}")
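For proxy-internal callers, the function is awaited with a live fastapi websocket. A hedged sketch (credentials and endpoint are placeholders):

# Illustrative proxy-internal call; `websocket` is an accepted fastapi WebSocket.
import litellm


async def handle(websocket):
    await litellm._arealtime(
        model="azure/gpt-4o-realtime-preview",
        websocket=websocket,
        api_base="https://my-endpoint.openai.azure.com",  # placeholder
        api_key="my-azure-key",  # placeholder
    )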
@@ -612,6 +612,7 @@ class Router:
         self, model: str, messages: List[Dict[str, str]], **kwargs
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         model_name = None
+        traceback.print_stack()
         try:
             # pick the one that is available (lowest TPM/RPM)
             deployment = self.get_available_deployment(
@@ -1800,6 +1801,40 @@ class Router:
         self.fail_calls[model_name] += 1
         raise e
 
+    async def _arealtime(self, model: str, **kwargs):
+        messages = [{"role": "user", "content": "dummy-text"}]
+        try:
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            kwargs.get("request_timeout", self.timeout)
+            kwargs.setdefault("metadata", {}).update({"model_group": model})
+
+            # pick the one that is available (lowest TPM/RPM)
+            deployment = await self.async_get_available_deployment(
+                model=model,
+                messages=messages,
+                specific_deployment=kwargs.pop("specific_deployment", None),
+            )
+
+            data = deployment["litellm_params"].copy()
+            for k, v in self.default_litellm_params.items():
+                if (
+                    k not in kwargs
+                ):  # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
+
+            return await litellm._arealtime(**{**data, "caching": self.cache_responses, **kwargs})  # type: ignore
+        except Exception as e:
+            traceback.print_exc()
+            if self.num_retries > 0:
+                kwargs["model"] = model
+                kwargs["messages"] = messages
+                kwargs["original_function"] = self._arealtime
+                return self.function_with_retries(**kwargs)
+            else:
+                raise e
+
     def text_completion(
         self,
         model: str,
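Through the router, realtime calls inherit model-group routing and retries like any other endpoint. A hedged usage sketch (model-list values are placeholders; the method is private and may change, per the commit message):

# Illustrative Router setup for realtime routing; not part of this diff.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4o-realtime-audio",
            "litellm_params": {
                "model": "openai/gpt-4o-realtime-preview-2024-10-01",
                "api_key": "sk-my-key",  # placeholder
            },
        }
    ]
)

# inside an accepted websocket handler:
#     await router._arealtime(model="gpt-4o-realtime-audio", websocket=websocket)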
@@ -1813,7 +1848,7 @@ class Router:
         try:
             kwargs["model"] = model
             kwargs["prompt"] = prompt
-            kwargs["original_function"] = self._acompletion
+            kwargs["original_function"] = self.text_completion
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
             kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
@@ -1840,7 +1875,7 @@ class Router:
         if self.num_retries > 0:
             kwargs["model"] = model
             kwargs["messages"] = messages
-            kwargs["original_function"] = self.completion
+            kwargs["original_function"] = self.text_completion
             return self.function_with_retries(**kwargs)
         else:
             raise e
@@ -46,4 +46,5 @@ aioboto3==12.3.0 # for async sagemaker calls
 tenacity==8.2.3  # for retrying requests, when litellm.num_retries set
 pydantic==2.7.1  # proxy + openai req.
 jsonschema==4.22.0  # validating json schema
+websockets==10.4  # for realtime API
 ####