From be68796ebac1a60c725cfdb9e43969e67ec29983 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 21 Dec 2023 14:38:44 +0530
Subject: [PATCH] fix(router.py): add support for async image generation
 endpoints

---
 litellm/llms/azure.py                       | 25 +++++++--
 litellm/llms/custom_httpx/azure_dall_e_2.py | 56 +++++++++++++++++++++
 litellm/main.py                             |  4 +-
 litellm/router.py                           |  4 +-
 litellm/tests/test_router.py                | 25 +++++++--
 litellm/utils.py                            |  8 ++-
 6 files changed, 109 insertions(+), 13 deletions(-)

diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index b269afec7..2e75f7b40 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -6,7 +6,7 @@ from typing import Callable, Optional
 from litellm import OpenAIConfig
 import litellm, json
 import httpx
-from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport
+from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
 from openai import AzureOpenAI, AsyncAzureOpenAI
 
 class AzureOpenAIError(Exception):
@@ -480,7 +480,8 @@ class AzureChatCompletion(BaseLLM):
         response = None
         try:
             if client is None:
-                openai_aclient = AsyncAzureOpenAI(**azure_client_params)
+                client_session = litellm.aclient_session or httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),)
+                openai_aclient = AsyncAzureOpenAI(http_client=client_session, **azure_client_params)
             else:
                 openai_aclient = client
             response = await openai_aclient.images.generate(**data)
@@ -492,7 +493,7 @@ class AzureChatCompletion(BaseLLM):
                 additional_args={"complete_input_dict": data},
                 original_response=stringified_response,
             )
-            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding")
+            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="image_generation")
         except Exception as e:
             ## LOGGING
             logging_obj.post_call(
@@ -511,6 +512,7 @@ class AzureChatCompletion(BaseLLM):
                        api_base: Optional[str] = None,
                        api_version: Optional[str] = None,
                        model_response: Optional[litellm.utils.ImageResponse] = None,
+                       azure_ad_token: Optional[str]=None,
                        logging_obj=None,
                        optional_params=None,
                        client=None,
@@ -531,13 +533,26 @@ class AzureChatCompletion(BaseLLM):
         if not isinstance(max_retries, int):
             raise AzureOpenAIError(status_code=422, message="max retries must be an int")
 
+        # init AzureOpenAI Client
+        azure_client_params = {
+            "api_version": api_version,
+            "azure_endpoint": api_base,
+            "azure_deployment": model,
+            "max_retries": max_retries,
+            "timeout": timeout
+        }
+        if api_key is not None:
+            azure_client_params["api_key"] = api_key
+        elif azure_ad_token is not None:
+            azure_client_params["azure_ad_token"] = azure_ad_token
+
         if aimg_generation == True:
-            response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
+            response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params) # type: ignore
             return response
 
         if client is None:
             client_session = litellm.client_session or httpx.Client(transport=CustomHTTPTransport(),)
-            azure_client = AzureOpenAI(api_key=api_key, azure_endpoint=api_base, http_client=client_session, timeout=timeout, max_retries=max_retries, api_version=api_version) # type: ignore
+            azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore
         else:
             azure_client = client
 
diff --git a/litellm/llms/custom_httpx/azure_dall_e_2.py b/litellm/llms/custom_httpx/azure_dall_e_2.py
index cda84b156..3bc50dda7 100644
--- a/litellm/llms/custom_httpx/azure_dall_e_2.py
+++ b/litellm/llms/custom_httpx/azure_dall_e_2.py
@@ -1,5 +1,61 @@
 import time, json, httpx, asyncio
 
+class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
+    """
+    Async implementation of custom http transport
+    """
+    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
+        if "images/generations" in request.url.path and request.url.params[
+            "api-version"
+        ] in [  # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
+            "2023-06-01-preview",
+            "2023-07-01-preview",
+            "2023-08-01-preview",
+            "2023-09-01-preview",
+            "2023-10-01-preview",
+        ]:
+            request.url = request.url.copy_with(path="/openai/images/generations:submit")
+            response = await super().handle_async_request(request)
+            operation_location_url = response.headers["operation-location"]
+            request.url = httpx.URL(operation_location_url)
+            request.method = "GET"
+            response = await super().handle_async_request(request)
+            await response.aread()
+
+            timeout_secs: int = 120
+            start_time = time.time()
+            while response.json()["status"] not in ["succeeded", "failed"]:
+                if time.time() - start_time > timeout_secs:
+                    timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
+                    return httpx.Response(
+                        status_code=400,
+                        headers=response.headers,
+                        content=json.dumps(timeout).encode("utf-8"),
+                        request=request,
+                    )
+
+                await asyncio.sleep(int(response.headers.get("retry-after") or 10))  # non-blocking sleep; default to 10s if no retry-after header
+                response = await super().handle_async_request(request)
+                await response.aread()
+
+            if response.json()["status"] == "failed":
+                error_data = response.json()
+                return httpx.Response(
+                    status_code=400,
+                    headers=response.headers,
+                    content=json.dumps(error_data).encode("utf-8"),
+                    request=request,
+                )
+
+            result = response.json()["result"]
+            return httpx.Response(
+                status_code=200,
+                headers=response.headers,
+                content=json.dumps(result).encode("utf-8"),
+                request=request,
+            )
+        return await super().handle_async_request(request)
+
 class CustomHTTPTransport(httpx.HTTPTransport):
     """
     This class was written as a workaround to support dall-e-2 on openai > v1.x
diff --git a/litellm/main.py b/litellm/main.py
index b2ed72f7f..0e7752e16 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2351,9 +2351,9 @@ def image_generation(prompt: str,
             get_secret("AZURE_AD_TOKEN")
         )
 
-        model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version, aimg_generation=aimage_generation)
+        model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version, aimg_generation=aimg_generation)
     elif custom_llm_provider == "openai":
-        model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, aimg_generation=aimage_generation)
+        model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, aimg_generation=aimg_generation)
 
     return model_response
 
diff --git a/litellm/router.py b/litellm/router.py
index ddfd6b87c..4ee067e2f 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -18,7 +18,7 @@ import inspect, concurrent
 from openai import AsyncOpenAI
 from collections import defaultdict
 from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
-from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport
+from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
 import copy
 class Router:
     """
@@ -525,7 +525,6 @@ class Router:
 
     async def async_function_with_retries(self, *args, **kwargs):
         self.print_verbose(f"Inside async function with retries: args - {args}; kwargs - {kwargs}")
-        backoff_factor = 1
         original_function = kwargs.pop("original_function")
         fallbacks = kwargs.pop("fallbacks", self.fallbacks)
         context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
@@ -1099,6 +1098,7 @@ class Router:
                         api_version=api_version,
                         timeout=timeout,
                         max_retries=max_retries,
+                        http_client=httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),) # type: ignore
                     )
                     model["client"] = openai.AzureOpenAI(
                         api_key=api_key,
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index d7f929f25..b52db394f 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -424,6 +424,7 @@ def test_function_calling_on_router():
 # test_function_calling_on_router()
 
 ### IMAGE GENERATION
+@pytest.mark.asyncio
 async def test_aimg_gen_on_router():
     litellm.set_verbose = True
     try:
@@ -442,14 +443,32 @@ async def test_aimg_gen_on_router():
                     "api_base": os.getenv("AZURE_SWEDEN_API_BASE"),
                     "api_key": os.getenv("AZURE_SWEDEN_API_KEY")
                 }
+            },
+            {
+                "model_name": "dall-e-2",
+                "litellm_params": {
+                    "model": "azure/",
+                    "api_version": "2023-06-01-preview",
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                    "api_key": os.getenv("AZURE_API_KEY")
+                }
             }
         ]
         router = Router(model_list=model_list)
+        # response = await router.aimage_generation(
+        #     model="dall-e-3",
+        #     prompt="A cute baby sea otter"
+        # )
+        # print(response)
+        # assert len(response.data) > 0
+
         response = await router.aimage_generation(
-            model="dall-e-3",
+            model="dall-e-2",
             prompt="A cute baby sea otter"
         )
         print(response)
+        assert len(response.data) > 0
+
+        router.reset()
     except Exception as e:
         traceback.print_exc()
@@ -489,7 +508,7 @@ def test_img_gen_on_router():
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")
 
-test_img_gen_on_router()
+# test_img_gen_on_router()
 ###
 
 def test_aembedding_on_router():
@@ -625,7 +644,7 @@ async def test_mistral_on_router():
         ]
     )
     print(response)
-asyncio.run(test_mistral_on_router())
+# asyncio.run(test_mistral_on_router())
 
 def test_openai_completion_on_router():
     # [PROD Use Case] - Makes an acompletion call + async acompletion call, and sync acompletion call, sync completion + stream
diff --git a/litellm/utils.py b/litellm/utils.py
index f68fd49b0..3a48958fc 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -551,6 +551,8 @@ class ImageResponse(OpenAIObject):
 
     data: Optional[list] = None
 
+    usage: Optional[dict] = None
+
     def __init__(self, created=None, data=None, response_ms=None):
         if response_ms:
             _response_ms = response_ms
@@ -565,8 +567,10 @@ class ImageResponse(OpenAIObject):
             created = created
         else:
             created = None
-
+
         super().__init__(data=data, created=created)
+        self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+
     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
@@ -1668,6 +1672,8 @@ def client(original_function):
                 return result
             elif "aembedding" in kwargs and kwargs["aembedding"] == True:
                 return result
+            elif "aimg_generation" in kwargs and kwargs["aimg_generation"] == True:
+                return result
 
             ### POST-CALL RULES ###
             post_call_processing(original_response=result, model=model or None)
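
-- 
Usage sketch: a minimal end-to-end example of the async image generation path this patch adds,
mirroring the dall-e-2 entry in the router test above. It assumes AZURE_API_BASE / AZURE_API_KEY
point at an Azure OpenAI resource served under one of the pre-2023-12-01-preview api-versions,
so the request exercises AsyncCustomHTTPTransport's submit-then-poll flow.

import asyncio, os
from litellm import Router

model_list = [
    {
        "model_name": "dall-e-2",
        "litellm_params": {
            "model": "azure/",  # same litellm_params as the dall-e-2 test entry above
            "api_version": "2023-06-01-preview",
            "api_base": os.getenv("AZURE_API_BASE"),
            "api_key": os.getenv("AZURE_API_KEY"),
        },
    }
]

async def main():
    router = Router(model_list=model_list)
    # Router.aimage_generation awaits AzureChatCompletion.aimage_generation, whose
    # AsyncAzureOpenAI client now rides on AsyncCustomHTTPTransport: the transport submits
    # the job to /openai/images/generations:submit, then polls the operation-location URL
    # until the operation succeeds or fails.
    response = await router.aimage_generation(model="dall-e-2", prompt="A cute baby sea otter")
    assert len(response.data) > 0
    print(response.data)

asyncio.run(main())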