From 512769e84105a24fd7118ba18b71805ab94c28d3 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Tue, 3 Oct 2023 07:10:05 -0700
Subject: [PATCH] support n param for hf

---
 litellm/__pycache__/utils.cpython-311.pyc | Bin 139358 -> 139395 bytes
 litellm/llms/huggingface_restapi.py       |   2 --
 litellm/tests/test_streaming.py           |   2 +-
 litellm/utils.py                          |   4 +++-
 pyproject.toml                            |   2 +-
 5 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index 6459cd2d8adb31965cd5580ce5ef0f1155c12788..9cef80ae0463a4dfda4c047a3e04743ff45a5ff0 100644
GIT binary patch
delta 654
delta 659

diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index 90abdb2d2..0739d1e92 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -117,7 +117,6 @@ def completion(
         inference_params = copy.deepcopy(optional_params)
         inference_params.pop("details")
         inference_params.pop("return_full_text")
-        inference_params.pop("task")
         past_user_inputs = []
         generated_responses = []
         text = ""
@@ -181,7 +180,6 @@ def completion(
         inference_params = copy.deepcopy(optional_params)
         inference_params.pop("details")
         inference_params.pop("return_full_text")
-        inference_params.pop("task")
         data = {
             "inputs": prompt,
             "parameters": inference_params,
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index f016ae97d..1bf8eadf9 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -351,7 +351,7 @@ def test_completion_cohere_stream_bad_key():
 #     },
 # ]
 # response = completion(
-#     model="huggingface/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000
+#     model="huggingface/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000, n=1
 # )
 # complete_response = ""
 # # Add any assertions here to check the response
diff --git a/litellm/utils.py b/litellm/utils.py
index b7995863d..122659829 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1039,7 +1039,7 @@ def get_optional_params(  # use the openai defaults
             optional_params["stop_sequences"] = stop
     elif custom_llm_provider == "huggingface":
         ## check if unsupported param passed in
-        supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop",]
+        supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
         _check_valid_arg(supported_params=supported_params)
         if temperature:
@@ -1055,6 +1055,8 @@ def get_optional_params(  # use the openai defaults
             optional_params["stop"] = stop
         if max_tokens:
             optional_params["max_new_tokens"] = max_tokens
+        if n:
+            optional_params["best_of"] = n
         if presence_penalty:
             optional_params["repetition_penalty"] = presence_penalty
     elif custom_llm_provider == "together_ai":
diff --git a/pyproject.toml b/pyproject.toml
index eedbb81a7..1bc0aa326 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.815"
+version = "0.1.816"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"