fix(router.py): add support for async image generation endpoints

Krrish Dholakia 2023-12-21 14:38:44 +05:30
parent a4aa645cf6
commit be68796eba
6 changed files with 109 additions and 13 deletions
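
In practice, this change makes image generation awaitable end to end: `Router.aimage_generation` now drives the Azure DALL-E flow over an async transport instead of only the sync path. A minimal usage sketch, assuming the same model name and environment variables as the test changes below:

import asyncio, os
from litellm import Router

# Mirrors the dall-e-2 entry added to model_list in the router test below.
model_list = [{
    "model_name": "dall-e-2",
    "litellm_params": {
        "model": "azure/",
        "api_version": "2023-06-01-preview",
        "api_base": os.getenv("AZURE_API_BASE"),
        "api_key": os.getenv("AZURE_API_KEY"),
    },
}]

async def main():
    router = Router(model_list=model_list)
    # The awaitable entrypoint this commit wires up.
    response = await router.aimage_generation(
        model="dall-e-2",
        prompt="A cute baby sea otter",
    )
    print(response.data)

asyncio.run(main())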

litellm/llms/azure.py

@@ -6,7 +6,7 @@ from typing import Callable, Optional
 from litellm import OpenAIConfig
 import litellm, json
 import httpx
-from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport
+from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
 from openai import AzureOpenAI, AsyncAzureOpenAI
 
 class AzureOpenAIError(Exception):
@@ -480,7 +480,8 @@ class AzureChatCompletion(BaseLLM):
         response = None
         try:
             if client is None:
-                openai_aclient = AsyncAzureOpenAI(**azure_client_params)
+                client_session = litellm.aclient_session or httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),)
+                openai_aclient = AsyncAzureOpenAI(http_client=client_session, **azure_client_params)
             else:
                 openai_aclient = client
             response = await openai_aclient.images.generate(**data)
@@ -492,7 +493,7 @@
                 additional_args={"complete_input_dict": data},
                 original_response=stringified_response,
             )
-            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding")
+            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="image_generation")
         except Exception as e:
             ## LOGGING
             logging_obj.post_call(
@@ -511,6 +512,7 @@
         api_base: Optional[str] = None,
         api_version: Optional[str] = None,
         model_response: Optional[litellm.utils.ImageResponse] = None,
+        azure_ad_token: Optional[str]=None,
         logging_obj=None,
         optional_params=None,
         client=None,
@@ -531,13 +533,26 @@
         if not isinstance(max_retries, int):
             raise AzureOpenAIError(status_code=422, message="max retries must be an int")
 
         # init AzureOpenAI Client
+        azure_client_params = {
+            "api_version": api_version,
+            "azure_endpoint": api_base,
+            "azure_deployment": model,
+            "max_retries": max_retries,
+            "timeout": timeout
+        }
+        if api_key is not None:
+            azure_client_params["api_key"] = api_key
+        elif azure_ad_token is not None:
+            azure_client_params["azure_ad_token"] = azure_ad_token
+
         if aimg_generation == True:
-            response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
+            response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params) # type: ignore
             return response
 
         if client is None:
             client_session = litellm.client_session or httpx.Client(transport=CustomHTTPTransport(),)
-            azure_client = AzureOpenAI(api_key=api_key, azure_endpoint=api_base, http_client=client_session, timeout=timeout, max_retries=max_retries, api_version=api_version) # type: ignore
+            azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore
         else:
             azure_client = client
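
For reference, the `azure_client_params` consolidation above lets the sync and async constructors share one config dict; the openai v1 SDK accepts the same keyword set plus `http_client` on both `AzureOpenAI` and `AsyncAzureOpenAI`. A minimal sketch with hypothetical values:

import httpx
from openai import AzureOpenAI, AsyncAzureOpenAI

azure_client_params = {
    "api_version": "2023-06-01-preview",
    "azure_endpoint": "https://my-resource.openai.azure.com",  # hypothetical endpoint
    "azure_deployment": "dall-e-2",
    "max_retries": 2,
    "timeout": 600,
    "api_key": "my-azure-key",  # or "azure_ad_token" instead, never both
}

# The same dict feeds both constructors; only the httpx client (and its
# custom transport) differs between the sync and async paths.
sync_client = AzureOpenAI(http_client=httpx.Client(), **azure_client_params)
async_client = AsyncAzureOpenAI(http_client=httpx.AsyncClient(), **azure_client_params)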

litellm/llms/custom_httpx/azure_dall_e_2.py

@@ -1,5 +1,61 @@
 import time, json, httpx, asyncio
 
+class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
+    """
+    Async implementation of custom http transport
+    """
+    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
+        if "images/generations" in request.url.path and request.url.params[
+            "api-version"
+        ] in [  # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
+            "2023-06-01-preview",
+            "2023-07-01-preview",
+            "2023-08-01-preview",
+            "2023-09-01-preview",
+            "2023-10-01-preview",
+        ]:
+            request.url = request.url.copy_with(path="/openai/images/generations:submit")
+            response = await super().handle_async_request(request)
+            operation_location_url = response.headers["operation-location"]
+            request.url = httpx.URL(operation_location_url)
+            request.method = "GET"
+            response = await super().handle_async_request(request)
+            await response.aread()
+
+            timeout_secs: int = 120
+            start_time = time.time()
+            while response.json()["status"] not in ["succeeded", "failed"]:
+                if time.time() - start_time > timeout_secs:
+                    timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
+                    return httpx.Response(
+                        status_code=400,
+                        headers=response.headers,
+                        content=json.dumps(timeout).encode("utf-8"),
+                        request=request,
+                    )
+                await asyncio.sleep(int(response.headers.get("retry-after") or 10))
+                response = await super().handle_async_request(request)
+                await response.aread()
+
+            if response.json()["status"] == "failed":
+                error_data = response.json()
+                return httpx.Response(
+                    status_code=400,
+                    headers=response.headers,
+                    content=json.dumps(error_data).encode("utf-8"),
+                    request=request,
+                )
+
+            result = response.json()["result"]
+            return httpx.Response(
+                status_code=200,
+                headers=response.headers,
+                content=json.dumps(result).encode("utf-8"),
+                request=request,
+            )
+        return await super().handle_async_request(request)
+
 class CustomHTTPTransport(httpx.HTTPTransport):
     """
     This class was written as a workaround to support dall-e-2 on openai > v1.x
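
Note the async path in azure.py checks `litellm.aclient_session` before building a fresh `httpx.AsyncClient` per call, so callers can supply one long-lived client that already carries the polling transport. A minimal sketch:

import httpx
import litellm
from litellm.llms.custom_httpx.azure_dall_e_2 import AsyncCustomHTTPTransport

# Shared async session reused across image-generation calls; the transport
# handles the submit-then-poll dance for the older dall-e-2 api versions.
litellm.aclient_session = httpx.AsyncClient(transport=AsyncCustomHTTPTransport())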

litellm/main.py

@@ -2351,9 +2351,9 @@ def image_generation(prompt: str,
             get_secret("AZURE_AD_TOKEN")
         )
 
-        model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version, aimg_generation=aimage_generation)
+        model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version, aimg_generation=aimg_generation)
     elif custom_llm_provider == "openai":
-        model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, aimg_generation=aimage_generation)
+        model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, aimg_generation=aimg_generation)
 
     return model_response

litellm/router.py

@@ -18,7 +18,7 @@ import inspect, concurrent
 from openai import AsyncOpenAI
 from collections import defaultdict
 from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
-from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport
+from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
 import copy
 
 class Router:
     """
@@ -525,7 +525,6 @@
     async def async_function_with_retries(self, *args, **kwargs):
         self.print_verbose(f"Inside async function with retries: args - {args}; kwargs - {kwargs}")
-        backoff_factor = 1
         original_function = kwargs.pop("original_function")
         fallbacks = kwargs.pop("fallbacks", self.fallbacks)
         context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
@@ -1099,6 +1098,7 @@
                     api_version=api_version,
                     timeout=timeout,
                     max_retries=max_retries,
+                    http_client=httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),) # type: ignore
                 )
                 model["client"] = openai.AzureOpenAI(
                     api_key=api_key,

litellm/tests/test_router.py

@@ -424,6 +424,7 @@ def test_function_calling_on_router():
 # test_function_calling_on_router()
 
 ### IMAGE GENERATION
+@pytest.mark.asyncio
 async def test_aimg_gen_on_router():
     litellm.set_verbose = True
     try:
@@ -442,14 +443,32 @@ async def test_aimg_gen_on_router():
                     "api_base": os.getenv("AZURE_SWEDEN_API_BASE"),
                     "api_key": os.getenv("AZURE_SWEDEN_API_KEY")
                 }
             },
+            {
+                "model_name": "dall-e-2",
+                "litellm_params": {
+                    "model": "azure/",
+                    "api_version": "2023-06-01-preview",
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                    "api_key": os.getenv("AZURE_API_KEY")
+                }
+            }
         ]
         router = Router(model_list=model_list)
+        # response = await router.aimage_generation(
+        #     model="dall-e-3",
+        #     prompt="A cute baby sea otter"
+        # )
+        # print(response)
+        # assert len(response.data) > 0
         response = await router.aimage_generation(
-            model="dall-e-3",
+            model="dall-e-2",
             prompt="A cute baby sea otter"
         )
         print(response)
         assert len(response.data) > 0
+        router.reset()
     except Exception as e:
         traceback.print_exc()
@@ -489,7 +508,7 @@ def test_img_gen_on_router():
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")
 
-test_img_gen_on_router()
+# test_img_gen_on_router()
 ###
 
 def test_aembedding_on_router():
@@ -625,7 +644,7 @@ async def test_mistral_on_router():
         ]
     )
     print(response)
-asyncio.run(test_mistral_on_router())
+# asyncio.run(test_mistral_on_router())
 
 def test_openai_completion_on_router():
     # [PROD Use Case] - Makes an acompletion call + async acompletion call, and sync acompletion call, sync completion + stream

litellm/utils.py

@@ -551,6 +551,8 @@ class ImageResponse(OpenAIObject):
     data: Optional[list] = None
 
+    usage: Optional[dict] = None
+
     def __init__(self, created=None, data=None, response_ms=None):
         if response_ms:
             _response_ms = response_ms
@@ -565,8 +567,10 @@ class ImageResponse(OpenAIObject):
             created = created
         else:
             created = None
 
         super().__init__(data=data, created=created)
+        self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+
     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
@@ -1668,6 +1672,8 @@ def client(original_function):
             return result
         elif "aembedding" in kwargs and kwargs["aembedding"] == True:
             return result
+        elif "aimg_generation" in kwargs and kwargs["aimg_generation"] == True:
+            return result
 
         ### POST-CALL RULES ###
         post_call_processing(original_response=result, model=model or None)
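
The decorator branch added above mirrors the existing `acompletion`/`aembedding` early returns: for async calls the wrapped function has produced a coroutine, so the wrapper hands it back unawaited and skips sync-only post-call processing. A stripped-down illustration of the pattern (hypothetical, not litellm's actual wrapper):

def client(original_function):
    def wrapper(*args, **kwargs):
        result = original_function(*args, **kwargs)
        # Async variants: result is still a coroutine here; return it for the
        # caller to await, bypassing the sync post-call rules below.
        if kwargs.get("aimg_generation") == True:
            return result
        # ... sync-only post-call rules would run here ...
        return result
    return wrapper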