fix(huggingface_restapi.py): support passing 'wait_for_model' param on completion calls

This commit is contained in:
Krrish Dholakia 2024-08-09 09:25:19 -07:00
parent 466dc9f32a
commit 0cf81eba62
2 changed files with 29 additions and 2 deletions

View file

@@ -131,6 +131,9 @@ class HuggingfaceConfig:
and v is not None
}
# Returns the names of the "special" option keys. The completion code shown
# further down pulls these keys out of optional_params and re-attaches them to
# the request data under the "options" key; _process_optional_params also reads
# this list instead of hard-coding the two names.
def get_special_options_params(self):
return ["use_cache", "wait_for_model"]
def get_supported_openai_params(self):
return [
"stream",
@@ -491,6 +494,20 @@ class Huggingface(BaseLLM):
optional_params[k] = v
### MAP INPUT PARAMS
#### HANDLE SPECIAL PARAMS
special_params = HuggingfaceConfig().get_special_options_params()
special_params_dict = {}
# Create a list of keys to pop after iteration
keys_to_pop = []
for k, v in optional_params.items():
if k in special_params:
special_params_dict[k] = v
keys_to_pop.append(k)
# Pop the keys from the dictionary after iteration
for k in keys_to_pop:
optional_params.pop(k)
if task == "conversational":
inference_params = copy.deepcopy(optional_params)
inference_params.pop("details")
@@ -578,6 +595,11 @@ class Huggingface(BaseLLM):
else False
)
input_text = prompt
### RE-ADD SPECIAL PARAMS
if len(special_params_dict.keys()) > 0:
data.update({"options": special_params_dict})
## LOGGING
logging_obj.pre_call(
input=input_text,
@@ -857,7 +879,7 @@ class Huggingface(BaseLLM):
return data
def _process_optional_params(self, data: dict, optional_params: dict) -> dict:
special_options_keys = ["use_cache", "wait_for_model"]
special_options_keys = HuggingfaceConfig().get_special_options_params()
special_parameters_keys = [
"min_length",
"max_length",