forked from phoenix/litellm-mirror
refactor(azure.py): replace the custom transport logic with our httpx client
Done to fix the http/https proxy issues people are facing when routing requests through a proxy.
parent 612af8f5be
commit 589c1c6280
4 changed files with 192 additions and 29 deletions
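
For context on the fix: httpx only mounts HTTP_PROXY / HTTPS_PROXY / NO_PROXY from the environment when it builds its own transport; an explicitly passed transport= takes precedence, so the old custom-transport client silently bypassed proxy settings. A minimal sketch of the difference (not from the commit; the subclass is a hypothetical stand-in for the removed transport class):

    import httpx

    # Stand-in for the removed custom transport; hypothetical, for illustration only.
    class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
        pass

    # Old pattern: an explicit transport overrides httpx's env-proxy handling,
    # so HTTP_PROXY / HTTPS_PROXY are ignored for these requests.
    proxy_blind = httpx.AsyncClient(transport=AsyncCustomHTTPTransport())

    # New pattern: a plain client (trust_env=True by default) builds its own
    # transport and mounts proxies from the environment.
    proxy_aware = httpx.AsyncClient()
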
litellm/llms/azure.py
@@ -1,6 +1,7 @@
+import asyncio
 import json
 import os
 import time
 import types
 import uuid
 from typing import (
@@ -21,9 +22,10 @@ from openai import AsyncAzureOpenAI, AzureOpenAI
 from typing_extensions import overload

 import litellm
-from litellm import OpenAIConfig
+from litellm import ImageResponse, OpenAIConfig
 from litellm.caching import DualCache
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import (
     Choices,
     CustomStreamWrapper,
@@ -33,6 +35,7 @@ from litellm.utils import (
     UnsupportedParamsError,
     convert_to_model_response_object,
     get_secret,
+    modify_url,
 )

 from ..types.llms.openai import (
@@ -1051,6 +1054,135 @@ class AzureChatCompletion(BaseLLM):
         else:
             raise AzureOpenAIError(status_code=500, message=str(e))

+    async def make_async_azure_httpx_request(
+        self,
+        client: Optional[AsyncHTTPHandler],
+        timeout: Optional[Union[float, httpx.Timeout]],
+        api_base: str,
+        api_version: str,
+        api_key: str,
+        data: dict,
+    ) -> httpx.Response:
+        """
+        Implemented for azure dall-e-2 image gen calls.
+
+        Alternative to needing a custom transport implementation.
+        """
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            async_handler = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            async_handler = client  # type: ignore
+
+        if (
+            "images/generations" in api_base
+            and api_version
+            in [  # dall-e-3 starts from `2023-12-01-preview`, so this list avoids a conflict
+                "2023-06-01-preview",
+                "2023-07-01-preview",
+                "2023-08-01-preview",
+                "2023-09-01-preview",
+                "2023-10-01-preview",
+            ]
+        ):  # CREATE + POLL for azure dall-e-2 calls
+            api_base = modify_url(
+                original_url=api_base, new_path="/openai/images/generations:submit"
+            )
+
+            data.pop(
+                "model", None
+            )  # remove 'model' from the dall-e-2 args: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#request-a-generated-image-dall-e-2-preview
+            response = await async_handler.post(
+                url=api_base,
+                data=json.dumps(data),
+                headers={
+                    "Content-Type": "application/json",
+                    "api-key": api_key,
+                },
+            )
+            operation_location_url = response.headers["operation-location"]
+            response = await async_handler.get(
+                url=operation_location_url,
+                headers={
+                    "api-key": api_key,
+                },
+            )
+
+            await response.aread()
+
+            timeout_secs: int = 120
+            start_time = time.time()
+            if "status" not in response.json():
+                raise Exception(
+                    "Expected 'status' in response. Got={}".format(response.json())
+                )
+            while response.json()["status"] not in ["succeeded", "failed"]:
+                if time.time() - start_time > timeout_secs:
+                    raise AzureOpenAIError(
+                        status_code=408, message="Operation polling timed out."
+                    )
+
+                await asyncio.sleep(int(response.headers.get("retry-after") or 10))
+                response = await async_handler.get(
+                    url=operation_location_url,
+                    headers={
+                        "api-key": api_key,
+                    },
+                )
+                await response.aread()
+
+            if response.json()["status"] == "failed":
+                error_data = response.json()
+                raise AzureOpenAIError(status_code=400, message=json.dumps(error_data))
+
+            return response
+
+        return await async_handler.post(
+            url=api_base,
+            json=data,
+            headers={
+                "Content-Type": "application/json",
+                "api-key": api_key,
+            },
+        )
+
+    def create_azure_base_url(
+        self, azure_client_params: dict, model: Optional[str]
+    ) -> str:
+        api_base: str = azure_client_params.get(
+            "azure_endpoint", ""
+        )  # "https://example-endpoint.openai.azure.com"
+        if api_base.endswith("/"):
+            api_base = api_base.rstrip("/")
+        api_version: str = azure_client_params.get("api_version", "")
+        if model is None:
+            model = ""
+        new_api_base = (
+            api_base
+            + "/openai/deployments/"
+            + model
+            + "/images/generations"
+            + "?api-version="
+            + api_version
+        )

+        return new_api_base
+
     async def aimage_generation(
         self,
         data: dict,
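
As a sanity check on the URL construction, here is a standalone sketch mirroring create_azure_base_url's concatenation; the endpoint, deployment, and api-version values are made up:

    def build_image_gen_url(endpoint: str, model: str, api_version: str) -> str:
        # Mirrors create_azure_base_url: strip the trailing slash, then join parts.
        return (
            endpoint.rstrip("/")
            + "/openai/deployments/"
            + model
            + "/images/generations?api-version="
            + api_version
        )

    assert build_image_gen_url(
        "https://example-endpoint.openai.azure.com/", "dall-e-2", "2023-06-01-preview"
    ) == (
        "https://example-endpoint.openai.azure.com"
        "/openai/deployments/dall-e-2/images/generations"
        "?api-version=2023-06-01-preview"
    )
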
@@ -1062,30 +1194,40 @@ class AzureChatCompletion(BaseLLM):
         logging_obj=None,
         timeout=None,
     ):
-        response = None
+        response: Optional[dict] = None
         try:
-            if client is None:
-                client_session = litellm.aclient_session or httpx.AsyncClient(
-                    transport=AsyncCustomHTTPTransport(),
-                )
-                azure_client = AsyncAzureOpenAI(
-                    http_client=client_session, **azure_client_params
-                )
-            else:
-                azure_client = client
-            ## LOGGING
-            logging_obj.pre_call(
-                input=data["prompt"],
-                api_key=azure_client.api_key,
-                additional_args={
-                    "headers": {"api_key": azure_client.api_key},
-                    "api_base": azure_client._base_url._uri_reference,
-                    "acompletion": True,
-                    "complete_input_dict": data,
-                },
-            )
-            response = await azure_client.images.generate(**data, timeout=timeout)
-            stringified_response = response.model_dump()
+            # ## LOGGING
+            # logging_obj.pre_call(
+            #     input=data["prompt"],
+            #     api_key=azure_client.api_key,
+            #     additional_args={
+            #         "headers": {"api_key": azure_client.api_key},
+            #         "api_base": azure_client._base_url._uri_reference,
+            #         "acompletion": True,
+            #         "complete_input_dict": data,
+            #     },
+            # )
+            # response = await azure_client.images.generate(**data, timeout=timeout)
+            api_base: str = azure_client_params.get(
+                "api_base", ""
+            )  # "https://example-endpoint.openai.azure.com"
+            if api_base.endswith("/"):
+                api_base = api_base.rstrip("/")
+            api_version: str = azure_client_params.get("api_version", "")
+            img_gen_api_base = self.create_azure_base_url(
+                azure_client_params=azure_client_params, model=data.get("model", "")
+            )
+            httpx_response: httpx.Response = await self.make_async_azure_httpx_request(
+                client=None,
+                timeout=timeout,
+                api_base=img_gen_api_base,
+                api_version=api_version,
+                api_key=api_key,
+                data=data,
+            )
+            response = httpx_response.json()["result"]
+
+            stringified_response = response
             ## LOGGING
             logging_obj.post_call(
                 input=input,
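
For readers unfamiliar with the dall-e-2 preview API that the CREATE + POLL branch targets: the submit call returns an operation-location header, and polling that URL yields a payload whose status and result fields drive the loop above. A sketch of the shape (field values are illustrative, inferred from the status/result checks in the code):

    operation_payload = {
        "status": "succeeded",  # polled until "succeeded" or "failed"
        "result": {             # what aimage_generation ultimately returns
            "created": 1700000000,
            "data": [{"url": "https://example.invalid/generated.png"}],
        },
    }
    assert operation_payload["status"] in ["succeeded", "failed"]
    image_result = operation_payload["result"]
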
litellm/tests/test_image_generation.py
@@ -1,20 +1,23 @@
 # What this tests?
 ## This tests the litellm support for the openai /generations endpoint

-import sys, os
-import traceback
-from dotenv import load_dotenv
 import logging
+import os
+import sys
+import traceback
+
+from dotenv import load_dotenv
+
 logging.basicConfig(level=logging.DEBUG)
 load_dotenv()
-import os
 import asyncio
+import os

 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import pytest

 import litellm
@@ -39,9 +42,10 @@ def test_image_generation_openai():
 # test_image_generation_openai()


-def test_image_generation_azure():
+@pytest.mark.asyncio
+async def test_image_generation_azure():
     try:
-        response = litellm.image_generation(
+        response = await litellm.aimage_generation(
             prompt="A cute baby sea otter",
             model="azure/",
             api_version="2023-06-01-preview",
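
The test change follows the standard pytest-asyncio pattern; a minimal self-contained sketch (the test name is illustrative, and a working Azure key and endpoint are assumed in the environment):

    import pytest

    import litellm

    @pytest.mark.asyncio
    async def test_aimage_generation_sketch():  # hypothetical name
        response = await litellm.aimage_generation(
            prompt="A cute baby sea otter",
            model="azure/",  # azure/<your-deployment-name>
            api_version="2023-06-01-preview",
        )
        assert len(response.data) > 0  # ImageResponse carries the generated images
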
litellm/types/llms/vertex_ai.py
@@ -155,6 +155,16 @@ class ToolConfig(TypedDict):
     functionCallingConfig: FunctionCallingConfig


+class TTL(TypedDict, total=False):
+    seconds: Required[float]
+    nano: float
+
+
+class CachedContent(TypedDict, total=False):
+    ttl: TTL
+    expire_time: str
+
+
 class RequestBody(TypedDict, total=False):
     contents: Required[List[ContentType]]
     system_instruction: SystemInstructions
@@ -162,6 +172,7 @@ class RequestBody(TypedDict, total=False):
     toolConfig: ToolConfig
     safetySettings: List[SafetSettingsConfig]
     generationConfig: GenerationConfig
+    cachedContent: str


 class SafetyRatings(TypedDict):
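
A quick illustration (not from the commit) of how these total=False TypedDicts behave: every key is optional except the ones wrapped in Required, so a TTL must carry seconds but may omit nano:

    from typing_extensions import Required, TypedDict

    class TTL(TypedDict, total=False):
        seconds: Required[float]
        nano: float

    class CachedContent(TypedDict, total=False):
        ttl: TTL
        expire_time: str

    ok: TTL = {"seconds": 300.0}         # valid: "nano" may be omitted
    cached: CachedContent = {"ttl": ok}  # valid: every key here is optional
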
litellm/utils.py
@@ -4815,6 +4815,12 @@ def function_to_dict(input_function):  # noqa: C901
     return result


+def modify_url(original_url, new_path):
+    url = httpx.URL(original_url)
+    modified_url = url.copy_with(path=new_path)
+    return str(modified_url)
+
+
 def load_test_model(
     model: str,
     custom_llm_provider: str = "",
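
What the new helper does in practice, runnable standalone: httpx.URL.copy_with(path=...) replaces only the path component and keeps the scheme, host, and query string, which is how the api-version query survives the rewrite to the :submit endpoint in azure.py above (the URL below is an example value):

    import httpx

    def modify_url(original_url, new_path):
        url = httpx.URL(original_url)
        modified_url = url.copy_with(path=new_path)
        return str(modified_url)

    print(modify_url(
        "https://example-endpoint.openai.azure.com/openai/deployments/dall-e-2"
        "/images/generations?api-version=2023-06-01-preview",
        "/openai/images/generations:submit",
    ))
    # https://example-endpoint.openai.azure.com/openai/images/generations:submit?api-version=2023-06-01-preview
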