Merge pull request #5292 from OgnjenFrancuski/main

Update SSL verification
This commit is contained in:
Krish Dholakia 2024-08-23 20:42:35 -07:00 committed by GitHub
commit f458f565af
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 54 additions and 15 deletions

View file

@ -111,7 +111,7 @@ common_cloud_provider_auth_params: dict = {
"providers": ["vertex_ai", "bedrock", "watsonx", "azure", "vertex_ai_beta"],
}
use_client: bool = False
ssl_verify: bool = True
ssl_verify: Union[str, bool] = True
ssl_certificate: Optional[str] = None
disable_streaming_logging: bool = False
in_memory_llm_clients_cache: dict = {}

View file

@ -791,7 +791,7 @@ class AzureChatCompletion(BaseLLM):
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.client_session,
"http_client": litellm.aclient_session,
"max_retries": max_retries,
"timeout": timeout,
}
@ -959,7 +959,7 @@ class AzureChatCompletion(BaseLLM):
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.client_session,
"http_client": litellm.aclient_session,
"max_retries": data.pop("max_retries", 2),
"timeout": timeout,
}
@ -1083,13 +1083,16 @@ class AzureChatCompletion(BaseLLM):
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if aembedding:
azure_client_params["http_client"] = litellm.aclient_session
else:
azure_client_params["http_client"] = litellm.client_session
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:

View file

@ -605,6 +605,9 @@ def init_bedrock_client(
aws_web_identity_token,
) = params_to_check
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
### SET REGION NAME
if region_name:
pass
@ -673,6 +676,7 @@ def init_bedrock_client(
region_name=region_name,
endpoint_url=endpoint_url,
config=config,
verify=ssl_verify,
)
elif aws_role_name is not None and aws_session_name is not None:
# use sts if role name passed in
@ -694,6 +698,7 @@ def init_bedrock_client(
region_name=region_name,
endpoint_url=endpoint_url,
config=config,
verify=ssl_verify,
)
elif aws_access_key_id is not None:
# uses auth params passed to completion
@ -706,6 +711,7 @@ def init_bedrock_client(
region_name=region_name,
endpoint_url=endpoint_url,
config=config,
verify=ssl_verify,
)
elif aws_profile_name is not None:
# uses auth values from AWS profile usually stored in ~/.aws/credentials
@ -715,6 +721,7 @@ def init_bedrock_client(
region_name=region_name,
endpoint_url=endpoint_url,
config=config,
verify=ssl_verify,
)
else:
# aws_access_key_id is None, assume user is trying to auth using env variables
@ -725,6 +732,7 @@ def init_bedrock_client(
region_name=region_name,
endpoint_url=endpoint_url,
config=config,
verify=ssl_verify,
)
if extra_headers:
client.meta.events.register(

View file

@ -35,11 +35,18 @@ class AsyncHTTPHandler:
self, timeout: Optional[Union[float, httpx.Timeout]], concurrent_limit: int
) -> httpx.AsyncClient:
# Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
ssl_verify = bool(os.getenv("SSL_VERIFY", litellm.ssl_verify))
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
# /path/to/certificate.pem
ssl_verify = os.getenv(
"SSL_VERIFY",
litellm.ssl_verify
)
# An SSL certificate used by the requested host to authenticate the client.
# /path/to/client.pem
cert = os.getenv(
"SSL_CERTIFICATE", litellm.ssl_certificate
) # /path/to/client.pem
"SSL_CERTIFICATE",
litellm.ssl_certificate
)
if timeout is None:
timeout = _DEFAULT_TIMEOUT
@ -268,11 +275,18 @@ class HTTPHandler:
if timeout is None:
timeout = _DEFAULT_TIMEOUT
# Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
ssl_verify = bool(os.getenv("SSL_VERIFY", litellm.ssl_verify))
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
# /path/to/certificate.pem
ssl_verify = os.getenv(
"SSL_VERIFY",
litellm.ssl_verify
)
# An SSL certificate used by the requested host to authenticate the client.
# /path/to/client.pem
cert = os.getenv(
"SSL_CERTIFICATE", litellm.ssl_certificate
) # /path/to/client.pem
"SSL_CERTIFICATE",
litellm.ssl_certificate
)
if client is None:
# Create a client with a connection pool

View file

@ -616,6 +616,10 @@ class Huggingface(BaseLLM):
},
)
## COMPLETION CALL
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
if acompletion is True:
### ASYNC STREAMING
if optional_params.get("stream", False):
@ -630,12 +634,16 @@ class Huggingface(BaseLLM):
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"],
verify=ssl_verify
)
return response.iter_lines()
### SYNC COMPLETION
else:
response = requests.post(
completion_url, headers=headers, data=json.dumps(data)
completion_url,
headers=headers,
data=json.dumps(data),
verify=ssl_verify
)
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
@ -731,9 +739,12 @@ class Huggingface(BaseLLM):
optional_params: dict,
timeout: float,
):
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
response = None
try:
async with httpx.AsyncClient(timeout=timeout) as client:
async with httpx.AsyncClient(timeout=timeout, verify=ssl_verify) as client:
response = await client.post(url=api_base, json=data, headers=headers)
response_json = response.json()
if response.status_code != 200:
@ -785,7 +796,10 @@ class Huggingface(BaseLLM):
model: str,
timeout: float,
):
async with httpx.AsyncClient(timeout=timeout) as client:
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
async with httpx.AsyncClient(timeout=timeout, verify=ssl_verify) as client:
response = client.stream(
"POST", url=f"{api_base}", json=data, headers=headers
)