mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
fix(router.py): fix setting httpx mounts
This commit is contained in:
parent
151d19960e
commit
98daedaf60
4 changed files with 93 additions and 26 deletions
|
@ -879,7 +879,7 @@ def completion(
|
||||||
if (
|
if (
|
||||||
supports_system_message is not None
|
supports_system_message is not None
|
||||||
and isinstance(supports_system_message, bool)
|
and isinstance(supports_system_message, bool)
|
||||||
and supports_system_message == False
|
and supports_system_message is False
|
||||||
):
|
):
|
||||||
messages = map_system_message_pt(messages=messages)
|
messages = map_system_message_pt(messages=messages)
|
||||||
model_api_key = get_api_key(
|
model_api_key = get_api_key(
|
||||||
|
|
|
@ -87,6 +87,7 @@ from litellm.utils import (
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
_is_region_eu,
|
_is_region_eu,
|
||||||
calculate_max_parallel_requests,
|
calculate_max_parallel_requests,
|
||||||
|
create_proxy_transport_and_mounts,
|
||||||
get_utc_datetime,
|
get_utc_datetime,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -3316,34 +3317,32 @@ class Router:
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
# Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
|
# Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
|
||||||
http_proxy = os.getenv("HTTP_PROXY", None)
|
# http_proxy = os.getenv("HTTP_PROXY", None)
|
||||||
https_proxy = os.getenv("HTTPS_PROXY", None)
|
# https_proxy = os.getenv("HTTPS_PROXY", None)
|
||||||
no_proxy = os.getenv("NO_PROXY", None)
|
# no_proxy = os.getenv("NO_PROXY", None)
|
||||||
|
|
||||||
# Create the proxies dictionary only if the environment variables are set.
|
# Create the proxies dictionary only if the environment variables are set.
|
||||||
sync_proxy_mounts = None
|
sync_proxy_mounts, async_proxy_mounts = create_proxy_transport_and_mounts()
|
||||||
async_proxy_mounts = None
|
# if http_proxy is not None and https_proxy is not None:
|
||||||
if http_proxy is not None and https_proxy is not None:
|
# sync_proxy_mounts = {
|
||||||
sync_proxy_mounts = {
|
# "http://": httpx.HTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
|
||||||
"http://": httpx.HTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
|
# "https://": httpx.HTTPTransport(proxy=httpx.Proxy(url=https_proxy)),
|
||||||
"https://": httpx.HTTPTransport(proxy=httpx.Proxy(url=https_proxy)),
|
# }
|
||||||
}
|
# async_proxy_mounts = {
|
||||||
async_proxy_mounts = {
|
# "http://": httpx.AsyncHTTPTransport(
|
||||||
"http://": httpx.AsyncHTTPTransport(
|
# proxy=httpx.Proxy(url=http_proxy)
|
||||||
proxy=httpx.Proxy(url=http_proxy)
|
# ),
|
||||||
),
|
# "https://": httpx.AsyncHTTPTransport(
|
||||||
"https://": httpx.AsyncHTTPTransport(
|
# proxy=httpx.Proxy(url=https_proxy)
|
||||||
proxy=httpx.Proxy(url=https_proxy)
|
# ),
|
||||||
),
|
# }
|
||||||
}
|
|
||||||
|
|
||||||
# assume no_proxy is a list of comma separated urls
|
# # assume no_proxy is a list of comma separated urls
|
||||||
if no_proxy is not None and isinstance(no_proxy, str):
|
# if no_proxy is not None and isinstance(no_proxy, str):
|
||||||
no_proxy_urls = no_proxy.split(",")
|
# no_proxy_urls = no_proxy.split(",")
|
||||||
|
|
||||||
for url in no_proxy_urls: # set no-proxy support for specific urls
|
# for url in no_proxy_urls: # set no-proxy support for specific urls
|
||||||
sync_proxy_mounts[url] = None # type: ignore
|
# sync_proxy_mounts[url] = None # type: ignore
|
||||||
async_proxy_mounts[url] = None # type: ignore
|
# async_proxy_mounts[url] = None # type: ignore
|
||||||
|
|
||||||
organization = litellm_params.get("organization", None)
|
organization = litellm_params.get("organization", None)
|
||||||
if isinstance(organization, str) and organization.startswith("os.environ/"):
|
if isinstance(organization, str) and organization.startswith("os.environ/"):
|
||||||
|
|
|
@ -1884,3 +1884,41 @@ async def test_router_model_usage(mock_response):
|
||||||
else:
|
else:
|
||||||
print(f"allowed_fails: {allowed_fails}")
|
print(f"allowed_fails: {allowed_fails}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_is_proxy_set():
|
||||||
|
"""
|
||||||
|
Assert if proxy is set
|
||||||
|
"""
|
||||||
|
from httpcore import AsyncHTTPProxy
|
||||||
|
|
||||||
|
os.environ["HTTPS_PROXY"] = "https://proxy.example.com:8080"
|
||||||
|
from openai import AsyncAzureOpenAI
|
||||||
|
|
||||||
|
# Function to check if a proxy is set on the client
|
||||||
|
# Function to check if a proxy is set on the client
|
||||||
|
def check_proxy(client: httpx.AsyncClient) -> bool:
|
||||||
|
return isinstance(client._transport.__dict__["_pool"], AsyncHTTPProxy)
|
||||||
|
|
||||||
|
llm_router = Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "gpt-4",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "azure/gpt-3.5-turbo",
|
||||||
|
"api_key": "my-key",
|
||||||
|
"api_base": "my-base",
|
||||||
|
"mock_response": "hello world",
|
||||||
|
},
|
||||||
|
"model_info": {"id": "1"},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
_deployment = llm_router.get_deployment(model_id="1")
|
||||||
|
model_client: AsyncAzureOpenAI = llm_router._get_client(
|
||||||
|
deployment=_deployment, kwargs={}, client_type="async"
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
assert check_proxy(client=model_client._client) is True
|
||||||
|
|
|
@ -42,6 +42,8 @@ import httpx
|
||||||
import openai
|
import openai
|
||||||
import requests
|
import requests
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
from httpx import Proxy
|
||||||
|
from httpx._utils import get_environment_proxies
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from tokenizers import Tokenizer
|
from tokenizers import Tokenizer
|
||||||
|
|
||||||
|
@ -4803,6 +4805,34 @@ def get_provider_fields(custom_llm_provider: str) -> List[ProviderField]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def create_proxy_transport_and_mounts():
|
||||||
|
proxies = {
|
||||||
|
key: None if url is None else Proxy(url=url)
|
||||||
|
for key, url in get_environment_proxies().items()
|
||||||
|
}
|
||||||
|
|
||||||
|
sync_proxy_mounts = {}
|
||||||
|
async_proxy_mounts = {}
|
||||||
|
|
||||||
|
# Retrieve NO_PROXY environment variable
|
||||||
|
no_proxy = os.getenv("NO_PROXY", None)
|
||||||
|
no_proxy_urls = no_proxy.split(",") if no_proxy else []
|
||||||
|
|
||||||
|
for key, proxy in proxies.items():
|
||||||
|
if proxy is None:
|
||||||
|
sync_proxy_mounts[key] = httpx.HTTPTransport()
|
||||||
|
async_proxy_mounts[key] = httpx.AsyncHTTPTransport()
|
||||||
|
else:
|
||||||
|
sync_proxy_mounts[key] = httpx.HTTPTransport(proxy=proxy)
|
||||||
|
async_proxy_mounts[key] = httpx.AsyncHTTPTransport(proxy=proxy)
|
||||||
|
|
||||||
|
for url in no_proxy_urls:
|
||||||
|
sync_proxy_mounts[url] = httpx.HTTPTransport()
|
||||||
|
async_proxy_mounts[url] = httpx.AsyncHTTPTransport()
|
||||||
|
|
||||||
|
return sync_proxy_mounts, async_proxy_mounts
|
||||||
|
|
||||||
|
|
||||||
def validate_environment(model: Optional[str] = None) -> dict:
|
def validate_environment(model: Optional[str] = None) -> dict:
|
||||||
"""
|
"""
|
||||||
Checks if the environment variables are valid for the given model.
|
Checks if the environment variables are valid for the given model.
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue