forked from phoenix/litellm-mirror
fix(huggingface_restapi.py): support passing 'wait_for_model' param on completion calls
parent 466dc9f32a
commit 0cf81eba62
2 changed files with 29 additions and 2 deletions
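For orientation, a minimal sketch of the call this change enables. The model, messages, and max_tokens mirror the updated test below; running this against the live Hugging Face endpoint would additionally require an API token, which is not shown.

import litellm

# 'wait_for_model' is not a generation parameter; the change below routes it
# into the Hugging Face request's "options" block instead of "parameters".
response = litellm.completion(
    model="huggingface/HuggingFaceH4/zephyr-7b-beta",
    messages=[{"content": "Hello, how are you?", "role": "user"}],
    max_tokens=10,
    wait_for_model=True,
)
print(response)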
@@ -131,6 +131,9 @@ class HuggingfaceConfig:
             and v is not None
         }
 
+    def get_special_options_params(self):
+        return ["use_cache", "wait_for_model"]
+
     def get_supported_openai_params(self):
         return [
             "stream",
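In isolation, the new helper simply exposes the two keys that Hugging Face accepts under "options"; a quick sketch of using it (the import path is assumed from the commit message, it is not shown in this diff):

from litellm.llms.huggingface_restapi import HuggingfaceConfig  # path assumed

print(HuggingfaceConfig().get_special_options_params())
# ['use_cache', 'wait_for_model']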
@@ -491,6 +494,20 @@ class Huggingface(BaseLLM):
                 optional_params[k] = v
 
         ### MAP INPUT PARAMS
+        #### HANDLE SPECIAL PARAMS
+        special_params = HuggingfaceConfig().get_special_options_params()
+        special_params_dict = {}
+        # Create a list of keys to pop after iteration
+        keys_to_pop = []
+
+        for k, v in optional_params.items():
+            if k in special_params:
+                special_params_dict[k] = v
+                keys_to_pop.append(k)
+
+        # Pop the keys from the dictionary after iteration
+        for k in keys_to_pop:
+            optional_params.pop(k)
         if task == "conversational":
             inference_params = copy.deepcopy(optional_params)
             inference_params.pop("details")
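Taken on its own, the splitting step above behaves like this self-contained sketch (the example values are illustrative, not from the PR):

# Illustrative input: one generation parameter and one special option
optional_params = {"max_new_tokens": 10, "wait_for_model": True}
special_params = ["use_cache", "wait_for_model"]

special_params_dict = {}
keys_to_pop = []
for k, v in optional_params.items():
    if k in special_params:
        special_params_dict[k] = v
        keys_to_pop.append(k)  # defer the pop to avoid mutating the dict mid-iteration
for k in keys_to_pop:
    optional_params.pop(k)

# optional_params     -> {"max_new_tokens": 10}
# special_params_dict -> {"wait_for_model": True}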
@@ -578,6 +595,11 @@ class Huggingface(BaseLLM):
                 else False
             )
             input_text = prompt
+
+            ### RE-ADD SPECIAL PARAMS
+            if len(special_params_dict.keys()) > 0:
+                data.update({"options": special_params_dict})
+
             ## LOGGING
             logging_obj.pre_call(
                 input=input_text,
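The net effect is that the special keys land under a top-level "options" object in the request body rather than in the generation parameters. A rough sketch of the resulting payload, where the "inputs" and "parameters" fields reflect the usual Hugging Face Inference API shape and are assumptions here; only "options" is what the test below asserts:

data = {
    "inputs": "Hello, how are you?",       # prompt text (assumed field)
    "parameters": {"max_new_tokens": 10},  # generation params (assumed field)
    "options": {"wait_for_model": True},   # special params re-added by this diff
}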
@@ -857,7 +879,7 @@ class Huggingface(BaseLLM):
         return data
 
     def _process_optional_params(self, data: dict, optional_params: dict) -> dict:
-        special_options_keys = ["use_cache", "wait_for_model"]
+        special_options_keys = HuggingfaceConfig().get_special_options_params()
         special_parameters_keys = [
             "min_length",
             "max_length",
@@ -1819,14 +1819,19 @@ def tgi_mock_post(url, data=None, json=None, headers=None):
 def test_hf_test_completion_tgi():
     litellm.set_verbose = True
     try:
-        with patch("requests.post", side_effect=tgi_mock_post):
+        with patch("requests.post", side_effect=tgi_mock_post) as mock_client:
             response = completion(
                 model="huggingface/HuggingFaceH4/zephyr-7b-beta",
                 messages=[{"content": "Hello, how are you?", "role": "user"}],
                 max_tokens=10,
+                wait_for_model=True,
             )
             # Add any assertions here to check the response
             print(response)
+            assert "options" in mock_client.call_args.kwargs["data"]
+            json_data = json.loads(mock_client.call_args.kwargs["data"])
+            assert "wait_for_model" in json_data["options"]
+            assert json_data["options"]["wait_for_model"] is True
     except litellm.ServiceUnavailableError as e:
         pass
     except Exception as e:
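As background for the new assertions, a short self-contained sketch of how unittest.mock exposes the last call's keyword arguments via call_args.kwargs; the names and payload here are illustrative, not taken from the PR:

import json
from unittest.mock import MagicMock

mock_client = MagicMock()
mock_client(url="https://example.invalid", data=json.dumps({"options": {"wait_for_model": True}}))

# call_args.kwargs holds the keyword arguments of the most recent call
json_data = json.loads(mock_client.call_args.kwargs["data"])
assert json_data["options"]["wait_for_model"] is True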