diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index e78529490..026f06fb8 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -6,6 +6,7 @@
 from typing import Callable, Optional
 from litellm import OpenAIConfig
 import litellm, json
 import httpx
+from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport
 from openai import AzureOpenAI, AsyncAzureOpenAI
 class AzureOpenAIError(Exception):
@@ -464,11 +465,12 @@ class AzureChatCompletion(BaseLLM):
             raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
 
     def image_generation(self,
-                         prompt: list,
+                         prompt: str,
                          timeout: float,
                          model: Optional[str]=None,
                          api_key: Optional[str] = None,
                          api_base: Optional[str] = None,
+                         api_version: Optional[str] = None,
                          model_response: Optional[litellm.utils.ImageResponse] = None,
                          logging_obj=None,
                          optional_params=None,
@@ -477,9 +479,12 @@
                          ):
         exception_mapping_worked = False
         try:
-            model = model
+            if model and len(model) > 0:
+                model = model
+            else:
+                model = None
             data = {
-                # "model": model,
+                "model": model,
                 "prompt": prompt,
                 **optional_params
             }
@@ -492,7 +497,8 @@
             # return response
 
             if client is None:
-                azure_client = AzureOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries) # type: ignore
+                client_session = litellm.client_session or httpx.Client(transport=CustomHTTPTransport(),)
+                azure_client = AzureOpenAI(api_key=api_key, azure_endpoint=api_base, http_client=client_session, timeout=timeout, max_retries=max_retries, api_version=api_version) # type: ignore
             else:
                 azure_client = client
 
diff --git a/litellm/llms/custom_httpx/azure_dall_e_2.py b/litellm/llms/custom_httpx/azure_dall_e_2.py
new file mode 100644
index 000000000..c5263bd49
--- /dev/null
+++ b/litellm/llms/custom_httpx/azure_dall_e_2.py
@@ -0,0 +1,64 @@
+import time
+import json
+import httpx
+
+class CustomHTTPTransport(httpx.HTTPTransport):
+    """
+    This class was written as a workaround to support dall-e-2 on openai > v1.x
+
+    Refer to this issue for more: https://github.com/openai/openai-python/issues/692
+    """
+    def handle_request(
+        self,
+        request: httpx.Request,
+    ) -> httpx.Response:
+        if "images/generations" in request.url.path and request.url.params[
+            "api-version"
+        ] in [  # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
+            "2023-06-01-preview",
+            "2023-07-01-preview",
+            "2023-08-01-preview",
+            "2023-09-01-preview",
+            "2023-10-01-preview",
+        ]:
+            request.url = request.url.copy_with(path="/openai/images/generations:submit")
+            response = super().handle_request(request)
+            operation_location_url = response.headers["operation-location"]
+            request.url = httpx.URL(operation_location_url)
+            request.method = "GET"
+            response = super().handle_request(request)
+            response.read()
+
+            timeout_secs: int = 120
+            start_time = time.time()
+            while response.json()["status"] not in ["succeeded", "failed"]:
+                if time.time() - start_time > timeout_secs:
+                    timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
+                    return httpx.Response(
+                        status_code=400,
+                        headers=response.headers,
+                        content=json.dumps(timeout).encode("utf-8"),
+                        request=request,
+                    )
+
+                time.sleep(int(response.headers.get("retry-after") or 10))
+                response = super().handle_request(request)
+                response.read()
+
+            if response.json()["status"] == "failed":
+                error_data = response.json()
+                return httpx.Response(
+                    status_code=400,
+                    headers=response.headers,
+                    content=json.dumps(error_data).encode("utf-8"),
+                    request=request,
+                )
+
+            result = response.json()["result"]
+            return httpx.Response(
+                status_code=200,
+                headers=response.headers,
+                content=json.dumps(result).encode("utf-8"),
+                request=request,
+            )
+        return super().handle_request(request)
\ No newline at end of file
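The transport above is a plain httpx.HTTPTransport subclass, so it can be exercised directly with an openai v1.x client, independent of litellm's routing. A minimal sketch, assuming hypothetical AZURE_API_KEY / AZURE_API_BASE environment variable names (not values defined in this PR):

# Sketch only: driving CustomHTTPTransport outside litellm.
# AZURE_API_KEY / AZURE_API_BASE are placeholder env var names for this example.
import os
import httpx
from openai import AzureOpenAI
from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport

client = AzureOpenAI(
    api_key=os.environ["AZURE_API_KEY"],
    azure_endpoint=os.environ["AZURE_API_BASE"],
    api_version="2023-06-01-preview",  # one of the dall-e-2 versions intercepted above
    http_client=httpx.Client(transport=CustomHTTPTransport()),
)
image = client.images.generate(prompt="A cute baby sea otter")
print(image.data[0])

For api-version values of 2023-12-01-preview or later the transport passes the request through untouched, so dall-e-3 deployments keep the native synchronous endpoint.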
diff --git a/litellm/main.py b/litellm/main.py
index 0018844c3..318ba8ffb 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2307,8 +2307,7 @@ def image_generation(prompt: str,
             get_secret("AZURE_AD_TOKEN")
         )
 
-        # model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response)
-        pass
+        model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version)
     elif custom_llm_provider == "openai":
         model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response)
 
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 1e4155062..2f561e881 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -727,7 +727,7 @@ def test_completion_azure():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-# test_completion_azure()
+test_completion_azure()
 
 def test_azure_openai_ad_token():
     # this tests if the azure ad token is set in the request header
diff --git a/litellm/tests/test_image_generation.py b/litellm/tests/test_image_generation.py
index a265c0f65..d177ec81d 100644
--- a/litellm/tests/test_image_generation.py
+++ b/litellm/tests/test_image_generation.py
@@ -4,7 +4,8 @@
 import sys, os
 import traceback
 from dotenv import load_dotenv
-
+import logging
+logging.basicConfig(level=logging.DEBUG)
 load_dotenv()
 import os
 
@@ -18,14 +19,22 @@ def test_image_generation_openai():
     litellm.set_verbose = True
     response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
     print(f"response: {response}")
+    assert len(response.data) > 0
 # test_image_generation_openai()
 
-# def test_image_generation_azure():
-#     response = litellm.image_generation(prompt="A cute baby sea otter", api_version="2023-06-01-preview", custom_llm_provider="azure")
-#     print(f"response: {response}")
+def test_image_generation_azure():
+    response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/", api_version="2023-06-01-preview")
+    print(f"response: {response}")
+    assert len(response.data) > 0
 # test_image_generation_azure()
 
+def test_image_generation_azure_dall_e_3():
+    litellm.set_verbose = True
+    response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/dall-e-3-test", api_version="2023-12-01-preview", api_base=os.getenv("AZURE_SWEDEN_API_BASE"), api_key=os.getenv("AZURE_SWEDEN_API_KEY"))
+    print(f"response: {response}")
+    assert len(response.data) > 0
+# test_image_generation_azure_dall_e_3()
 
 # @pytest.mark.asyncio
 # async def test_async_image_generation_openai():
 #     response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
diff --git a/litellm/utils.py b/litellm/utils.py
index c449a239e..b46ba5a9b 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1613,6 +1613,7 @@ def client(original_function):
         try:
             model = args[0] if len(args) > 0 else kwargs["model"]
         except:
+            model = None
             call_type = original_function.__name__
             if call_type != CallTypes.image_generation.value:
                 raise ValueError("model param not passed in.")
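Taken together, these changes route Azure image generation through the new polling transport for legacy API versions, while the utils.py fallback lets image_generation calls proceed without a positional model. A minimal end-to-end sketch of the call this PR enables, assuming a hypothetical deployment name and configured Azure credentials:

# Sketch only: the end-to-end litellm call enabled by this PR.
# "azure/my-dalle-deployment" is a hypothetical deployment name, not one from this PR.
import litellm

response = litellm.image_generation(
    prompt="A cute baby sea otter",
    model="azure/my-dalle-deployment",
    api_version="2023-06-01-preview",  # legacy version, handled by CustomHTTPTransport polling
)
print(response.data)

With 2023-12-01-preview and newer, the same call goes straight to the synchronous endpoint, as exercised by test_image_generation_azure_dall_e_3 above.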