From 04bbd0649f37f3991b5bd570a3cf5788af599287 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 20 Dec 2023 19:10:59 +0530 Subject: [PATCH] fix(router.py): only do sync image gen fallbacks for now The CustomHTTPTransport we use for dall-e-2 only works for sync httpx calls, not async. Will need to spend some time writing the async version. --- litellm/llms/custom_httpx/azure_dall_e_2.py | 4 +- litellm/router.py | 49 +++++++++++++++++++-- litellm/tests/test_proxy_custom_logger.py | 2 - litellm/tests/test_router.py | 36 ++++++++++++++- 4 files changed, 82 insertions(+), 9 deletions(-) diff --git a/litellm/llms/custom_httpx/azure_dall_e_2.py b/litellm/llms/custom_httpx/azure_dall_e_2.py index c5263bd49..cda84b156 100644 --- a/litellm/llms/custom_httpx/azure_dall_e_2.py +++ b/litellm/llms/custom_httpx/azure_dall_e_2.py @@ -1,6 +1,4 @@ -import time -import json -import httpx +import time, json, httpx, asyncio class CustomHTTPTransport(httpx.HTTPTransport): """ diff --git a/litellm/router.py b/litellm/router.py index 1e2a32263..ddfd6b87c 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -261,6 +261,50 @@ class Router: self.fail_calls[model_name] +=1 raise e + def image_generation(self, + prompt: str, + model: str, + **kwargs): + try: + kwargs["model"] = model + kwargs["prompt"] = prompt + kwargs["original_function"] = self._image_generation + kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) + timeout = kwargs.get("request_timeout", self.timeout) + kwargs.setdefault("metadata", {}).update({"model_group": model}) + response = self.function_with_fallbacks(**kwargs) + + return response + except Exception as e: + raise e + + def _image_generation(self, + prompt: str, + model: str, + **kwargs): + try: + self.print_verbose(f"Inside _image_generation()- model: {model}; kwargs: {kwargs}") + deployment = self.get_available_deployment(model=model, messages=[{"role": "user", "content": "prompt"}], 
specific_deployment=kwargs.pop("specific_deployment", None)) + kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]}) + kwargs["model_info"] = deployment.get("model_info", {}) + data = deployment["litellm_params"].copy() + model_name = data["model"] + for k, v in self.default_litellm_params.items(): + if k not in kwargs: # prioritize model-specific params > default router params + kwargs[k] = v + elif k == "metadata": + kwargs[k].update(v) + + model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async") + self.total_calls[model_name] +=1 + response = litellm.image_generation(**{**data, "prompt": prompt, "caching": self.cache_responses, "client": model_client, **kwargs}) + self.success_calls[model_name] +=1 + return response + except Exception as e: + if model_name is not None: + self.fail_calls[model_name] +=1 + raise e + async def aimage_generation(self, prompt: str, model: str, @@ -268,7 +312,7 @@ class Router: try: kwargs["model"] = model kwargs["prompt"] = prompt - kwargs["original_function"] = self._image_generation + kwargs["original_function"] = self._aimage_generation kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) timeout = kwargs.get("request_timeout", self.timeout) kwargs.setdefault("metadata", {}).update({"model_group": model}) @@ -278,7 +322,7 @@ class Router: except Exception as e: raise e - async def _image_generation(self, + async def _aimage_generation(self, prompt: str, model: str, **kwargs): @@ -1055,7 +1099,6 @@ class Router: api_version=api_version, timeout=timeout, max_retries=max_retries, - http_client=httpx.Client(transport=CustomHTTPTransport(),) # type: ignore ) model["client"] = openai.AzureOpenAI( api_key=api_key, diff --git a/litellm/tests/test_proxy_custom_logger.py b/litellm/tests/test_proxy_custom_logger.py index 6ddc9caac..0a3097af9 100644 --- a/litellm/tests/test_proxy_custom_logger.py +++ b/litellm/tests/test_proxy_custom_logger.py @@ -99,8 
+99,6 @@ def test_embedding(client): def test_chat_completion(client): try: # Your test data - - print("initialized proxy") litellm.set_verbose=False from litellm.proxy.utils import get_instance_fn my_custom_logger = get_instance_fn( diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 435be3ed5..d7f929f25 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -455,7 +455,41 @@ async def test_aimg_gen_on_router(): traceback.print_exc() pytest.fail(f"Error occurred: {e}") -asyncio.run(test_aimg_gen_on_router()) +# asyncio.run(test_aimg_gen_on_router()) + +def test_img_gen_on_router(): + litellm.set_verbose = True + try: + model_list = [ + { + "model_name": "dall-e-3", + "litellm_params": { + "model": "dall-e-3", + }, + }, + { + "model_name": "dall-e-3", + "litellm_params": { + "model": "azure/dall-e-3-test", + "api_version": "2023-12-01-preview", + "api_base": os.getenv("AZURE_SWEDEN_API_BASE"), + "api_key": os.getenv("AZURE_SWEDEN_API_KEY") + } + } + ] + router = Router(model_list=model_list) + response = router.image_generation( + model="dall-e-3", + prompt="A cute baby sea otter" + ) + print(response) + assert len(response.data) > 0 + router.reset() + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") + +test_img_gen_on_router() ### def test_aembedding_on_router():