forked from phoenix/litellm-mirror
refactor(azure.py): refactor sync azure calls to httpx
parent 589c1c6280
commit cf5334fe8a
2 changed files with 144 additions and 32 deletions
azure.py

@@ -1160,6 +1160,105 @@ class AzureChatCompletion(BaseLLM):
             },
         )

+    def make_sync_azure_httpx_request(
+        self,
+        client: Optional[HTTPHandler],
+        timeout: Optional[Union[float, httpx.Timeout]],
+        api_base: str,
+        api_version: str,
+        api_key: str,
+        data: dict,
+    ) -> httpx.Response:
+        """
+        Implemented for azure dall-e-2 image gen calls
+
+        Alternative to needing a custom transport implementation
+        """
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            sync_handler = HTTPHandler(**_params)  # type: ignore
+        else:
+            sync_handler = client  # type: ignore
+
+        if (
+            "images/generations" in api_base
+            and api_version
+            in [  # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
+                "2023-06-01-preview",
+                "2023-07-01-preview",
+                "2023-08-01-preview",
+                "2023-09-01-preview",
+                "2023-10-01-preview",
+            ]
+        ):  # CREATE + POLL for azure dall-e-2 calls
+            api_base = modify_url(
+                original_url=api_base, new_path="/openai/images/generations:submit"
+            )
+
+            data.pop(
+                "model", None
+            )  # REMOVE 'model' from dall-e-2 arg https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#request-a-generated-image-dall-e-2-preview
+            response = sync_handler.post(
+                url=api_base,
+                data=json.dumps(data),
+                headers={
+                    "Content-Type": "application/json",
+                    "api-key": api_key,
+                },
+            )
+            operation_location_url = response.headers["operation-location"]
+            response = sync_handler.get(
+                url=operation_location_url,
+                headers={
+                    "api-key": api_key,
+                },
+            )
+
+            response.read()
+
+            timeout_secs: int = 120
+            start_time = time.time()
+            if "status" not in response.json():
+                raise Exception(
+                    "Expected 'status' in response. Got={}".format(response.json())
+                )
+            # Poll the operation-location URL until Azure reports a terminal status.
+            while response.json()["status"] not in ["succeeded", "failed"]:
+                if time.time() - start_time > timeout_secs:
+                    raise AzureOpenAIError(
+                        status_code=408, message="Operation polling timed out."
+                    )
+
+                time.sleep(int(response.headers.get("retry-after") or 10))
+                response = sync_handler.get(
+                    url=operation_location_url,
+                    headers={
+                        "api-key": api_key,
+                    },
+                )
+                response.read()
+
+            if response.json()["status"] == "failed":
+                error_data = response.json()
+                raise AzureOpenAIError(status_code=400, message=json.dumps(error_data))
+
+            return response
+
+        # Non-polling case: plain synchronous POST (dall-e-3 / newer api versions).
+        return sync_handler.post(
+            url=api_base,
+            json=data,
+            headers={
+                "Content-Type": "application/json",
+                "api-key": api_key,
+            },
+        )
+
     def create_azure_base_url(
         self, azure_client_params: dict, model: Optional[str]
     ) -> str:
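The method above implements Azure's long-running image-generation protocol for DALL-E 2: POST to the `:submit` endpoint, then poll the URL returned in the `operation-location` header until the operation reaches a terminal status. Below is a minimal standalone sketch of the same CREATE + POLL flow with plain httpx; the path, headers, status values, 120-second budget, and retry-after fallback are taken from the diff, while the api-version query string and request-body fields are illustrative assumptions.

import time

import httpx


def submit_and_poll(api_base: str, api_key: str, prompt: str) -> dict:
    headers = {"Content-Type": "application/json", "api-key": api_key}
    with httpx.Client(timeout=httpx.Timeout(600.0, connect=5.0)) as client:
        # CREATE: Azure replies with an operation-location header to poll.
        submit = client.post(
            f"{api_base}/openai/images/generations:submit?api-version=2023-06-01-preview",
            json={"prompt": prompt, "n": 1, "size": "1024x1024"},  # illustrative body
            headers=headers,
        )
        poll_url = submit.headers["operation-location"]

        # POLL: re-GET until a terminal status, honoring retry-after when present.
        deadline = time.time() + 120
        while True:
            poll = client.get(poll_url, headers=headers)
            body = poll.json()
            if body.get("status") in ("succeeded", "failed"):
                return body
            if time.time() > deadline:
                raise TimeoutError("Operation polling timed out.")
            time.sleep(int(poll.headers.get("retry-after") or 10))

The production method reuses litellm's HTTPHandler (or the caller's client, when given) rather than constructing a raw httpx.Client per call.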
@@ -1196,17 +1295,6 @@ class AzureChatCompletion(BaseLLM):
     ):
         response: Optional[dict] = None
         try:
-            # ## LOGGING
-            # logging_obj.pre_call(
-            #     input=data["prompt"],
-            #     api_key=azure_client.api_key,
-            #     additional_args={
-            #         "headers": {"api_key": azure_client.api_key},
-            #         "api_base": azure_client._base_url._uri_reference,
-            #         "acompletion": True,
-            #         "complete_input_dict": data,
-            #     },
-            # )
-            # response = await azure_client.images.generate(**data, timeout=timeout)
             api_base: str = azure_client_params.get(
                 "api_base", ""
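The `modify_url` helper called in the new sync method above swaps the path of the deployment URL for the legacy `:submit` route while keeping the rest of the URL intact. Its body lives elsewhere in litellm; a plausible equivalent built on `httpx.URL` follows (only the call-site signature comes from the diff, the implementation is an assumption):

import httpx


# Hypothetical stand-in for litellm's modify_url: preserve scheme, host, and
# query of the original URL, replacing only its path component.
def modify_url(original_url: str, new_path: str) -> str:
    return str(httpx.URL(original_url).copy_with(path=new_path))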
@@ -1217,6 +1305,17 @@ class AzureChatCompletion(BaseLLM):
         img_gen_api_base = self.create_azure_base_url(
             azure_client_params=azure_client_params, model=data.get("model", "")
         )
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=data["prompt"],
+            api_key=api_key,
+            additional_args={
+                "complete_input_dict": data,
+                "api_base": img_gen_api_base,
+                "headers": {"api_key": api_key},
+            },
+        )
         httpx_response: httpx.Response = await self.make_async_azure_httpx_request(
             client=None,
             timeout=timeout,
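Both the async path here and the new sync path below derive their target URL from `create_azure_base_url`. Its body sits outside this diff; a sketch of the URL shape it presumably produces, based on Azure's deployment-scoped routing (field names and formatting are assumptions):

from typing import Optional


# Assumed shape only: Azure routes image generation per deployment, e.g.
# https://<resource>.openai.azure.com/openai/deployments/<model>/images/generations
def create_azure_base_url(azure_client_params: dict, model: Optional[str]) -> str:
    api_base: str = azure_client_params.get("api_base", "").rstrip("/")
    api_version: str = azure_client_params.get("api_version", "")
    return (
        f"{api_base}/openai/deployments/{model or ''}"
        f"/images/generations?api-version={api_version}"
    )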
@@ -1310,28 +1409,30 @@ class AzureChatCompletion(BaseLLM):
                 response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout)  # type: ignore
                 return response

-            if client is None:
-                client_session = litellm.client_session or httpx.Client(
-                    transport=CustomHTTPTransport(),
-                )
-                azure_client = AzureOpenAI(http_client=client_session, **azure_client_params)  # type: ignore
-            else:
-                azure_client = client
+            img_gen_api_base = self.create_azure_base_url(
+                azure_client_params=azure_client_params, model=data.get("model", "")
+            )

             ## LOGGING
             logging_obj.pre_call(
-                input=prompt,
-                api_key=azure_client.api_key,
+                input=data["prompt"],
+                api_key=api_key,
                 additional_args={
-                    "headers": {"api_key": azure_client.api_key},
-                    "api_base": azure_client._base_url._uri_reference,
-                    "acompletion": False,
                     "complete_input_dict": data,
+                    "api_base": img_gen_api_base,
+                    "headers": {"api_key": api_key},
                 },
             )
+            httpx_response: httpx.Response = self.make_sync_azure_httpx_request(
+                client=None,
+                timeout=timeout,
+                api_base=img_gen_api_base,
+                api_version=api_version or "",
+                api_key=api_key or "",
+                data=data,
+            )
+            response = httpx_response.json()["result"]

-            ## COMPLETION CALL
-            response = azure_client.images.generate(**data, timeout=timeout)  # type: ignore
             ## LOGGING
             logging_obj.post_call(
                 input=prompt,
@@ -1340,7 +1441,7 @@ class AzureChatCompletion(BaseLLM):
                 original_response=response,
             )
             # return response
-            return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="image_generation")  # type: ignore
+            return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation")  # type: ignore
         except AzureOpenAIError as e:
             exception_mapping_worked = True
             raise e
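Because the OpenAI SDK client is gone from this path, the sync flow now handles a plain dict from `httpx_response.json()` rather than a pydantic object, which is why the final `convert_to_model_response_object` call drops `.model_dump()`. A hedged sketch of the terminal poll payload it unwraps; the field shape follows the Azure DALL-E 2 preview docs and is illustrative only:

# Illustrative terminal poll response: the sync path passes polled["result"]
# (already a plain dict, hence no .model_dump()) on for conversion.
polled = {
    "id": "op-placeholder",
    "status": "succeeded",
    "result": {
        "created": 1700000000,
        "data": [{"url": "https://example.invalid/generated.png"}],
    },
}
response = polled["result"]
assert response["data"][0]["url"]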
test_image_generation.py

@@ -42,9 +42,20 @@ def test_image_generation_openai():
 # test_image_generation_openai()


+@pytest.mark.parametrize(
+    "sync_mode",
+    [True, False],
+)
 @pytest.mark.asyncio
-async def test_image_generation_azure():
+async def test_image_generation_azure(sync_mode):
     try:
-        response = await litellm.aimage_generation(
-            prompt="A cute baby sea otter",
-            model="azure/",
+        if sync_mode:
+            response = litellm.image_generation(
+                prompt="A cute baby sea otter",
+                model="azure/",
+                api_version="2023-06-01-preview",
+            )
+        else:
+            response = await litellm.aimage_generation(
+                prompt="A cute baby sea otter",
+                model="azure/",
+                api_version="2023-06-01-preview",
+            )
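For reference, a standalone sketch of the sync call the new parametrization exercises. Endpoint, key, and deployment name are placeholders; litellm can also pick up AZURE_API_BASE / AZURE_API_KEY from the environment:

import litellm

# The pre-dall-e-3 api_version steers the request through the new
# CREATE + POLL path; all credentials below are placeholders.
response = litellm.image_generation(
    prompt="A cute baby sea otter",
    model="azure/dall-e-2-deployment",  # hypothetical deployment name
    api_version="2023-06-01-preview",
    api_base="https://my-resource.openai.azure.com",
    api_key="azure-key-placeholder",
)
print(response.data[0].url)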