refactor(azure.py): refactor sync azure calls to httpx

This commit is contained in:
Krrish Dholakia 2024-07-02 17:06:48 -07:00
parent 589c1c6280
commit cf5334fe8a
2 changed files with 144 additions and 32 deletions

View file

@ -1160,6 +1160,105 @@ class AzureChatCompletion(BaseLLM):
},
)
def make_sync_azure_httpx_request(
    self,
    client: Optional[HTTPHandler],
    timeout: Optional[Union[float, httpx.Timeout]],
    api_base: str,
    api_version: str,
    api_key: str,
    data: dict,
) -> httpx.Response:
    """
    Make a synchronous httpx request to an Azure OpenAI endpoint.

    Implemented for azure dall-e-2 image gen calls.
    Alternative to needing a custom transport implementation.

    For the legacy dall-e-2 api-versions (pre `2023-12-01-preview`) Azure uses an
    async CREATE + POLL protocol: we submit the job to
    `/openai/images/generations:submit`, then poll the URL returned in the
    `operation-location` response header until the job reports `succeeded` or
    `failed`. All other calls are a single plain POST to `api_base`.

    Args:
        client: Optional pre-configured HTTPHandler; a new one is built when None.
        timeout: Request timeout as seconds or a pre-built httpx.Timeout.
            Defaults to 600s total / 5s connect when None.
        api_base: Fully-formed request URL (already includes deployment + path).
        api_version: Azure api-version string; selects the dall-e-2 polling path.
        api_key: Azure API key, sent via the `api-key` header.
        data: JSON-serializable request body. NOTE: mutated in place on the
            dall-e-2 path (the `model` key is popped).

    Returns:
        The final httpx.Response (the terminal poll response on the dall-e-2 path).

    Raises:
        AzureOpenAIError: 408 if polling exceeds 120s; 400 if the job fails.
        Exception: if a poll response is missing the `status` field.
    """
    if client is None:
        _params = {}
        if timeout is not None:
            if isinstance(timeout, (float, int)):
                _params["timeout"] = httpx.Timeout(timeout)
            else:
                # BUGFIX: a caller-supplied httpx.Timeout was previously
                # dropped here (no assignment), silently falling back to the
                # handler's default. Pass it through unchanged.
                _params["timeout"] = timeout
        else:
            _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
        sync_handler = HTTPHandler(**_params)  # type: ignore
    else:
        sync_handler = client  # type: ignore

    if (
        "images/generations" in api_base
        and api_version
        in [  # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
            "2023-06-01-preview",
            "2023-07-01-preview",
            "2023-08-01-preview",
            "2023-09-01-preview",
            "2023-10-01-preview",
        ]
    ):  # CREATE + POLL for azure dall-e-2 calls
        api_base = modify_url(
            original_url=api_base, new_path="/openai/images/generations:submit"
        )
        data.pop(
            "model", None
        )  # REMOVE 'model' from dall-e-2 arg https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#request-a-generated-image-dall-e-2-preview
        response = sync_handler.post(
            url=api_base,
            data=json.dumps(data),
            headers={
                "Content-Type": "application/json",
                "api-key": api_key,
            },
        )
        # Azure returns the polling URL for the submitted job in this header.
        operation_location_url = response.headers["operation-location"]
        response = sync_handler.get(
            url=operation_location_url,
            headers={
                "api-key": api_key,
            },
        )
        response.read()

        timeout_secs: int = 120  # hard cap on total polling time
        start_time = time.time()
        if "status" not in response.json():
            raise Exception(
                "Expected 'status' in response. Got={}".format(response.json())
            )
        while response.json()["status"] not in ["succeeded", "failed"]:
            if time.time() - start_time > timeout_secs:
                raise AzureOpenAIError(
                    status_code=408, message="Operation polling timed out."
                )

            # Honor Azure's suggested backoff; default to 10s when absent.
            time.sleep(int(response.headers.get("retry-after") or 10))
            response = sync_handler.get(
                url=operation_location_url,
                headers={
                    "api-key": api_key,
                },
            )
            response.read()

        if response.json()["status"] == "failed":
            error_data = response.json()
            raise AzureOpenAIError(status_code=400, message=json.dumps(error_data))

        return response

    # Plain single-shot POST for everything else (e.g. dall-e-3).
    # BUGFIX: stray trailing ';' removed from the Content-Type value
    # ("application/json;" -> "application/json"), matching the header used
    # on the dall-e-2 path above.
    return sync_handler.post(
        url=api_base,
        json=data,
        headers={
            "Content-Type": "application/json",
            "api-key": api_key,
        },
    )
def create_azure_base_url(
self, azure_client_params: dict, model: Optional[str]
) -> str:
@ -1196,17 +1295,6 @@ class AzureChatCompletion(BaseLLM):
):
response: Optional[dict] = None
try:
# ## LOGGING
# logging_obj.pre_call(
# input=data["prompt"],
# api_key=azure_client.api_key,
# additional_args={
# "headers": {"api_key": azure_client.api_key},
# "api_base": azure_client._base_url._uri_reference,
# "acompletion": True,
# "complete_input_dict": data,
# },
# )
# response = await azure_client.images.generate(**data, timeout=timeout)
api_base: str = azure_client_params.get(
"api_base", ""
@ -1217,6 +1305,17 @@ class AzureChatCompletion(BaseLLM):
img_gen_api_base = self.create_azure_base_url(
azure_client_params=azure_client_params, model=data.get("model", "")
)
## LOGGING
logging_obj.pre_call(
input=data["prompt"],
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": img_gen_api_base,
"headers": {"api_key": api_key},
},
)
httpx_response: httpx.Response = await self.make_async_azure_httpx_request(
client=None,
timeout=timeout,
@ -1310,28 +1409,30 @@ class AzureChatCompletion(BaseLLM):
response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout) # type: ignore
return response
if client is None:
client_session = litellm.client_session or httpx.Client(
transport=CustomHTTPTransport(),
img_gen_api_base = self.create_azure_base_url(
azure_client_params=azure_client_params, model=data.get("model", "")
)
azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore
else:
azure_client = client
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=azure_client.api_key,
input=data["prompt"],
api_key=api_key,
additional_args={
"headers": {"api_key": azure_client.api_key},
"api_base": azure_client._base_url._uri_reference,
"acompletion": False,
"complete_input_dict": data,
"api_base": img_gen_api_base,
"headers": {"api_key": api_key},
},
)
httpx_response: httpx.Response = self.make_sync_azure_httpx_request(
client=None,
timeout=timeout,
api_base=img_gen_api_base,
api_version=api_version or "",
api_key=api_key or "",
data=data,
)
response = httpx_response.json()["result"]
## COMPLETION CALL
response = azure_client.images.generate(**data, timeout=timeout) # type: ignore
## LOGGING
logging_obj.post_call(
input=prompt,
@ -1340,7 +1441,7 @@ class AzureChatCompletion(BaseLLM):
original_response=response,
)
# return response
return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="image_generation") # type: ignore
return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e

View file

@ -42,9 +42,20 @@ def test_image_generation_openai():
# test_image_generation_openai()
@pytest.mark.parametrize(
"sync_mode",
[True, False],
) #
@pytest.mark.asyncio
async def test_image_generation_azure():
async def test_image_generation_azure(sync_mode):
try:
if sync_mode:
response = litellm.image_generation(
prompt="A cute baby sea otter",
model="azure/",
api_version="2023-06-01-preview",
)
else:
response = await litellm.aimage_generation(
prompt="A cute baby sea otter",
model="azure/",