mirror of https://github.com/BerriAI/litellm.git
support n param for hf

parent cd2c1dff2d
commit 512769e841

5 changed files with 5 additions and 5 deletions
Binary file not shown.
```diff
@@ -117,7 +117,6 @@ def completion(
         inference_params = copy.deepcopy(optional_params)
         inference_params.pop("details")
         inference_params.pop("return_full_text")
-        inference_params.pop("task")
         past_user_inputs = []
         generated_responses = []
         text = ""
```
```diff
@@ -181,7 +180,6 @@ def completion(
         inference_params = copy.deepcopy(optional_params)
         inference_params.pop("details")
         inference_params.pop("return_full_text")
-        inference_params.pop("task")
         data = {
             "inputs": prompt,
             "parameters": inference_params,
```
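Both hunks above touch the same payload-building step in litellm's Hugging Face handler: `optional_params` is deep-copied, litellm-internal keys are stripped so they are not forwarded as raw generation parameters, and the remainder is sent under `"parameters"`. Below is a minimal sketch of that step under those assumptions; `build_tgi_payload` is an illustrative name, not litellm's actual function, which does this inline inside `completion()`.

```python
import copy

def build_tgi_payload(prompt: str, optional_params: dict) -> dict:
    # Copy so the caller's dict is not mutated.
    inference_params = copy.deepcopy(optional_params)
    # "details" and "return_full_text" steer litellm's own response handling
    # and are stripped before forwarding; after this commit, "task" is no
    # longer popped here. Defaults are used for safety in this sketch,
    # whereas the diff's code pops unconditionally.
    inference_params.pop("details", None)
    inference_params.pop("return_full_text", None)
    # Text Generation Inference-style request body.
    return {"inputs": prompt, "parameters": inference_params}
```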
```diff
@@ -351,7 +351,7 @@ def test_completion_cohere_stream_bad_key():
 #         },
 #     ]
 #     response = completion(
-#         model="huggingface/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000
+#         model="huggingface/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000, n=1
 #     )
 #     complete_response = ""
 #     # Add any assertions here to check the response
```
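The updated (commented-out) streaming test passes the new `n` parameter straight through `completion()`. A hedged usage example with the same call shape follows; the `messages` payload is illustrative, since the test's actual messages list sits above the quoted hunk.

```python
from litellm import completion

# Illustrative call mirroring the updated test: n is now accepted for
# huggingface/* models and translated to the endpoint's best_of parameter.
response = completion(
    model="huggingface/mistralai/Mistral-7B-Instruct-v0.1",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],  # assumed
    api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud",
    stream=True,
    max_tokens=1000,
    n=1,
)
for chunk in response:
    print(chunk)
```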
```diff
@@ -1039,7 +1039,7 @@ def get_optional_params(  # use the openai defaults
             optional_params["stop_sequences"] = stop
     elif custom_llm_provider == "huggingface":
         ## check if unsupported param passed in
-        supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop",]
+        supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
         _check_valid_arg(supported_params=supported_params)
 
         if temperature:
```
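The allow-list is what `_check_valid_arg` validates against: any OpenAI-style argument the caller passed that is not in `supported_params` is rejected, so adding `"n"` here is what makes the parameter legal for huggingface at all. The real helper is defined elsewhere in the same file and is not shown in this diff; the sketch below assumes an explicit `passed_params` argument and may differ from litellm's actual signature and error message.

```python
def _check_valid_arg(supported_params: list, passed_params: dict) -> None:
    # Reject any caller-supplied parameter the provider cannot map.
    unsupported = [k for k in passed_params if k not in supported_params]
    if unsupported:
        raise ValueError(
            f"huggingface does not support these parameters: {unsupported}"
        )
```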
```diff
@@ -1055,6 +1055,8 @@ def get_optional_params(  # use the openai defaults
             optional_params["stop"] = stop
         if max_tokens:
             optional_params["max_new_tokens"] = max_tokens
+        if n:
+            optional_params["best_of"] = n
         if presence_penalty:
             optional_params["repetition_penalty"] = presence_penalty
     elif custom_llm_provider == "together_ai":
```
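The actual mapping lands in this hunk: OpenAI's `n` is translated to `best_of`, the closest analogue that Hugging Face text-generation-inference endpoints expose. Note the semantics differ slightly: OpenAI's `n` returns `n` separate choices, while `best_of` samples that many sequences server-side and returns the single highest-scoring one. A condensed sketch of the huggingface branch of this translation, limited to the parameters visible in the diff (the real function covers many providers):

```python
def map_hf_params(temperature=None, top_p=None, max_tokens=None,
                  stop=None, n=None, presence_penalty=None) -> dict:
    optional_params = {}
    if temperature:
        optional_params["temperature"] = temperature
    if top_p:
        optional_params["top_p"] = top_p
    if stop:
        optional_params["stop"] = stop
    if max_tokens:
        # TGI names this max_new_tokens rather than max_tokens.
        optional_params["max_new_tokens"] = max_tokens
    if n:
        # best_of asks the server to sample n sequences and keep the best.
        optional_params["best_of"] = n
    if presence_penalty:
        # Approximate mapping: HF endpoints expose repetition_penalty.
        optional_params["repetition_penalty"] = presence_penalty
    return optional_params
```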
```diff
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.815"
+version = "0.1.816"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
```