diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py
index 51be1f53d..1ee43ec2e 100644
--- a/litellm/llms/sagemaker.py
+++ b/litellm/llms/sagemaker.py
@@ -9,6 +9,7 @@ from litellm.utils import ModelResponse, get_secret, Usage
 import sys
 from copy import deepcopy
 import httpx
+from .prompt_templates.factory import prompt_factory, custom_prompt

 class SagemakerError(Exception):
     def __init__(self, status_code, message):
@@ -61,6 +62,7 @@ def completion(
     print_verbose: Callable,
     encoding,
     logging_obj,
+    custom_prompt_dict={},
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
@@ -107,19 +109,23 @@ def completion(
             inference_params[k] = v

     model = model
-    prompt = ""
-    for message in messages:
-        if "role" in message:
-            if message["role"] == "user":
-                prompt += (
-                    f"{message['content']}"
-                )
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        prompt = custom_prompt(
+            role_dict=model_prompt_details.get("roles", None),
+            initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+            messages=messages
+        )
+    else:
+        hf_model_name = model
+        if "jumpstart-dft-meta-textgeneration-llama" in model: # llama2 model
+            if model.endswith("-f") or "-f-" in model: # sagemaker default for a chat model
+                hf_model_name = "meta-llama/Llama-2-7b-chat" # apply the prompt template for a llama2 chat model
             else:
-                prompt += (
-                    f"{message['content']}"
-                )
-        else:
-            prompt += f"{message['content']}"
+                hf_model_name = "meta-llama/Llama-2-7b" # apply the normal prompt template
+        prompt = prompt_factory(model=hf_model_name, messages=messages)

     data = json.dumps({
         "inputs": prompt,
diff --git a/litellm/main.py b/litellm/main.py
index 5c421a351..e0c7cf3b1 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1166,6 +1166,7 @@ def completion(
             print_verbose=print_verbose,
             optional_params=optional_params,
             litellm_params=litellm_params,
+            custom_prompt_dict=custom_prompt_dict,
             logger_fn=logger_fn,
             encoding=encoding,
             logging_obj=logging
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 15b547a85..3c4a9aa40 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -1048,20 +1048,24 @@ def test_completion_sagemaker():

 def test_completion_chat_sagemaker():
     try:
+        messages = [{"role": "user", "content": "Hey, how's it going?"}]
         print("testing sagemaker")
         litellm.set_verbose=True
         response = completion(
             model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b-f",
             messages=messages,
+            max_tokens=100,
             stream=True,
         )
         # Add any assertions here to check the response
-        print(response)
+        complete_response = ""
         for chunk in response:
-            print(chunk)
+            complete_response += chunk.choices[0].delta.content or ""
+        print(f"complete_response: {complete_response}")
+        assert len(complete_response) > 0
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-# test_completion_chat_sagemaker()
+test_completion_chat_sagemaker()

 def test_completion_bedrock_titan():
     try:
diff --git a/litellm/utils.py b/litellm/utils.py
index f419ebcd3..814735f88 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2206,32 +2206,31 @@ def get_optional_params( # use the openai defaults
         if max_tokens is not None:
             optional_params["max_output_tokens"] = max_tokens
     elif custom_llm_provider == "sagemaker":
-        if "llama-2" in model.lower() or (
-            "llama" in model.lower() and "2" in model.lower() # some combination of llama and "2" should exist
-        ): # jumpstart can also send "Llama-2-70b-chat-hf-48xlarge"
-            # llama-2 models on sagemaker support the following args
-            """
-            max_new_tokens: Model generates text until the output length (excluding the input context length) reaches max_new_tokens. If specified, it must be a positive integer.
-            temperature: Controls the randomness in the output. Higher temperature results in output sequence with low-probability words and lower temperature results in output sequence with high-probability words. If temperature -> 0, it results in greedy decoding. If specified, it must be a positive float.
-            top_p: In each step of text generation, sample from the smallest possible set of words with cumulative probability top_p. If specified, it must be a float between 0 and 1.
-            return_full_text: If True, input text will be part of the output generated text. If specified, it must be boolean. The default value for it is False.
-            """
-            ## check if unsupported param passed in
-            supported_params = ["temperature", "max_tokens", "stream"]
-            _check_valid_arg(supported_params=supported_params)
-
-            if max_tokens is not None:
-                optional_params["max_new_tokens"] = max_tokens
-            if temperature is not None:
-                optional_params["temperature"] = temperature
-            if top_p is not None:
-                optional_params["top_p"] = top_p
-            if stream:
-                optional_params["stream"] = stream
-        else:
-            ## check if unsupported param passed in
-            supported_params = []
-            _check_valid_arg(supported_params=supported_params)
+        ## check if unsupported param passed in
+        supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+        _check_valid_arg(supported_params=supported_params)
+        # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
+        if temperature is not None:
+            if temperature == 0.0 or temperature == 0:
+                # hugging face exception raised when temp==0
+                # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+                temperature = 0.01
+            optional_params["temperature"] = temperature
+        if top_p is not None:
+            optional_params["top_p"] = top_p
+        if n is not None:
+            optional_params["best_of"] = n
+            optional_params["do_sample"] = True # Need to sample if you want best of for hf inference endpoints
+        if stream is not None:
+            optional_params["stream"] = stream
+        if stop is not None:
+            optional_params["stop"] = stop
+        if max_tokens is not None:
+            # HF TGI raises the following exception when max_new_tokens==0
+            # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+            if max_tokens == 0:
+                max_tokens = 1
+            optional_params["max_new_tokens"] = max_tokens
     elif custom_llm_provider == "bedrock":
         if "ai21" in model:
             supported_params = ["max_tokens", "temperature", "top_p", "stream"]
@@ -5270,11 +5269,9 @@ class CustomStreamWrapper:
                 else:
                     model_response.choices[0].finish_reason = "stop"
                     self.sent_last_chunk = True
-                chunk_size = 30
-                new_chunk = self.completion_stream[:chunk_size]
+                new_chunk = self.completion_stream
                 completion_obj["content"] = new_chunk
-                self.completion_stream = self.completion_stream[chunk_size:]
-                time.sleep(0.05)
+                self.completion_stream = self.completion_stream[len(self.completion_stream):]
             elif self.custom_llm_provider == "petals":
                 if len(self.completion_stream)==0:
                     if self.sent_last_chunk:
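
Not part of the patch: a small standalone sketch of the two prompt-construction paths the sagemaker.py change introduces, using the same keyword arguments the new code passes to custom_prompt and prompt_factory (imported from litellm/llms/prompt_templates/factory.py, as shown in the diff). The role markers and the registered-template dict are illustrative assumptions, not values required by the patch.

from litellm.llms.prompt_templates.factory import custom_prompt, prompt_factory

messages = [{"role": "user", "content": "Hey, how's it going?"}]

# Fallback path: a JumpStart Llama-2 endpoint ending in "-f" (or containing "-f-")
# gets the Llama-2 chat template via prompt_factory.
print(prompt_factory(model="meta-llama/Llama-2-7b-chat", messages=messages))

# Registered-template path: this dict shape mirrors what sagemaker.py reads out of
# custom_prompt_dict (roles / initial_prompt_value / final_prompt_value).
model_prompt_details = {
    "roles": {"user": {"pre_message": "[INST] ", "post_message": " [/INST]\n"}},
    "initial_prompt_value": "<s>",
    "final_prompt_value": "",
}
print(custom_prompt(
    role_dict=model_prompt_details.get("roles", None),
    initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
    final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
    messages=messages,
))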
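
The reworked sagemaker branch of get_optional_params() now maps OpenAI-style arguments onto TGI-style payload keys (max_tokens -> max_new_tokens, n -> best_of + do_sample, with floors for temperature and max_new_tokens). The following self-contained sketch mirrors that branch for illustration only; it is not the library function itself.

def map_sagemaker_params_sketch(temperature=None, top_p=None, n=None,
                                stream=None, stop=None, max_tokens=None):
    # Mirror of the new sagemaker branch in get_optional_params(), for illustration.
    optional_params = {}
    if temperature is not None:
        if temperature == 0.0 or temperature == 0:
            temperature = 0.01  # TGI rejects temperature == 0 ("must be strictly positive")
        optional_params["temperature"] = temperature
    if top_p is not None:
        optional_params["top_p"] = top_p
    if n is not None:
        optional_params["best_of"] = n
        optional_params["do_sample"] = True  # best_of requires sampling
    if stream is not None:
        optional_params["stream"] = stream
    if stop is not None:
        optional_params["stop"] = stop
    if max_tokens is not None:
        if max_tokens == 0:
            max_tokens = 1  # TGI rejects max_new_tokens == 0
        optional_params["max_new_tokens"] = max_tokens
    return optional_params

# Edge cases the new branch handles: temperature floor, n -> best_of, max_tokens floor.
assert map_sagemaker_params_sketch(temperature=0, n=2, max_tokens=0) == {
    "temperature": 0.01, "best_of": 2, "do_sample": True, "max_new_tokens": 1,
}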