diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py index 80fd9592f2..0cfe25c1e6 100644 --- a/litellm/llms/huggingface_restapi.py +++ b/litellm/llms/huggingface_restapi.py @@ -856,6 +856,30 @@ class Huggingface(BaseLLM): return data + def _process_optional_params(self, data: dict, optional_params: dict) -> dict: + special_options_keys = ["use_cache", "wait_for_model"] + special_parameters_keys = [ + "min_length", + "max_length", + "top_k", + "top_p", + "temperature", + "repetition_penalty", + "max_time", + ] + + for k, v in optional_params.items(): + if k in special_options_keys: + data.setdefault("options", {}) + data["options"][k] = v + elif k in special_parameters_keys: + data.setdefault("parameters", {}) + data["parameters"][k] = v + else: + data[k] = v + + return data + def _transform_input( self, input: List, @@ -892,7 +916,9 @@ class Huggingface(BaseLLM): ) if len(optional_params.keys()) > 0: - data["options"] = optional_params + data = self._process_optional_params( + data=data, optional_params=optional_params + ) return data diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py index 2d2e125eb8..9c1e4aa2cd 100644 --- a/litellm/tests/test_embedding.py +++ b/litellm/tests/test_embedding.py @@ -761,6 +761,8 @@ async def test_hf_embedddings_with_optional_params(sync_mode): response = embedding( model="huggingface/jinaai/jina-embeddings-v2-small-en", input=["good morning from litellm"], + top_p=10, + top_k=10, wait_for_model=True, client=client, ) @@ -768,6 +770,8 @@ async def test_hf_embedddings_with_optional_params(sync_mode): response = await litellm.aembedding( model="huggingface/jinaai/jina-embeddings-v2-small-en", input=["good morning from litellm"], + top_p=10, + top_k=10, wait_for_model=True, client=client, ) @@ -781,3 +785,5 @@ async def test_hf_embedddings_with_optional_params(sync_mode): json_data = json.loads(mock_client.call_args.kwargs["data"]) assert "wait_for_model" in json_data["options"] assert json_data["options"]["wait_for_model"] is True + assert json_data["parameters"]["top_p"] == 10 + assert json_data["parameters"]["top_k"] == 10