Support the `n` param for Hugging Face

This commit is contained in:
Krrish Dholakia 2023-10-03 07:10:05 -07:00
parent cd2c1dff2d
commit 512769e841
5 changed files with 5 additions and 5 deletions

View file

@@ -117,7 +117,6 @@ def completion(
inference_params = copy.deepcopy(optional_params)
inference_params.pop("details")
inference_params.pop("return_full_text")
inference_params.pop("task")
past_user_inputs = []
generated_responses = []
text = ""
@@ -181,7 +180,6 @@ def completion(
inference_params = copy.deepcopy(optional_params)
inference_params.pop("details")
inference_params.pop("return_full_text")
inference_params.pop("task")
data = {
"inputs": prompt,
"parameters": inference_params,

View file

@@ -351,7 +351,7 @@ def test_completion_cohere_stream_bad_key():
# },
# ]
# response = completion(
# model="huggingface/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000
# model="huggingface/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000, n=1
# )
# complete_response = ""
# # Add any assertions here to check the response

View file

@@ -1039,7 +1039,7 @@ def get_optional_params(  # use the openai defaults
optional_params["stop_sequences"] = stop
elif custom_llm_provider == "huggingface":
## check if unsupported param passed in
supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop",]
supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
_check_valid_arg(supported_params=supported_params)
if temperature:
@@ -1055,6 +1055,8 @@
optional_params["stop"] = stop
if max_tokens:
optional_params["max_new_tokens"] = max_tokens
if n:
optional_params["best_of"] = n
if presence_penalty:
optional_params["repetition_penalty"] = presence_penalty
elif custom_llm_provider == "together_ai":

View file

@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "0.1.815"
version = "0.1.816"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT License"