From f9d0bcc5a1dbe0adf167d63b03dcb7d972d3665d Mon Sep 17 00:00:00 2001
From: Krish Dholakia
Date: Thu, 3 Oct 2024 17:11:22 -0400
Subject: [PATCH] OpenAI `/v1/realtime` api support (#6047)

* feat(azure/realtime): initial working commit for proxy azure openai realtime endpoint support

Adds support for passing /v1/realtime calls via litellm proxy

* feat(realtime_api/main.py): abstraction for handling openai realtime api calls

* feat(router.py): add `arealtime()` endpoint in router for realtime api calls

Allows using `model_list` in proxy for realtime as well

* fix: make realtime api a private function

Structure might change based on feedback. Make that clear to users.

* build(requirements.txt): add websockets to the requirements.txt

* feat(openai/realtime): add openai /v1/realtime api support
---
 litellm/__init__.py                          |  1 +
 .../get_llm_provider_logic.py                |  2 +
 litellm/llms/AzureOpenAI/realtime/handler.py | 86 ++++++++++++++++++
 litellm/llms/OpenAI/realtime/handler.py      | 81 +++++++++++++++++
 litellm/proxy/_new_secret_config.yaml        | 14 ++-
 litellm/proxy/proxy_server.py                | 39 ++++++++
 litellm/proxy/route_llm_request.py           |  2 +-
 litellm/realtime_api/README.md               |  1 +
 litellm/realtime_api/main.py                 | 91 +++++++++++++++++++
 litellm/router.py                            | 39 +++++++-
 requirements.txt                             |  1 +
 11 files changed, 350 insertions(+), 7 deletions(-)
 create mode 100644 litellm/llms/AzureOpenAI/realtime/handler.py
 create mode 100644 litellm/llms/OpenAI/realtime/handler.py
 create mode 100644 litellm/realtime_api/README.md
 create mode 100644 litellm/realtime_api/main.py

diff --git a/litellm/__init__.py b/litellm/__init__.py
index ea0dcb45c..17154e2f7 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -1033,6 +1033,7 @@ from .router import Router
 from .assistants.main import *
 from .batches.main import *
 from .rerank_api.main import *
+from .realtime_api.main import _arealtime
 from .fine_tuning.main import *
 from .files.main import *
 from .scheduler import *
diff --git a/litellm/litellm_core_utils/get_llm_provider_logic.py b/litellm/litellm_core_utils/get_llm_provider_logic.py
index c69912c9c..41132a39e 100644
--- a/litellm/litellm_core_utils/get_llm_provider_logic.py
+++ b/litellm/litellm_core_utils/get_llm_provider_logic.py
@@ -311,6 +311,8 @@ def get_llm_provider(
                     dynamic_api_key
                 )
             )
+        if dynamic_api_key is None and api_key is not None:
+            dynamic_api_key = api_key
         return model, custom_llm_provider, dynamic_api_key, api_base
     elif model.split("/", 1)[0] in litellm.provider_list:
         custom_llm_provider = model.split("/", 1)[0]
diff --git a/litellm/llms/AzureOpenAI/realtime/handler.py b/litellm/llms/AzureOpenAI/realtime/handler.py
new file mode 100644
index 000000000..7d58ee78f
--- /dev/null
+++ b/litellm/llms/AzureOpenAI/realtime/handler.py
@@ -0,0 +1,86 @@
+"""
+This file contains the logic for calling Azure OpenAI's `/openai/realtime` endpoint.
+
+This requires websockets, and is currently only supported on LiteLLM Proxy.
+"""
+
+import asyncio
+from typing import Any, Optional
+
+from ..azure import AzureChatCompletion
+
+# BACKEND_WS_URL = "ws://localhost:8080/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01"
+
+
+async def forward_messages(client_ws: Any, backend_ws: Any):
+    import websockets
+
+    try:
+        while True:
+            message = await backend_ws.recv()
+            await client_ws.send_text(message)
+    except websockets.exceptions.ConnectionClosed:  # type: ignore
+        pass
+
+
+class AzureOpenAIRealtime(AzureChatCompletion):
+    def _construct_url(self, api_base: str, model: str, api_version: str) -> str:
+        """
+        Example output:
+        "wss://my-endpoint-sweden-berri992.openai.azure.com/openai/realtime?api-version=2024-10-01-preview&deployment=gpt-4o-realtime-preview"
+        """
+        api_base = api_base.replace("https://", "wss://")
+        return (
+            f"{api_base}/openai/realtime?api-version={api_version}&deployment={model}"
+        )
+
+    async def async_realtime(
+        self,
+        model: str,
+        websocket: Any,
+        api_base: Optional[str] = None,
+        api_key: Optional[str] = None,
+        api_version: Optional[str] = None,
+        azure_ad_token: Optional[str] = None,
+        client: Optional[Any] = None,
+        timeout: Optional[float] = None,
+    ):
+        import websockets
+
+        if api_base is None:
+            raise ValueError("api_base is required for Azure OpenAI calls")
+        if api_version is None:
+            raise ValueError("api_version is required for Azure OpenAI calls")
+
+        url = self._construct_url(api_base, model, api_version)
+
+        try:
+            async with websockets.connect(  # type: ignore
+                url,
+                extra_headers={
+                    "api-key": api_key,  # type: ignore
+                },
+            ) as backend_ws:
+                forward_task = asyncio.create_task(
+                    forward_messages(websocket, backend_ws)
+                )
+
+                try:
+                    while True:
+                        message = await websocket.receive_text()
+                        await backend_ws.send(message)
+                except websockets.exceptions.ConnectionClosed:  # type: ignore
+                    forward_task.cancel()
+                finally:
+                    if not forward_task.done():
+                        forward_task.cancel()
+                        try:
+                            await forward_task
+                        except asyncio.CancelledError:
+                            pass
+
+        except websockets.exceptions.InvalidStatusCode as e:  # type: ignore
+            await websocket.close(code=e.status_code, reason=str(e))
+        except Exception as e:
+            await websocket.close(code=1011, reason=f"Internal server error: {str(e)}")
diff --git a/litellm/llms/OpenAI/realtime/handler.py b/litellm/llms/OpenAI/realtime/handler.py
new file mode 100644
index 000000000..08e5fa0b9
--- /dev/null
+++ b/litellm/llms/OpenAI/realtime/handler.py
@@ -0,0 +1,81 @@
+"""
+This file contains the logic for calling OpenAI's `/v1/realtime` endpoint.
+
+This requires websockets, and is currently only supported on LiteLLM Proxy.
+"""
+
+import asyncio
+from typing import Any, Optional
+
+from ..openai import OpenAIChatCompletion
+
+
+async def forward_messages(client_ws: Any, backend_ws: Any):
+    import websockets
+
+    try:
+        while True:
+            message = await backend_ws.recv()
+            await client_ws.send_text(message)
+    except websockets.exceptions.ConnectionClosed:  # type: ignore
+        pass
+
+
+class OpenAIRealtime(OpenAIChatCompletion):
+    def _construct_url(self, api_base: str, model: str) -> str:
+        """
+        Example output:
+        "wss://localhost:8080/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01"
+        """
+        api_base = api_base.replace("https://", "wss://")
+        api_base = api_base.replace("http://", "ws://")
+        return f"{api_base}/v1/realtime?model={model}"
+
+    async def async_realtime(
+        self,
+        model: str,
+        websocket: Any,
+        api_base: Optional[str] = None,
+        api_key: Optional[str] = None,
+        client: Optional[Any] = None,
+        timeout: Optional[float] = None,
+    ):
+        import websockets
+
+        if api_base is None:
+            raise ValueError("api_base is required for OpenAI calls")
+        if api_key is None:
+            raise ValueError("api_key is required for OpenAI calls")
+
+        url = self._construct_url(api_base, model)
+
+        try:
+            async with websockets.connect(  # type: ignore
+                url,
+                extra_headers={
+                    "Authorization": f"Bearer {api_key}",  # type: ignore
+                    "OpenAI-Beta": "realtime=v1",
+                },
+            ) as backend_ws:
+                forward_task = asyncio.create_task(
+                    forward_messages(websocket, backend_ws)
+                )
+
+                try:
+                    while True:
+                        message = await websocket.receive_text()
+                        await backend_ws.send(message)
+                except websockets.exceptions.ConnectionClosed:  # type: ignore
+                    forward_task.cancel()
+                finally:
+                    if not forward_task.done():
+                        forward_task.cancel()
+                        try:
+                            await forward_task
+                        except asyncio.CancelledError:
+                            pass
+
+        except websockets.exceptions.InvalidStatusCode as e:  # type: ignore
+            await websocket.close(code=e.status_code, reason=str(e))
+        except Exception as e:
+            await websocket.close(code=1011, reason=f"Internal server error: {str(e)}")
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 19856ff3d..040a4a536 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,7 +1,13 @@
 model_list:
-  - model_name: whisper
+  - model_name: gpt-4o-realtime-audio
     litellm_params:
-      model: whisper-1
+      model: azure/gpt-4o-realtime-preview
+      api_key: os.environ/AZURE_SWEDEN_API_KEY
+      api_base: os.environ/AZURE_SWEDEN_API_BASE
+
+  - model_name: openai-gpt-4o-realtime-audio
+    litellm_params:
+      model: openai/gpt-4o-realtime-preview-2024-10-01
       api_key: os.environ/OPENAI_API_KEY
-    model_info:
-      mode: audio_transcription
+      api_base: http://localhost:8080
+
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 2f002d964..12224634c 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -4143,6 +4143,45 @@ async def audio_transcriptions(
     )
 
 
+######################################################################
+
+# /v1/realtime Endpoints
+
+######################################################################
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+
+from litellm import _arealtime
+
+
+@app.websocket("/v1/realtime")
+async def websocket_endpoint(websocket: WebSocket, model: str):
+    import websockets
+
+    await websocket.accept()
+
+    data = {
+        "model": model,
+        "websocket": websocket,
+    }
+
+    ### ROUTE THE REQUEST ###
+    try:
+        llm_call = await route_request(
+            data=data,
+            route_type="_arealtime",
+            llm_router=llm_router,
+            user_model=user_model,
+        )
+
+        await llm_call
+    except websockets.exceptions.InvalidStatusCode as e:  # type: ignore
+        verbose_proxy_logger.exception("Invalid status code")
+        await websocket.close(code=e.status_code, reason="Invalid status code")
+    except Exception:
+        verbose_proxy_logger.exception("Internal server error")
+        await websocket.close(code=1011, reason="Internal server error")
+
+
 ######################################################################
 
 # /v1/assistant Endpoints
diff --git a/litellm/proxy/route_llm_request.py b/litellm/proxy/route_llm_request.py
index 4d1ac6c15..63e41f64f 100644
--- a/litellm/proxy/route_llm_request.py
+++ b/litellm/proxy/route_llm_request.py
@@ -58,6 +58,7 @@ async def route_request(
         "atranscription",
         "amoderation",
         "arerank",
+        "_arealtime",  # private function for realtime API
     ],
 ):
     """
@@ -65,7 +66,6 @@ async def route_request(
     """
 
     router_model_names = llm_router.model_names if llm_router is not None else []
-
     if "api_key" in data or "api_base" in data:
         return getattr(litellm, f"{route_type}")(**data)
diff --git a/litellm/realtime_api/README.md b/litellm/realtime_api/README.md
new file mode 100644
index 000000000..6b467c056
--- /dev/null
+++ b/litellm/realtime_api/README.md
@@ -0,0 +1 @@
+Abstraction / Routing logic for OpenAI's `/v1/realtime` endpoint.
\ No newline at end of file
diff --git a/litellm/realtime_api/main.py b/litellm/realtime_api/main.py
new file mode 100644
index 000000000..5e512795f
--- /dev/null
+++ b/litellm/realtime_api/main.py
@@ -0,0 +1,91 @@
+"""Abstraction function for OpenAI's realtime API"""
+
+import os
+from typing import Any, Optional
+
+import litellm
+from litellm import get_llm_provider
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.router import GenericLiteLLMParams
+
+from ..llms.AzureOpenAI.realtime.handler import AzureOpenAIRealtime
+from ..llms.OpenAI.realtime.handler import OpenAIRealtime
+
+azure_realtime = AzureOpenAIRealtime()
+openai_realtime = OpenAIRealtime()
+
+
+async def _arealtime(
+    model: str,
+    websocket: Any,  # fastapi websocket
+    api_base: Optional[str] = None,
+    api_key: Optional[str] = None,
+    api_version: Optional[str] = None,
+    azure_ad_token: Optional[str] = None,
+    client: Optional[Any] = None,
+    timeout: Optional[float] = None,
+    **kwargs,
+):
+    """
+    Private function to handle the realtime API call.
+
+    For PROXY use only.
+    """
+    litellm_params = GenericLiteLLMParams(**kwargs)
+
+    model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = get_llm_provider(
+        model=model,
+        api_base=api_base,
+        api_key=api_key,
+    )
+
+    if _custom_llm_provider == "azure":
+        api_base = (
+            dynamic_api_base
+            or litellm_params.api_base
+            or litellm.api_base
+            or get_secret_str("AZURE_API_BASE")
+        )
+        # set API KEY
+        api_key = (
+            dynamic_api_key
+            or litellm.api_key
+            or litellm.openai_key
+            or get_secret_str("AZURE_API_KEY")
+        )
+
+        await azure_realtime.async_realtime(
+            model=model,
+            websocket=websocket,
+            api_base=api_base,
+            api_key=api_key,
+            api_version="2024-10-01-preview",
+            azure_ad_token=None,
+            client=None,
+            timeout=timeout,
+        )
+    elif _custom_llm_provider == "openai":
+        api_base = (
+            dynamic_api_base
+            or litellm_params.api_base
+            or litellm.api_base
+            or "https://api.openai.com/"
+        )
+        # set API KEY
+        api_key = (
+            dynamic_api_key
+            or litellm.api_key
+            or litellm.openai_key
+            or get_secret_str("OPENAI_API_KEY")
+        )
+
+        await openai_realtime.async_realtime(
+            model=model,
+            websocket=websocket,
+            api_base=api_base,
+            api_key=api_key,
+            client=None,
+            timeout=timeout,
+        )
+    else:
+        raise ValueError(f"Unsupported model: {model}")
diff --git a/litellm/router.py b/litellm/router.py
index 043cd64ee..d73f5d4b3 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -612,6 +612,7 @@ class Router:
         self, model: str, messages: List[Dict[str, str]], **kwargs
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         model_name = None
+        traceback.print_stack()
         try:
             # pick the one that is available (lowest TPM/RPM)
             deployment = self.get_available_deployment(
@@ -1800,6 +1801,40 @@ class Router:
             self.fail_calls[model_name] += 1
             raise e
 
+    async def _arealtime(self, model: str, **kwargs):
+        messages = [{"role": "user", "content": "dummy-text"}]
+        try:
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            kwargs.get("request_timeout", self.timeout)
+            kwargs.setdefault("metadata", {}).update({"model_group": model})
+
+            # pick the one that is available (lowest TPM/RPM)
+            deployment = await self.async_get_available_deployment(
+                model=model,
+                messages=messages,
+                specific_deployment=kwargs.pop("specific_deployment", None),
+            )
+
+            data = deployment["litellm_params"].copy()
+            for k, v in self.default_litellm_params.items():
+                if (
+                    k not in kwargs
+                ):  # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
+
+            return await litellm._arealtime(**{**data, "caching": self.cache_responses, **kwargs})  # type: ignore
+        except Exception as e:
+            traceback.print_exc()
+            if self.num_retries > 0:
+                kwargs["model"] = model
+                kwargs["messages"] = messages
+                kwargs["original_function"] = self._arealtime
+                return self.function_with_retries(**kwargs)
+            else:
+                raise e
+
     def text_completion(
         self,
         model: str,
@@ -1813,7 +1848,7 @@
         try:
             kwargs["model"] = model
             kwargs["prompt"] = prompt
-            kwargs["original_function"] = self._acompletion
+            kwargs["original_function"] = self.text_completion
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
             kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
@@ -1840,7 +1875,7 @@
             if self.num_retries > 0:
                 kwargs["model"] = model
                 kwargs["messages"] = messages
-                kwargs["original_function"] = self.completion
+                kwargs["original_function"] = self.text_completion
                 return self.function_with_retries(**kwargs)
             else:
                 raise e
diff --git a/requirements.txt b/requirements.txt
index 966771bc7..1cedeeaf7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -46,4 +46,5 @@ aioboto3==12.3.0 # for async sagemaker calls
 tenacity==8.2.3 # for retrying requests, when litellm.num_retries set
 pydantic==2.7.1 # proxy + openai req.
 jsonschema==4.22.0 # validating json schema
+websockets==10.4 # for realtime API
 ####
\ No newline at end of file
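
Example usage (illustrative sketch, not part of the patch): once a proxy is running with a realtime-capable model group such as the `openai-gpt-4o-realtime-audio` entry in the example config above, a client can connect to the proxy's new `/v1/realtime` websocket endpoint using the same `websockets` library this PR adds to `requirements.txt`. The proxy URL, port, and API key below are placeholders, and the `response.create` payload follows OpenAI's realtime event format; adjust all of them to your deployment.

```python
# Minimal client sketch for the proxy's /v1/realtime endpoint.
# Assumes a LiteLLM proxy on localhost:4000 with master key "sk-1234" -- both placeholders.
import asyncio
import json

import websockets  # same dependency pinned in requirements.txt


async def main():
    url = "ws://localhost:4000/v1/realtime?model=openai-gpt-4o-realtime-audio"
    async with websockets.connect(
        url,
        extra_headers={"Authorization": "Bearer sk-1234"},  # proxy API key (placeholder)
    ) as ws:
        # Ask the model for a text-only response over the realtime session.
        await ws.send(
            json.dumps(
                {
                    "type": "response.create",
                    "response": {"modalities": ["text"], "instructions": "Say hello."},
                }
            )
        )
        # Print server events until the response is complete.
        async for raw_event in ws:
            event = json.loads(raw_event)
            print(event.get("type"))
            if event.get("type") == "response.done":
                break


if __name__ == "__main__":
    asyncio.run(main())
```

The proxy itself does no message translation here: the handlers above simply forward frames in both directions between the client websocket and the upstream Azure/OpenAI realtime endpoint.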