refactor(azure.py): replace the custom transport logic with our httpx client

Done to fix the HTTP/HTTPS proxy issues people have been facing when routing requests through a proxy.
Krrish Dholakia 2024-07-02 15:32:53 -07:00
parent 612af8f5be
commit 589c1c6280
4 changed files with 192 additions and 29 deletions
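
The likely mechanism behind the fix: httpx clients honor the standard HTTP_PROXY/HTTPS_PROXY environment variables by default (trust_env=True), so a plain httpx client picks up the user's proxy configuration automatically, whereas a custom transport can bypass it. A minimal sketch of that behavior (the proxy address is hypothetical):

    import asyncio
    import os

    import httpx

    # httpx reads the standard proxy env vars by default (trust_env=True),
    # so this client routes its requests through the configured proxy.
    os.environ["HTTPS_PROXY"] = "http://my-proxy:8080"  # hypothetical proxy address

    async def demo():
        async with httpx.AsyncClient() as client:
            resp = await client.get("https://example-endpoint.openai.azure.com")
            print(resp.status_code)

    asyncio.run(demo())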


@@ -1,6 +1,7 @@
 import asyncio
 import json
 import os
+import time
 import types
 import uuid
 from typing import (
@@ -21,9 +22,10 @@ from openai import AsyncAzureOpenAI, AzureOpenAI
 from typing_extensions import overload

 import litellm
-from litellm import OpenAIConfig
+from litellm import ImageResponse, OpenAIConfig
 from litellm.caching import DualCache
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import (
     Choices,
     CustomStreamWrapper,
@@ -33,6 +35,7 @@ from litellm.utils import (
     UnsupportedParamsError,
     convert_to_model_response_object,
     get_secret,
+    modify_url,
 )

 from ..types.llms.openai import (
@@ -1051,6 +1054,135 @@ class AzureChatCompletion(BaseLLM):
             else:
                 raise AzureOpenAIError(status_code=500, message=str(e))

+    async def make_async_azure_httpx_request(
+        self,
+        client: Optional[AsyncHTTPHandler],
+        timeout: Optional[Union[float, httpx.Timeout]],
+        api_base: str,
+        api_version: str,
+        api_key: str,
+        data: dict,
+    ) -> httpx.Response:
+        """
+        Implemented for azure dall-e-2 image gen calls
+
+        Alternative to needing a custom transport implementation
+        """
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            async_handler = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            async_handler = client  # type: ignore
+
+        if (
+            "images/generations" in api_base
+            and api_version
+            in [  # dall-e-3 starts from `2023-12-01-preview`, so we should be able to avoid a conflict
+                "2023-06-01-preview",
+                "2023-07-01-preview",
+                "2023-08-01-preview",
+                "2023-09-01-preview",
+                "2023-10-01-preview",
+            ]
+        ):  # CREATE + POLL for azure dall-e-2 calls
+            api_base = modify_url(
+                original_url=api_base, new_path="/openai/images/generations:submit"
+            )
+
+            data.pop(
+                "model", None
+            )  # REMOVE 'model' from dall-e-2 args: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#request-a-generated-image-dall-e-2-preview
+
+            # submit the generation job
+            response = await async_handler.post(
+                url=api_base,
+                data=json.dumps(data),
+                headers={
+                    "Content-Type": "application/json",
+                    "api-key": api_key,
+                },
+            )
+
+            # poll the operation-location URL until the job finishes
+            operation_location_url = response.headers["operation-location"]
+            response = await async_handler.get(
+                url=operation_location_url,
+                headers={
+                    "api-key": api_key,
+                },
+            )
+            await response.aread()
+
+            timeout_secs: int = 120
+            start_time = time.time()
+            if "status" not in response.json():
+                raise Exception(
+                    "Expected 'status' in response. Got={}".format(response.json())
+                )
+            while response.json()["status"] not in ["succeeded", "failed"]:
+                if time.time() - start_time > timeout_secs:
+                    raise AzureOpenAIError(
+                        status_code=408, message="Operation polling timed out."
+                    )
+
+                await asyncio.sleep(int(response.headers.get("retry-after") or 10))
+                response = await async_handler.get(
+                    url=operation_location_url,
+                    headers={
+                        "api-key": api_key,
+                    },
+                )
+                await response.aread()
+
+            if response.json()["status"] == "failed":
+                error_data = response.json()
+                raise AzureOpenAIError(status_code=400, message=json.dumps(error_data))
+
+            return response
+
+        return await async_handler.post(
+            url=api_base,
+            json=data,
+            headers={
+                "Content-Type": "application/json;",
+                "api-key": api_key,
+            },
+        )
+
+    def create_azure_base_url(
+        self, azure_client_params: dict, model: Optional[str]
+    ) -> str:
+        api_base: str = azure_client_params.get(
+            "azure_endpoint", ""
+        )  # "https://example-endpoint.openai.azure.com"
+        if api_base.endswith("/"):
+            api_base = api_base.rstrip("/")
+        api_version: str = azure_client_params.get("api_version", "")
+        if model is None:
+            model = ""
+        new_api_base = (
+            api_base
+            + "/openai/deployments/"
+            + model
+            + "/images/generations"
+            + "?api-version="
+            + api_version
+        )
+        return new_api_base
+
     async def aimage_generation(
         self,
         data: dict,
@@ -1062,30 +1194,40 @@ class AzureChatCompletion(BaseLLM):
         logging_obj=None,
         timeout=None,
     ):
-        response = None
+        response: Optional[dict] = None
         try:
-            if client is None:
-                client_session = litellm.aclient_session or httpx.AsyncClient(
-                    transport=AsyncCustomHTTPTransport(),
-                )
-                azure_client = AsyncAzureOpenAI(
-                    http_client=client_session, **azure_client_params
-                )
-            else:
-                azure_client = client
-            ## LOGGING
-            logging_obj.pre_call(
-                input=data["prompt"],
-                api_key=azure_client.api_key,
-                additional_args={
-                    "headers": {"api_key": azure_client.api_key},
-                    "api_base": azure_client._base_url._uri_reference,
-                    "acompletion": True,
-                    "complete_input_dict": data,
-                },
-            )
-            response = await azure_client.images.generate(**data, timeout=timeout)
-            stringified_response = response.model_dump()
+            # ## LOGGING
+            # logging_obj.pre_call(
+            #     input=data["prompt"],
+            #     api_key=azure_client.api_key,
+            #     additional_args={
+            #         "headers": {"api_key": azure_client.api_key},
+            #         "api_base": azure_client._base_url._uri_reference,
+            #         "acompletion": True,
+            #         "complete_input_dict": data,
+            #     },
+            # )
+            # response = await azure_client.images.generate(**data, timeout=timeout)
+            api_base: str = azure_client_params.get(
+                "api_base", ""
+            )  # "https://example-endpoint.openai.azure.com"
+            if api_base.endswith("/"):
+                api_base = api_base.rstrip("/")
+            api_version: str = azure_client_params.get("api_version", "")
+            img_gen_api_base = self.create_azure_base_url(
+                azure_client_params=azure_client_params, model=data.get("model", "")
+            )
+            httpx_response: httpx.Response = await self.make_async_azure_httpx_request(
+                client=None,
+                timeout=timeout,
+                api_base=img_gen_api_base,
+                api_version=api_version,
+                api_key=api_key,
+                data=data,
+            )
+            response = httpx_response.json()["result"]
+            stringified_response = response
             ## LOGGING
             logging_obj.post_call(
                 input=input,
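
For reference, a minimal usage sketch of the new async path (the deployment name, endpoint, and key are placeholders, not from this commit):

    import asyncio
    import os

    import litellm

    async def main():
        # aimage_generation now builds the Azure image-generation URL itself and
        # sends the request through litellm's shared httpx handler.
        response = await litellm.aimage_generation(
            prompt="A cute baby sea otter",
            model="azure/dall-e",  # hypothetical Azure deployment name
            api_base="https://example-endpoint.openai.azure.com",
            api_version="2023-06-01-preview",
            api_key=os.environ["AZURE_API_KEY"],
        )
        print(response)

    asyncio.run(main())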


@@ -1,20 +1,23 @@
 # What this tests?
 ## This tests the litellm support for the openai /generations endpoint
-import sys, os
-import traceback
-from dotenv import load_dotenv
 import logging
+import os
+import sys
+import traceback
+
+from dotenv import load_dotenv

 logging.basicConfig(level=logging.DEBUG)
 load_dotenv()
-import os
 import asyncio
+import os

 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import pytest

 import litellm
@@ -39,9 +42,10 @@ def test_image_generation_openai():
 # test_image_generation_openai()

-def test_image_generation_azure():
+@pytest.mark.asyncio
+async def test_image_generation_azure():
     try:
-        response = litellm.image_generation(
+        response = await litellm.aimage_generation(
             prompt="A cute baby sea otter",
             model="azure/",
             api_version="2023-06-01-preview",


@@ -155,6 +155,16 @@ class ToolConfig(TypedDict):
     functionCallingConfig: FunctionCallingConfig

+class TTL(TypedDict, total=False):
+    seconds: Required[float]
+    nano: float
+
+
+class CachedContent(TypedDict, total=False):
+    ttl: TTL
+    expire_time: str
+
+
 class RequestBody(TypedDict, total=False):
     contents: Required[List[ContentType]]
     system_instruction: SystemInstructions
@@ -162,6 +172,7 @@ class RequestBody(TypedDict, total=False):
     toolConfig: ToolConfig
     safetySettings: List[SafetSettingsConfig]
     generationConfig: GenerationConfig
+    cachedContent: str

 class SafetyRatings(TypedDict):
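
A hedged sketch of how these new caching types might be populated (the resource name and values are illustrative, not part of this commit):

    cached_content: CachedContent = {
        "ttl": {"seconds": 3600.0},  # cache entry lives for one hour
        "expire_time": "2024-07-03T00:00:00Z",
    }

    request_body: RequestBody = {
        "contents": [],  # the usual request contents go here
        "cachedContent": "cachedContents/abc123",  # hypothetical cached-content resource name
    }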


@@ -4815,6 +4815,12 @@ def function_to_dict(input_function):  # noqa: C901
     return result

+def modify_url(original_url, new_path):
+    url = httpx.URL(original_url)
+    modified_url = url.copy_with(path=new_path)
+    return str(modified_url)
+
+
 def load_test_model(
     model: str,
     custom_llm_provider: str = "",
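
Roughly what the new helper does: httpx.URL.copy_with swaps out the path while preserving the scheme, host, and query string. An illustrative call (URL made up):

    >>> modify_url(
    ...     "https://example-endpoint.openai.azure.com/openai/deployments/dalle2/images/generations?api-version=2023-06-01-preview",
    ...     "/openai/images/generations:submit",
    ... )
    'https://example-endpoint.openai.azure.com/openai/images/generations:submit?api-version=2023-06-01-preview'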