fix(huggingface_restapi.py): support passing 'wait_for_model' param on completion calls

Krrish Dholakia 2024-08-09 09:25:19 -07:00
parent 466dc9f32a
commit 0cf81eba62
2 changed files with 29 additions and 2 deletions
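
For context, a minimal usage sketch of what this change enables (the model, prompt, and token limit mirror the test below; `wait_for_model` is not a generation parameter but a Hugging Face Inference API option that asks the endpoint to block until the model is loaded instead of returning a 503):

import litellm

# With this commit, litellm treats 'wait_for_model' as a special param and
# sends it in the request body's "options" object rather than "parameters".
response = litellm.completion(
    model="huggingface/HuggingFaceH4/zephyr-7b-beta",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    max_tokens=10,
    wait_for_model=True,
)
print(response)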

litellm/llms/huggingface_restapi.py

@@ -131,6 +131,9 @@ class HuggingfaceConfig:
             and v is not None
         }
 
+    def get_special_options_params(self):
+        return ["use_cache", "wait_for_model"]
+
     def get_supported_openai_params(self):
         return [
             "stream",
@@ -491,6 +494,20 @@ class Huggingface(BaseLLM):
                 optional_params[k] = v
 
         ### MAP INPUT PARAMS
+        #### HANDLE SPECIAL PARAMS
+        special_params = HuggingfaceConfig().get_special_options_params()
+        special_params_dict = {}
+
+        # Create a list of keys to pop after iteration
+        keys_to_pop = []
+        for k, v in optional_params.items():
+            if k in special_params:
+                special_params_dict[k] = v
+                keys_to_pop.append(k)
+
+        # Pop the keys from the dictionary after iteration
+        for k in keys_to_pop:
+            optional_params.pop(k)
         if task == "conversational":
             inference_params = copy.deepcopy(optional_params)
             inference_params.pop("details")
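
Outside the diff, the two-pass split above reduces to this standalone sketch (the sample `optional_params` values are illustrative); collecting keys first and popping afterwards matters because mutating a dict while iterating over it raises RuntimeError:

# Sketch of the HANDLE SPECIAL PARAMS step with illustrative inputs.
special_params = ["use_cache", "wait_for_model"]
optional_params = {"temperature": 0.7, "wait_for_model": True}

special_params_dict = {}
keys_to_pop = []
for k, v in optional_params.items():
    if k in special_params:
        special_params_dict[k] = v
        keys_to_pop.append(k)
for k in keys_to_pop:
    optional_params.pop(k)

assert optional_params == {"temperature": 0.7}
assert special_params_dict == {"wait_for_model": True}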
@@ -578,6 +595,11 @@
                     else False
                 )
             input_text = prompt
+
+        ### RE-ADD SPECIAL PARAMS
+        if len(special_params_dict.keys()) > 0:
+            data.update({"options": special_params_dict})
+
         ## LOGGING
         logging_obj.pre_call(
             input=input_text,
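
The RE-ADD step places the collected options next to the prompt and generation parameters, so the serialized body looks roughly like this (field values are illustrative; the exact shape of `data` varies by task):

# Illustrative payload after data.update({"options": special_params_dict}).
data = {
    "inputs": "Hello, how are you?",
    "parameters": {"max_new_tokens": 10},
}
special_params_dict = {"wait_for_model": True}
if len(special_params_dict.keys()) > 0:
    data.update({"options": special_params_dict})
# data == {"inputs": ..., "parameters": {...},
#          "options": {"wait_for_model": True}}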
@@ -857,7 +879,7 @@
         return data
 
     def _process_optional_params(self, data: dict, optional_params: dict) -> dict:
-        special_options_keys = ["use_cache", "wait_for_model"]
+        special_options_keys = HuggingfaceConfig().get_special_options_params()
         special_parameters_keys = [
             "min_length",
             "max_length",

litellm/tests/test_completion.py

@@ -1819,14 +1819,19 @@ def tgi_mock_post(url, data=None, json=None, headers=None):
 def test_hf_test_completion_tgi():
     litellm.set_verbose = True
     try:
-        with patch("requests.post", side_effect=tgi_mock_post):
+        with patch("requests.post", side_effect=tgi_mock_post) as mock_client:
             response = completion(
                 model="huggingface/HuggingFaceH4/zephyr-7b-beta",
                 messages=[{"content": "Hello, how are you?", "role": "user"}],
                 max_tokens=10,
+                wait_for_model=True,
             )
             # Add any assertions here to check the response
             print(response)
+            assert "options" in mock_client.call_args.kwargs["data"]
+            json_data = json.loads(mock_client.call_args.kwargs["data"])
+            assert "wait_for_model" in json_data["options"]
+            assert json_data["options"]["wait_for_model"] is True
     except litellm.ServiceUnavailableError as e:
         pass
     except Exception as e:
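
As a usage note on the assertions above: `mock_client.call_args.kwargs["data"]` is a JSON string because the client serializes the payload before calling `requests.post`, hence the `json.loads` round-trip. A self-contained sketch of the same pattern (the URL and payload are hypothetical):

import json
from unittest.mock import patch

import requests

with patch("requests.post") as mock_client:
    # Stand-in for the litellm call; what matters is that the body is a
    # JSON-encoded string passed via the 'data' kwarg.
    requests.post(
        "https://example.invalid/models/fake",
        data=json.dumps({"inputs": "hi", "options": {"wait_for_model": True}}),
    )
    json_data = json.loads(mock_client.call_args.kwargs["data"])
    assert json_data["options"]["wait_for_model"] is True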