diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index 0cfe25c1e..06ef0e6fc 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -131,6 +131,9 @@ class HuggingfaceConfig:
             and v is not None
         }
 
+    def get_special_options_params(self):
+        return ["use_cache", "wait_for_model"]
+
     def get_supported_openai_params(self):
         return [
             "stream",
@@ -491,6 +494,20 @@ class Huggingface(BaseLLM):
                     optional_params[k] = v
 
             ### MAP INPUT PARAMS
+            #### HANDLE SPECIAL PARAMS
+            special_params = HuggingfaceConfig().get_special_options_params()
+            special_params_dict = {}
+            # Create a list of keys to pop after iteration
+            keys_to_pop = []
+
+            for k, v in optional_params.items():
+                if k in special_params:
+                    special_params_dict[k] = v
+                    keys_to_pop.append(k)
+
+            # Pop the keys from the dictionary after iteration
+            for k in keys_to_pop:
+                optional_params.pop(k)
             if task == "conversational":
                 inference_params = copy.deepcopy(optional_params)
                 inference_params.pop("details")
@@ -578,6 +595,11 @@ class Huggingface(BaseLLM):
                     else False
                 )
                 input_text = prompt
+
+            ### RE-ADD SPECIAL PARAMS
+            if len(special_params_dict.keys()) > 0:
+                data.update({"options": special_params_dict})
+
             ## LOGGING
             logging_obj.pre_call(
                 input=input_text,
@@ -857,7 +879,7 @@ class Huggingface(BaseLLM):
         return data
 
     def _process_optional_params(self, data: dict, optional_params: dict) -> dict:
-        special_options_keys = ["use_cache", "wait_for_model"]
+        special_options_keys = HuggingfaceConfig().get_special_options_params()
         special_parameters_keys = [
             "min_length",
             "max_length",
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index c73fcddb6..c4426a243 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -1819,14 +1819,19 @@ def tgi_mock_post(url, data=None, json=None, headers=None):
 def test_hf_test_completion_tgi():
     litellm.set_verbose = True
     try:
-        with patch("requests.post", side_effect=tgi_mock_post):
+        with patch("requests.post", side_effect=tgi_mock_post) as mock_client:
             response = completion(
                 model="huggingface/HuggingFaceH4/zephyr-7b-beta",
                 messages=[{"content": "Hello, how are you?", "role": "user"}],
                 max_tokens=10,
+                wait_for_model=True,
            )
             # Add any assertions-here to check the response
             print(response)
+            assert "options" in mock_client.call_args.kwargs["data"]
+            json_data = json.loads(mock_client.call_args.kwargs["data"])
+            assert "wait_for_model" in json_data["options"]
+            assert json_data["options"]["wait_for_model"] is True
     except litellm.ServiceUnavailableError as e:
         pass
     except Exception as e:
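
Usage note (not part of the patch): the updated test shows the intended call pattern. With this change, the Hugging Face-specific request options ("use_cache", "wait_for_model") are stripped from the generation parameters and sent under the request body's "options" key instead of "parameters". A minimal sketch of how a caller would pass them, assuming valid Hugging Face credentials are configured; use_cache is included only to illustrate the second special option:

    import litellm

    # wait_for_model / use_cache are pulled out of optional_params and sent as
    # {"options": {"wait_for_model": True, "use_cache": False}} in the HF payload
    response = litellm.completion(
        model="huggingface/HuggingFaceH4/zephyr-7b-beta",
        messages=[{"content": "Hello, how are you?", "role": "user"}],
        max_tokens=10,
        wait_for_model=True,
        use_cache=False,
    )
    print(response)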