Merge pull request #5292 from OgnjenFrancuski/main

Update SSL verification
This commit is contained in:
Krish Dholakia 2024-08-23 20:42:35 -07:00 committed by GitHub
commit f458f565af
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 54 additions and 15 deletions

View file

@ -111,7 +111,7 @@ common_cloud_provider_auth_params: dict = {
"providers": ["vertex_ai", "bedrock", "watsonx", "azure", "vertex_ai_beta"], "providers": ["vertex_ai", "bedrock", "watsonx", "azure", "vertex_ai_beta"],
} }
use_client: bool = False use_client: bool = False
ssl_verify: bool = True ssl_verify: Union[str, bool] = True
ssl_certificate: Optional[str] = None ssl_certificate: Optional[str] = None
disable_streaming_logging: bool = False disable_streaming_logging: bool = False
in_memory_llm_clients_cache: dict = {} in_memory_llm_clients_cache: dict = {}

View file

@ -791,7 +791,7 @@ class AzureChatCompletion(BaseLLM):
"api_version": api_version, "api_version": api_version,
"azure_endpoint": api_base, "azure_endpoint": api_base,
"azure_deployment": model, "azure_deployment": model,
"http_client": litellm.client_session, "http_client": litellm.aclient_session,
"max_retries": max_retries, "max_retries": max_retries,
"timeout": timeout, "timeout": timeout,
} }
@ -959,7 +959,7 @@ class AzureChatCompletion(BaseLLM):
"api_version": api_version, "api_version": api_version,
"azure_endpoint": api_base, "azure_endpoint": api_base,
"azure_deployment": model, "azure_deployment": model,
"http_client": litellm.client_session, "http_client": litellm.aclient_session,
"max_retries": data.pop("max_retries", 2), "max_retries": data.pop("max_retries", 2),
"timeout": timeout, "timeout": timeout,
} }
@ -1083,13 +1083,16 @@ class AzureChatCompletion(BaseLLM):
"api_version": api_version, "api_version": api_version,
"azure_endpoint": api_base, "azure_endpoint": api_base,
"azure_deployment": model, "azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries, "max_retries": max_retries,
"timeout": timeout, "timeout": timeout,
} }
azure_client_params = select_azure_base_url_or_endpoint( azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params azure_client_params=azure_client_params
) )
if aembedding:
azure_client_params["http_client"] = litellm.aclient_session
else:
azure_client_params["http_client"] = litellm.client_session
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:

View file

@ -605,6 +605,9 @@ def init_bedrock_client(
aws_web_identity_token, aws_web_identity_token,
) = params_to_check ) = params_to_check
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
### SET REGION NAME ### SET REGION NAME
if region_name: if region_name:
pass pass
@ -673,6 +676,7 @@ def init_bedrock_client(
region_name=region_name, region_name=region_name,
endpoint_url=endpoint_url, endpoint_url=endpoint_url,
config=config, config=config,
verify=ssl_verify,
) )
elif aws_role_name is not None and aws_session_name is not None: elif aws_role_name is not None and aws_session_name is not None:
# use sts if role name passed in # use sts if role name passed in
@ -694,6 +698,7 @@ def init_bedrock_client(
region_name=region_name, region_name=region_name,
endpoint_url=endpoint_url, endpoint_url=endpoint_url,
config=config, config=config,
verify=ssl_verify,
) )
elif aws_access_key_id is not None: elif aws_access_key_id is not None:
# uses auth params passed to completion # uses auth params passed to completion
@ -706,6 +711,7 @@ def init_bedrock_client(
region_name=region_name, region_name=region_name,
endpoint_url=endpoint_url, endpoint_url=endpoint_url,
config=config, config=config,
verify=ssl_verify,
) )
elif aws_profile_name is not None: elif aws_profile_name is not None:
# uses auth values from AWS profile usually stored in ~/.aws/credentials # uses auth values from AWS profile usually stored in ~/.aws/credentials
@ -715,6 +721,7 @@ def init_bedrock_client(
region_name=region_name, region_name=region_name,
endpoint_url=endpoint_url, endpoint_url=endpoint_url,
config=config, config=config,
verify=ssl_verify,
) )
else: else:
# aws_access_key_id is None, assume user is trying to auth using env variables # aws_access_key_id is None, assume user is trying to auth using env variables
@ -725,6 +732,7 @@ def init_bedrock_client(
region_name=region_name, region_name=region_name,
endpoint_url=endpoint_url, endpoint_url=endpoint_url,
config=config, config=config,
verify=ssl_verify,
) )
if extra_headers: if extra_headers:
client.meta.events.register( client.meta.events.register(

View file

@ -35,11 +35,18 @@ class AsyncHTTPHandler:
self, timeout: Optional[Union[float, httpx.Timeout]], concurrent_limit: int self, timeout: Optional[Union[float, httpx.Timeout]], concurrent_limit: int
) -> httpx.AsyncClient: ) -> httpx.AsyncClient:
# Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly. # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
ssl_verify = bool(os.getenv("SSL_VERIFY", litellm.ssl_verify)) # /path/to/certificate.pem
ssl_verify = os.getenv(
"SSL_VERIFY",
litellm.ssl_verify
)
# An SSL certificate used by the requested host to authenticate the client.
# /path/to/client.pem
cert = os.getenv( cert = os.getenv(
"SSL_CERTIFICATE", litellm.ssl_certificate "SSL_CERTIFICATE",
) # /path/to/client.pem litellm.ssl_certificate
)
if timeout is None: if timeout is None:
timeout = _DEFAULT_TIMEOUT timeout = _DEFAULT_TIMEOUT
@ -268,11 +275,18 @@ class HTTPHandler:
if timeout is None: if timeout is None:
timeout = _DEFAULT_TIMEOUT timeout = _DEFAULT_TIMEOUT
# Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly. # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
ssl_verify = bool(os.getenv("SSL_VERIFY", litellm.ssl_verify)) # /path/to/certificate.pem
ssl_verify = os.getenv(
"SSL_VERIFY",
litellm.ssl_verify
)
# An SSL certificate used by the requested host to authenticate the client.
# /path/to/client.pem
cert = os.getenv( cert = os.getenv(
"SSL_CERTIFICATE", litellm.ssl_certificate "SSL_CERTIFICATE",
) # /path/to/client.pem litellm.ssl_certificate
)
if client is None: if client is None:
# Create a client with a connection pool # Create a client with a connection pool

View file

@ -616,6 +616,10 @@ class Huggingface(BaseLLM):
}, },
) )
## COMPLETION CALL ## COMPLETION CALL
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
if acompletion is True: if acompletion is True:
### ASYNC STREAMING ### ASYNC STREAMING
if optional_params.get("stream", False): if optional_params.get("stream", False):
@ -630,12 +634,16 @@ class Huggingface(BaseLLM):
headers=headers, headers=headers,
data=json.dumps(data), data=json.dumps(data),
stream=optional_params["stream"], stream=optional_params["stream"],
verify=ssl_verify
) )
return response.iter_lines() return response.iter_lines()
### SYNC COMPLETION ### SYNC COMPLETION
else: else:
response = requests.post( response = requests.post(
completion_url, headers=headers, data=json.dumps(data) completion_url,
headers=headers,
data=json.dumps(data),
verify=ssl_verify
) )
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten) ## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
@ -731,9 +739,12 @@ class Huggingface(BaseLLM):
optional_params: dict, optional_params: dict,
timeout: float, timeout: float,
): ):
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
response = None response = None
try: try:
async with httpx.AsyncClient(timeout=timeout) as client: async with httpx.AsyncClient(timeout=timeout, verify=ssl_verify) as client:
response = await client.post(url=api_base, json=data, headers=headers) response = await client.post(url=api_base, json=data, headers=headers)
response_json = response.json() response_json = response.json()
if response.status_code != 200: if response.status_code != 200:
@ -785,7 +796,10 @@ class Huggingface(BaseLLM):
model: str, model: str,
timeout: float, timeout: float,
): ):
async with httpx.AsyncClient(timeout=timeout) as client: # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
async with httpx.AsyncClient(timeout=timeout, verify=ssl_verify) as client:
response = client.stream( response = client.stream(
"POST", url=f"{api_base}", json=data, headers=headers "POST", url=f"{api_base}", json=data, headers=headers
) )