diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index 000feed44..7221fdbe9 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -1,6 +1,7 @@
 import asyncio
 import json
 import os
+import time
 import types
 import uuid
 from typing import (
@@ -21,9 +22,10 @@ from openai import AsyncAzureOpenAI, AzureOpenAI
 from typing_extensions import overload
 
 import litellm
-from litellm import OpenAIConfig
+from litellm import ImageResponse, OpenAIConfig
 from litellm.caching import DualCache
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import (
     Choices,
     CustomStreamWrapper,
@@ -33,6 +35,7 @@
     UnsupportedParamsError,
     convert_to_model_response_object,
     get_secret,
+    modify_url,
 )
 
 from ..types.llms.openai import (
@@ -1051,6 +1054,135 @@ class AzureChatCompletion(BaseLLM):
             else:
                 raise AzureOpenAIError(status_code=500, message=str(e))
 
+    async def make_async_azure_httpx_request(
+        self,
+        client: Optional[AsyncHTTPHandler],
+        timeout: Optional[Union[float, httpx.Timeout]],
+        api_base: str,
+        api_version: str,
+        api_key: str,
+        data: dict,
+    ) -> httpx.Response:
+        """
+        Implemented for azure dall-e-2 image gen calls
+
+        Alternative to needing a custom transport implementation
+        """
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            async_handler = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            async_handler = client  # type: ignore
+
+        if (
+            "images/generations" in api_base
+            and api_version
+            in [  # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
+                "2023-06-01-preview",
+                "2023-07-01-preview",
+                "2023-08-01-preview",
+                "2023-09-01-preview",
+                "2023-10-01-preview",
+            ]
+        ):  # CREATE + POLL for azure dall-e-2 calls
+
+            api_base = modify_url(
+                original_url=api_base, new_path="/openai/images/generations:submit"
+            )
+
+            data.pop(
+                "model", None
+            )  # REMOVE 'model' from dall-e-2 arg https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#request-a-generated-image-dall-e-2-preview
+            response = await async_handler.post(
+                url=api_base,
+                data=json.dumps(data),
+                headers={
+                    "Content-Type": "application/json",
+                    "api-key": api_key,
+                },
+            )
+            operation_location_url = response.headers["operation-location"]
+            response = await async_handler.get(
+                url=operation_location_url,
+                headers={
+                    "api-key": api_key,
+                },
+            )
+
+            await response.aread()
+
+            timeout_secs: int = 120
+            start_time = time.time()
+            if "status" not in response.json():
+                raise Exception(
+                    "Expected 'status' in response. Got={}".format(response.json())
+                )
+            while response.json()["status"] not in ["succeeded", "failed"]:
+                if time.time() - start_time > timeout_secs:
+                    timeout_msg = {
+                        "error": {
+                            "code": "Timeout",
+                            "message": "Operation polling timed out.",
+                        }
+                    }
+
+                    raise AzureOpenAIError(
+                        status_code=408, message="Operation polling timed out."
+                    )
+
+                await asyncio.sleep(int(response.headers.get("retry-after") or 10))
+                response = await async_handler.get(
+                    url=operation_location_url,
+                    headers={
+                        "api-key": api_key,
+                    },
+                )
+                await response.aread()
+
+            if response.json()["status"] == "failed":
+                error_data = response.json()
+                raise AzureOpenAIError(status_code=400, message=json.dumps(error_data))
+
+            return response
+        return await async_handler.post(
+            url=api_base,
+            json=data,
+            headers={
+                "Content-Type": "application/json",
+                "api-key": api_key,
+            },
+        )
+
+    def create_azure_base_url(
+        self, azure_client_params: dict, model: Optional[str]
+    ) -> str:
+
+        api_base: str = azure_client_params.get(
+            "azure_endpoint", ""
+        )  # "https://example-endpoint.openai.azure.com"
+        if api_base.endswith("/"):
+            api_base = api_base.rstrip("/")
+        api_version: str = azure_client_params.get("api_version", "")
+        if model is None:
+            model = ""
+        new_api_base = (
+            api_base
+            + "/openai/deployments/"
+            + model
+            + "/images/generations"
+            + "?api-version="
+            + api_version
+        )
+
+        return new_api_base
+
     async def aimage_generation(
         self,
         data: dict,
@@ -1062,30 +1194,40 @@ class AzureChatCompletion(BaseLLM):
         logging_obj=None,
         timeout=None,
     ):
-        response = None
+        response: Optional[dict] = None
         try:
-            if client is None:
-                client_session = litellm.aclient_session or httpx.AsyncClient(
-                    transport=AsyncCustomHTTPTransport(),
-                )
-                azure_client = AsyncAzureOpenAI(
-                    http_client=client_session, **azure_client_params
-                )
-            else:
-                azure_client = client
-            ## LOGGING
-            logging_obj.pre_call(
-                input=data["prompt"],
-                api_key=azure_client.api_key,
-                additional_args={
-                    "headers": {"api_key": azure_client.api_key},
-                    "api_base": azure_client._base_url._uri_reference,
-                    "acompletion": True,
-                    "complete_input_dict": data,
-                },
+            # ## LOGGING
+            # logging_obj.pre_call(
+            #     input=data["prompt"],
+            #     api_key=azure_client.api_key,
+            #     additional_args={
+            #         "headers": {"api_key": azure_client.api_key},
+            #         "api_base": azure_client._base_url._uri_reference,
+            #         "acompletion": True,
+            #         "complete_input_dict": data,
+            #     },
+            # )
+            # response = await azure_client.images.generate(**data, timeout=timeout)
+            api_base: str = azure_client_params.get(
+                "api_base", ""
+            )  # "https://example-endpoint.openai.azure.com"
+            if api_base.endswith("/"):
+                api_base = api_base.rstrip("/")
+            api_version: str = azure_client_params.get("api_version", "")
+            img_gen_api_base = self.create_azure_base_url(
+                azure_client_params=azure_client_params, model=data.get("model", "")
             )
-            response = await azure_client.images.generate(**data, timeout=timeout)
-            stringified_response = response.model_dump()
+            httpx_response: httpx.Response = await self.make_async_azure_httpx_request(
+                client=None,
+                timeout=timeout,
+                api_base=img_gen_api_base,
+                api_version=api_version,
+                api_key=api_key,
+                data=data,
+            )
+            response = httpx_response.json()["result"]
+
+            stringified_response = response
             ## LOGGING
             logging_obj.post_call(
                 input=input,
diff --git a/litellm/tests/test_image_generation.py b/litellm/tests/test_image_generation.py
index 49ec18f24..16a759c9e 100644
--- a/litellm/tests/test_image_generation.py
+++ b/litellm/tests/test_image_generation.py
@@ -1,20 +1,23 @@
 # What this tests?
 ## This tests the litellm support for the openai /generations endpoint
-import sys, os
-import traceback
-from dotenv import load_dotenv
 import logging
+import os
+import sys
+import traceback
+
+from dotenv import load_dotenv
 
 logging.basicConfig(level=logging.DEBUG)
 load_dotenv()
-import os
 import asyncio
+import os
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import pytest
+
 import litellm
@@ -39,9 +42,10 @@ def test_image_generation_openai():
 # test_image_generation_openai()
 
 
-def test_image_generation_azure():
+@pytest.mark.asyncio
+async def test_image_generation_azure():
     try:
-        response = litellm.image_generation(
+        response = await litellm.aimage_generation(
             prompt="A cute baby sea otter",
             model="azure/",
             api_version="2023-06-01-preview",
diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py
index 2dda57c2e..17fc26d60 100644
--- a/litellm/types/llms/vertex_ai.py
+++ b/litellm/types/llms/vertex_ai.py
@@ -155,6 +155,16 @@ class ToolConfig(TypedDict):
     functionCallingConfig: FunctionCallingConfig
 
 
+class TTL(TypedDict, total=False):
+    seconds: Required[float]
+    nano: float
+
+
+class CachedContent(TypedDict, total=False):
+    ttl: TTL
+    expire_time: str
+
+
 class RequestBody(TypedDict, total=False):
     contents: Required[List[ContentType]]
     system_instruction: SystemInstructions
@@ -162,6 +172,7 @@ class RequestBody(TypedDict, total=False):
     toolConfig: ToolConfig
     safetySettings: List[SafetSettingsConfig]
     generationConfig: GenerationConfig
+    cachedContent: str
 
 
 class SafetyRatings(TypedDict):
diff --git a/litellm/utils.py b/litellm/utils.py
index f8e8566f8..82e3ca171 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4815,6 +4815,12 @@ def function_to_dict(input_function):  # noqa: C901
     return result
 
 
+def modify_url(original_url, new_path):
+    url = httpx.URL(original_url)
+    modified_url = url.copy_with(path=new_path)
+    return str(modified_url)
+
+
 def load_test_model(
     model: str,
     custom_llm_provider: str = "",
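
Notes (not part of the patch). A minimal sketch of the URLs the new helpers produce; the endpoint, deployment name, and api-version below are made-up examples. modify_url swaps only the path component, so the host and the ?api-version=... query string survive the rewrite to the dall-e-2 submit route:

    import httpx

    def modify_url(original_url, new_path):
        # mirrors litellm.utils.modify_url: replace only the path,
        # keeping scheme, host, and query
        return str(httpx.URL(original_url).copy_with(path=new_path))

    # what create_azure_base_url builds for a hypothetical "dall-e-2" deployment:
    base = (
        "https://example-endpoint.openai.azure.com"
        "/openai/deployments/dall-e-2/images/generations"
        "?api-version=2023-06-01-preview"
    )

    print(modify_url(base, "/openai/images/generations:submit"))
    # https://example-endpoint.openai.azure.com/openai/images/generations:submit?api-version=2023-06-01-preview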
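
And a hedged usage sketch of the call path the updated test exercises; the deployment name, endpoint, and key are placeholders. For the preview api-versions listed above, the async path submits the generation request and then polls the operation-location URL until the job succeeds, fails, or times out:

    import asyncio
    import os

    import litellm

    async def main():
        # routes through AzureChatCompletion.aimage_generation ->
        # make_async_azure_httpx_request (create + poll for dall-e-2)
        response = await litellm.aimage_generation(
            prompt="A cute baby sea otter",
            model="azure/dall-e-2",  # placeholder deployment name
            api_version="2023-06-01-preview",
            api_base="https://example-endpoint.openai.azure.com",  # placeholder
            api_key=os.environ["AZURE_API_KEY"],
        )
        print(response)

    asyncio.run(main())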