mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Merge 6d77f38bb1
into b82af5b826
This commit is contained in:
commit
ffb42e3973
4 changed files with 559 additions and 1 deletions
|
@ -11,10 +11,12 @@ from ..litellm_core_utils.get_litellm_params import get_litellm_params
|
|||
from ..litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
|
||||
from ..llms.azure.realtime.handler import AzureOpenAIRealtime
|
||||
from ..llms.openai.realtime.handler import OpenAIRealtime
|
||||
from ..llms.vertex_ai.multimodal_live.handler import GeminiLive # Import the new handler
|
||||
from ..utils import client as wrapper_client
|
||||
|
||||
azure_realtime = AzureOpenAIRealtime()
|
||||
openai_realtime = OpenAIRealtime()
|
||||
gemini_realtime = GeminiLive() # Instantiate the Gemini handler
|
||||
|
||||
|
||||
@wrapper_client
|
||||
|
@ -104,6 +106,31 @@ async def _arealtime(
|
|||
client=None,
|
||||
timeout=timeout,
|
||||
)
|
||||
elif _custom_llm_provider == "vertex_ai_beta" or _custom_llm_provider == "vertex_ai": # Add the Gemini case
|
||||
api_base = (
|
||||
dynamic_api_base
|
||||
or litellm_params.api_base
|
||||
or litellm.api_base
|
||||
or get_secret_str("GEMINI_API_BASE")
|
||||
# default base for vertexs
|
||||
or "https://us-central1-aiplatform.googleapis.com"
|
||||
)
|
||||
|
||||
try:
|
||||
await gemini_realtime.async_realtime(
|
||||
model=model,
|
||||
websocket=websocket,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
optional_params={},
|
||||
logging_obj=litellm_logging_obj,
|
||||
vertex_location=litellm_params.vertex_location, # Add default vertex location
|
||||
vertex_credentials_path=str(litellm_params.vertex_credentials), # Add default vertex credentials
|
||||
vertex_project=litellm_params.vertex_project, # Add default vertex project
|
||||
custom_llm_provider=_custom_llm_provider, # Add custom llm provider
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to connect to Gemini realtime API: {e}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported model: {model}")
|
||||
|
||||
|
@ -143,8 +170,16 @@ async def _realtime_health_check(
|
|||
url = openai_realtime._construct_url(
|
||||
api_base=api_base or "https://api.openai.com/", model=model
|
||||
)
|
||||
elif custom_llm_provider == "gemini": # Add Gemini case
|
||||
url = gemini_realtime._construct_url(
|
||||
api_base=api_base or "https://generativelanguage.googleapis.com",
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model: {model}")
|
||||
|
||||
if url is None:
|
||||
raise ValueError("Failed to construct WebSocket URL")
|
||||
|
||||
async with websockets.connect( # type: ignore
|
||||
url,
|
||||
extra_headers={
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue