fix(huggingface_restapi.py): support passing 'wait_for_model' param on completion calls

parent 466dc9f32a
commit 0cf81eba62

2 changed files with 29 additions and 2 deletions

@@ -131,6 +131,9 @@ class HuggingfaceConfig:
             and v is not None
         }
 
+    def get_special_options_params(self):
+        return ["use_cache", "wait_for_model"]
+
     def get_supported_openai_params(self):
         return [
             "stream",
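
Note (not part of the diff): the Hugging Face Inference API separates top-level `options` such as `use_cache` and `wait_for_model` from the generation `parameters`, which is why these two keys get their own accessor. A minimal sketch of the payload shape this enables; the prompt and parameter values are placeholders, not taken from the diff:

import json

# Illustrative request body only: "options" sits beside, not inside,
# the normal generation "parameters".
payload = {
    "inputs": "Hello, how are you?",       # prompt text
    "parameters": {"max_new_tokens": 10},  # normal generation parameters
    "options": {"wait_for_model": True},   # special top-level options
}
print(json.dumps(payload, indent=2))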

@@ -491,6 +494,20 @@ class Huggingface(BaseLLM):
                 optional_params[k] = v
 
         ### MAP INPUT PARAMS
+        #### HANDLE SPECIAL PARAMS
+        special_params = HuggingfaceConfig().get_special_options_params()
+        special_params_dict = {}
+        # Create a list of keys to pop after iteration
+        keys_to_pop = []
+
+        for k, v in optional_params.items():
+            if k in special_params:
+                special_params_dict[k] = v
+                keys_to_pop.append(k)
+
+        # Pop the keys from the dictionary after iteration
+        for k in keys_to_pop:
+            optional_params.pop(k)
         if task == "conversational":
             inference_params = copy.deepcopy(optional_params)
             inference_params.pop("details")
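
The keys are collected first and popped afterwards because removing entries from `optional_params` while iterating over it would raise `RuntimeError: dictionary changed size during iteration`. A standalone sketch of the same split; the function name and sample values are illustrative, not from the diff:

def split_special_params(optional_params: dict, special_keys: list) -> tuple:
    # Collect matches first, then pop, so the dict is never
    # mutated while it is being iterated.
    special = {k: v for k, v in optional_params.items() if k in special_keys}
    for k in special:
        optional_params.pop(k)
    return optional_params, special

params, options = split_special_params(
    {"max_new_tokens": 10, "wait_for_model": True},
    ["use_cache", "wait_for_model"],
)
assert options == {"wait_for_model": True}
assert "wait_for_model" not in params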

@@ -578,6 +595,11 @@ class Huggingface(BaseLLM):
                     else False
                 )
                 input_text = prompt
+
+            ### RE-ADD SPECIAL PARAMS
+            if len(special_params_dict.keys()) > 0:
+                data.update({"options": special_params_dict})
+
             ## LOGGING
             logging_obj.pre_call(
                 input=input_text,
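
The special params extracted earlier are re-attached to the request body as a top-level `options` key just before pre-call logging and the HTTP request. A sketch of the resulting `data` dict, with placeholder values:

data = {"inputs": "Hello, how are you?", "parameters": {"max_new_tokens": 10}}
special_params_dict = {"wait_for_model": True}

# Mirrors the diff: only attach "options" when there is something to send.
if len(special_params_dict.keys()) > 0:
    data.update({"options": special_params_dict})

assert data["options"] == {"wait_for_model": True}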

@@ -857,7 +879,7 @@ class Huggingface(BaseLLM):
         return data
 
     def _process_optional_params(self, data: dict, optional_params: dict) -> dict:
-        special_options_keys = ["use_cache", "wait_for_model"]
+        special_options_keys = HuggingfaceConfig().get_special_options_params()
         special_parameters_keys = [
             "min_length",
             "max_length",
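
Swapping the hardcoded list for `HuggingfaceConfig().get_special_options_params()` leaves a single source of truth, so a future special option only needs to be added in one place. A trimmed sketch of the pattern (class reduced to the relevant method):

class HuggingfaceConfig:
    def get_special_options_params(self):
        # Single source of truth for Hugging Face "options" keys.
        return ["use_cache", "wait_for_model"]

# Callers look the list up rather than duplicating it:
special_options_keys = HuggingfaceConfig().get_special_options_params()
assert "wait_for_model" in special_options_keys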

@@ -1819,14 +1819,19 @@ def tgi_mock_post(url, data=None, json=None, headers=None):
 def test_hf_test_completion_tgi():
     litellm.set_verbose = True
     try:
-        with patch("requests.post", side_effect=tgi_mock_post):
+        with patch("requests.post", side_effect=tgi_mock_post) as mock_client:
             response = completion(
                 model="huggingface/HuggingFaceH4/zephyr-7b-beta",
                 messages=[{"content": "Hello, how are you?", "role": "user"}],
                 max_tokens=10,
+                wait_for_model=True,
             )
             # Add any assertions here to check the response
             print(response)
+            assert "options" in mock_client.call_args.kwargs["data"]
+            json_data = json.loads(mock_client.call_args.kwargs["data"])
+            assert "wait_for_model" in json_data["options"]
+            assert json_data["options"]["wait_for_model"] is True
     except litellm.ServiceUnavailableError as e:
         pass
     except Exception as e:
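
The test patches `requests.post` and asserts that `wait_for_model` lands under `options` in the serialized request body. For a live call, usage would look like the sketch below; it assumes a valid Hugging Face token (e.g. `HUGGINGFACE_API_KEY` in the environment) and network access:

import litellm

# wait_for_model asks the Inference API to block until the model is
# loaded instead of returning a 503 while it cold-starts.
response = litellm.completion(
    model="huggingface/HuggingFaceH4/zephyr-7b-beta",
    messages=[{"content": "Hello, how are you?", "role": "user"}],
    max_tokens=10,
    wait_for_model=True,
)
print(response.choices[0].message.content)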