fix(huggingface_restapi.py): fix linting errors

Krrish Dholakia 2023-11-15 15:34:21 -08:00
parent f84db3ce14
commit a59494571f
2 changed files with 13 additions and 7 deletions


@@ -12,11 +12,17 @@ from typing import Optional
 from .prompt_templates.factory import prompt_factory, custom_prompt
 
 class HuggingfaceError(Exception):
-    def __init__(self, status_code, message):
+    def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None):
         self.status_code = status_code
         self.message = message
-        self.request = httpx.Request(method="POST", url="https://api-inference.huggingface.co/models")
-        self.response = httpx.Response(status_code=status_code, request=self.request)
+        if request is not None:
+            self.request = request
+        else:
+            self.request = httpx.Request(method="POST", url="https://api-inference.huggingface.co/models")
+        if response is not None:
+            self.response = response
+        else:
+            self.response = httpx.Response(status_code=status_code, request=self.request)
         super().__init__(
             self.message
         ) # Call the base class constructor with the parameters it needs
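
The widened constructor lets a call site forward the real httpx request and response from a failed API call instead of the synthetic placeholders that are otherwise built as defaults. A minimal sketch of how the new parameters might be used at a call site (the endpoint URL and error handling below are illustrative, not part of this commit):

    import httpx

    resp = httpx.post(
        "https://api-inference.huggingface.co/models/gpt2",  # illustrative model endpoint
        json={"inputs": "Hello"},
    )
    if resp.status_code != 200:
        # Forward the actual request/response; leaving them out falls back to
        # the synthetic defaults constructed inside HuggingfaceError.__init__.
        raise HuggingfaceError(
            status_code=resp.status_code,
            message=resp.text,
            request=resp.request,
            response=resp,
        )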
@@ -252,7 +258,7 @@ class Huggingface(BaseLLM):
             task = get_hf_task_for_model(model)
             print_verbose(f"{model}, {task}")
             completion_url = ""
-            input_text = None
+            input_text = ""
             if "https" in model:
                 completion_url = model
             elif api_base:
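
Switching the default for input_text from None to an empty string keeps the variable a plain str, which is what the later token-counting/encoding call expects; with None, the type checker reports an Optional value being passed where a str is required. A generic illustration of the pattern the linter flags (count_tokens is a stand-in, not the actual litellm helper):

    def count_tokens(text: str) -> int:  # stand-in for the later encoding call
        return len(text.split())

    # input_text = None            # mypy would flag passing this below:
    # count_tokens(input_text)     # incompatible type "None"; expected "str"

    input_text = ""                # always a str, so the call type-checks
    print(count_tokens(input_text))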
@@ -348,10 +354,10 @@
             if acompletion is True:
                 ### ASYNC STREAMING
                 if optional_params.get("stream", False):
-                    return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model)
+                    return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model) # type: ignore
                 else:
                     ### ASYNC COMPLETION
-                    return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params)
+                    return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params) # type: ignore
             ### SYNC STREAMING
             if "stream" in optional_params and optional_params["stream"] == True:
                 response = requests.post(
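
Both # type: ignore comments silence the same kind of linting error: completion is annotated with a synchronous return type, but when acompletion is requested it returns an un-awaited coroutine for the caller to await. A stripped-down sketch of that pattern (the class and signatures are illustrative, not litellm's actual API):

    import asyncio

    class Demo:
        async def _acompletion(self, prompt: str) -> dict:
            return {"text": prompt.upper()}

        def completion(self, prompt: str, acompletion: bool = False) -> dict:
            if acompletion:
                # Returns a coroutine rather than a dict, so the checker flags a
                # return-type mismatch unless it is suppressed.
                return self._acompletion(prompt)  # type: ignore
            return {"text": prompt.upper()}

    # An async caller awaits the returned coroutine itself:
    print(asyncio.run(Demo().completion("hi", acompletion=True)))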