feat(azure.py): add support for azure image generations endpoint

This commit is contained in:
Krrish Dholakia 2023-12-20 16:37:21 +05:30
parent f0df28362a
commit b3962e483f
6 changed files with 90 additions and 11 deletions

View file

@ -6,6 +6,7 @@ from typing import Callable, Optional
from litellm import OpenAIConfig from litellm import OpenAIConfig
import litellm, json import litellm, json
import httpx import httpx
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI from openai import AzureOpenAI, AsyncAzureOpenAI
class AzureOpenAIError(Exception): class AzureOpenAIError(Exception):
@ -464,11 +465,12 @@ class AzureChatCompletion(BaseLLM):
raise AzureOpenAIError(status_code=500, message=traceback.format_exc()) raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
def image_generation(self, def image_generation(self,
prompt: list, prompt: str,
timeout: float, timeout: float,
model: Optional[str]=None, model: Optional[str]=None,
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
api_version: Optional[str] = None,
model_response: Optional[litellm.utils.ImageResponse] = None, model_response: Optional[litellm.utils.ImageResponse] = None,
logging_obj=None, logging_obj=None,
optional_params=None, optional_params=None,
@ -477,9 +479,12 @@ class AzureChatCompletion(BaseLLM):
): ):
exception_mapping_worked = False exception_mapping_worked = False
try: try:
model = model if model and len(model) > 0:
model = model
else:
model = None
data = { data = {
# "model": model, "model": model,
"prompt": prompt, "prompt": prompt,
**optional_params **optional_params
} }
@ -492,7 +497,8 @@ class AzureChatCompletion(BaseLLM):
# return response # return response
if client is None: if client is None:
azure_client = AzureOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries) # type: ignore client_session = litellm.client_session or httpx.Client(transport=CustomHTTPTransport(),)
azure_client = AzureOpenAI(api_key=api_key, azure_endpoint=api_base, http_client=client_session, timeout=timeout, max_retries=max_retries, api_version=api_version) # type: ignore
else: else:
azure_client = client azure_client = client

View file

@ -0,0 +1,64 @@
import time
import json
import httpx
class CustomHTTPTransport(httpx.HTTPTransport):
"""
This class was written as a workaround to support dall-e-2 on openai > v1.x
Refer to this issue for more: https://github.com/openai/openai-python/issues/692
"""
def handle_request(
self,
request: httpx.Request,
) -> httpx.Response:
if "images/generations" in request.url.path and request.url.params[
"api-version"
] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
"2023-06-01-preview",
"2023-07-01-preview",
"2023-08-01-preview",
"2023-09-01-preview",
"2023-10-01-preview",
]:
request.url = request.url.copy_with(path="/openai/images/generations:submit")
response = super().handle_request(request)
operation_location_url = response.headers["operation-location"]
request.url = httpx.URL(operation_location_url)
request.method = "GET"
response = super().handle_request(request)
response.read()
timeout_secs: int = 120
start_time = time.time()
while response.json()["status"] not in ["succeeded", "failed"]:
if time.time() - start_time > timeout_secs:
timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
return httpx.Response(
status_code=400,
headers=response.headers,
content=json.dumps(timeout).encode("utf-8"),
request=request,
)
time.sleep(int(response.headers.get("retry-after")) or 10)
response = super().handle_request(request)
response.read()
if response.json()["status"] == "failed":
error_data = response.json()
return httpx.Response(
status_code=400,
headers=response.headers,
content=json.dumps(error_data).encode("utf-8"),
request=request,
)
result = response.json()["result"]
return httpx.Response(
status_code=200,
headers=response.headers,
content=json.dumps(result).encode("utf-8"),
request=request,
)
return super().handle_request(request)

View file

@ -2307,8 +2307,7 @@ def image_generation(prompt: str,
get_secret("AZURE_AD_TOKEN") get_secret("AZURE_AD_TOKEN")
) )
# model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response) model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version)
pass
elif custom_llm_provider == "openai": elif custom_llm_provider == "openai":
model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response) model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response)

View file

@ -727,7 +727,7 @@ def test_completion_azure():
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_completion_azure() test_completion_azure()
def test_azure_openai_ad_token(): def test_azure_openai_ad_token():
# this tests if the azure ad token is set in the request header # this tests if the azure ad token is set in the request header

View file

@ -4,7 +4,8 @@
import sys, os import sys, os
import traceback import traceback
from dotenv import load_dotenv from dotenv import load_dotenv
import logging
logging.basicConfig(level=logging.DEBUG)
load_dotenv() load_dotenv()
import os import os
@ -18,14 +19,22 @@ def test_image_generation_openai():
litellm.set_verbose = True litellm.set_verbose = True
response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3") response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
print(f"response: {response}") print(f"response: {response}")
assert len(response.data) > 0
# test_image_generation_openai() # test_image_generation_openai()
# def test_image_generation_azure(): def test_image_generation_azure():
# response = litellm.image_generation(prompt="A cute baby sea otter", api_version="2023-06-01-preview", custom_llm_provider="azure") response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/", api_version="2023-06-01-preview")
# print(f"response: {response}") print(f"response: {response}")
assert len(response.data) > 0
# test_image_generation_azure() # test_image_generation_azure()
def test_image_generation_azure_dall_e_3():
litellm.set_verbose = True
response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/dall-e-3-test", api_version="2023-12-01-preview", api_base=os.getenv("AZURE_SWEDEN_API_BASE"), api_key=os.getenv("AZURE_SWEDEN_API_KEY"))
print(f"response: {response}")
assert len(response.data) > 0
# test_image_generation_azure_dall_e_3()
# @pytest.mark.asyncio # @pytest.mark.asyncio
# async def test_async_image_generation_openai(): # async def test_async_image_generation_openai():
# response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3") # response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")

View file

@ -1613,6 +1613,7 @@ def client(original_function):
try: try:
model = args[0] if len(args) > 0 else kwargs["model"] model = args[0] if len(args) > 0 else kwargs["model"]
except: except:
model = None
call_type = original_function.__name__ call_type = original_function.__name__
if call_type != CallTypes.image_generation.value: if call_type != CallTypes.image_generation.value:
raise ValueError("model param not passed in.") raise ValueError("model param not passed in.")