fix(router.py): add support for async image generation endpoints

Krrish Dholakia 2023-12-21 14:38:44 +05:30
parent ae361230fd
commit c084f04a35
6 changed files with 109 additions and 13 deletions
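
For context on how the new path is reached: litellm exposes an async image-generation entry point, and a call like the sketch below would route into the AsyncAzureOpenAI client wired up in this diff. This is a minimal sketch, not taken from the commit; the deployment name, endpoint, API version, and key are placeholders, and the keyword names are assumed to mirror the sync image_generation call.

    import asyncio
    import litellm

    async def main():
        # Placeholder Azure details; "azure/<deployment>" selects the Azure image path.
        response = await litellm.aimage_generation(
            prompt="a watercolor sketch of a lighthouse",
            model="azure/dall-e-3",
            api_base="https://my-endpoint.openai.azure.com/",
            api_version="2023-12-01-preview",
            api_key="my-azure-api-key",
        )
        # The result is an ImageResponse shaped like the OpenAI images payload.
        print(response.data)

    asyncio.run(main())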

@@ -6,7 +6,7 @@ from typing import Callable, Optional
 from litellm import OpenAIConfig
 import litellm, json
 import httpx
-from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport
+from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
 from openai import AzureOpenAI, AsyncAzureOpenAI
 
 class AzureOpenAIError(Exception):
@@ -480,7 +480,8 @@ class AzureChatCompletion(BaseLLM):
         response = None
         try:
             if client is None:
-                openai_aclient = AsyncAzureOpenAI(**azure_client_params)
+                client_session = litellm.aclient_session or httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),)
+                openai_aclient = AsyncAzureOpenAI(http_client=client_session, **azure_client_params)
             else:
                 openai_aclient = client
             response = await openai_aclient.images.generate(**data)
@@ -492,7 +493,7 @@ class AzureChatCompletion(BaseLLM):
                 additional_args={"complete_input_dict": data},
                 original_response=stringified_response,
             )
-            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding")
+            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="image_generation")
         except Exception as e:
             ## LOGGING
             logging_obj.post_call(
@@ -511,6 +512,7 @@ class AzureChatCompletion(BaseLLM):
         api_base: Optional[str] = None,
         api_version: Optional[str] = None,
         model_response: Optional[litellm.utils.ImageResponse] = None,
+        azure_ad_token: Optional[str]=None,
         logging_obj=None,
         optional_params=None,
         client=None,
@@ -531,13 +533,26 @@ class AzureChatCompletion(BaseLLM):
         if not isinstance(max_retries, int):
             raise AzureOpenAIError(status_code=422, message="max retries must be an int")
 
         # init AzureOpenAI Client
+        azure_client_params = {
+            "api_version": api_version,
+            "azure_endpoint": api_base,
+            "azure_deployment": model,
+            "max_retries": max_retries,
+            "timeout": timeout
+        }
+        if api_key is not None:
+            azure_client_params["api_key"] = api_key
+        elif azure_ad_token is not None:
+            azure_client_params["azure_ad_token"] = azure_ad_token
+
         if aimg_generation == True:
-            response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
+            response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params) # type: ignore
             return response
 
         if client is None:
             client_session = litellm.client_session or httpx.Client(transport=CustomHTTPTransport(),)
-            azure_client = AzureOpenAI(api_key=api_key, azure_endpoint=api_base, http_client=client_session, timeout=timeout, max_retries=max_retries, api_version=api_version) # type: ignore
+            azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore
         else:
             azure_client = client
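
Read on its own, the azure_client_params handling added above boils down to the pattern sketched below: assemble one dict of Azure connection settings, decide between api_key and azure_ad_token in a single place, and feed the same dict to both the sync and async OpenAI clients. The sketch uses plain httpx clients and placeholder values; the actual code injects litellm's CustomHTTPTransport / AsyncCustomHTTPTransport, and build_client_params is a hypothetical helper, not a function from the commit.

    import httpx
    from openai import AzureOpenAI, AsyncAzureOpenAI

    def build_client_params(api_version, api_base, model, max_retries, timeout,
                            api_key=None, azure_ad_token=None):
        # Mirrors the dict assembled in image_generation() above.
        params = {
            "api_version": api_version,
            "azure_endpoint": api_base,
            "azure_deployment": model,
            "max_retries": max_retries,
            "timeout": timeout,
        }
        # Prefer an explicit key; otherwise fall back to an Azure AD token.
        if api_key is not None:
            params["api_key"] = api_key
        elif azure_ad_token is not None:
            params["azure_ad_token"] = azure_ad_token
        return params

    params = build_client_params(
        api_version="2023-12-01-preview",              # placeholder
        api_base="https://my-endpoint.openai.azure.com/",
        model="dall-e-3",                              # Azure deployment name
        max_retries=2,
        timeout=600,
        api_key="my-azure-api-key",
    )

    # One params dict, two clients; the commit swaps in its custom transports here.
    sync_client = AzureOpenAI(http_client=httpx.Client(), **params)
    async_client = AsyncAzureOpenAI(http_client=httpx.AsyncClient(), **params)

Keeping the credential branch in one place is what lets the async aimage_generation() path receive a ready-made azure_client_params instead of re-deriving api_base, timeout, and retries itself.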