forked from phoenix/litellm-mirror
feat(azure.py): add support for azure image generations endpoint
This commit is contained in:
parent f0df28362a
commit b3962e483f
6 changed files with 90 additions and 11 deletions
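In practical terms, the commit lets litellm.image_generation route to an Azure OpenAI image deployment. A minimal sketch of the new call path (the deployment name and environment variable names are illustrative, not taken from the diff):

    import os
    import litellm

    # "azure/<deployment>" routes the request to the Azure provider;
    # "my-dalle-deployment" is a hypothetical deployment name.
    response = litellm.image_generation(
        prompt="A cute baby sea otter",
        model="azure/my-dalle-deployment",
        api_version="2023-06-01-preview",
        api_base=os.getenv("AZURE_API_BASE"),
        api_key=os.getenv("AZURE_API_KEY"),
    )
    print(response.data)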
litellm/llms/azure.py

@@ -6,6 +6,7 @@ from typing import Callable, Optional
 from litellm import OpenAIConfig
 import litellm, json
 import httpx
+from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport
 from openai import AzureOpenAI, AsyncAzureOpenAI
 
 class AzureOpenAIError(Exception):
@@ -464,11 +465,12 @@ class AzureChatCompletion(BaseLLM):
             raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
 
     def image_generation(self,
-                        prompt: list,
+                        prompt: str,
                         timeout: float,
                         model: Optional[str]=None,
                         api_key: Optional[str] = None,
                         api_base: Optional[str] = None,
+                        api_version: Optional[str] = None,
                         model_response: Optional[litellm.utils.ImageResponse] = None,
                         logging_obj=None,
                         optional_params=None,
@@ -477,9 +479,12 @@ class AzureChatCompletion(BaseLLM):
                         ):
         exception_mapping_worked = False
         try:
+            if model and len(model) > 0:
+                model = model
+            else:
+                model = None
             data = {
-                # "model": model,
+                "model": model,
                 "prompt": prompt,
                 **optional_params
             }
@@ -492,7 +497,8 @@ class AzureChatCompletion(BaseLLM):
             # return response
 
             if client is None:
-                azure_client = AzureOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries) # type: ignore
+                client_session = litellm.client_session or httpx.Client(transport=CustomHTTPTransport(),)
+                azure_client = AzureOpenAI(api_key=api_key, azure_endpoint=api_base, http_client=client_session, timeout=timeout, max_retries=max_retries, api_version=api_version) # type: ignore
             else:
                 azure_client = client
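For reference, the `client is None` branch above now builds its own httpx client around the new transport. Stripped of the litellm plumbing, the wiring looks roughly like this (a sketch; the endpoint and key are placeholders, not values from the diff):

    import os
    import httpx
    from openai import AzureOpenAI
    from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport

    # Every request from this client passes through CustomHTTPTransport.handle_request,
    # which reroutes pre-dall-e-3 image requests to Azure's async submit endpoint.
    client_session = httpx.Client(transport=CustomHTTPTransport())
    azure_client = AzureOpenAI(
        api_key=os.getenv("AZURE_API_KEY"),
        azure_endpoint="https://my-resource.openai.azure.com",  # hypothetical endpoint
        api_version="2023-06-01-preview",
        http_client=client_session,
    )

Note the switch from `base_url` to `azure_endpoint`, and that `api_version` is now passed through; the Azure SDK needs both to build the `/openai/...` request paths the transport rewrites.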

litellm/llms/custom_httpx/azure_dall_e_2.py (new file, 64 lines)

@@ -0,0 +1,64 @@
+import time
+import json
+import httpx
+
+class CustomHTTPTransport(httpx.HTTPTransport):
+    """
+    This class was written as a workaround to support dall-e-2 on openai > v1.x
+
+    Refer to this issue for more: https://github.com/openai/openai-python/issues/692
+    """
+    def handle_request(
+        self,
+        request: httpx.Request,
+    ) -> httpx.Response:
+        if "images/generations" in request.url.path and request.url.params[
+            "api-version"
+        ] in [  # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
+            "2023-06-01-preview",
+            "2023-07-01-preview",
+            "2023-08-01-preview",
+            "2023-09-01-preview",
+            "2023-10-01-preview",
+        ]:
+            request.url = request.url.copy_with(path="/openai/images/generations:submit")
+            response = super().handle_request(request)
+            operation_location_url = response.headers["operation-location"]
+            request.url = httpx.URL(operation_location_url)
+            request.method = "GET"
+            response = super().handle_request(request)
+            response.read()
+
+            timeout_secs: int = 120
+            start_time = time.time()
+            while response.json()["status"] not in ["succeeded", "failed"]:
+                if time.time() - start_time > timeout_secs:
+                    timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
+                    return httpx.Response(
+                        status_code=400,
+                        headers=response.headers,
+                        content=json.dumps(timeout).encode("utf-8"),
+                        request=request,
+                    )
+
+                time.sleep(int(response.headers.get("retry-after")) or 10)
+                response = super().handle_request(request)
+                response.read()
+
+            if response.json()["status"] == "failed":
+                error_data = response.json()
+                return httpx.Response(
+                    status_code=400,
+                    headers=response.headers,
+                    content=json.dumps(error_data).encode("utf-8"),
+                    request=request,
+                )
+
+            result = response.json()["result"]
+            return httpx.Response(
+                status_code=200,
+                headers=response.headers,
+                content=json.dumps(result).encode("utf-8"),
+                request=request,
+            )
+        return super().handle_request(request)
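The class implements Azure's long-running-operation flow for DALL-E 2: rewrite the request to `/openai/images/generations:submit`, then poll the URL returned in the `operation-location` header until `status` becomes `succeeded` or `failed`. One caveat worth flagging: `int(response.headers.get("retry-after")) or 10` raises a TypeError when the header is absent, because `int(None)` fails before the `or 10` fallback can apply. A more defensive variant (a sketch only, not part of the commit):

    retry_after = response.headers.get("retry-after")
    # Fall back to a 10s poll interval when Azure omits the retry-after header.
    time.sleep(int(retry_after) if retry_after else 10)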

litellm/main.py

@@ -2307,8 +2307,7 @@ def image_generation(prompt: str,
                 get_secret("AZURE_AD_TOKEN")
             )
 
-        # model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response)
-        pass
+        model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version)
     elif custom_llm_provider == "openai":
         model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response)
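Here the Azure branch of image_generation goes from a commented-out stub to a real call and forwards the new api_version parameter; the OpenAI branch is unchanged. The custom_llm_provider value is resolved earlier from the model string, so from the caller's side the routing reduces to the prefix (a sketch; the deployment name is hypothetical):

    litellm.image_generation(prompt="A cute baby sea otter", model="azure/my-deployment")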

litellm/tests/test_completion.py

@@ -727,7 +727,7 @@ def test_completion_azure():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-# test_completion_azure()
+test_completion_azure()
 
 def test_azure_openai_ad_token():
     # this tests if the azure ad token is set in the request header

litellm/tests/test_image_generation.py

@@ -4,7 +4,8 @@
 import sys, os
 import traceback
 from dotenv import load_dotenv
-
+import logging
+logging.basicConfig(level=logging.DEBUG)
 load_dotenv()
 import os
@@ -18,14 +19,22 @@ def test_image_generation_openai():
     litellm.set_verbose = True
     response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
     print(f"response: {response}")
     assert len(response.data) > 0
 
 # test_image_generation_openai()
 
-# def test_image_generation_azure():
-#     response = litellm.image_generation(prompt="A cute baby sea otter", api_version="2023-06-01-preview", custom_llm_provider="azure")
-#     print(f"response: {response}")
+def test_image_generation_azure():
+    response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/", api_version="2023-06-01-preview")
+    print(f"response: {response}")
+    assert len(response.data) > 0
+# test_image_generation_azure()
+
+def test_image_generation_azure_dall_e_3():
+    litellm.set_verbose = True
+    response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/dall-e-3-test", api_version="2023-12-01-preview", api_base=os.getenv("AZURE_SWEDEN_API_BASE"), api_key=os.getenv("AZURE_SWEDEN_API_KEY"))
+    print(f"response: {response}")
+    assert len(response.data) > 0
+# test_image_generation_azure_dall_e_3()
 # @pytest.mark.asyncio
 # async def test_async_image_generation_openai():
 #     response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
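Both new tests assert on response.data, which follows the OpenAI images response shape. A sketch of consuming the result (assuming a URL-format response, the API default):

    image_url = response.data[0]["url"]  # or response.data[0].url, depending on whether the element is a dict or an Image object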

litellm/utils.py

@@ -1613,6 +1613,7 @@ def client(original_function):
         try:
             model = args[0] if len(args) > 0 else kwargs["model"]
         except:
             model = None
             call_type = original_function.__name__
-            raise ValueError("model param not passed in.")
+            if call_type != CallTypes.image_generation.value:
+                raise ValueError("model param not passed in.")
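This utils.py change relaxes the wrapper's model check: a missing model argument now raises only for call types other than image generation. That is what permits image calls that supply no concrete deployment, e.g. (a sketch mirroring the previously commented-out test; whether a None model then succeeds depends on the provider's default):

    litellm.image_generation(prompt="A cute baby sea otter", custom_llm_provider="azure", api_version="2023-06-01-preview")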