diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index 69217fac2..da8189c4a 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -55,7 +55,6 @@ from ..types.llms.openai import (
     Thread,
 )
 from .base import BaseLLM
-from .custom_httpx.azure_dall_e_2 import AsyncCustomHTTPTransport, CustomHTTPTransport

 azure_ad_cache = DualCache()

@@ -1718,9 +1717,7 @@ class AzureChatCompletion(BaseLLM):
         input: Optional[list] = None,
         prompt: Optional[str] = None,
     ) -> dict:
-        client_session = litellm.client_session or httpx.Client(
-            transport=CustomHTTPTransport(),  # handle dall-e-2 calls
-        )
+        client_session = litellm.client_session or httpx.Client()
         if "gateway.ai.cloudflare.com" in api_base:
             ## build base url - assume api base includes resource name
             if not api_base.endswith("/"):
@@ -1793,9 +1790,10 @@ class AzureChatCompletion(BaseLLM):
         input: Optional[list] = None,
         prompt: Optional[str] = None,
     ) -> dict:
-        client_session = litellm.aclient_session or httpx.AsyncClient(
-            transport=AsyncCustomHTTPTransport(),  # handle dall-e-2 calls
-        )
+        client_session = (
+            litellm.aclient_session or httpx.AsyncClient()
+        )  # handle dall-e-2 calls
+
         if "gateway.ai.cloudflare.com" in api_base:
             ## build base url - assume api base includes resource name
             if not api_base.endswith("/"):
diff --git a/litellm/llms/azure_text.py b/litellm/llms/azure_text.py
index 640ab8222..72d6f134b 100644
--- a/litellm/llms/azure_text.py
+++ b/litellm/llms/azure_text.py
@@ -1,24 +1,27 @@
-from typing import Optional, Union, Any
-import types, requests  # type: ignore
-from .base import BaseLLM
-from litellm.utils import (
-    ModelResponse,
-    Choices,
-    Message,
-    CustomStreamWrapper,
-    convert_to_model_response_object,
-    TranscriptionResponse,
-    TextCompletionResponse,
-)
-from typing import Callable, Optional, BinaryIO
-from litellm import OpenAIConfig
-import litellm, json
-import httpx
-from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
-from openai import AzureOpenAI, AsyncAzureOpenAI
-from ..llms.openai import OpenAITextCompletion, OpenAITextCompletionConfig
+import json
+import types  # type: ignore
 import uuid
-from .prompt_templates.factory import prompt_factory, custom_prompt
+from typing import Any, BinaryIO, Callable, Optional, Union
+
+import httpx
+import requests
+from openai import AsyncAzureOpenAI, AzureOpenAI
+
+import litellm
+from litellm import OpenAIConfig
+from litellm.utils import (
+    Choices,
+    CustomStreamWrapper,
+    Message,
+    ModelResponse,
+    TextCompletionResponse,
+    TranscriptionResponse,
+    convert_to_model_response_object,
+)
+
+from ..llms.openai import OpenAITextCompletion, OpenAITextCompletionConfig
+from .base import BaseLLM
+from .prompt_templates.factory import custom_prompt, prompt_factory

 openai_text_completion_config = OpenAITextCompletionConfig()

diff --git a/litellm/llms/custom_httpx/azure_dall_e_2.py b/litellm/llms/custom_httpx/azure_dall_e_2.py
deleted file mode 100644
index a6726eb98..000000000
--- a/litellm/llms/custom_httpx/azure_dall_e_2.py
+++ /dev/null
@@ -1,143 +0,0 @@
-import asyncio
-import json
-import time
-
-import httpx
-
-
-class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
-    """
-    Async implementation of custom http transport
-    """
-
-    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
-        _api_version = request.url.params.get("api-version", "")
-        if (
-            "images/generations" in request.url.path
-            and _api_version
-            in [  # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
-                "2023-06-01-preview",
-                "2023-07-01-preview",
-                "2023-08-01-preview",
-                "2023-09-01-preview",
-                "2023-10-01-preview",
-            ]
-        ):
-            request.url = request.url.copy_with(
-                path="/openai/images/generations:submit"
-            )
-            response = await super().handle_async_request(request)
-            operation_location_url = response.headers["operation-location"]
-            request.url = httpx.URL(operation_location_url)
-            request.method = "GET"
-            response = await super().handle_async_request(request)
-            await response.aread()
-
-            timeout_secs: int = 120
-            start_time = time.time()
-            while response.json()["status"] not in ["succeeded", "failed"]:
-                if time.time() - start_time > timeout_secs:
-                    timeout = {
-                        "error": {
-                            "code": "Timeout",
-                            "message": "Operation polling timed out.",
-                        }
-                    }
-                    return httpx.Response(
-                        status_code=400,
-                        headers=response.headers,
-                        content=json.dumps(timeout).encode("utf-8"),
-                        request=request,
-                    )
-
-                await asyncio.sleep(int(response.headers.get("retry-after") or 10))
-                response = await super().handle_async_request(request)
-                await response.aread()
-
-            if response.json()["status"] == "failed":
-                error_data = response.json()
-                return httpx.Response(
-                    status_code=400,
-                    headers=response.headers,
-                    content=json.dumps(error_data).encode("utf-8"),
-                    request=request,
-                )
-
-            result = response.json()["result"]
-            return httpx.Response(
-                status_code=200,
-                headers=response.headers,
-                content=json.dumps(result).encode("utf-8"),
-                request=request,
-            )
-        return await super().handle_async_request(request)
-
-
-class CustomHTTPTransport(httpx.HTTPTransport):
-    """
-    This class was written as a workaround to support dall-e-2 on openai > v1.x
-
-    Refer to this issue for more: https://github.com/openai/openai-python/issues/692
-    """
-
-    def handle_request(
-        self,
-        request: httpx.Request,
-    ) -> httpx.Response:
-        _api_version = request.url.params.get("api-version", "")
-        if (
-            "images/generations" in request.url.path
-            and _api_version
-            in [  # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
-                "2023-06-01-preview",
-                "2023-07-01-preview",
-                "2023-08-01-preview",
-                "2023-09-01-preview",
-                "2023-10-01-preview",
-            ]
-        ):
-            request.url = request.url.copy_with(
-                path="/openai/images/generations:submit"
-            )
-            response = super().handle_request(request)
-            operation_location_url = response.headers["operation-location"]
-            request.url = httpx.URL(operation_location_url)
-            request.method = "GET"
-            response = super().handle_request(request)
-            response.read()
-            timeout_secs: int = 120
-            start_time = time.time()
-            while response.json()["status"] not in ["succeeded", "failed"]:
-                if time.time() - start_time > timeout_secs:
-                    timeout = {
-                        "error": {
-                            "code": "Timeout",
-                            "message": "Operation polling timed out.",
-                        }
-                    }
-                    return httpx.Response(
-                        status_code=400,
-                        headers=response.headers,
-                        content=json.dumps(timeout).encode("utf-8"),
-                        request=request,
-                    )
-                time.sleep(int(response.headers.get("retry-after", None) or 10))
-                response = super().handle_request(request)
-                response.read()
-            if response.json()["status"] == "failed":
-                error_data = response.json()
-                return httpx.Response(
-                    status_code=400,
-                    headers=response.headers,
-                    content=json.dumps(error_data).encode("utf-8"),
-                    request=request,
-                )
-
-            result = response.json()["result"]
-            return httpx.Response(
-                status_code=200,
-                headers=response.headers,
-                content=json.dumps(result).encode("utf-8"),
-                request=request,
-            )
-        return super().handle_request(request)
diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index 9b01c96b1..ef79b487f 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -26,30 +26,12 @@ class AsyncHTTPHandler:
         self, timeout: Optional[Union[float, httpx.Timeout]], concurrent_limit: int
     ) -> httpx.AsyncClient:

-        async_proxy_mounts = None
         # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
-        http_proxy = os.getenv("HTTP_PROXY", None)
-        https_proxy = os.getenv("HTTPS_PROXY", None)
-        no_proxy = os.getenv("NO_PROXY", None)
         ssl_verify = bool(os.getenv("SSL_VERIFY", litellm.ssl_verify))
         cert = os.getenv(
             "SSL_CERTIFICATE", litellm.ssl_certificate
         )  # /path/to/client.pem

-        if http_proxy is not None and https_proxy is not None:
-            async_proxy_mounts = {
-                "http://": httpx.AsyncHTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
-                "https://": httpx.AsyncHTTPTransport(
-                    proxy=httpx.Proxy(url=https_proxy)
-                ),
-            }
-            # assume no_proxy is a list of comma separated urls
-            if no_proxy is not None and isinstance(no_proxy, str):
-                no_proxy_urls = no_proxy.split(",")
-
-                for url in no_proxy_urls:  # set no-proxy support for specific urls
-                    async_proxy_mounts[url] = None  # type: ignore
-
         if timeout is None:
             timeout = _DEFAULT_TIMEOUT
         # Create a client with a connection pool
@@ -61,7 +43,6 @@ class AsyncHTTPHandler:
                 max_keepalive_connections=concurrent_limit,
             ),
             verify=ssl_verify,
-            mounts=async_proxy_mounts,
             cert=cert,
         )

@@ -163,27 +144,11 @@ class HTTPHandler:
             timeout = _DEFAULT_TIMEOUT

         # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
-        http_proxy = os.getenv("HTTP_PROXY", None)
-        https_proxy = os.getenv("HTTPS_PROXY", None)
-        no_proxy = os.getenv("NO_PROXY", None)
         ssl_verify = bool(os.getenv("SSL_VERIFY", litellm.ssl_verify))
         cert = os.getenv(
             "SSL_CERTIFICATE", litellm.ssl_certificate
         )  # /path/to/client.pem

-        sync_proxy_mounts = None
-        if http_proxy is not None and https_proxy is not None:
-            sync_proxy_mounts = {
-                "http://": httpx.HTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
-                "https://": httpx.HTTPTransport(proxy=httpx.Proxy(url=https_proxy)),
-            }
-            # assume no_proxy is a list of comma separated urls
-            if no_proxy is not None and isinstance(no_proxy, str):
-                no_proxy_urls = no_proxy.split(",")
-
-                for url in no_proxy_urls:  # set no-proxy support for specific urls
-                    sync_proxy_mounts[url] = None  # type: ignore
-
         if client is None:
             # Create a client with a connection pool
             self.client = httpx.Client(
@@ -193,7 +158,6 @@ class HTTPHandler:
                     max_keepalive_connections=concurrent_limit,
                 ),
                 verify=ssl_verify,
-                mounts=sync_proxy_mounts,
                 cert=cert,
             )
         else:
diff --git a/litellm/main.py b/litellm/main.py
index 37ae125b9..1bd33c4b3 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -893,7 +893,7 @@ def completion(
         if (
             supports_system_message is not None
             and isinstance(supports_system_message, bool)
-            and supports_system_message == False
+            and supports_system_message is False
         ):
             messages = map_system_message_pt(messages=messages)
         model_api_key = get_api_key(
diff --git a/litellm/router.py b/litellm/router.py
index db68197a4..d8364fa24 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -46,10 +46,6 @@ from litellm._logging import verbose_router_logger
 from litellm.caching import DualCache, InMemoryCache, RedisCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.azure import get_azure_ad_token_from_oidc
-from litellm.llms.custom_httpx.azure_dall_e_2 import (
-    AsyncCustomHTTPTransport,
-    CustomHTTPTransport,
-)
 from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
 from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
 from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
@@ -93,6 +89,7 @@ from litellm.utils import (
     ModelResponse,
     _is_region_eu,
     calculate_max_parallel_requests,
+    create_proxy_transport_and_mounts,
     get_utc_datetime,
 )

diff --git a/litellm/router_utils/client_initalization_utils.py b/litellm/router_utils/client_initalization_utils.py
index 0160ffda1..073a87901 100644
--- a/litellm/router_utils/client_initalization_utils.py
+++ b/litellm/router_utils/client_initalization_utils.py
@@ -1,16 +1,14 @@
 import asyncio
+import os
 import traceback
 from typing import TYPE_CHECKING, Any

+import httpx
 import openai

 import litellm
 from litellm._logging import verbose_router_logger
 from litellm.llms.azure import get_azure_ad_token_from_oidc
-from litellm.llms.custom_httpx.azure_dall_e_2 import (
-    AsyncCustomHTTPTransport,
-    CustomHTTPTransport,
-)
 from litellm.utils import calculate_max_parallel_requests

 if TYPE_CHECKING:
@@ -169,39 +167,6 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                 max_retries = litellm.get_secret(max_retries_env_name)
                 litellm_params["max_retries"] = max_retries

-        # proxy support
-        import os
-
-        import httpx
-
-        # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
-        http_proxy = os.getenv("HTTP_PROXY", None)
-        https_proxy = os.getenv("HTTPS_PROXY", None)
-        no_proxy = os.getenv("NO_PROXY", None)
-
-        # Create the proxies dictionary only if the environment variables are set.
-        sync_proxy_mounts = None
-        async_proxy_mounts = None
-        if http_proxy is not None and https_proxy is not None:
-            sync_proxy_mounts = {
-                "http://": httpx.HTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
-                "https://": httpx.HTTPTransport(proxy=httpx.Proxy(url=https_proxy)),
-            }
-            async_proxy_mounts = {
-                "http://": httpx.AsyncHTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
-                "https://": httpx.AsyncHTTPTransport(
-                    proxy=httpx.Proxy(url=https_proxy)
-                ),
-            }
-
-            # assume no_proxy is a list of comma separated urls
-            if no_proxy is not None and isinstance(no_proxy, str):
-                no_proxy_urls = no_proxy.split(",")
-
-                for url in no_proxy_urls:  # set no-proxy support for specific urls
-                    sync_proxy_mounts[url] = None  # type: ignore
-                    async_proxy_mounts[url] = None  # type: ignore
-
         organization = litellm_params.get("organization", None)
         if isinstance(organization, str) and organization.startswith("os.environ/"):
             organization_env_name = organization.replace("os.environ/", "")
@@ -241,13 +206,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                     timeout=timeout,
                     max_retries=max_retries,
                     http_client=httpx.AsyncClient(
-                        transport=AsyncCustomHTTPTransport(
-                            limits=httpx.Limits(
-                                max_connections=1000, max_keepalive_connections=100
-                            ),
-                            verify=litellm.ssl_verify,
+                        limits=httpx.Limits(
+                            max_connections=1000, max_keepalive_connections=100
                         ),
-                        mounts=async_proxy_mounts,
+                        verify=litellm.ssl_verify,
                     ),  # type: ignore
                 )
                 litellm_router_instance.cache.set_cache(
@@ -269,13 +231,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                     timeout=timeout,
                     max_retries=max_retries,
                     http_client=httpx.Client(
-                        transport=CustomHTTPTransport(
-                            limits=httpx.Limits(
-                                max_connections=1000, max_keepalive_connections=100
-                            ),
-                            verify=litellm.ssl_verify,
+                        limits=httpx.Limits(
+                            max_connections=1000, max_keepalive_connections=100
                         ),
-                        mounts=sync_proxy_mounts,
+                        verify=litellm.ssl_verify,
                     ),  # type: ignore
                 )
                 litellm_router_instance.cache.set_cache(
@@ -294,13 +253,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                     timeout=stream_timeout,
                     max_retries=max_retries,
                     http_client=httpx.AsyncClient(
-                        transport=AsyncCustomHTTPTransport(
-                            limits=httpx.Limits(
-                                max_connections=1000, max_keepalive_connections=100
-                            ),
-                            verify=litellm.ssl_verify,
+                        limits=httpx.Limits(
+                            max_connections=1000, max_keepalive_connections=100
                         ),
-                        mounts=async_proxy_mounts,
+                        verify=litellm.ssl_verify,
                     ),  # type: ignore
                 )
                 litellm_router_instance.cache.set_cache(
@@ -322,13 +278,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                     timeout=stream_timeout,
                     max_retries=max_retries,
                     http_client=httpx.Client(
-                        transport=CustomHTTPTransport(
-                            limits=httpx.Limits(
-                                max_connections=1000, max_keepalive_connections=100
-                            ),
-                            verify=litellm.ssl_verify,
+                        limits=httpx.Limits(
+                            max_connections=1000, max_keepalive_connections=100
                         ),
-                        mounts=sync_proxy_mounts,
+                        verify=litellm.ssl_verify,
                     ),  # type: ignore
                 )
                 litellm_router_instance.cache.set_cache(
@@ -365,13 +318,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                     timeout=timeout,
                     max_retries=max_retries,
                     http_client=httpx.AsyncClient(
-                        transport=AsyncCustomHTTPTransport(
-                            limits=httpx.Limits(
-                                max_connections=1000, max_keepalive_connections=100
-                            ),
-                            verify=litellm.ssl_verify,
+                        limits=httpx.Limits(
+                            max_connections=1000, max_keepalive_connections=100
                         ),
-                        mounts=async_proxy_mounts,
+                        verify=litellm.ssl_verify,
                     ),  # type: ignore
                 )
                 litellm_router_instance.cache.set_cache(
@@ -389,13 +339,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                     timeout=timeout,
                     max_retries=max_retries,
                     http_client=httpx.Client(
-                        transport=CustomHTTPTransport(
-                            verify=litellm.ssl_verify,
-                            limits=httpx.Limits(
-                                max_connections=1000, max_keepalive_connections=100
-                            ),
+                        limits=httpx.Limits(
+                            max_connections=1000, max_keepalive_connections=100
                         ),
-                        mounts=sync_proxy_mounts,
+                        verify=litellm.ssl_verify,
                     ),  # type: ignore
                 )
                 litellm_router_instance.cache.set_cache(
@@ -412,13 +359,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                     timeout=stream_timeout,
                     max_retries=max_retries,
                     http_client=httpx.AsyncClient(
-                        transport=AsyncCustomHTTPTransport(
-                            limits=httpx.Limits(
-                                max_connections=1000, max_keepalive_connections=100
-                            ),
-                            verify=litellm.ssl_verify,
+                        limits=httpx.Limits(
+                            max_connections=1000, max_keepalive_connections=100
                         ),
-                        mounts=async_proxy_mounts,
+                        verify=litellm.ssl_verify,
                     ),
                 )
                 litellm_router_instance.cache.set_cache(
@@ -437,13 +381,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                     timeout=stream_timeout,
                     max_retries=max_retries,
                     http_client=httpx.Client(
-                        transport=CustomHTTPTransport(
-                            limits=httpx.Limits(
-                                max_connections=1000, max_keepalive_connections=100
-                            ),
-                            verify=litellm.ssl_verify,
+                        limits=httpx.Limits(
+                            max_connections=1000, max_keepalive_connections=100
                         ),
-                        mounts=sync_proxy_mounts,
+                        verify=litellm.ssl_verify,
                     ),
                 )
                 litellm_router_instance.cache.set_cache(
@@ -469,13 +410,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                 max_retries=max_retries,
                 organization=organization,
                 http_client=httpx.AsyncClient(
-                    transport=AsyncCustomHTTPTransport(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
+                    limits=httpx.Limits(
+                        max_connections=1000, max_keepalive_connections=100
                     ),
-                    mounts=async_proxy_mounts,
+                    verify=litellm.ssl_verify,
                 ),  # type: ignore
             )
             litellm_router_instance.cache.set_cache(
@@ -496,13 +434,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                 max_retries=max_retries,
                 organization=organization,
                 http_client=httpx.Client(
-                    transport=CustomHTTPTransport(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
+                    limits=httpx.Limits(
+                        max_connections=1000, max_keepalive_connections=100
                     ),
-                    mounts=sync_proxy_mounts,
+                    verify=litellm.ssl_verify,
                 ),  # type: ignore
             )
             litellm_router_instance.cache.set_cache(
@@ -521,13 +456,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                 max_retries=max_retries,
                 organization=organization,
                 http_client=httpx.AsyncClient(
-                    transport=AsyncCustomHTTPTransport(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
+                    limits=httpx.Limits(
+                        max_connections=1000, max_keepalive_connections=100
                     ),
-                    mounts=async_proxy_mounts,
+                    verify=litellm.ssl_verify,
                 ),  # type: ignore
             )
             litellm_router_instance.cache.set_cache(
@@ -549,13 +481,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                 max_retries=max_retries,
                 organization=organization,
                 http_client=httpx.Client(
-                    transport=CustomHTTPTransport(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
+                    limits=httpx.Limits(
+                        max_connections=1000, max_keepalive_connections=100
                     ),
-                    mounts=sync_proxy_mounts,
+                    verify=litellm.ssl_verify,
                 ),  # type: ignore
             )
             litellm_router_instance.cache.set_cache(
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index 43525afcb..86506a589 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -1894,6 +1894,49 @@ async def test_router_model_usage(mock_response):
         raise e


+@pytest.mark.skip(reason="Check if this is causing ci/cd issues.")
+@pytest.mark.asyncio
+async def test_is_proxy_set():
+    """
+    Assert if proxy is set
+    """
+    from httpx import AsyncHTTPTransport
+
+    os.environ["HTTPS_PROXY"] = "https://proxy.example.com:8080"
+    from openai import AsyncAzureOpenAI
+
+    # Function to check if a proxy is set on the client
+    # Function to check if a proxy is set on the client
+    def check_proxy(client: httpx.AsyncClient) -> bool:
+        print(f"client._mounts: {client._mounts}")
+        assert len(client._mounts) == 1
+        for k, v in client._mounts.items():
+            assert isinstance(v, AsyncHTTPTransport)
+        return True
+
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-4",
+                "litellm_params": {
+                    "model": "azure/gpt-3.5-turbo",
+                    "api_key": "my-key",
+                    "api_base": "my-base",
+                    "mock_response": "hello world",
+                },
+                "model_info": {"id": "1"},
+            }
+        ]
+    )
+
+    _deployment = llm_router.get_deployment(model_id="1")
+    model_client: AsyncAzureOpenAI = llm_router._get_client(
+        deployment=_deployment, kwargs={}, client_type="async"
+    )  # type: ignore
+
+    assert check_proxy(client=model_client._client)
+
+
 @pytest.mark.parametrize(
     "model, base_model, llm_provider",
     [
diff --git a/litellm/utils.py b/litellm/utils.py
index 13be18422..b95d54eb7 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -42,6 +42,8 @@ import httpx
 import openai
 import requests
 import tiktoken
+from httpx import Proxy
+from httpx._utils import get_environment_proxies
 from pydantic import BaseModel
 from tokenizers import Tokenizer

@@ -4913,6 +4915,34 @@ def get_provider_fields(custom_llm_provider: str) -> List[ProviderField]:
         return []


+def create_proxy_transport_and_mounts():
+    proxies = {
+        key: None if url is None else Proxy(url=url)
+        for key, url in get_environment_proxies().items()
+    }
+
+    sync_proxy_mounts = {}
+    async_proxy_mounts = {}
+
+    # Retrieve NO_PROXY environment variable
+    no_proxy = os.getenv("NO_PROXY", None)
+    no_proxy_urls = no_proxy.split(",") if no_proxy else []
+
+    for key, proxy in proxies.items():
+        if proxy is None:
+            sync_proxy_mounts[key] = httpx.HTTPTransport()
+            async_proxy_mounts[key] = httpx.AsyncHTTPTransport()
+        else:
+            sync_proxy_mounts[key] = httpx.HTTPTransport(proxy=proxy)
+            async_proxy_mounts[key] = httpx.AsyncHTTPTransport(proxy=proxy)
+
+    for url in no_proxy_urls:
+        sync_proxy_mounts[url] = httpx.HTTPTransport()
+        async_proxy_mounts[url] = httpx.AsyncHTTPTransport()
+
+    return sync_proxy_mounts, async_proxy_mounts
+
+
 def validate_environment(model: Optional[str] = None) -> dict:
     """
     Checks if the environment variables are valid for the given model.
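Note: the utils.py hunk above adds create_proxy_transport_and_mounts() and router.py imports it, but no call site appears in this excerpt. A minimal usage sketch, assuming the caller hands the returned dictionaries to httpx clients through their mounts argument (this wiring is illustrative, not part of the diff):

    import httpx

    from litellm.utils import create_proxy_transport_and_mounts

    # Build sync/async proxy mounts from HTTP_PROXY / HTTPS_PROXY / NO_PROXY,
    # reusing httpx's own environment parsing (get_environment_proxies).
    sync_proxy_mounts, async_proxy_mounts = create_proxy_transport_and_mounts()

    # Assumed wiring only: pass the mounts to plain httpx clients.
    sync_client = httpx.Client(mounts=sync_proxy_mounts)
    async_client = httpx.AsyncClient(mounts=async_proxy_mounts)

Note that httpx clients also honor proxy environment variables on their own (trust_env defaults to True), which is what the skipped test_is_proxy_set relies on when it inspects client._mounts.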