fix(utils.py): support sagemaker llama2 custom endpoints

Krrish Dholakia 2023-12-05 16:04:43 -08:00
parent 09c2c1610d
commit b4c78c7b9e
4 changed files with 53 additions and 45 deletions

@@ -2206,32 +2206,31 @@ def get_optional_params( # use the openai defaults
         if max_tokens is not None:
             optional_params["max_output_tokens"] = max_tokens
     elif custom_llm_provider == "sagemaker":
-        if "llama-2" in model.lower() or (
-            "llama" in model.lower() and "2" in model.lower() # some combination of llama and "2" should exist
-        ): # jumpstart can also send "Llama-2-70b-chat-hf-48xlarge"
-            # llama-2 models on sagemaker support the following args
-            """
-            max_new_tokens: Model generates text until the output length (excluding the input context length) reaches max_new_tokens. If specified, it must be a positive integer.
-            temperature: Controls the randomness in the output. Higher temperature results in output sequence with low-probability words and lower temperature results in output sequence with high-probability words. If temperature -> 0, it results in greedy decoding. If specified, it must be a positive float.
-            top_p: In each step of text generation, sample from the smallest possible set of words with cumulative probability top_p. If specified, it must be a float between 0 and 1.
-            return_full_text: If True, input text will be part of the output generated text. If specified, it must be boolean. The default value for it is False.
-            """
-            ## check if unsupported param passed in
-            supported_params = ["temperature", "max_tokens", "stream"]
-            _check_valid_arg(supported_params=supported_params)
-
-            if max_tokens is not None:
-                optional_params["max_new_tokens"] = max_tokens
-            if temperature is not None:
-                optional_params["temperature"] = temperature
-            if top_p is not None:
-                optional_params["top_p"] = top_p
-            if stream:
-                optional_params["stream"] = stream
-        else:
-            ## check if unsupported param passed in
-            supported_params = []
-            _check_valid_arg(supported_params=supported_params)
+        ## check if unsupported param passed in
+        supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+        _check_valid_arg(supported_params=supported_params)
+        # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
+        if temperature is not None:
+            if temperature == 0.0 or temperature == 0:
+                # hugging face exception raised when temp==0
+                # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+                temperature = 0.01
+            optional_params["temperature"] = temperature
+        if top_p is not None:
+            optional_params["top_p"] = top_p
+        if n is not None:
+            optional_params["best_of"] = n
+            optional_params["do_sample"] = True # Need to sample if you want best of for hf inference endpoints
+        if stream is not None:
+            optional_params["stream"] = stream
+        if stop is not None:
+            optional_params["stop"] = stop
+        if max_tokens is not None:
+            # HF TGI raises the following exception when max_new_tokens==0
+            # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+            if max_tokens == 0:
+                max_tokens = 1
+            optional_params["max_new_tokens"] = max_tokens
     elif custom_llm_provider == "bedrock":
         if "ai21" in model:
             supported_params = ["max_tokens", "temperature", "top_p", "stream"]
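
What the new branch enables, as a minimal usage sketch (the endpoint name below is hypothetical, and AWS credentials are assumed to be configured in the environment): a custom SageMaker endpoint whose name does not contain "llama-2" can now pass temperature, top_p, stop, n, and max_tokens, which get_optional_params maps to the TGI-style keys shown above (max_tokens -> max_new_tokens, n -> best_of).

    import litellm

    response = litellm.completion(
        model="sagemaker/my-llama2-custom-endpoint",  # hypothetical endpoint name
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        temperature=0.2,
        top_p=0.9,
        max_tokens=100,
    )
    print(response.choices[0].message.content)
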
@@ -5270,11 +5269,9 @@ class CustomStreamWrapper:
                     else:
                         model_response.choices[0].finish_reason = "stop"
                         self.sent_last_chunk = True
-                chunk_size = 30
-                new_chunk = self.completion_stream[:chunk_size]
+                new_chunk = self.completion_stream
                 completion_obj["content"] = new_chunk
-                self.completion_stream = self.completion_stream[chunk_size:]
-                time.sleep(0.05)
+                self.completion_stream = self.completion_stream[len(self.completion_stream):]
             elif self.custom_llm_provider == "petals":
                 if len(self.completion_stream)==0:
                     if self.sent_last_chunk:
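
The second hunk removes the artificial 30-character slicing (and the 0.05 s sleep) for SageMaker responses in CustomStreamWrapper, so a streamed SageMaker completion is now surfaced as one chunk containing the full text. A minimal consumer sketch, assuming the same hypothetical endpoint name as above and the OpenAI-style delta format litellm chunks follow:

    import litellm

    response = litellm.completion(
        model="sagemaker/my-llama2-custom-endpoint",  # hypothetical endpoint name
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    )
    for chunk in response:
        # sagemaker now yields the whole completion in a single chunk;
        # the final chunk carries finish_reason="stop"
        print(chunk.choices[0].delta.content or "", end="")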