fix(utils.py): support sagemaker llama2 custom endpoints

Krrish Dholakia 2023-12-05 16:04:43 -08:00
parent 09c2c1610d
commit b4c78c7b9e
4 changed files with 53 additions and 45 deletions

litellm/llms/sagemaker.py

@@ -9,6 +9,7 @@ from litellm.utils import ModelResponse, get_secret, Usage
 import sys
 from copy import deepcopy
 import httpx
+from .prompt_templates.factory import prompt_factory, custom_prompt
 
 class SagemakerError(Exception):
     def __init__(self, status_code, message):
@@ -61,6 +62,7 @@ def completion(
     print_verbose: Callable,
     encoding,
     logging_obj,
+    custom_prompt_dict={},
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
@@ -107,19 +109,23 @@ def completion(
             inference_params[k] = v
 
     model = model
-    prompt = ""
-    for message in messages:
-        if "role" in message:
-            if message["role"] == "user":
-                prompt += (
-                    f"{message['content']}"
-                )
-            else:
-                prompt += (
-                    f"{message['content']}"
-                )
-        else:
-            prompt += f"{message['content']}"
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        prompt = custom_prompt(
+            role_dict=model_prompt_details.get("roles", None),
+            initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+            messages=messages
+        )
+    else:
+        hf_model_name = model
+        if "jumpstart-dft-meta-textgeneration-llama" in model: # llama2 model
+            if model.endswith("-f") or "-f-" in model: # sagemaker default for a chat model
+                hf_model_name = "meta-llama/Llama-2-7b-chat" # apply the prompt template for a llama2 chat model
+            else:
+                hf_model_name = "meta-llama/Llama-2-7b" # apply the normal prompt template
+        prompt = prompt_factory(model=hf_model_name, messages=messages)
 
     data = json.dumps({
         "inputs": prompt,

litellm/main.py

@@ -1166,6 +1166,7 @@ def completion(
             print_verbose=print_verbose,
             optional_params=optional_params,
             litellm_params=litellm_params,
+            custom_prompt_dict=custom_prompt_dict,
             logger_fn=logger_fn,
             encoding=encoding,
             logging_obj=logging

litellm/tests/test_completion.py

@@ -1048,20 +1048,24 @@ def test_completion_sagemaker():
 
 def test_completion_chat_sagemaker():
     try:
+        messages = [{"role": "user", "content": "Hey, how's it going?"}]
         print("testing sagemaker")
         litellm.set_verbose=True
         response = completion(
             model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b-f",
             messages=messages,
+            max_tokens=100,
             stream=True,
         )
         # Add any assertions here to check the response
-        print(response)
+        complete_response = ""
         for chunk in response:
-            print(chunk)
+            complete_response += chunk.choices[0].delta.content or ""
+        print(f"complete_response: {complete_response}")
+        assert len(complete_response) > 0
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-# test_completion_chat_sagemaker()
+test_completion_chat_sagemaker()
 
 def test_completion_bedrock_titan():
     try:
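
For reference, the equivalent non-streaming call (a sketch; assumes AWS credentials and the same jumpstart endpoint are configured in the environment):

```python
# Sketch: same SageMaker Llama-2 chat endpoint, non-streaming.
from litellm import completion

response = completion(
    model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b-f",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    max_tokens=100,
)
print(response.choices[0].message.content)
```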

litellm/utils.py

@@ -2206,32 +2206,31 @@ def get_optional_params( # use the openai defaults
         if max_tokens is not None:
             optional_params["max_output_tokens"] = max_tokens
     elif custom_llm_provider == "sagemaker":
-        if "llama-2" in model.lower() or (
-            "llama" in model.lower() and "2" in model.lower() # some combination of llama and "2" should exist
-        ): # jumpstart can also send "Llama-2-70b-chat-hf-48xlarge"
-            # llama-2 models on sagemaker support the following args
-            """
-            max_new_tokens: Model generates text until the output length (excluding the input context length) reaches max_new_tokens. If specified, it must be a positive integer.
-            temperature: Controls the randomness in the output. Higher temperature results in output sequence with low-probability words and lower temperature results in output sequence with high-probability words. If temperature -> 0, it results in greedy decoding. If specified, it must be a positive float.
-            top_p: In each step of text generation, sample from the smallest possible set of words with cumulative probability top_p. If specified, it must be a float between 0 and 1.
-            return_full_text: If True, input text will be part of the output generated text. If specified, it must be boolean. The default value for it is False.
-            """
-            ## check if unsupported param passed in
-            supported_params = ["temperature", "max_tokens", "stream"]
-            _check_valid_arg(supported_params=supported_params)
-            if max_tokens is not None:
-                optional_params["max_new_tokens"] = max_tokens
-            if temperature is not None:
-                optional_params["temperature"] = temperature
-            if top_p is not None:
-                optional_params["top_p"] = top_p
-            if stream:
-                optional_params["stream"] = stream
-        else:
-            ## check if unsupported param passed in
-            supported_params = []
-            _check_valid_arg(supported_params=supported_params)
+        ## check if unsupported param passed in
+        supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+        _check_valid_arg(supported_params=supported_params)
+        # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
+        if temperature is not None:
+            if temperature == 0.0 or temperature == 0:
+                # hugging face exception raised when temp==0
+                # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+                temperature = 0.01
+            optional_params["temperature"] = temperature
+        if top_p is not None:
+            optional_params["top_p"] = top_p
+        if n is not None:
+            optional_params["best_of"] = n
+            optional_params["do_sample"] = True # Need to sample if you want best of for hf inference endpoints
+        if stream is not None:
+            optional_params["stream"] = stream
+        if stop is not None:
+            optional_params["stop"] = stop
+        if max_tokens is not None:
+            # HF TGI raises the following exception when max_new_tokens==0
+            # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+            if max_tokens == 0:
+                max_tokens = 1
+            optional_params["max_new_tokens"] = max_tokens
     elif custom_llm_provider == "bedrock":
         if "ai21" in model:
             supported_params = ["max_tokens", "temperature", "top_p", "stream"]
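
Side note: a standalone re-statement of the sagemaker parameter mapping above, showing the temperature==0 and max_tokens==0 guardrails in isolation (the helper name is illustrative, not litellm's API):

```python
# Illustrative re-statement of the sagemaker branch above (not litellm's API):
# OpenAI-style args are mapped onto HF TGI-style inference params, with
# temperature==0 bumped to 0.01 and max_tokens==0 bumped to 1, since the
# container rejects non-positive values.
def map_sagemaker_params(temperature=None, top_p=None, n=None, stream=None, stop=None, max_tokens=None):
    optional_params = {}
    if temperature is not None:
        optional_params["temperature"] = 0.01 if temperature == 0 else temperature
    if top_p is not None:
        optional_params["top_p"] = top_p
    if n is not None:
        optional_params["best_of"] = n
        optional_params["do_sample"] = True  # best_of requires sampling
    if stream is not None:
        optional_params["stream"] = stream
    if stop is not None:
        optional_params["stop"] = stop
    if max_tokens is not None:
        optional_params["max_new_tokens"] = 1 if max_tokens == 0 else max_tokens
    return optional_params

# e.g. map_sagemaker_params(temperature=0, max_tokens=0, n=2)
# -> {"temperature": 0.01, "best_of": 2, "do_sample": True, "max_new_tokens": 1}
```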
@@ -5270,11 +5269,9 @@ class CustomStreamWrapper:
                     else:
                         model_response.choices[0].finish_reason = "stop"
                         self.sent_last_chunk = True
-                chunk_size = 30
-                new_chunk = self.completion_stream[:chunk_size]
+                new_chunk = self.completion_stream
                 completion_obj["content"] = new_chunk
-                self.completion_stream = self.completion_stream[chunk_size:]
-                time.sleep(0.05)
+                self.completion_stream = self.completion_stream[len(self.completion_stream):]
             elif self.custom_llm_provider == "petals":
                 if len(self.completion_stream)==0:
                     if self.sent_last_chunk:
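
Context for the CustomStreamWrapper change: the old code fake-streamed the fully buffered SageMaker response in 30-character slices with a 0.05s sleep; the new code hands the whole buffer back as a single chunk. A rough sketch of the before/after behaviour (illustrative only, not litellm code):

```python
# Illustrative contrast (not litellm code): old vs. new fake-streaming of a
# fully buffered SageMaker completion.
import time

def old_fake_stream(buffered: str, chunk_size: int = 30):
    # Old behaviour: slice into fixed-size chunks with an artificial delay.
    while buffered:
        yield buffered[:chunk_size]
        buffered = buffered[chunk_size:]
        time.sleep(0.05)

def new_fake_stream(buffered: str):
    # New behaviour: the whole buffered text arrives as one chunk.
    yield buffered

text = "Hello from a SageMaker Llama-2 endpoint. " * 3
assert "".join(old_fake_stream(text)) == "".join(new_fake_stream(text))
```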