forked from phoenix/litellm-mirror
refactor(azure.py): replace the custom transport logic with our httpx client
Done to fix the http/https proxy issues people hit when routing Azure calls through a proxy.
This commit is contained in:
parent
612af8f5be
commit
589c1c6280
4 changed files with 192 additions and 29 deletions
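Note on the fix (illustration only, not part of the diff): httpx only mounts HTTP_PROXY/HTTPS_PROXY/ALL_PROXY from the environment when the client is built without an explicit transport, so the old AsyncCustomHTTPTransport-based client silently bypassed user proxies. A minimal sketch of the difference:

import httpx

# Default client: trust_env=True, so proxy settings are read from the
# environment and requests are routed through the user's proxy.
proxied = httpx.AsyncClient()

# Explicit transport: httpx skips environment proxy mounting entirely,
# which is the behavior this commit removes.
unproxied = httpx.AsyncClient(transport=httpx.AsyncHTTPTransport())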
@@ -1,6 +1,7 @@
 import asyncio
 import json
 import os
+import time
 import types
 import uuid
 from typing import (
@@ -21,9 +22,10 @@ from openai import AsyncAzureOpenAI, AzureOpenAI
 from typing_extensions import overload
 
 import litellm
-from litellm import OpenAIConfig
+from litellm import ImageResponse, OpenAIConfig
 from litellm.caching import DualCache
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import (
     Choices,
     CustomStreamWrapper,
@@ -33,6 +35,7 @@ from litellm.utils import (
     UnsupportedParamsError,
     convert_to_model_response_object,
     get_secret,
+    modify_url,
 )
 
 from ..types.llms.openai import (
@@ -1051,6 +1054,135 @@ class AzureChatCompletion(BaseLLM):
             else:
                 raise AzureOpenAIError(status_code=500, message=str(e))
 
+    async def make_async_azure_httpx_request(
+        self,
+        client: Optional[AsyncHTTPHandler],
+        timeout: Optional[Union[float, httpx.Timeout]],
+        api_base: str,
+        api_version: str,
+        api_key: str,
+        data: dict,
+    ) -> httpx.Response:
+        """
+        Implemented for azure dall-e-2 image gen calls
+
+        Alternative to needing a custom transport implementation
+        """
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            async_handler = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            async_handler = client  # type: ignore
+
+        if (
+            "images/generations" in api_base
+            and api_version
+            in [  # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
+                "2023-06-01-preview",
+                "2023-07-01-preview",
+                "2023-08-01-preview",
+                "2023-09-01-preview",
+                "2023-10-01-preview",
+            ]
+        ):  # CREATE + POLL for azure dall-e-2 calls
+
+            api_base = modify_url(
+                original_url=api_base, new_path="/openai/images/generations:submit"
+            )
+
+            data.pop(
+                "model", None
+            )  # REMOVE 'model' from dall-e-2 arg https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#request-a-generated-image-dall-e-2-preview
+            response = await async_handler.post(
+                url=api_base,
+                data=json.dumps(data),
+                headers={
+                    "Content-Type": "application/json",
+                    "api-key": api_key,
+                },
+            )
+            operation_location_url = response.headers["operation-location"]
+            response = await async_handler.get(
+                url=operation_location_url,
+                headers={
+                    "api-key": api_key,
+                },
+            )
+
+            await response.aread()
+
+            timeout_secs: int = 120
+            start_time = time.time()
+            if "status" not in response.json():
+                raise Exception(
+                    "Expected 'status' in response. Got={}".format(response.json())
+                )
+            while response.json()["status"] not in ["succeeded", "failed"]:
+                if time.time() - start_time > timeout_secs:
+                    timeout_msg = {
+                        "error": {
+                            "code": "Timeout",
+                            "message": "Operation polling timed out.",
+                        }
+                    }
+
+                    raise AzureOpenAIError(
+                        status_code=408, message="Operation polling timed out."
+                    )
+
+                await asyncio.sleep(int(response.headers.get("retry-after") or 10))
+                response = await async_handler.get(
+                    url=operation_location_url,
+                    headers={
+                        "api-key": api_key,
+                    },
+                )
+                await response.aread()
+
+                if response.json()["status"] == "failed":
+                    error_data = response.json()
+                    raise AzureOpenAIError(status_code=400, message=json.dumps(error_data))
+
+            return response
+        return await async_handler.post(
+            url=api_base,
+            json=data,
+            headers={
+                "Content-Type": "application/json;",
+                "api-key": api_key,
+            },
+        )
+
+    def create_azure_base_url(
+        self, azure_client_params: dict, model: Optional[str]
+    ) -> str:
+
+        api_base: str = azure_client_params.get(
+            "azure_endpoint", ""
+        )  # "https://example-endpoint.openai.azure.com"
+        if api_base.endswith("/"):
+            api_base = api_base.rstrip("/")
+        api_version: str = azure_client_params.get("api_version", "")
+        if model is None:
+            model = ""
+        new_api_base = (
+            api_base
+            + "/openai/deployments/"
+            + model
+            + "/images/generations"
+            + "?api-version="
+            + api_version
+        )
+
+        return new_api_base
+
     async def aimage_generation(
         self,
         data: dict,
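Note: illustration only (not part of the diff) of what create_azure_base_url composes, using the example endpoint from the code comments and a hypothetical dall-e-2 deployment name:

from litellm.llms.azure import AzureChatCompletion  # import path assumed from this diff

base = AzureChatCompletion().create_azure_base_url(
    azure_client_params={
        "azure_endpoint": "https://example-endpoint.openai.azure.com",
        "api_version": "2023-06-01-preview",
    },
    model="dall-e-2",
)
# base == "https://example-endpoint.openai.azure.com/openai/deployments/dall-e-2/images/generations?api-version=2023-06-01-preview"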
@@ -1062,30 +1194,40 @@ class AzureChatCompletion(BaseLLM):
         logging_obj=None,
         timeout=None,
     ):
-        response = None
+        response: Optional[dict] = None
         try:
-            if client is None:
-                client_session = litellm.aclient_session or httpx.AsyncClient(
-                    transport=AsyncCustomHTTPTransport(),
-                )
-                azure_client = AsyncAzureOpenAI(
-                    http_client=client_session, **azure_client_params
-                )
-            else:
-                azure_client = client
-            ## LOGGING
-            logging_obj.pre_call(
-                input=data["prompt"],
-                api_key=azure_client.api_key,
-                additional_args={
-                    "headers": {"api_key": azure_client.api_key},
-                    "api_base": azure_client._base_url._uri_reference,
-                    "acompletion": True,
-                    "complete_input_dict": data,
-                },
-            )
-            response = await azure_client.images.generate(**data, timeout=timeout)
-            stringified_response = response.model_dump()
+            # ## LOGGING
+            # logging_obj.pre_call(
+            #     input=data["prompt"],
+            #     api_key=azure_client.api_key,
+            #     additional_args={
+            #         "headers": {"api_key": azure_client.api_key},
+            #         "api_base": azure_client._base_url._uri_reference,
+            #         "acompletion": True,
+            #         "complete_input_dict": data,
+            #     },
+            # )
+            # response = await azure_client.images.generate(**data, timeout=timeout)
+            api_base: str = azure_client_params.get(
+                "api_base", ""
+            )  # "https://example-endpoint.openai.azure.com"
+            if api_base.endswith("/"):
+                api_base = api_base.rstrip("/")
+            api_version: str = azure_client_params.get("api_version", "")
+            img_gen_api_base = self.create_azure_base_url(
+                azure_client_params=azure_client_params, model=data.get("model", "")
+            )
+            httpx_response: httpx.Response = await self.make_async_azure_httpx_request(
+                client=None,
+                timeout=timeout,
+                api_base=img_gen_api_base,
+                api_version=api_version,
+                api_key=api_key,
+                data=data,
+            )
+            response = httpx_response.json()["result"]
+
+            stringified_response = response
             ## LOGGING
             logging_obj.post_call(
                 input=input,
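Note: the ["result"] unwrap assumes the operation body documented in the linked dall-e-2 preview reference; roughly this shape (values hypothetical):

# Hypothetical polled-operation body per the linked preview reference;
# aimage_generation now passes the nested "result" object downstream.
operation_body = {
    "id": "<operation-id>",
    "status": "succeeded",
    "result": {
        "created": 1698116662,
        "data": [{"url": "<signed-image-url>"}],
    },
}
response = operation_body["result"]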
@@ -1,20 +1,23 @@
 # What this tests?
 ## This tests the litellm support for the openai /generations endpoint
 
-import sys, os
-import traceback
-from dotenv import load_dotenv
 import logging
+import os
+import sys
+import traceback
 
+from dotenv import load_dotenv
+
 logging.basicConfig(level=logging.DEBUG)
 load_dotenv()
-import os
 import asyncio
+import os
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import pytest
 
 import litellm
@@ -39,9 +42,10 @@ def test_image_generation_openai():
 # test_image_generation_openai()
 
 
-def test_image_generation_azure():
+@pytest.mark.asyncio
+async def test_image_generation_azure():
     try:
-        response = litellm.image_generation(
+        response = await litellm.aimage_generation(
             prompt="A cute baby sea otter",
             model="azure/",
             api_version="2023-06-01-preview",
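Note: the async rewrite assumes the pytest-asyncio plugin is installed so @pytest.mark.asyncio can drive the coroutine. The equivalent manual invocation (sketch only; needs live Azure credentials) is:

import asyncio

import litellm

# Mirrors the test body above; exercises the new non-transport httpx path.
asyncio.run(
    litellm.aimage_generation(
        prompt="A cute baby sea otter",
        model="azure/",
        api_version="2023-06-01-preview",
    )
)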
@@ -155,6 +155,16 @@ class ToolConfig(TypedDict):
     functionCallingConfig: FunctionCallingConfig
 
 
+class TTL(TypedDict, total=False):
+    seconds: Required[float]
+    nano: float
+
+
+class CachedContent(TypedDict, total=False):
+    ttl: TTL
+    expire_time: str
+
+
 class RequestBody(TypedDict, total=False):
     contents: Required[List[ContentType]]
     system_instruction: SystemInstructions
@@ -162,6 +172,7 @@ class RequestBody(TypedDict, total=False):
     toolConfig: ToolConfig
     safetySettings: List[SafetSettingsConfig]
     generationConfig: GenerationConfig
+    cachedContent: str
 
 
 class SafetyRatings(TypedDict):
@@ -4815,6 +4815,12 @@ def function_to_dict(input_function):  # noqa: C901
     return result
 
 
+def modify_url(original_url, new_path):
+    url = httpx.URL(original_url)
+    modified_url = url.copy_with(path=new_path)
+    return str(modified_url)
+
+
 def load_test_model(
     model: str,
     custom_llm_provider: str = "",
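Note: illustration only of the modify_url helper's behavior. httpx.URL.copy_with swaps the path while preserving scheme, host, and the api-version query string, which is exactly what the dall-e-2 submit URL needs:

import httpx

url = httpx.URL(
    "https://example-endpoint.openai.azure.com/openai/deployments/dall-e-2"
    "/images/generations?api-version=2023-06-01-preview"
)
print(url.copy_with(path="/openai/images/generations:submit"))
# https://example-endpoint.openai.azure.com/openai/images/generations:submit?api-version=2023-06-01-preview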