fix(huggingface_restapi.py): fix linting errors

Krrish Dholakia 2023-11-15 15:34:21 -08:00
parent f84db3ce14
commit a59494571f
2 changed files with 13 additions and 7 deletions

@@ -12,11 +12,17 @@ from typing import Optional
 from .prompt_templates.factory import prompt_factory, custom_prompt

 class HuggingfaceError(Exception):
-    def __init__(self, status_code, message):
+    def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None):
         self.status_code = status_code
         self.message = message
-        self.request = httpx.Request(method="POST", url="https://api-inference.huggingface.co/models")
-        self.response = httpx.Response(status_code=status_code, request=self.request)
+        if request is not None:
+            self.request = request
+        else:
+            self.request = httpx.Request(method="POST", url="https://api-inference.huggingface.co/models")
+        if response is not None:
+            self.response = response
+        else:
+            self.response = httpx.Response(status_code=status_code, request=self.request)
         super().__init__(
             self.message
         ) # Call the base class constructor with the parameters it needs
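
Note: the widened constructor lets callers attach the real failing httpx request/response instead of the synthesized placeholders. A minimal sketch of hypothetical usage (the status code and message below are made up for illustration):

    import httpx

    # Build the request/response the API actually produced, then hand them to
    # the error so downstream handlers see the real objects, not defaults.
    req = httpx.Request(method="POST", url="https://api-inference.huggingface.co/models")
    resp = httpx.Response(status_code=429, request=req)
    raise HuggingfaceError(status_code=429, message="rate limited", request=req, response=resp)
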
@@ -252,7 +258,7 @@ class Huggingface(BaseLLM):
         task = get_hf_task_for_model(model)
         print_verbose(f"{model}, {task}")
         completion_url = ""
-        input_text = None
+        input_text = ""
         if "https" in model:
             completion_url = model
         elif api_base:
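
Note: this is the kind of lint the `input_text` change addresses. A rough sketch under assumed names (not litellm's actual ones): initializing the variable to None makes mypy infer an Optional type, so passing it later to a parameter annotated as str is flagged; starting from "" keeps it a plain str.

    def token_count(text: str) -> int:
        # stand-in for whatever later consumes input_text
        return len(text.split())

    # input_text = None   # Optional -> mypy: incompatible argument type "None"
    input_text = ""        # plain str -> type-checks
    token_count(input_text)
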
@@ -348,10 +354,10 @@ class Huggingface(BaseLLM):
         if acompletion is True:
             ### ASYNC STREAMING
             if optional_params.get("stream", False):
-                return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model)
+                return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model) # type: ignore
             else:
                 ### ASYNC COMPLETION
-                return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params)
+                return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params) # type: ignore
         ### SYNC STREAMING
         if "stream" in optional_params and optional_params["stream"] == True:
             response = requests.post(
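
Note on the `# type: ignore` markers: they typically silence mypy where a synchronously-typed method conditionally returns a coroutine. A minimal sketch of that pattern (signatures are illustrative, not litellm's):

    class Example:
        async def acompletion(self) -> dict:
            return {"done": True}

        def completion(self, use_async: bool) -> dict:
            if use_async:
                # mypy: Incompatible return value type
                # (got "Coroutine[Any, Any, dict]", expected "dict")
                return self.acompletion()  # type: ignore
            return {"done": True}
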

@@ -1779,7 +1779,7 @@ def embedding(
             or get_secret("HUGGINGFACE_API_KEY")
             or litellm.api_key
         )
-        response = huggingface_restapi.embedding(
+        response = huggingface.embedding(
             model=model,
             input=input,
             encoding=encoding,
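
Note: the rename at this call site suggests, though this diff does not show the import, that the caller holds an instance of the Huggingface class rather than calling module-level functions. Hypothetical wiring, labeled as a guess:

    # Not shown in this commit; an assumed setup that would make the
    # new call site valid.
    from litellm.llms.huggingface_restapi import Huggingface
    huggingface = Huggingface()
    # huggingface.embedding(...) is then a method call on that instance.
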