forked from phoenix/litellm-mirror
Update Huggingface provider to utilize the SSL verification through 'SSL_VERIFY' env var or 'litellm.ssl_verify'.
This commit is contained in:
parent
924dfe4226
commit
31aac9a1e4
1 changed files with 17 additions and 3 deletions
|
@ -613,6 +613,10 @@ class Huggingface(BaseLLM):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
## COMPLETION CALL
|
## COMPLETION CALL
|
||||||
|
|
||||||
|
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
|
||||||
|
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
|
||||||
|
|
||||||
if acompletion is True:
|
if acompletion is True:
|
||||||
### ASYNC STREAMING
|
### ASYNC STREAMING
|
||||||
if optional_params.get("stream", False):
|
if optional_params.get("stream", False):
|
||||||
|
@ -627,12 +631,16 @@ class Huggingface(BaseLLM):
|
||||||
headers=headers,
|
headers=headers,
|
||||||
data=json.dumps(data),
|
data=json.dumps(data),
|
||||||
stream=optional_params["stream"],
|
stream=optional_params["stream"],
|
||||||
|
verify=ssl_verify
|
||||||
)
|
)
|
||||||
return response.iter_lines()
|
return response.iter_lines()
|
||||||
### SYNC COMPLETION
|
### SYNC COMPLETION
|
||||||
else:
|
else:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
completion_url, headers=headers, data=json.dumps(data)
|
completion_url,
|
||||||
|
headers=headers,
|
||||||
|
data=json.dumps(data),
|
||||||
|
verify=ssl_verify
|
||||||
)
|
)
|
||||||
|
|
||||||
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
|
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
|
||||||
|
@ -728,9 +736,12 @@ class Huggingface(BaseLLM):
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
timeout: float,
|
timeout: float,
|
||||||
):
|
):
|
||||||
|
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
|
||||||
|
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
|
||||||
|
|
||||||
response = None
|
response = None
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
async with httpx.AsyncClient(timeout=timeout, verify=ssl_verify) as client:
|
||||||
response = await client.post(url=api_base, json=data, headers=headers)
|
response = await client.post(url=api_base, json=data, headers=headers)
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
|
@ -782,7 +793,10 @@ class Huggingface(BaseLLM):
|
||||||
model: str,
|
model: str,
|
||||||
timeout: float,
|
timeout: float,
|
||||||
):
|
):
|
||||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
|
||||||
|
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=timeout, verify=ssl_verify) as client:
|
||||||
response = client.stream(
|
response = client.stream(
|
||||||
"POST", url=f"{api_base}", json=data, headers=headers
|
"POST", url=f"{api_base}", json=data, headers=headers
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue