support n param for hf

Krrish Dholakia 2023-10-03 07:10:05 -07:00
parent cd2c1dff2d
commit 512769e841
5 changed files with 5 additions and 5 deletions


@@ -117,7 +117,6 @@ def completion(
     inference_params = copy.deepcopy(optional_params)
     inference_params.pop("details")
     inference_params.pop("return_full_text")
-    inference_params.pop("task")
     past_user_inputs = []
     generated_responses = []
     text = ""
@@ -181,7 +180,6 @@ def completion(
     inference_params = copy.deepcopy(optional_params)
     inference_params.pop("details")
     inference_params.pop("return_full_text")
-    inference_params.pop("task")
     data = {
         "inputs": prompt,
         "parameters": inference_params,


@@ -351,7 +351,7 @@ def test_completion_cohere_stream_bad_key():
     # },
     # ]
     # response = completion(
-    #     model="huggingface/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000
+    #     model="huggingface/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000, n=1
     # )
     # complete_response = ""
     # # Add any assertions here to check the response
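A hedged usage sketch of what the (commented-out) test exercises: with this commit, passing the OpenAI-style n parameter to a huggingface/ model no longer trips the unsupported-parameter check. The message content is a placeholder, and the api_base is the test's own dedicated endpoint, which will not resolve for other accounts.

from litellm import completion

messages = [{"role": "user", "content": "Hey, how's it going?"}]  # placeholder
response = completion(
    model="huggingface/mistralai/Mistral-7B-Instruct-v0.1",
    messages=messages,
    api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud",
    stream=True,
    max_tokens=1000,
    n=1,
)
for chunk in response:  # stream=True yields incremental chunks
    print(chunk)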


@@ -1039,7 +1039,7 @@ def get_optional_params(  # use the openai defaults
         optional_params["stop_sequences"] = stop
     elif custom_llm_provider == "huggingface":
         ## check if unsupported param passed in
-        supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop",]
+        supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
         _check_valid_arg(supported_params=supported_params)
         if temperature:
@@ -1055,6 +1055,8 @@
             optional_params["stop"] = stop
         if max_tokens:
             optional_params["max_new_tokens"] = max_tokens
+        if n:
+            optional_params["best_of"] = n
         if presence_penalty:
             optional_params["repetition_penalty"] = presence_penalty
     elif custom_llm_provider == "together_ai":
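These two hunks are the heart of the change: "n" joins the provider's supported-params whitelist, and OpenAI's n (number of completions) is forwarded as text-generation-inference's best_of, which samples that many candidate sequences server-side. A minimal standalone sketch of the branch (hypothetical helper name; the renames come from the diff):

def map_hf_params(stream=None, temperature=None, max_tokens=None,
                  top_p=None, stop=None, n=None, presence_penalty=None):
    # Standalone rendition of the huggingface branch of
    # get_optional_params; only truthy values are forwarded.
    optional_params = {}
    if stream:
        optional_params["stream"] = stream
    if temperature:
        optional_params["temperature"] = temperature
    if top_p:
        optional_params["top_p"] = top_p
    if stop:
        optional_params["stop"] = stop
    if max_tokens:
        optional_params["max_new_tokens"] = max_tokens  # max_tokens -> max_new_tokens
    if n:
        optional_params["best_of"] = n  # n -> best_of (added by this commit)
    if presence_penalty:
        optional_params["repetition_penalty"] = presence_penalty  # approximate mapping
    return optional_params

# e.g. map_hf_params(max_tokens=1000, n=2) -> {"max_new_tokens": 1000, "best_of": 2}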


@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.815"
+version = "0.1.816"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"