fix(router.py): add support for async image generation endpoints
commit be68796eba (parent a4aa645cf6)

6 changed files with 109 additions and 13 deletions

File 1 of 6: litellm/llms/azure.py (path inferred from the relative custom_httpx import). The Azure handler imports the new async transport.

@@ -6,7 +6,7 @@ from typing import Callable, Optional
 from litellm import OpenAIConfig
 import litellm, json
 import httpx
-from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport
+from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
 from openai import AzureOpenAI, AsyncAzureOpenAI

 class AzureOpenAIError(Exception):
@@ -480,7 +480,8 @@ class AzureChatCompletion(BaseLLM):
         response = None
         try:
             if client is None:
-                openai_aclient = AsyncAzureOpenAI(**azure_client_params)
+                client_session = litellm.aclient_session or httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),)
+                openai_aclient = AsyncAzureOpenAI(http_client=client_session, **azure_client_params)
             else:
                 openai_aclient = client
             response = await openai_aclient.images.generate(**data)
@@ -492,7 +493,7 @@
                 additional_args={"complete_input_dict": data},
                 original_response=stringified_response,
             )
-            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding")
+            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="image_generation")
         except Exception as e:
             ## LOGGING
             logging_obj.post_call(
@@ -511,6 +512,7 @@ class AzureChatCompletion(BaseLLM):
         api_base: Optional[str] = None,
         api_version: Optional[str] = None,
         model_response: Optional[litellm.utils.ImageResponse] = None,
+        azure_ad_token: Optional[str]=None,
         logging_obj=None,
         optional_params=None,
         client=None,
@@ -531,13 +533,26 @@
         if not isinstance(max_retries, int):
             raise AzureOpenAIError(status_code=422, message="max retries must be an int")

+        # init AzureOpenAI Client
+        azure_client_params = {
+            "api_version": api_version,
+            "azure_endpoint": api_base,
+            "azure_deployment": model,
+            "max_retries": max_retries,
+            "timeout": timeout
+        }
+        if api_key is not None:
+            azure_client_params["api_key"] = api_key
+        elif azure_ad_token is not None:
+            azure_client_params["azure_ad_token"] = azure_ad_token
+
         if aimg_generation == True:
-            response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
+            response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params) # type: ignore
             return response

         if client is None:
             client_session = litellm.client_session or httpx.Client(transport=CustomHTTPTransport(),)
-            azure_client = AzureOpenAI(api_key=api_key, azure_endpoint=api_base, http_client=client_session, timeout=timeout, max_retries=max_retries, api_version=api_version) # type: ignore
+            azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore
         else:
             azure_client = client

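For reference, a minimal sketch of the async client wiring the hunks above produce, using only public openai/httpx APIs; the endpoint, deployment, and key below are placeholders, not values from the commit:

    import httpx
    from openai import AsyncAzureOpenAI
    from litellm.llms.custom_httpx.azure_dall_e_2 import AsyncCustomHTTPTransport

    # Mirrors the azure_client_params dict built in image_generation() above.
    azure_client_params = {
        "api_version": "2023-06-01-preview",
        "azure_endpoint": "https://my-resource.openai.azure.com",  # placeholder
        "azure_deployment": "dall-e-2",                            # placeholder
        "api_key": "my-azure-key",                                 # placeholder
        "max_retries": 2,
        "timeout": 600,
    }

    # The async client rides on an httpx.AsyncClient whose transport knows how
    # to drive Azure's long-running image-generation operations (see file 2).
    client_session = httpx.AsyncClient(transport=AsyncCustomHTTPTransport())
    openai_aclient = AsyncAzureOpenAI(http_client=client_session, **azure_client_params)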

File 2 of 6: litellm/llms/custom_httpx/azure_dall_e_2.py. Adds an async twin of the existing sync workaround transport.

@@ -1,5 +1,61 @@
 import time, json, httpx, asyncio

+class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
+    """
+    Async implementation of custom http transport
+    """
+    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
+        if "images/generations" in request.url.path and request.url.params[
+            "api-version"
+        ] in [  # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
+            "2023-06-01-preview",
+            "2023-07-01-preview",
+            "2023-08-01-preview",
+            "2023-09-01-preview",
+            "2023-10-01-preview",
+        ]:
+            request.url = request.url.copy_with(path="/openai/images/generations:submit")
+            response = await super().handle_async_request(request)
+            operation_location_url = response.headers["operation-location"]
+            request.url = httpx.URL(operation_location_url)
+            request.method = "GET"
+            response = await super().handle_async_request(request)
+            await response.aread()
+
+            timeout_secs: int = 120
+            start_time = time.time()
+            while response.json()["status"] not in ["succeeded", "failed"]:
+                if time.time() - start_time > timeout_secs:
+                    timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
+                    return httpx.Response(
+                        status_code=400,
+                        headers=response.headers,
+                        content=json.dumps(timeout).encode("utf-8"),
+                        request=request,
+                    )
+
+                time.sleep(int(response.headers.get("retry-after")) or 10)
+                response = await super().handle_async_request(request)
+                await response.aread()
+
+            if response.json()["status"] == "failed":
+                error_data = response.json()
+                return httpx.Response(
+                    status_code=400,
+                    headers=response.headers,
+                    content=json.dumps(error_data).encode("utf-8"),
+                    request=request,
+                )
+
+            result = response.json()["result"]
+            return httpx.Response(
+                status_code=200,
+                headers=response.headers,
+                content=json.dumps(result).encode("utf-8"),
+                request=request,
+            )
+        return await super().handle_async_request(request)
+
 class CustomHTTPTransport(httpx.HTTPTransport):
     """
     This class was written as a workaround to support dall-e-2 on openai > v1.x
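The async transport mirrors the existing sync workaround: for the pre-dall-e-3 API versions listed, it reroutes images/generations to the :submit endpoint, then polls the URL from the operation-location header until the operation succeeds, fails, or 120 seconds elapse. Note that the polling loop calls time.sleep rather than asyncio.sleep, so each wait blocks the event loop, and int(response.headers.get("retry-after")) raises a TypeError when the header is absent instead of falling back to 10. A hedged usage sketch, with placeholder endpoint and key:

    import asyncio, httpx
    from litellm.llms.custom_httpx.azure_dall_e_2 import AsyncCustomHTTPTransport

    async def main():
        async with httpx.AsyncClient(transport=AsyncCustomHTTPTransport()) as client:
            # The transport rewrites this path to /openai/images/generations:submit
            # and performs the operation-location polling behind the scenes.
            resp = await client.post(
                "https://my-resource.openai.azure.com/openai/images/generations"  # placeholder
                "?api-version=2023-06-01-preview",
                headers={"api-key": "my-azure-key"},  # placeholder
                json={"prompt": "A cute baby sea otter", "n": 1},
            )
            print(resp.status_code, resp.json())

    asyncio.run(main())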

File 3 of 6: the top-level image_generation dispatcher (likely litellm/main.py). Fixes the flag name passed to the provider handlers (aimage_generation to aimg_generation).

@@ -2351,9 +2351,9 @@ def image_generation(prompt: str,
             get_secret("AZURE_AD_TOKEN")
         )

-        model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version, aimg_generation=aimage_generation)
+        model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version, aimg_generation=aimg_generation)
     elif custom_llm_provider == "openai":
-        model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, aimg_generation=aimage_generation)
+        model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, aimg_generation=aimg_generation)

     return model_response
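With the flag name fixed, the async path is reachable from the public API. A sketch of the resulting user-facing call; the deployment name is a placeholder, and litellm.aimage_generation is the async counterpart this commit wires up:

    import asyncio, litellm

    async def main():
        response = await litellm.aimage_generation(
            model="azure/dall-e-2",  # placeholder deployment name
            prompt="A cute baby sea otter",
            api_version="2023-06-01-preview",
        )
        print(response.data)

    asyncio.run(main())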

File 4 of 6: litellm/router.py. Imports the async transport, removes a leftover backoff_factor assignment, and attaches the polling transport to the cached async Azure client.

@@ -18,7 +18,7 @@ import inspect, concurrent
 from openai import AsyncOpenAI
 from collections import defaultdict
 from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
-from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport
+from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
 import copy
 class Router:
     """
@@ -525,7 +525,6 @@

     async def async_function_with_retries(self, *args, **kwargs):
         self.print_verbose(f"Inside async function with retries: args - {args}; kwargs - {kwargs}")
-        backoff_factor = 1
         original_function = kwargs.pop("original_function")
         fallbacks = kwargs.pop("fallbacks", self.fallbacks)
         context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
@@ -1099,6 +1098,7 @@
                 api_version=api_version,
                 timeout=timeout,
                 max_retries=max_retries,
+                http_client=httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),) # type: ignore
             )
             model["client"] = openai.AzureOpenAI(
                 api_key=api_key,
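The last hunk shows only the tail of a client constructor. A hedged reconstruction of the full statement it lands in; the model["async_client"] key, the surrounding variable names, and their values are assumptions based on the visible fragment:

    import httpx, openai
    from litellm.llms.custom_httpx.azure_dall_e_2 import AsyncCustomHTTPTransport

    # Placeholder deployment settings (assumptions, not from the diff).
    api_key, api_base = "my-azure-key", "https://my-resource.openai.azure.com"
    api_version, timeout, max_retries = "2023-06-01-preview", 600, 2

    model = {}
    model["async_client"] = openai.AsyncAzureOpenAI(  # dict key is an assumption
        api_key=api_key,
        azure_endpoint=api_base,  # leading kwargs inferred; only the tail appears above
        api_version=api_version,
        timeout=timeout,
        max_retries=max_retries,
        http_client=httpx.AsyncClient(transport=AsyncCustomHTTPTransport()),
    )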

File 5 of 6: the router test module (likely litellm/tests/test_router.py). Marks the async test for pytest-asyncio, adds a dall-e-2 deployment, asserts on the response, and comments out two direct test invocations.

@@ -424,6 +424,7 @@ def test_function_calling_on_router():
 # test_function_calling_on_router()

 ### IMAGE GENERATION
+@pytest.mark.asyncio
 async def test_aimg_gen_on_router():
     litellm.set_verbose = True
     try:
@@ -442,14 +443,32 @@ async def test_aimg_gen_on_router():
                     "api_base": os.getenv("AZURE_SWEDEN_API_BASE"),
                     "api_key": os.getenv("AZURE_SWEDEN_API_KEY")
                 }
+            },
+            {
+                "model_name": "dall-e-2",
+                "litellm_params": {
+                    "model": "azure/",
+                    "api_version": "2023-06-01-preview",
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                    "api_key": os.getenv("AZURE_API_KEY")
+                }
             }
         ]
         router = Router(model_list=model_list)
+        # response = await router.aimage_generation(
+        #     model="dall-e-3",
+        #     prompt="A cute baby sea otter"
+        # )
+        # print(response)
+        # assert len(response.data) > 0
+
         response = await router.aimage_generation(
-            model="dall-e-3",
+            model="dall-e-2",
             prompt="A cute baby sea otter"
         )
         print(response)
+        assert len(response.data) > 0

         router.reset()
     except Exception as e:
         traceback.print_exc()
@@ -489,7 +508,7 @@ def test_img_gen_on_router():
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")

-test_img_gen_on_router()
+# test_img_gen_on_router()
 ###

 def test_aembedding_on_router():
@@ -625,7 +644,7 @@ async def test_mistral_on_router():
         ]
     )
     print(response)
-asyncio.run(test_mistral_on_router())
+# asyncio.run(test_mistral_on_router())

 def test_openai_completion_on_router():
     # [PROD Use Case] - Makes an acompletion call + async acompletion call, and sync acompletion call, sync completion + stream

File 6 of 6: litellm/utils.py. Gives ImageResponse a usage field and teaches the client decorator to return early for async image generation.

@@ -551,6 +551,8 @@ class ImageResponse(OpenAIObject):

     data: Optional[list] = None

+    usage: Optional[dict] = None
+
     def __init__(self, created=None, data=None, response_ms=None):
         if response_ms:
             _response_ms = response_ms
@@ -565,8 +567,10 @@
             created = created
         else:
             created = None

         super().__init__(data=data, created=created)
+        self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+

     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
@@ -1668,6 +1672,8 @@ def client(original_function):
                 return result
             elif "aembedding" in kwargs and kwargs["aembedding"] == True:
                 return result
+            elif "aimg_generation" in kwargs and kwargs["aimg_generation"] == True:
+                return result

             ### POST-CALL RULES ###
             post_call_processing(original_response=result, model=model or None)
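The new usage field means image responses expose the same usage shape as chat responses, with token counts zeroed since image endpoints do not report tokens. A quick sketch of the resulting behavior, based on the __init__ shown above:

    from litellm.utils import ImageResponse

    resp = ImageResponse(data=[{"url": "https://example.com/otter.png"}])  # placeholder data
    assert resp.usage == {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}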