forked from phoenix/litellm-mirror
Merge pull request #5292 from OgnjenFrancuski/main
Update SSL verification
This commit is contained in:
commit
f458f565af
5 changed files with 54 additions and 15 deletions
|
@ -111,7 +111,7 @@ common_cloud_provider_auth_params: dict = {
|
||||||
"providers": ["vertex_ai", "bedrock", "watsonx", "azure", "vertex_ai_beta"],
|
"providers": ["vertex_ai", "bedrock", "watsonx", "azure", "vertex_ai_beta"],
|
||||||
}
|
}
|
||||||
use_client: bool = False
|
use_client: bool = False
|
||||||
ssl_verify: bool = True
|
ssl_verify: Union[str, bool] = True
|
||||||
ssl_certificate: Optional[str] = None
|
ssl_certificate: Optional[str] = None
|
||||||
disable_streaming_logging: bool = False
|
disable_streaming_logging: bool = False
|
||||||
in_memory_llm_clients_cache: dict = {}
|
in_memory_llm_clients_cache: dict = {}
|
||||||
|
|
|
@ -791,7 +791,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
"api_version": api_version,
|
"api_version": api_version,
|
||||||
"azure_endpoint": api_base,
|
"azure_endpoint": api_base,
|
||||||
"azure_deployment": model,
|
"azure_deployment": model,
|
||||||
"http_client": litellm.client_session,
|
"http_client": litellm.aclient_session,
|
||||||
"max_retries": max_retries,
|
"max_retries": max_retries,
|
||||||
"timeout": timeout,
|
"timeout": timeout,
|
||||||
}
|
}
|
||||||
|
@ -959,7 +959,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
"api_version": api_version,
|
"api_version": api_version,
|
||||||
"azure_endpoint": api_base,
|
"azure_endpoint": api_base,
|
||||||
"azure_deployment": model,
|
"azure_deployment": model,
|
||||||
"http_client": litellm.client_session,
|
"http_client": litellm.aclient_session,
|
||||||
"max_retries": data.pop("max_retries", 2),
|
"max_retries": data.pop("max_retries", 2),
|
||||||
"timeout": timeout,
|
"timeout": timeout,
|
||||||
}
|
}
|
||||||
|
@ -1083,13 +1083,16 @@ class AzureChatCompletion(BaseLLM):
|
||||||
"api_version": api_version,
|
"api_version": api_version,
|
||||||
"azure_endpoint": api_base,
|
"azure_endpoint": api_base,
|
||||||
"azure_deployment": model,
|
"azure_deployment": model,
|
||||||
"http_client": litellm.client_session,
|
|
||||||
"max_retries": max_retries,
|
"max_retries": max_retries,
|
||||||
"timeout": timeout,
|
"timeout": timeout,
|
||||||
}
|
}
|
||||||
azure_client_params = select_azure_base_url_or_endpoint(
|
azure_client_params = select_azure_base_url_or_endpoint(
|
||||||
azure_client_params=azure_client_params
|
azure_client_params=azure_client_params
|
||||||
)
|
)
|
||||||
|
if aembedding:
|
||||||
|
azure_client_params["http_client"] = litellm.aclient_session
|
||||||
|
else:
|
||||||
|
azure_client_params["http_client"] = litellm.client_session
|
||||||
if api_key is not None:
|
if api_key is not None:
|
||||||
azure_client_params["api_key"] = api_key
|
azure_client_params["api_key"] = api_key
|
||||||
elif azure_ad_token is not None:
|
elif azure_ad_token is not None:
|
||||||
|
|
|
@ -605,6 +605,9 @@ def init_bedrock_client(
|
||||||
aws_web_identity_token,
|
aws_web_identity_token,
|
||||||
) = params_to_check
|
) = params_to_check
|
||||||
|
|
||||||
|
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
|
||||||
|
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
|
||||||
|
|
||||||
### SET REGION NAME
|
### SET REGION NAME
|
||||||
if region_name:
|
if region_name:
|
||||||
pass
|
pass
|
||||||
|
@ -673,6 +676,7 @@ def init_bedrock_client(
|
||||||
region_name=region_name,
|
region_name=region_name,
|
||||||
endpoint_url=endpoint_url,
|
endpoint_url=endpoint_url,
|
||||||
config=config,
|
config=config,
|
||||||
|
verify=ssl_verify,
|
||||||
)
|
)
|
||||||
elif aws_role_name is not None and aws_session_name is not None:
|
elif aws_role_name is not None and aws_session_name is not None:
|
||||||
# use sts if role name passed in
|
# use sts if role name passed in
|
||||||
|
@ -694,6 +698,7 @@ def init_bedrock_client(
|
||||||
region_name=region_name,
|
region_name=region_name,
|
||||||
endpoint_url=endpoint_url,
|
endpoint_url=endpoint_url,
|
||||||
config=config,
|
config=config,
|
||||||
|
verify=ssl_verify,
|
||||||
)
|
)
|
||||||
elif aws_access_key_id is not None:
|
elif aws_access_key_id is not None:
|
||||||
# uses auth params passed to completion
|
# uses auth params passed to completion
|
||||||
|
@ -706,6 +711,7 @@ def init_bedrock_client(
|
||||||
region_name=region_name,
|
region_name=region_name,
|
||||||
endpoint_url=endpoint_url,
|
endpoint_url=endpoint_url,
|
||||||
config=config,
|
config=config,
|
||||||
|
verify=ssl_verify,
|
||||||
)
|
)
|
||||||
elif aws_profile_name is not None:
|
elif aws_profile_name is not None:
|
||||||
# uses auth values from AWS profile usually stored in ~/.aws/credentials
|
# uses auth values from AWS profile usually stored in ~/.aws/credentials
|
||||||
|
@ -715,6 +721,7 @@ def init_bedrock_client(
|
||||||
region_name=region_name,
|
region_name=region_name,
|
||||||
endpoint_url=endpoint_url,
|
endpoint_url=endpoint_url,
|
||||||
config=config,
|
config=config,
|
||||||
|
verify=ssl_verify,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# aws_access_key_id is None, assume user is trying to auth using env variables
|
# aws_access_key_id is None, assume user is trying to auth using env variables
|
||||||
|
@ -725,6 +732,7 @@ def init_bedrock_client(
|
||||||
region_name=region_name,
|
region_name=region_name,
|
||||||
endpoint_url=endpoint_url,
|
endpoint_url=endpoint_url,
|
||||||
config=config,
|
config=config,
|
||||||
|
verify=ssl_verify,
|
||||||
)
|
)
|
||||||
if extra_headers:
|
if extra_headers:
|
||||||
client.meta.events.register(
|
client.meta.events.register(
|
||||||
|
|
|
@ -35,11 +35,18 @@ class AsyncHTTPHandler:
|
||||||
self, timeout: Optional[Union[float, httpx.Timeout]], concurrent_limit: int
|
self, timeout: Optional[Union[float, httpx.Timeout]], concurrent_limit: int
|
||||||
) -> httpx.AsyncClient:
|
) -> httpx.AsyncClient:
|
||||||
|
|
||||||
# Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
|
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
|
||||||
ssl_verify = bool(os.getenv("SSL_VERIFY", litellm.ssl_verify))
|
# /path/to/certificate.pem
|
||||||
|
ssl_verify = os.getenv(
|
||||||
|
"SSL_VERIFY",
|
||||||
|
litellm.ssl_verify
|
||||||
|
)
|
||||||
|
# An SSL certificate used by the requested host to authenticate the client.
|
||||||
|
# /path/to/client.pem
|
||||||
cert = os.getenv(
|
cert = os.getenv(
|
||||||
"SSL_CERTIFICATE", litellm.ssl_certificate
|
"SSL_CERTIFICATE",
|
||||||
) # /path/to/client.pem
|
litellm.ssl_certificate
|
||||||
|
)
|
||||||
|
|
||||||
if timeout is None:
|
if timeout is None:
|
||||||
timeout = _DEFAULT_TIMEOUT
|
timeout = _DEFAULT_TIMEOUT
|
||||||
|
@ -268,11 +275,18 @@ class HTTPHandler:
|
||||||
if timeout is None:
|
if timeout is None:
|
||||||
timeout = _DEFAULT_TIMEOUT
|
timeout = _DEFAULT_TIMEOUT
|
||||||
|
|
||||||
# Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
|
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
|
||||||
ssl_verify = bool(os.getenv("SSL_VERIFY", litellm.ssl_verify))
|
# /path/to/certificate.pem
|
||||||
|
ssl_verify = os.getenv(
|
||||||
|
"SSL_VERIFY",
|
||||||
|
litellm.ssl_verify
|
||||||
|
)
|
||||||
|
# An SSL certificate used by the requested host to authenticate the client.
|
||||||
|
# /path/to/client.pem
|
||||||
cert = os.getenv(
|
cert = os.getenv(
|
||||||
"SSL_CERTIFICATE", litellm.ssl_certificate
|
"SSL_CERTIFICATE",
|
||||||
) # /path/to/client.pem
|
litellm.ssl_certificate
|
||||||
|
)
|
||||||
|
|
||||||
if client is None:
|
if client is None:
|
||||||
# Create a client with a connection pool
|
# Create a client with a connection pool
|
||||||
|
|
|
@ -616,6 +616,10 @@ class Huggingface(BaseLLM):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
## COMPLETION CALL
|
## COMPLETION CALL
|
||||||
|
|
||||||
|
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
|
||||||
|
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
|
||||||
|
|
||||||
if acompletion is True:
|
if acompletion is True:
|
||||||
### ASYNC STREAMING
|
### ASYNC STREAMING
|
||||||
if optional_params.get("stream", False):
|
if optional_params.get("stream", False):
|
||||||
|
@ -630,12 +634,16 @@ class Huggingface(BaseLLM):
|
||||||
headers=headers,
|
headers=headers,
|
||||||
data=json.dumps(data),
|
data=json.dumps(data),
|
||||||
stream=optional_params["stream"],
|
stream=optional_params["stream"],
|
||||||
|
verify=ssl_verify
|
||||||
)
|
)
|
||||||
return response.iter_lines()
|
return response.iter_lines()
|
||||||
### SYNC COMPLETION
|
### SYNC COMPLETION
|
||||||
else:
|
else:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
completion_url, headers=headers, data=json.dumps(data)
|
completion_url,
|
||||||
|
headers=headers,
|
||||||
|
data=json.dumps(data),
|
||||||
|
verify=ssl_verify
|
||||||
)
|
)
|
||||||
|
|
||||||
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
|
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
|
||||||
|
@ -731,9 +739,12 @@ class Huggingface(BaseLLM):
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
timeout: float,
|
timeout: float,
|
||||||
):
|
):
|
||||||
|
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
|
||||||
|
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
|
||||||
|
|
||||||
response = None
|
response = None
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
async with httpx.AsyncClient(timeout=timeout, verify=ssl_verify) as client:
|
||||||
response = await client.post(url=api_base, json=data, headers=headers)
|
response = await client.post(url=api_base, json=data, headers=headers)
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
|
@ -785,7 +796,10 @@ class Huggingface(BaseLLM):
|
||||||
model: str,
|
model: str,
|
||||||
timeout: float,
|
timeout: float,
|
||||||
):
|
):
|
||||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
|
||||||
|
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=timeout, verify=ssl_verify) as client:
|
||||||
response = client.stream(
|
response = client.stream(
|
||||||
"POST", url=f"{api_base}", json=data, headers=headers
|
"POST", url=f"{api_base}", json=data, headers=headers
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue