Update the Huggingface provider to honor the SSL verification setting from the 'SSL_VERIFY' env var or 'litellm.ssl_verify'.

This commit is contained in:
Ognjen Francuski 2024-08-20 14:55:12 +02:00
parent 924dfe4226
commit 31aac9a1e4

View file

@ -613,6 +613,10 @@ class Huggingface(BaseLLM):
}, },
) )
## COMPLETION CALL ## COMPLETION CALL
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
if acompletion is True: if acompletion is True:
### ASYNC STREAMING ### ASYNC STREAMING
if optional_params.get("stream", False): if optional_params.get("stream", False):
@ -627,12 +631,16 @@ class Huggingface(BaseLLM):
headers=headers, headers=headers,
data=json.dumps(data), data=json.dumps(data),
stream=optional_params["stream"], stream=optional_params["stream"],
verify=ssl_verify
) )
return response.iter_lines() return response.iter_lines()
### SYNC COMPLETION ### SYNC COMPLETION
else: else:
response = requests.post( response = requests.post(
completion_url, headers=headers, data=json.dumps(data) completion_url,
headers=headers,
data=json.dumps(data),
verify=ssl_verify
) )
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten) ## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
@ -728,9 +736,12 @@ class Huggingface(BaseLLM):
optional_params: dict, optional_params: dict,
timeout: float, timeout: float,
): ):
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
response = None response = None
try: try:
async with httpx.AsyncClient(timeout=timeout) as client: async with httpx.AsyncClient(timeout=timeout, verify=ssl_verify) as client:
response = await client.post(url=api_base, json=data, headers=headers) response = await client.post(url=api_base, json=data, headers=headers)
response_json = response.json() response_json = response.json()
if response.status_code != 200: if response.status_code != 200:
@ -782,7 +793,10 @@ class Huggingface(BaseLLM):
model: str, model: str,
timeout: float, timeout: float,
): ):
async with httpx.AsyncClient(timeout=timeout) as client: # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
async with httpx.AsyncClient(timeout=timeout, verify=ssl_verify) as client:
response = client.stream( response = client.stream(
"POST", url=f"{api_base}", json=data, headers=headers "POST", url=f"{api_base}", json=data, headers=headers
) )