forked from phoenix/litellm-mirror
fix(router.py): add support for async image generation endpoints
commit be68796eba (parent a4aa645cf6)
6 changed files with 109 additions and 13 deletions
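In practice, this commit enables the async call path sketched below — a minimal usage sketch based on the test added in this diff (the env var names, `api_version`, and model values are taken from that test; treat them as placeholders for your own deployment):

    import asyncio, os
    from litellm import Router

    async def main():
        # model_list mirrors the dall-e-2 entry added to test_aimg_gen_on_router below
        router = Router(model_list=[{
            "model_name": "dall-e-2",
            "litellm_params": {
                "model": "azure/",
                "api_version": "2023-06-01-preview",
                "api_base": os.getenv("AZURE_API_BASE"),
                "api_key": os.getenv("AZURE_API_KEY"),
            },
        }])
        # the async image-generation entrypoint this commit wires up
        response = await router.aimage_generation(model="dall-e-2", prompt="A cute baby sea otter")
        print(response.data)

    asyncio.run(main())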
@@ -6,7 +6,7 @@ from typing import Callable, Optional
 from litellm import OpenAIConfig
 import litellm, json
 import httpx
-from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport
+from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
 from openai import AzureOpenAI, AsyncAzureOpenAI

 class AzureOpenAIError(Exception):
@@ -480,7 +480,8 @@ class AzureChatCompletion(BaseLLM):
         response = None
         try:
             if client is None:
-                openai_aclient = AsyncAzureOpenAI(**azure_client_params)
+                client_session = litellm.aclient_session or httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),)
+                openai_aclient = AsyncAzureOpenAI(http_client=client_session, **azure_client_params)
             else:
                 openai_aclient = client
             response = await openai_aclient.images.generate(**data)
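Note the `litellm.aclient_session or ...` pattern in the hunk above: a user-supplied session takes precedence, so the custom async transport is only attached when no session is configured. A minimal sketch of what that implies if you override the session yourself (illustrative only, not part of this commit):

    import httpx, litellm
    from litellm.llms.custom_httpx.azure_dall_e_2 import AsyncCustomHTTPTransport

    # if you set aclient_session yourself, attach the transport explicitly,
    # otherwise the dall-e-2 submit/poll workaround above is bypassed
    litellm.aclient_session = httpx.AsyncClient(transport=AsyncCustomHTTPTransport())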
@@ -492,7 +493,7 @@ class AzureChatCompletion(BaseLLM):
                 additional_args={"complete_input_dict": data},
                 original_response=stringified_response,
             )
-            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding")
+            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="image_generation")
         except Exception as e:
             ## LOGGING
             logging_obj.post_call(
@@ -511,6 +512,7 @@ class AzureChatCompletion(BaseLLM):
         api_base: Optional[str] = None,
         api_version: Optional[str] = None,
         model_response: Optional[litellm.utils.ImageResponse] = None,
+        azure_ad_token: Optional[str]=None,
         logging_obj=None,
         optional_params=None,
         client=None,
@@ -531,13 +533,26 @@ class AzureChatCompletion(BaseLLM):
         if not isinstance(max_retries, int):
             raise AzureOpenAIError(status_code=422, message="max retries must be an int")

+        # init AzureOpenAI Client
+        azure_client_params = {
+            "api_version": api_version,
+            "azure_endpoint": api_base,
+            "azure_deployment": model,
+            "max_retries": max_retries,
+            "timeout": timeout
+        }
+        if api_key is not None:
+            azure_client_params["api_key"] = api_key
+        elif azure_ad_token is not None:
+            azure_client_params["azure_ad_token"] = azure_ad_token
+
         if aimg_generation == True:
-            response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
+            response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params) # type: ignore
             return response

         if client is None:
             client_session = litellm.client_session or httpx.Client(transport=CustomHTTPTransport(),)
-            azure_client = AzureOpenAI(api_key=api_key, azure_endpoint=api_base, http_client=client_session, timeout=timeout, max_retries=max_retries, api_version=api_version) # type: ignore
+            azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore
         else:
             azure_client = client
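The refactor above centralizes client construction: a single `azure_client_params` dict now feeds both the sync and async clients, with the credential chosen by whichever of `api_key` or `azure_ad_token` is set. A minimal sketch of the resulting construction (all values are placeholders):

    from openai import AzureOpenAI, AsyncAzureOpenAI

    azure_client_params = {
        "api_version": "2023-06-01-preview",                        # placeholder
        "azure_endpoint": "https://my-resource.openai.azure.com",   # placeholder
        "azure_deployment": "dall-e-2",                             # placeholder
        "max_retries": 2,
        "timeout": 600,
        "api_key": "placeholder",  # or "azure_ad_token": "..."
    }
    sync_client = AzureOpenAI(**azure_client_params)
    async_client = AsyncAzureOpenAI(**azure_client_params)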
@@ -1,5 +1,61 @@
 import time, json, httpx, asyncio

+class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
+    """
+    Async implementation of custom http transport
+    """
+    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
+        if "images/generations" in request.url.path and request.url.params[
+            "api-version"
+        ] in [  # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
+            "2023-06-01-preview",
+            "2023-07-01-preview",
+            "2023-08-01-preview",
+            "2023-09-01-preview",
+            "2023-10-01-preview",
+        ]:
+            # reroute dall-e-2 calls to Azure's async submit endpoint
+            request.url = request.url.copy_with(path="/openai/images/generations:submit")
+            response = await super().handle_async_request(request)
+            operation_location_url = response.headers["operation-location"]
+            # poll the operation-location URL until the job succeeds or fails
+            request.url = httpx.URL(operation_location_url)
+            request.method = "GET"
+            response = await super().handle_async_request(request)
+            await response.aread()
+
+            timeout_secs: int = 120
+            start_time = time.time()
+            while response.json()["status"] not in ["succeeded", "failed"]:
+                if time.time() - start_time > timeout_secs:
+                    timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
+                    return httpx.Response(
+                        status_code=400,
+                        headers=response.headers,
+                        content=json.dumps(timeout).encode("utf-8"),
+                        request=request,
+                    )
+
+                # non-blocking sleep in async code; fall back to 10s when no retry-after header is set
+                await asyncio.sleep(int(response.headers.get("retry-after") or 10))
+                response = await super().handle_async_request(request)
+                await response.aread()
+
+            if response.json()["status"] == "failed":
+                error_data = response.json()
+                return httpx.Response(
+                    status_code=400,
+                    headers=response.headers,
+                    content=json.dumps(error_data).encode("utf-8"),
+                    request=request,
+                )
+
+            result = response.json()["result"]
+            return httpx.Response(
+                status_code=200,
+                headers=response.headers,
+                content=json.dumps(result).encode("utf-8"),
+                request=request,
+            )
+        return await super().handle_async_request(request)
+
 class CustomHTTPTransport(httpx.HTTPTransport):
     """
     This class was written as a workaround to support dall-e-2 on openai > v1.x
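Both transports plug into a client the same way as the azure.py and router.py hunks in this diff: wrap them in an httpx client and hand that to the OpenAI SDK. A minimal sketch (endpoint, key, and api_version are placeholders):

    import httpx
    from openai import AsyncAzureOpenAI
    from litellm.llms.custom_httpx.azure_dall_e_2 import AsyncCustomHTTPTransport

    client = AsyncAzureOpenAI(
        api_key="placeholder",
        api_version="2023-06-01-preview",                       # placeholder
        azure_endpoint="https://my-resource.openai.azure.com",  # placeholder
        http_client=httpx.AsyncClient(transport=AsyncCustomHTTPTransport()),
    )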
@@ -2351,9 +2351,9 @@ def image_generation(prompt: str,
             get_secret("AZURE_AD_TOKEN")
         )

-        model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version, aimg_generation=aimage_generation)
+        model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version, aimg_generation=aimg_generation)
     elif custom_llm_provider == "openai":
-        model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, aimg_generation=aimage_generation)
+        model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, aimg_generation=aimg_generation)

     return model_response
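The two changed call sites above forward the async flag under the name the provider implementations actually receive, `aimg_generation`, rather than `aimage_generation`. A minimal sketch of the async entrypoint that sets this flag (this assumes `litellm.aimage_generation` is the public async wrapper, which is not shown in this diff; model and prompt are placeholders):

    import asyncio, litellm

    async def main():
        # assumption: litellm.aimage_generation dispatches to image_generation
        # with the aimg_generation flag set, as checked in the utils.py hunk below
        resp = await litellm.aimage_generation(
            model="azure/dall-e-2",      # placeholder deployment
            prompt="A cute baby sea otter",
        )
        print(resp.data)

    asyncio.run(main())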
@@ -18,7 +18,7 @@ import inspect, concurrent
 from openai import AsyncOpenAI
 from collections import defaultdict
 from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
-from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport
+from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
 import copy
 class Router:
     """
@@ -525,7 +525,6 @@ class Router:

     async def async_function_with_retries(self, *args, **kwargs):
         self.print_verbose(f"Inside async function with retries: args - {args}; kwargs - {kwargs}")
-        backoff_factor = 1
         original_function = kwargs.pop("original_function")
         fallbacks = kwargs.pop("fallbacks", self.fallbacks)
         context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
@@ -1099,6 +1098,7 @@ class Router:
                     api_version=api_version,
                     timeout=timeout,
                     max_retries=max_retries,
+                    http_client=httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),) # type: ignore
                 )
                 model["client"] = openai.AzureOpenAI(
                     api_key=api_key,
@@ -424,6 +424,7 @@ def test_function_calling_on_router():
 # test_function_calling_on_router()

+### IMAGE GENERATION
 @pytest.mark.asyncio
 async def test_aimg_gen_on_router():
     litellm.set_verbose = True
     try:
@@ -442,14 +443,32 @@ async def test_aimg_gen_on_router():
                     "api_base": os.getenv("AZURE_SWEDEN_API_BASE"),
                     "api_key": os.getenv("AZURE_SWEDEN_API_KEY")
                 }
             },
+            {
+                "model_name": "dall-e-2",
+                "litellm_params": {
+                    "model": "azure/",
+                    "api_version": "2023-06-01-preview",
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                    "api_key": os.getenv("AZURE_API_KEY")
+                }
+            }
         ]
         router = Router(model_list=model_list)
+        # response = await router.aimage_generation(
+        #     model="dall-e-3",
+        #     prompt="A cute baby sea otter"
+        # )
+        # print(response)
+        # assert len(response.data) > 0
+
         response = await router.aimage_generation(
-            model="dall-e-3",
+            model="dall-e-2",
             prompt="A cute baby sea otter"
         )
         print(response)
         assert len(response.data) > 0
+
+        router.reset()
     except Exception as e:
         traceback.print_exc()
@@ -489,7 +508,7 @@ def test_img_gen_on_router():
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")

-test_img_gen_on_router()
+# test_img_gen_on_router()
 ###

 def test_aembedding_on_router():
@@ -625,7 +644,7 @@ async def test_mistral_on_router():
         ]
     )
     print(response)
-asyncio.run(test_mistral_on_router())
+# asyncio.run(test_mistral_on_router())

 def test_openai_completion_on_router():
     # [PROD Use Case] - Makes an acompletion call + async acompletion call, and sync acompletion call, sync completion + stream
@@ -551,6 +551,8 @@ class ImageResponse(OpenAIObject):

     data: Optional[list] = None

+    usage: Optional[dict] = None
+
     def __init__(self, created=None, data=None, response_ms=None):
         if response_ms:
             _response_ms = response_ms
@@ -565,8 +567,10 @@ class ImageResponse(OpenAIObject):
             created = created
         else:
             created = None
+

         super().__init__(data=data, created=created)
+        self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}

     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
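With the two hunks above, ImageResponse gains a `usage` field that is zeroed on construction, so downstream accounting can read it unconditionally. A quick sketch of the resulting shape (the created timestamp and URL are placeholders):

    from litellm.utils import ImageResponse

    resp = ImageResponse(created=1700000000, data=[{"url": "https://example.com/otter.png"}])
    print(resp.usage)  # {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}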
@@ -1668,6 +1672,8 @@ def client(original_function):
                 return result
             elif "aembedding" in kwargs and kwargs["aembedding"] == True:
                 return result
+            elif "aimg_generation" in kwargs and kwargs["aimg_generation"] == True:
+                return result

             ### POST-CALL RULES ###
             post_call_processing(original_response=result, model=model or None)