feat(azure.py): add support for azure image generations endpoint
commit b3962e483f (parent f0df28362a)
6 changed files with 90 additions and 11 deletions
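
Example (a hedged usage sketch, not part of the diff): the new test added below exercises the Azure image endpoint roughly like this. The azure/dall-e-3-test deployment name and the AZURE_SWEDEN_* environment variables come from that test and are illustrative, not required names.

    import os
    import litellm

    # Deployment name and env vars below are placeholders taken from the new test.
    response = litellm.image_generation(
        prompt="A cute baby sea otter",
        model="azure/dall-e-3-test",  # "azure/<your-deployment>"
        api_version="2023-12-01-preview",
        api_base=os.getenv("AZURE_SWEDEN_API_BASE"),
        api_key=os.getenv("AZURE_SWEDEN_API_KEY"),
    )
    print(response.data)  # the new tests assert len(response.data) > 0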
litellm/llms/azure.py

@@ -6,6 +6,7 @@ from typing import Callable, Optional
 from litellm import OpenAIConfig
 import litellm, json
 import httpx
+from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport
 from openai import AzureOpenAI, AsyncAzureOpenAI
 
 class AzureOpenAIError(Exception):
@@ -464,11 +465,12 @@ class AzureChatCompletion(BaseLLM):
             raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
 
    def image_generation(self,
-                        prompt: list,
+                        prompt: str,
                         timeout: float,
                         model: Optional[str]=None,
                         api_key: Optional[str] = None,
                         api_base: Optional[str] = None,
+                        api_version: Optional[str] = None,
                         model_response: Optional[litellm.utils.ImageResponse] = None,
                         logging_obj=None,
                         optional_params=None,
@@ -477,9 +479,12 @@ class AzureChatCompletion(BaseLLM):
    ):
        exception_mapping_worked = False
        try:
-            model = model
+            if model and len(model) > 0:
+                model = model
+            else:
+                model = None
            data = {
-                # "model": model,
+                "model": model,
                "prompt": prompt,
                **optional_params
            }
@@ -492,7 +497,8 @@ class AzureChatCompletion(BaseLLM):
            # return response
 
            if client is None:
-                azure_client = AzureOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries) # type: ignore
+                client_session = litellm.client_session or httpx.Client(transport=CustomHTTPTransport(),)
+                azure_client = AzureOpenAI(api_key=api_key, azure_endpoint=api_base, http_client=client_session, timeout=timeout, max_retries=max_retries, api_version=api_version) # type: ignore
            else:
                azure_client = client
 
litellm/llms/custom_httpx/azure_dall_e_2.py (new file)

@@ -0,0 +1,64 @@
+import time
+import json
+import httpx
+
+class CustomHTTPTransport(httpx.HTTPTransport):
+    """
+    This class was written as a workaround to support dall-e-2 on openai > v1.x
+
+    Refer to this issue for more: https://github.com/openai/openai-python/issues/692
+    """
+    def handle_request(
+        self,
+        request: httpx.Request,
+    ) -> httpx.Response:
+        if "images/generations" in request.url.path and request.url.params[
+            "api-version"
+        ] in [  # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
+            "2023-06-01-preview",
+            "2023-07-01-preview",
+            "2023-08-01-preview",
+            "2023-09-01-preview",
+            "2023-10-01-preview",
+        ]:
+            request.url = request.url.copy_with(path="/openai/images/generations:submit")
+            response = super().handle_request(request)
+            operation_location_url = response.headers["operation-location"]
+            request.url = httpx.URL(operation_location_url)
+            request.method = "GET"
+            response = super().handle_request(request)
+            response.read()
+
+            timeout_secs: int = 120
+            start_time = time.time()
+            while response.json()["status"] not in ["succeeded", "failed"]:
+                if time.time() - start_time > timeout_secs:
+                    timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
+                    return httpx.Response(
+                        status_code=400,
+                        headers=response.headers,
+                        content=json.dumps(timeout).encode("utf-8"),
+                        request=request,
+                    )
+
+                time.sleep(int(response.headers.get("retry-after")) or 10)
+                response = super().handle_request(request)
+                response.read()
+
+            if response.json()["status"] == "failed":
+                error_data = response.json()
+                return httpx.Response(
+                    status_code=400,
+                    headers=response.headers,
+                    content=json.dumps(error_data).encode("utf-8"),
+                    request=request,
+                )
+
+            result = response.json()["result"]
+            return httpx.Response(
+                status_code=200,
+                headers=response.headers,
+                content=json.dumps(result).encode("utf-8"),
+                request=request,
+            )
+        return super().handle_request(request)
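
For context, a hedged sketch of how this transport is meant to be used, mirroring the azure.py change above (which wraps it in an httpx.Client when no client is passed in); the endpoint and key are placeholders:

    import httpx
    from openai import AzureOpenAI
    from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport

    # For the listed pre-2023-12-01-preview api-versions, the transport rewrites
    # images/generations requests to the async generations:submit route, then
    # polls the operation-location URL until the job succeeds, fails, or times out.
    client = AzureOpenAI(
        api_key="my-azure-key",                                 # placeholder
        azure_endpoint="https://my-resource.openai.azure.com",  # placeholder
        api_version="2023-06-01-preview",
        http_client=httpx.Client(transport=CustomHTTPTransport()),
    )
    image = client.images.generate(prompt="A cute baby sea otter")

Note that CustomHTTPTransport subclasses the synchronous httpx.HTTPTransport only; an async client would need an equivalent httpx.AsyncHTTPTransport, which this commit does not add.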
@@ -2307,8 +2307,7 @@ def image_generation(prompt: str,
                 get_secret("AZURE_AD_TOKEN")
             )
 
-        # model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response)
-        pass
+        model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version)
    elif custom_llm_provider == "openai":
        model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response)
 
@@ -727,7 +727,7 @@ def test_completion_azure():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-# test_completion_azure()
+test_completion_azure()
 
 def test_azure_openai_ad_token():
     # this tests if the azure ad token is set in the request header
@@ -4,7 +4,8 @@
 import sys, os
 import traceback
 from dotenv import load_dotenv
+import logging
+logging.basicConfig(level=logging.DEBUG)
 load_dotenv()
 import os
 
@@ -18,14 +19,22 @@ def test_image_generation_openai():
     litellm.set_verbose = True
     response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
     print(f"response: {response}")
+    assert len(response.data) > 0
+
 # test_image_generation_openai()
 
-# def test_image_generation_azure():
-#     response = litellm.image_generation(prompt="A cute baby sea otter", api_version="2023-06-01-preview", custom_llm_provider="azure")
-#     print(f"response: {response}")
+def test_image_generation_azure():
+    response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/", api_version="2023-06-01-preview")
+    print(f"response: {response}")
+    assert len(response.data) > 0
 # test_image_generation_azure()
 
+def test_image_generation_azure_dall_e_3():
+    litellm.set_verbose = True
+    response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/dall-e-3-test", api_version="2023-12-01-preview", api_base=os.getenv("AZURE_SWEDEN_API_BASE"), api_key=os.getenv("AZURE_SWEDEN_API_KEY"))
+    print(f"response: {response}")
+    assert len(response.data) > 0
+# test_image_generation_azure_dall_e_3()
 # @pytest.mark.asyncio
 # async def test_async_image_generation_openai():
 #     response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
@@ -1613,6 +1613,7 @@ def client(original_function):
         try:
             model = args[0] if len(args) > 0 else kwargs["model"]
         except:
+            model = None
             call_type = original_function.__name__
             if call_type != CallTypes.image_generation.value:
                 raise ValueError("model param not passed in.")
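
This guard change is what lets the Azure image path run without an explicit model: for image_generation the wrapper now falls back to model=None instead of raising. A minimal sketch, modeled on the commented-out test in this diff (whether the endpoint accepts a missing deployment depends on the api-version in use):

    import litellm

    # No model kwarg: the client wrapper now sets model=None for
    # image_generation rather than raising "model param not passed in."
    response = litellm.image_generation(
        prompt="A cute baby sea otter",
        custom_llm_provider="azure",
        api_version="2023-06-01-preview",
    )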