Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
fix(utils.py): support sagemaker llama2 custom endpoints
parent 09c2c1610d
commit b4c78c7b9e
4 changed files with 53 additions and 45 deletions
@@ -9,6 +9,7 @@ from litellm.utils import ModelResponse, get_secret, Usage
 import sys
 from copy import deepcopy
 import httpx
+from .prompt_templates.factory import prompt_factory, custom_prompt
 
 class SagemakerError(Exception):
     def __init__(self, status_code, message):
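Context for the new prompt_factory/custom_prompt import: prompt_factory(model=..., messages=...) converts OpenAI-style messages into a model-specific prompt string. For a Llama-2 chat model that string follows roughly the [INST]/<<SYS>> convention sketched below. This is a hand-rolled illustration of the format, not litellm's implementation, and details such as spacing or BOS tokens may differ.

# Illustrative sketch of the Llama-2 chat prompt format (not litellm's prompt_factory).
def llama2_chat_prompt(messages):
    prompt = ""
    for message in messages:
        if message["role"] == "system":
            prompt += f"[INST] <<SYS>>\n{message['content']}\n<</SYS>>\n [/INST]\n"
        elif message["role"] == "user":
            prompt += f"[INST] {message['content']} [/INST]\n"
        else:  # assistant turns are appended as plain text
            prompt += f"{message['content']}\n"
    return prompt

print(llama2_chat_prompt([{"role": "user", "content": "Hey, how's it going?"}]))
# [INST] Hey, how's it going? [/INST]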
@@ -61,6 +62,7 @@ def completion(
     print_verbose: Callable,
     encoding,
     logging_obj,
+    custom_prompt_dict={},
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
@@ -107,19 +109,23 @@ def completion(
             inference_params[k] = v
 
     model = model
-    prompt = ""
-    for message in messages:
-        if "role" in message:
-            if message["role"] == "user":
-                prompt += (
-                    f"{message['content']}"
-                )
-            else:
-                prompt += (
-                    f"{message['content']}"
-                )
-        else:
-            prompt += f"{message['content']}"
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        prompt = custom_prompt(
+            role_dict=model_prompt_details.get("roles", None),
+            initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+            messages=messages
+        )
+    else:
+        hf_model_name = model
+        if "jumpstart-dft-meta-textgeneration-llama" in model: # llama2 model
+            if model.endswith("-f") or "-f-" in model: # sagemaker default for a chat model
+                hf_model_name = "meta-llama/Llama-2-7b-chat" # apply the prompt template for a llama2 chat model
+            else:
+                hf_model_name = "meta-llama/Llama-2-7b" # apply the normal prompt template
+        prompt = prompt_factory(model=hf_model_name, messages=messages)
 
     data = json.dumps({
         "inputs": prompt,
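The hunk above routes a SageMaker JumpStart Llama-2 endpoint name to a Hugging Face model id so the matching chat template gets applied. The sketch below restates that heuristic as a standalone function; resolve_hf_model_name and the example endpoint names are hypothetical, written only to make the branch logic easy to test in isolation.

# Hypothetical helper mirroring the endpoint-name heuristic in the hunk above.
def resolve_hf_model_name(endpoint_name: str) -> str:
    hf_model_name = endpoint_name
    if "jumpstart-dft-meta-textgeneration-llama" in endpoint_name:  # JumpStart llama2 endpoint
        if endpoint_name.endswith("-f") or "-f-" in endpoint_name:  # "-f" marks the chat (fine-tuned) variant
            hf_model_name = "meta-llama/Llama-2-7b-chat"  # chat prompt template
        else:
            hf_model_name = "meta-llama/Llama-2-7b"  # plain completion template
    return hf_model_name

# Assumed endpoint names, for illustration only:
print(resolve_hf_model_name("jumpstart-dft-meta-textgeneration-llama-2-7b-f"))  # meta-llama/Llama-2-7b-chat
print(resolve_hf_model_name("jumpstart-dft-meta-textgeneration-llama-2-7b"))    # meta-llama/Llama-2-7b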
@@ -1166,6 +1166,7 @@ def completion(
             print_verbose=print_verbose,
             optional_params=optional_params,
             litellm_params=litellm_params,
+            custom_prompt_dict=custom_prompt_dict,
             logger_fn=logger_fn,
             encoding=encoding,
             logging_obj=logging
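With custom_prompt_dict now threaded through to the SageMaker handler, a caller can register a custom prompt template for an endpoint instead of relying on the Llama-2 heuristic. A minimal usage sketch follows, assuming litellm.register_prompt_template behaves as documented around this release; the endpoint name is a placeholder, and whether the registry key keeps the "sagemaker/" prefix is not shown in this diff.

# Usage sketch (assumptions noted above); the endpoint name is hypothetical.
import litellm
from litellm import completion

litellm.register_prompt_template(
    model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",  # placeholder endpoint name
    roles={
        "system": {"pre_message": "[INST] <<SYS>>\n", "post_message": "\n<</SYS>>\n [/INST]\n"},
        "user": {"pre_message": "[INST] ", "post_message": " [/INST]\n"},
        "assistant": {"post_message": "\n"},
    },
    initial_prompt_value="",
    final_prompt_value="",
)

response = completion(
    model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)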
@@ -1048,20 +1048,24 @@ def test_completion_sagemaker():
 
 def test_completion_chat_sagemaker():
     try:
+        messages = [{"role": "user", "content": "Hey, how's it going?"}]
         print("testing sagemaker")
        litellm.set_verbose=True
         response = completion(
             model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b-f",
             messages=messages,
+            max_tokens=100,
             stream=True,
         )
         # Add any assertions here to check the response
-        print(response)
+        complete_response = ""
         for chunk in response:
-            print(chunk)
+            complete_response += chunk.choices[0].delta.content or ""
+        print(f"complete_response: {complete_response}")
+        assert len(complete_response) > 0
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-# test_completion_chat_sagemaker()
+test_completion_chat_sagemaker()
 
 def test_completion_bedrock_titan():
     try:
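The updated test accumulates streamed deltas rather than printing raw chunks. The `or ""` guard matters because the final chunk of a stream typically carries only a finish_reason and a None delta content. The self-contained sketch below uses stand-in dataclasses (not litellm types) to show that pattern.

# Illustrative only: stand-in objects showing why `chunk.choices[0].delta.content or ""` is used.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Delta:
    content: Optional[str] = None

@dataclass
class Choice:
    delta: Delta
    finish_reason: Optional[str] = None

@dataclass
class Chunk:
    choices: List[Choice]

stream = [
    Chunk([Choice(Delta("Hey"))]),
    Chunk([Choice(Delta(", I'm doing well"))]),
    Chunk([Choice(Delta(None), finish_reason="stop")]),  # final chunk: no content
]

complete_response = ""
for chunk in stream:
    complete_response += chunk.choices[0].delta.content or ""  # `or ""` skips the None safely

assert complete_response == "Hey, I'm doing well"
print(complete_response)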
@@ -2206,32 +2206,31 @@ def get_optional_params(  # use the openai defaults
         if max_tokens is not None:
             optional_params["max_output_tokens"] = max_tokens
     elif custom_llm_provider == "sagemaker":
-        if "llama-2" in model.lower() or (
-            "llama" in model.lower() and "2" in model.lower() # some combination of llama and "2" should exist
-        ): # jumpstart can also send "Llama-2-70b-chat-hf-48xlarge"
-            # llama-2 models on sagemaker support the following args
-            """
-            max_new_tokens: Model generates text until the output length (excluding the input context length) reaches max_new_tokens. If specified, it must be a positive integer.
-            temperature: Controls the randomness in the output. Higher temperature results in output sequence with low-probability words and lower temperature results in output sequence with high-probability words. If temperature -> 0, it results in greedy decoding. If specified, it must be a positive float.
-            top_p: In each step of text generation, sample from the smallest possible set of words with cumulative probability top_p. If specified, it must be a float between 0 and 1.
-            return_full_text: If True, input text will be part of the output generated text. If specified, it must be boolean. The default value for it is False.
-            """
-            ## check if unsupported param passed in
-            supported_params = ["temperature", "max_tokens", "stream"]
-            _check_valid_arg(supported_params=supported_params)
-            if max_tokens is not None:
-                optional_params["max_new_tokens"] = max_tokens
-            if temperature is not None:
-                optional_params["temperature"] = temperature
-            if top_p is not None:
-                optional_params["top_p"] = top_p
-            if stream:
-                optional_params["stream"] = stream
-        else:
-            ## check if unsupported param passed in
-            supported_params = []
-            _check_valid_arg(supported_params=supported_params)
+        ## check if unsupported param passed in
+        supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+        _check_valid_arg(supported_params=supported_params)
+        # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
+        if temperature is not None:
+            if temperature == 0.0 or temperature == 0:
+                # hugging face exception raised when temp==0
+                # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+                temperature = 0.01
+            optional_params["temperature"] = temperature
+        if top_p is not None:
+            optional_params["top_p"] = top_p
+        if n is not None:
+            optional_params["best_of"] = n
+            optional_params["do_sample"] = True # Need to sample if you want best of for hf inference endpoints
+        if stream is not None:
+            optional_params["stream"] = stream
+        if stop is not None:
+            optional_params["stop"] = stop
+        if max_tokens is not None:
+            # HF TGI raises the following exception when max_new_tokens==0
+            # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+            if max_tokens == 0:
+                max_tokens = 1
+            optional_params["max_new_tokens"] = max_tokens
     elif custom_llm_provider == "bedrock":
         if "ai21" in model:
             supported_params = ["max_tokens", "temperature", "top_p", "stream"]
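The rewritten sagemaker branch translates OpenAI-style arguments into the fields a Hugging Face TGI / JumpStart Llama-2 container accepts: temperature gets a strictly positive floor, max_tokens becomes max_new_tokens with a minimum of 1, and n maps to best_of with sampling enabled. The standalone function below restates that mapping for readability; its name is hypothetical and it is an illustration, not litellm's code.

# Hypothetical restatement of the sagemaker branch above (illustration only).
def map_openai_params_to_sagemaker(temperature=None, top_p=None, n=None,
                                   stream=None, stop=None, max_tokens=None) -> dict:
    params = {}
    if temperature is not None:
        if temperature == 0:        # TGI rejects temperature == 0 ("must be strictly positive")
            temperature = 0.01
        params["temperature"] = temperature
    if top_p is not None:
        params["top_p"] = top_p
    if n is not None:
        params["best_of"] = n       # OpenAI `n` maps to TGI `best_of`
        params["do_sample"] = True  # best_of requires sampling
    if stream is not None:
        params["stream"] = stream
    if stop is not None:
        params["stop"] = stop
    if max_tokens is not None:
        if max_tokens == 0:         # TGI rejects max_new_tokens == 0
            max_tokens = 1
        params["max_new_tokens"] = max_tokens
    return params

# Example: OpenAI-style call arguments become a TGI-style payload.
print(map_openai_params_to_sagemaker(temperature=0, max_tokens=100, n=2))
# {'temperature': 0.01, 'best_of': 2, 'do_sample': True, 'max_new_tokens': 100}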
@@ -5270,11 +5269,9 @@ class CustomStreamWrapper:
                     else:
                         model_response.choices[0].finish_reason = "stop"
                         self.sent_last_chunk = True
-                chunk_size = 30
-                new_chunk = self.completion_stream[:chunk_size]
+                new_chunk = self.completion_stream
                 completion_obj["content"] = new_chunk
-                self.completion_stream = self.completion_stream[chunk_size:]
-                time.sleep(0.05)
+                self.completion_stream = self.completion_stream[len(self.completion_stream):]
             elif self.custom_llm_provider == "petals":
                 if len(self.completion_stream)==0:
                     if self.sent_last_chunk:
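For the CustomStreamWrapper change: previously the buffered SageMaker completion was dripped out 30 characters at a time with a 50 ms sleep between pieces; now the whole remaining buffer is emitted as a single chunk and the buffer is cleared. A reduced sketch of that behavior change, using plain generators rather than the wrapper class:

# Reduced sketch of the behavior change (not the actual CustomStreamWrapper code).
import time

def fake_stream_old(text):
    # old: drip the buffered text out 30 characters at a time
    while text:
        chunk, text = text[:30], text[30:]
        time.sleep(0.05)
        yield chunk

def fake_stream_new(text):
    # new: hand the whole buffered text back as one chunk
    if text:
        yield text

buffered = "SageMaker returned this completion in a single response body."
assert "".join(fake_stream_old(buffered)) == "".join(fake_stream_new(buffered))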