diff --git a/litellm/router.py b/litellm/router.py index 0276f5a444..1e2a32263b 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -7,7 +7,7 @@ # # Thank you ! We ❤️ you! - Krrish & Ishaan -import copy +import copy, httpx from datetime import datetime from typing import Dict, List, Optional, Union, Literal, Any import random, threading, time, traceback, uuid @@ -18,6 +18,7 @@ import inspect, concurrent from openai import AsyncOpenAI from collections import defaultdict from litellm.router_strategy.least_busy import LeastBusyLoggingHandler +from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport import copy class Router: """ @@ -166,7 +167,7 @@ class Router: self.print_verbose(f"Intialized router with Routing strategy: {self.routing_strategy}\n") - ### COMPLETION + EMBEDDING FUNCTIONS + ### COMPLETION, EMBEDDING, IMG GENERATION FUNCTIONS def completion(self, model: str, @@ -260,6 +261,50 @@ class Router: self.fail_calls[model_name] +=1 raise e + async def aimage_generation(self, + prompt: str, + model: str, + **kwargs): + try: + kwargs["model"] = model + kwargs["prompt"] = prompt + kwargs["original_function"] = self._image_generation + kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) + timeout = kwargs.get("request_timeout", self.timeout) + kwargs.setdefault("metadata", {}).update({"model_group": model}) + response = await self.async_function_with_fallbacks(**kwargs) + + return response + except Exception as e: + raise e + + async def _image_generation(self, + prompt: str, + model: str, + **kwargs): + try: + self.print_verbose(f"Inside _image_generation()- model: {model}; kwargs: {kwargs}") + deployment = self.get_available_deployment(model=model, messages=[{"role": "user", "content": "prompt"}], specific_deployment=kwargs.pop("specific_deployment", None)) + kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]}) + kwargs["model_info"] = deployment.get("model_info", {}) + data = deployment["litellm_params"].copy() + model_name = data["model"] + for k, v in self.default_litellm_params.items(): + if k not in kwargs: # prioritize model-specific params > default router params + kwargs[k] = v + elif k == "metadata": + kwargs[k].update(v) + + model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async") + self.total_calls[model_name] +=1 + response = await litellm.aimage_generation(**{**data, "prompt": prompt, "caching": self.cache_responses, "client": model_client, **kwargs}) + self.success_calls[model_name] +=1 + return response + except Exception as e: + if model_name is not None: + self.fail_calls[model_name] +=1 + raise e + def text_completion(self, model: str, prompt: str, @@ -1009,14 +1054,16 @@ class Router: azure_endpoint=api_base, api_version=api_version, timeout=timeout, - max_retries=max_retries + max_retries=max_retries, + http_client=httpx.Client(transport=CustomHTTPTransport(),) # type: ignore ) model["client"] = openai.AzureOpenAI( api_key=api_key, azure_endpoint=api_base, api_version=api_version, timeout=timeout, - max_retries=max_retries + max_retries=max_retries, + http_client=httpx.Client(transport=CustomHTTPTransport(),) # type: ignore ) # streaming clients should have diff timeouts model["stream_async_client"] = openai.AsyncAzureOpenAI( @@ -1024,7 +1071,7 @@ class Router: azure_endpoint=api_base, api_version=api_version, timeout=stream_timeout, - max_retries=max_retries + max_retries=max_retries, ) model["stream_client"] = openai.AzureOpenAI( diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 403c8dc2af..435be3ed5f 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -423,6 +423,41 @@ def test_function_calling_on_router(): # test_function_calling_on_router() +### IMAGE GENERATION +async def test_aimg_gen_on_router(): + litellm.set_verbose = True + try: + model_list = [ + { + "model_name": "dall-e-3", + "litellm_params": { + "model": "dall-e-3", + }, + }, + { + "model_name": "dall-e-3", + "litellm_params": { + "model": "azure/dall-e-3-test", + "api_version": "2023-12-01-preview", + "api_base": os.getenv("AZURE_SWEDEN_API_BASE"), + "api_key": os.getenv("AZURE_SWEDEN_API_KEY") + } + } + ] + router = Router(model_list=model_list) + response = await router.aimage_generation( + model="dall-e-3", + prompt="A cute baby sea otter" + ) + print(response) + router.reset() + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") + +asyncio.run(test_aimg_gen_on_router()) +### + def test_aembedding_on_router(): litellm.set_verbose = True try: