Merge pull request #4434 from BerriAI/litellm_fix_httpx_transport

fix(router.py): fix setting httpx mounts

commit 49d7faa31e
9 changed files with 141 additions and 320 deletions
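At a glance, the change removes litellm's hand-rolled proxy plumbing (per-callsite parsing of HTTP_PROXY, HTTPS_PROXY, and NO_PROXY, plus the dall-e-2 transport wrappers) and routes proxy configuration through httpx's own mounts mechanism instead. A minimal sketch of how mounts work, with placeholder proxy URL and host that are not from this diff:

```python
import httpx

# httpx "mounts" map URL patterns to transports. A pattern mounted to None
# falls back to the client's default (unproxied) transport, which is how
# NO_PROXY-style bypasses are expressed.
mounts = {
    "http://": httpx.HTTPTransport(proxy=httpx.Proxy(url="http://localhost:3128")),
    "https://": httpx.HTTPTransport(proxy=httpx.Proxy(url="http://localhost:3128")),
    "all://internal.example.com": None,  # bypass the proxy for this host
}
client = httpx.Client(mounts=mounts)
```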
@@ -55,7 +55,6 @@ from ..types.llms.openai import (
     Thread,
 )
 from .base import BaseLLM
-from .custom_httpx.azure_dall_e_2 import AsyncCustomHTTPTransport, CustomHTTPTransport
 
 azure_ad_cache = DualCache()
 
@@ -1718,9 +1717,7 @@ class AzureChatCompletion(BaseLLM):
         input: Optional[list] = None,
         prompt: Optional[str] = None,
     ) -> dict:
-        client_session = litellm.client_session or httpx.Client(
-            transport=CustomHTTPTransport(),  # handle dall-e-2 calls
-        )
+        client_session = litellm.client_session or httpx.Client()
         if "gateway.ai.cloudflare.com" in api_base:
             ## build base url - assume api base includes resource name
             if not api_base.endswith("/"):
@@ -1793,9 +1790,10 @@ class AzureChatCompletion(BaseLLM):
         input: Optional[list] = None,
         prompt: Optional[str] = None,
     ) -> dict:
-        client_session = litellm.aclient_session or httpx.AsyncClient(
-            transport=AsyncCustomHTTPTransport(),  # handle dall-e-2 calls
-        )
+        client_session = (
+            litellm.aclient_session or httpx.AsyncClient()
+        )  # handle dall-e-2 calls
+
         if "gateway.ai.cloudflare.com" in api_base:
             ## build base url - assume api base includes resource name
             if not api_base.endswith("/"):
@@ -1,24 +1,27 @@
-from typing import Optional, Union, Any
-import types, requests  # type: ignore
-from .base import BaseLLM
-from litellm.utils import (
-    ModelResponse,
-    Choices,
-    Message,
-    CustomStreamWrapper,
-    convert_to_model_response_object,
-    TranscriptionResponse,
-    TextCompletionResponse,
-)
-from typing import Callable, Optional, BinaryIO
-from litellm import OpenAIConfig
-import litellm, json
-import httpx
-from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
-from openai import AzureOpenAI, AsyncAzureOpenAI
-from ..llms.openai import OpenAITextCompletion, OpenAITextCompletionConfig
+import json
+import types  # type: ignore
 import uuid
-from .prompt_templates.factory import prompt_factory, custom_prompt
+from typing import Any, BinaryIO, Callable, Optional, Union
+
+import httpx
+import requests
+from openai import AsyncAzureOpenAI, AzureOpenAI
+
+import litellm
+from litellm import OpenAIConfig
+from litellm.utils import (
+    Choices,
+    CustomStreamWrapper,
+    Message,
+    ModelResponse,
+    TextCompletionResponse,
+    TranscriptionResponse,
+    convert_to_model_response_object,
+)
+
+from ..llms.openai import OpenAITextCompletion, OpenAITextCompletionConfig
+from .base import BaseLLM
+from .prompt_templates.factory import custom_prompt, prompt_factory
 
 openai_text_completion_config = OpenAITextCompletionConfig()
 
@@ -1,143 +0,0 @@
-import asyncio
-import json
-import time
-
-import httpx
-
-
-class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
-    """
-    Async implementation of custom http transport
-    """
-
-    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
-        _api_version = request.url.params.get("api-version", "")
-        if (
-            "images/generations" in request.url.path
-            and _api_version
-            in [  # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
-                "2023-06-01-preview",
-                "2023-07-01-preview",
-                "2023-08-01-preview",
-                "2023-09-01-preview",
-                "2023-10-01-preview",
-            ]
-        ):
-            request.url = request.url.copy_with(
-                path="/openai/images/generations:submit"
-            )
-            response = await super().handle_async_request(request)
-            operation_location_url = response.headers["operation-location"]
-            request.url = httpx.URL(operation_location_url)
-            request.method = "GET"
-            response = await super().handle_async_request(request)
-            await response.aread()
-
-            timeout_secs: int = 120
-            start_time = time.time()
-            while response.json()["status"] not in ["succeeded", "failed"]:
-                if time.time() - start_time > timeout_secs:
-                    timeout = {
-                        "error": {
-                            "code": "Timeout",
-                            "message": "Operation polling timed out.",
-                        }
-                    }
-                    return httpx.Response(
-                        status_code=400,
-                        headers=response.headers,
-                        content=json.dumps(timeout).encode("utf-8"),
-                        request=request,
-                    )
-
-                await asyncio.sleep(int(response.headers.get("retry-after") or 10))
-                response = await super().handle_async_request(request)
-                await response.aread()
-
-            if response.json()["status"] == "failed":
-                error_data = response.json()
-                return httpx.Response(
-                    status_code=400,
-                    headers=response.headers,
-                    content=json.dumps(error_data).encode("utf-8"),
-                    request=request,
-                )
-
-            result = response.json()["result"]
-            return httpx.Response(
-                status_code=200,
-                headers=response.headers,
-                content=json.dumps(result).encode("utf-8"),
-                request=request,
-            )
-        return await super().handle_async_request(request)
-
-
-class CustomHTTPTransport(httpx.HTTPTransport):
-    """
-    This class was written as a workaround to support dall-e-2 on openai > v1.x
-
-    Refer to this issue for more: https://github.com/openai/openai-python/issues/692
-    """
-
-    def handle_request(
-        self,
-        request: httpx.Request,
-    ) -> httpx.Response:
-        _api_version = request.url.params.get("api-version", "")
-        if (
-            "images/generations" in request.url.path
-            and _api_version
-            in [  # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
-                "2023-06-01-preview",
-                "2023-07-01-preview",
-                "2023-08-01-preview",
-                "2023-09-01-preview",
-                "2023-10-01-preview",
-            ]
-        ):
-            request.url = request.url.copy_with(
-                path="/openai/images/generations:submit"
-            )
-            response = super().handle_request(request)
-            operation_location_url = response.headers["operation-location"]
-            request.url = httpx.URL(operation_location_url)
-            request.method = "GET"
-            response = super().handle_request(request)
-            response.read()
-            timeout_secs: int = 120
-            start_time = time.time()
-            while response.json()["status"] not in ["succeeded", "failed"]:
-                if time.time() - start_time > timeout_secs:
-                    timeout = {
-                        "error": {
-                            "code": "Timeout",
-                            "message": "Operation polling timed out.",
-                        }
-                    }
-                    return httpx.Response(
-                        status_code=400,
-                        headers=response.headers,
-                        content=json.dumps(timeout).encode("utf-8"),
-                        request=request,
-                    )
-                time.sleep(int(response.headers.get("retry-after", None) or 10))
-                response = super().handle_request(request)
-                response.read()
-            if response.json()["status"] == "failed":
-                error_data = response.json()
-                return httpx.Response(
-                    status_code=400,
-                    headers=response.headers,
-                    content=json.dumps(error_data).encode("utf-8"),
-                    request=request,
-                )
-
-            result = response.json()["result"]
-            return httpx.Response(
-                status_code=200,
-                headers=response.headers,
-                content=json.dumps(result).encode("utf-8"),
-                request=request,
-            )
-        return super().handle_request(request)
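The file deleted above carried the dall-e-2 workaround for openai >= 1.x. Its core trick is a transport that intercepts matching requests, reroutes them to the async submit endpoint, then polls the operation-location URL until the job resolves. A condensed sketch of that pattern follows; the class name is mine, the logic tracks the deleted code, and the api-version gate plus the 120-second timeout branch are omitted for brevity:

```python
import json
import time

import httpx


class PollingImageTransport(httpx.HTTPTransport):  # hypothetical name
    def handle_request(self, request: httpx.Request) -> httpx.Response:
        if "images/generations" not in request.url.path:
            return super().handle_request(request)  # pass through untouched
        # Reroute to the async submit endpoint, then follow operation-location.
        request.url = request.url.copy_with(path="/openai/images/generations:submit")
        response = super().handle_request(request)
        request.url = httpx.URL(response.headers["operation-location"])
        request.method = "GET"
        while True:
            response = super().handle_request(request)
            response.read()
            if response.json()["status"] in ("succeeded", "failed"):
                break
            time.sleep(int(response.headers.get("retry-after") or 10))
        succeeded = response.json()["status"] == "succeeded"
        body = response.json()["result"] if succeeded else response.json()
        # Synthesize a plain HTTP response so the openai client sees a normal reply.
        return httpx.Response(
            status_code=200 if succeeded else 400,
            headers=response.headers,
            content=json.dumps(body).encode("utf-8"),
            request=request,
        )
```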
@@ -26,30 +26,12 @@ class AsyncHTTPHandler:
         self, timeout: Optional[Union[float, httpx.Timeout]], concurrent_limit: int
     ) -> httpx.AsyncClient:
-
-        async_proxy_mounts = None
         # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
-        http_proxy = os.getenv("HTTP_PROXY", None)
-        https_proxy = os.getenv("HTTPS_PROXY", None)
-        no_proxy = os.getenv("NO_PROXY", None)
         ssl_verify = bool(os.getenv("SSL_VERIFY", litellm.ssl_verify))
         cert = os.getenv(
             "SSL_CERTIFICATE", litellm.ssl_certificate
         )  # /path/to/client.pem
-
-        if http_proxy is not None and https_proxy is not None:
-            async_proxy_mounts = {
-                "http://": httpx.AsyncHTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
-                "https://": httpx.AsyncHTTPTransport(
-                    proxy=httpx.Proxy(url=https_proxy)
-                ),
-            }
-            # assume no_proxy is a list of comma separated urls
-            if no_proxy is not None and isinstance(no_proxy, str):
-                no_proxy_urls = no_proxy.split(",")
-
-                for url in no_proxy_urls:  # set no-proxy support for specific urls
-                    async_proxy_mounts[url] = None  # type: ignore
-
         if timeout is None:
             timeout = _DEFAULT_TIMEOUT
         # Create a client with a connection pool
@@ -61,7 +43,6 @@ class AsyncHTTPHandler:
                 max_keepalive_connections=concurrent_limit,
             ),
             verify=ssl_verify,
-            mounts=async_proxy_mounts,
             cert=cert,
         )
 
@@ -163,27 +144,11 @@ class HTTPHandler:
             timeout = _DEFAULT_TIMEOUT
 
         # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
-        http_proxy = os.getenv("HTTP_PROXY", None)
-        https_proxy = os.getenv("HTTPS_PROXY", None)
-        no_proxy = os.getenv("NO_PROXY", None)
         ssl_verify = bool(os.getenv("SSL_VERIFY", litellm.ssl_verify))
         cert = os.getenv(
             "SSL_CERTIFICATE", litellm.ssl_certificate
         )  # /path/to/client.pem
-
-        sync_proxy_mounts = None
-        if http_proxy is not None and https_proxy is not None:
-            sync_proxy_mounts = {
-                "http://": httpx.HTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
-                "https://": httpx.HTTPTransport(proxy=httpx.Proxy(url=https_proxy)),
-            }
-            # assume no_proxy is a list of comma separated urls
-            if no_proxy is not None and isinstance(no_proxy, str):
-                no_proxy_urls = no_proxy.split(",")
-
-                for url in no_proxy_urls:  # set no-proxy support for specific urls
-                    sync_proxy_mounts[url] = None  # type: ignore
-
         if client is None:
             # Create a client with a connection pool
             self.client = httpx.Client(
@@ -193,7 +158,6 @@ class HTTPHandler:
                     max_keepalive_connections=concurrent_limit,
                 ),
                 verify=ssl_verify,
-                mounts=sync_proxy_mounts,
                 cert=cert,
             )
         else:
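Both handlers now leave proxy resolution to httpx itself: clients are constructed with trust_env enabled by default, so HTTP_PROXY, HTTPS_PROXY, and NO_PROXY are honored without any explicit mounts. A minimal illustration:

```python
import httpx

client = httpx.Client()  # trust_env=True by default: environment proxies are honored
isolated = httpx.Client(trust_env=False)  # ignores HTTP_PROXY / HTTPS_PROXY / NO_PROXY
```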
@@ -893,7 +893,7 @@ def completion(
         if (
             supports_system_message is not None
             and isinstance(supports_system_message, bool)
-            and supports_system_message == False
+            and supports_system_message is False
         ):
             messages = map_system_message_pt(messages=messages)
         model_api_key = get_api_key(
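The isinstance guard already rules out non-bool values here, so the switch from `== False` to `is False` is mainly the lint-clean idiom (flake8 E712). The general distinction it encodes:

```python
x = 0  # a falsy non-bool, for illustration
assert x == False  # noqa: E712 -- equality coerces across types
assert x is not False  # identity matches only the bool singleton
```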
@@ -46,10 +46,6 @@ from litellm._logging import verbose_router_logger
 from litellm.caching import DualCache, InMemoryCache, RedisCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.azure import get_azure_ad_token_from_oidc
-from litellm.llms.custom_httpx.azure_dall_e_2 import (
-    AsyncCustomHTTPTransport,
-    CustomHTTPTransport,
-)
 from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
 from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
 from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
@@ -93,6 +89,7 @@ from litellm.utils import (
     ModelResponse,
     _is_region_eu,
     calculate_max_parallel_requests,
+    create_proxy_transport_and_mounts,
     get_utc_datetime,
 )
 
@@ -1,16 +1,14 @@
 import asyncio
+import os
 import traceback
 from typing import TYPE_CHECKING, Any
 
+import httpx
 import openai
 
 import litellm
 from litellm._logging import verbose_router_logger
 from litellm.llms.azure import get_azure_ad_token_from_oidc
-from litellm.llms.custom_httpx.azure_dall_e_2 import (
-    AsyncCustomHTTPTransport,
-    CustomHTTPTransport,
-)
 from litellm.utils import calculate_max_parallel_requests
 
 if TYPE_CHECKING:
@@ -169,39 +167,6 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
             max_retries = litellm.get_secret(max_retries_env_name)
         litellm_params["max_retries"] = max_retries
 
-    # proxy support
-    import os
-
-    import httpx
-
-    # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
-    http_proxy = os.getenv("HTTP_PROXY", None)
-    https_proxy = os.getenv("HTTPS_PROXY", None)
-    no_proxy = os.getenv("NO_PROXY", None)
-
-    # Create the proxies dictionary only if the environment variables are set.
-    sync_proxy_mounts = None
-    async_proxy_mounts = None
-    if http_proxy is not None and https_proxy is not None:
-        sync_proxy_mounts = {
-            "http://": httpx.HTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
-            "https://": httpx.HTTPTransport(proxy=httpx.Proxy(url=https_proxy)),
-        }
-        async_proxy_mounts = {
-            "http://": httpx.AsyncHTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
-            "https://": httpx.AsyncHTTPTransport(
-                proxy=httpx.Proxy(url=https_proxy)
-            ),
-        }
-
-        # assume no_proxy is a list of comma separated urls
-        if no_proxy is not None and isinstance(no_proxy, str):
-            no_proxy_urls = no_proxy.split(",")
-
-            for url in no_proxy_urls:  # set no-proxy support for specific urls
-                sync_proxy_mounts[url] = None  # type: ignore
-                async_proxy_mounts[url] = None  # type: ignore
-
     organization = litellm_params.get("organization", None)
     if isinstance(organization, str) and organization.startswith("os.environ/"):
         organization_env_name = organization.replace("os.environ/", "")
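The inline block removed above is superseded by a shared helper, create_proxy_transport_and_mounts, which this PR adds to litellm/utils.py (see the final hunk) and imports in router.py. Its call site is not visible in these hunks; presumably it is consumed along these lines:

```python
import httpx

from litellm.utils import create_proxy_transport_and_mounts

# Hypothetical wiring: build the mount tables once, hand them to the clients.
sync_proxy_mounts, async_proxy_mounts = create_proxy_transport_and_mounts()
sync_client = httpx.Client(mounts=sync_proxy_mounts)
async_client = httpx.AsyncClient(mounts=async_proxy_mounts)
```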
@@ -241,13 +206,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                 timeout=timeout,
                 max_retries=max_retries,
                 http_client=httpx.AsyncClient(
-                    transport=AsyncCustomHTTPTransport(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
-                    ),
-                    mounts=async_proxy_mounts,
+                    limits=httpx.Limits(
+                        max_connections=1000, max_keepalive_connections=100
+                    ),
+                    verify=litellm.ssl_verify,
                 ),  # type: ignore
             )
             litellm_router_instance.cache.set_cache(
@@ -269,13 +231,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                 timeout=timeout,
                 max_retries=max_retries,
                 http_client=httpx.Client(
-                    transport=CustomHTTPTransport(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
-                    ),
-                    mounts=sync_proxy_mounts,
+                    limits=httpx.Limits(
+                        max_connections=1000, max_keepalive_connections=100
+                    ),
+                    verify=litellm.ssl_verify,
                 ),  # type: ignore
             )
             litellm_router_instance.cache.set_cache(
@@ -294,13 +253,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                 timeout=stream_timeout,
                 max_retries=max_retries,
                 http_client=httpx.AsyncClient(
-                    transport=AsyncCustomHTTPTransport(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
-                    ),
-                    mounts=async_proxy_mounts,
+                    limits=httpx.Limits(
+                        max_connections=1000, max_keepalive_connections=100
+                    ),
+                    verify=litellm.ssl_verify,
                 ),  # type: ignore
             )
             litellm_router_instance.cache.set_cache(
@@ -322,13 +278,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                 timeout=stream_timeout,
                 max_retries=max_retries,
                 http_client=httpx.Client(
-                    transport=CustomHTTPTransport(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
-                    ),
-                    mounts=sync_proxy_mounts,
+                    limits=httpx.Limits(
+                        max_connections=1000, max_keepalive_connections=100
+                    ),
+                    verify=litellm.ssl_verify,
                 ),  # type: ignore
             )
             litellm_router_instance.cache.set_cache(
@@ -365,13 +318,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                 timeout=timeout,
                 max_retries=max_retries,
                 http_client=httpx.AsyncClient(
-                    transport=AsyncCustomHTTPTransport(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
-                    ),
-                    mounts=async_proxy_mounts,
+                    limits=httpx.Limits(
+                        max_connections=1000, max_keepalive_connections=100
+                    ),
+                    verify=litellm.ssl_verify,
                 ),  # type: ignore
             )
             litellm_router_instance.cache.set_cache(
@@ -389,13 +339,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                 timeout=timeout,
                 max_retries=max_retries,
                 http_client=httpx.Client(
-                    transport=CustomHTTPTransport(
-                        verify=litellm.ssl_verify,
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                    ),
-                    mounts=sync_proxy_mounts,
+                    limits=httpx.Limits(
+                        max_connections=1000, max_keepalive_connections=100
+                    ),
+                    verify=litellm.ssl_verify,
                 ),  # type: ignore
             )
             litellm_router_instance.cache.set_cache(
@@ -412,13 +359,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                 timeout=stream_timeout,
                 max_retries=max_retries,
                 http_client=httpx.AsyncClient(
-                    transport=AsyncCustomHTTPTransport(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
-                    ),
-                    mounts=async_proxy_mounts,
+                    limits=httpx.Limits(
+                        max_connections=1000, max_keepalive_connections=100
+                    ),
+                    verify=litellm.ssl_verify,
                 ),
             )
             litellm_router_instance.cache.set_cache(
@@ -437,13 +381,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                 timeout=stream_timeout,
                 max_retries=max_retries,
                 http_client=httpx.Client(
-                    transport=CustomHTTPTransport(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
-                    ),
-                    mounts=sync_proxy_mounts,
+                    limits=httpx.Limits(
+                        max_connections=1000, max_keepalive_connections=100
+                    ),
+                    verify=litellm.ssl_verify,
                 ),
             )
             litellm_router_instance.cache.set_cache(
@@ -469,13 +410,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                 max_retries=max_retries,
                 organization=organization,
                 http_client=httpx.AsyncClient(
-                    transport=AsyncCustomHTTPTransport(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
-                    ),
-                    mounts=async_proxy_mounts,
+                    limits=httpx.Limits(
+                        max_connections=1000, max_keepalive_connections=100
+                    ),
+                    verify=litellm.ssl_verify,
                 ),  # type: ignore
             )
             litellm_router_instance.cache.set_cache(
@@ -496,13 +434,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                 max_retries=max_retries,
                 organization=organization,
                 http_client=httpx.Client(
-                    transport=CustomHTTPTransport(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
-                    ),
-                    mounts=sync_proxy_mounts,
+                    limits=httpx.Limits(
+                        max_connections=1000, max_keepalive_connections=100
+                    ),
+                    verify=litellm.ssl_verify,
                 ),  # type: ignore
             )
             litellm_router_instance.cache.set_cache(
@@ -521,13 +456,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                 max_retries=max_retries,
                 organization=organization,
                 http_client=httpx.AsyncClient(
-                    transport=AsyncCustomHTTPTransport(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
-                    ),
-                    mounts=async_proxy_mounts,
+                    limits=httpx.Limits(
+                        max_connections=1000, max_keepalive_connections=100
+                    ),
+                    verify=litellm.ssl_verify,
                 ),  # type: ignore
             )
             litellm_router_instance.cache.set_cache(
@@ -549,13 +481,10 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
                 max_retries=max_retries,
                 organization=organization,
                 http_client=httpx.Client(
-                    transport=CustomHTTPTransport(
-                        limits=httpx.Limits(
-                            max_connections=1000, max_keepalive_connections=100
-                        ),
-                        verify=litellm.ssl_verify,
-                    ),
-                    mounts=sync_proxy_mounts,
+                    limits=httpx.Limits(
+                        max_connections=1000, max_keepalive_connections=100
+                    ),
+                    verify=litellm.ssl_verify,
                 ),  # type: ignore
             )
             litellm_router_instance.cache.set_cache(
@@ -1894,6 +1894,49 @@ async def test_router_model_usage(mock_response):
         raise e
 
 
+@pytest.mark.skip(reason="Check if this is causing ci/cd issues.")
+@pytest.mark.asyncio
+async def test_is_proxy_set():
+    """
+    Assert if proxy is set
+    """
+    from httpx import AsyncHTTPTransport
+
+    os.environ["HTTPS_PROXY"] = "https://proxy.example.com:8080"
+    from openai import AsyncAzureOpenAI
+
+    # Function to check if a proxy is set on the client
+    # Function to check if a proxy is set on the client
+    def check_proxy(client: httpx.AsyncClient) -> bool:
+        print(f"client._mounts: {client._mounts}")
+        assert len(client._mounts) == 1
+        for k, v in client._mounts.items():
+            assert isinstance(v, AsyncHTTPTransport)
+        return True
+
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-4",
+                "litellm_params": {
+                    "model": "azure/gpt-3.5-turbo",
+                    "api_key": "my-key",
+                    "api_base": "my-base",
+                    "mock_response": "hello world",
+                },
+                "model_info": {"id": "1"},
+            }
+        ]
+    )
+
+    _deployment = llm_router.get_deployment(model_id="1")
+    model_client: AsyncAzureOpenAI = llm_router._get_client(
+        deployment=_deployment, kwargs={}, client_type="async"
+    )  # type: ignore
+
+    assert check_proxy(client=model_client._client)
+
+
 @pytest.mark.parametrize(
     "model, base_model, llm_provider",
     [
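The test goes through the Router, but the behavior it pins down is plain httpx: with HTTPS_PROXY set, a freshly constructed client gets exactly one proxy transport mounted for https:// traffic. A stripped-down equivalent using the same private _mounts attribute the test inspects, assuming no other proxy variables are set:

```python
import os

import httpx
from httpx import AsyncHTTPTransport

os.environ["HTTPS_PROXY"] = "https://proxy.example.com:8080"
client = httpx.AsyncClient()  # trust_env=True by default

assert len(client._mounts) == 1  # one mount, keyed on the https:// pattern
for pattern, transport in client._mounts.items():
    assert isinstance(transport, AsyncHTTPTransport)
```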
@@ -42,6 +42,8 @@ import httpx
 import openai
 import requests
 import tiktoken
+from httpx import Proxy
+from httpx._utils import get_environment_proxies
 from pydantic import BaseModel
 from tokenizers import Tokenizer
 
@@ -4913,6 +4915,34 @@ def get_provider_fields(custom_llm_provider: str) -> List[ProviderField]:
     return []
 
 
+def create_proxy_transport_and_mounts():
+    proxies = {
+        key: None if url is None else Proxy(url=url)
+        for key, url in get_environment_proxies().items()
+    }
+
+    sync_proxy_mounts = {}
+    async_proxy_mounts = {}
+
+    # Retrieve NO_PROXY environment variable
+    no_proxy = os.getenv("NO_PROXY", None)
+    no_proxy_urls = no_proxy.split(",") if no_proxy else []
+
+    for key, proxy in proxies.items():
+        if proxy is None:
+            sync_proxy_mounts[key] = httpx.HTTPTransport()
+            async_proxy_mounts[key] = httpx.AsyncHTTPTransport()
+        else:
+            sync_proxy_mounts[key] = httpx.HTTPTransport(proxy=proxy)
+            async_proxy_mounts[key] = httpx.AsyncHTTPTransport(proxy=proxy)
+
+    for url in no_proxy_urls:
+        sync_proxy_mounts[url] = httpx.HTTPTransport()
+        async_proxy_mounts[url] = httpx.AsyncHTTPTransport()
+
+    return sync_proxy_mounts, async_proxy_mounts
+
+
 def validate_environment(model: Optional[str] = None) -> dict:
     """
     Checks if the environment variables are valid for the given model.
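Note that get_environment_proxies is a private httpx helper (hence the _utils import). It returns a mapping from mount patterns to proxy URLs, with NO_PROXY entries already folded in as None values, which the helper above maps to plain unproxied transports; the explicit NO_PROXY loop then re-applies the same bypass on top. A quick probe of the expected shape, with illustrative values:

```python
import os

from httpx._utils import get_environment_proxies  # private API, as in the diff

os.environ["HTTPS_PROXY"] = "http://localhost:3128"
os.environ["NO_PROXY"] = "localhost"
# Expected shape, roughly:
# {'https://': 'http://localhost:3128', 'all://localhost': None}
print(get_environment_proxies())
```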