mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
fix(huggingface_restapi.py): fix hf embeddings optional param processing
This commit is contained in:
parent
ebe1275fb0
commit
cac91dcae4
2 changed files with 33 additions and 1 deletions
|
@ -856,6 +856,30 @@ class Huggingface(BaseLLM):
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
def _process_optional_params(self, data: dict, optional_params: dict) -> dict:
    """Merge caller-supplied optional params into the request payload.

    Each entry of ``optional_params`` is routed by its key:

    * ``use_cache`` / ``wait_for_model`` are stored under ``data["options"]``.
    * Generation-style keys (``min_length``, ``max_length``, ``top_k``,
      ``top_p``, ``temperature``, ``repetition_penalty``, ``max_time``) are
      stored under ``data["parameters"]``.
    * Any other key is written directly onto ``data``.

    Args:
        data: Request payload to update. It is modified in place.
        optional_params: Extra key/value pairs to fold into ``data``.

    Returns:
        The same ``data`` object that was passed in, after the update.
    """
    # Nothing to merge; leave the payload untouched (no empty sub-dicts).
    if not optional_params:
        return data

    # Sets give O(1) membership checks and make the grouping intent explicit.
    option_keys = frozenset(("use_cache", "wait_for_model"))
    parameter_keys = frozenset(
        (
            "min_length",
            "max_length",
            "top_k",
            "top_p",
            "temperature",
            "repetition_penalty",
            "max_time",
        )
    )

    for key, value in optional_params.items():
        if key in option_keys:
            # setdefault keeps entries already present under "options".
            data.setdefault("options", {})[key] = value
        elif key in parameter_keys:
            data.setdefault("parameters", {})[key] = value
        else:
            # Unrecognised keys are passed through at the top level.
            data[key] = value

    return data
|
||||||
|
|
||||||
def _transform_input(
|
def _transform_input(
|
||||||
self,
|
self,
|
||||||
input: List,
|
input: List,
|
||||||
|
@ -892,7 +916,9 @@ class Huggingface(BaseLLM):
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(optional_params.keys()) > 0:
|
if len(optional_params.keys()) > 0:
|
||||||
data["options"] = optional_params
|
data = self._process_optional_params(
|
||||||
|
data=data, optional_params=optional_params
|
||||||
|
)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
|
@ -761,6 +761,8 @@ async def test_hf_embedddings_with_optional_params(sync_mode):
|
||||||
response = embedding(
|
response = embedding(
|
||||||
model="huggingface/jinaai/jina-embeddings-v2-small-en",
|
model="huggingface/jinaai/jina-embeddings-v2-small-en",
|
||||||
input=["good morning from litellm"],
|
input=["good morning from litellm"],
|
||||||
|
top_p=10,
|
||||||
|
top_k=10,
|
||||||
wait_for_model=True,
|
wait_for_model=True,
|
||||||
client=client,
|
client=client,
|
||||||
)
|
)
|
||||||
|
@ -768,6 +770,8 @@ async def test_hf_embedddings_with_optional_params(sync_mode):
|
||||||
response = await litellm.aembedding(
|
response = await litellm.aembedding(
|
||||||
model="huggingface/jinaai/jina-embeddings-v2-small-en",
|
model="huggingface/jinaai/jina-embeddings-v2-small-en",
|
||||||
input=["good morning from litellm"],
|
input=["good morning from litellm"],
|
||||||
|
top_p=10,
|
||||||
|
top_k=10,
|
||||||
wait_for_model=True,
|
wait_for_model=True,
|
||||||
client=client,
|
client=client,
|
||||||
)
|
)
|
||||||
|
@ -781,3 +785,5 @@ async def test_hf_embedddings_with_optional_params(sync_mode):
|
||||||
json_data = json.loads(mock_client.call_args.kwargs["data"])
|
json_data = json.loads(mock_client.call_args.kwargs["data"])
|
||||||
assert "wait_for_model" in json_data["options"]
|
assert "wait_for_model" in json_data["options"]
|
||||||
assert json_data["options"]["wait_for_model"] is True
|
assert json_data["options"]["wait_for_model"] is True
|
||||||
|
assert json_data["parameters"]["top_p"] == 10
|
||||||
|
assert json_data["parameters"]["top_k"] == 10
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue