fix(utils.py): support sagemaker llama2 custom endpoints

Krrish Dholakia 2023-12-05 16:04:43 -08:00
parent 09c2c1610d
commit b4c78c7b9e
4 changed files with 53 additions and 45 deletions

@@ -9,6 +9,7 @@ from litellm.utils import ModelResponse, get_secret, Usage
 import sys
 from copy import deepcopy
 import httpx
+from .prompt_templates.factory import prompt_factory, custom_prompt

 class SagemakerError(Exception):
     def __init__(self, status_code, message):
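
The added import pulls in the two prompt-construction helpers this commit relies on: custom_prompt for endpoints with a user-registered template and prompt_factory for the default llama2 templates. As a rough illustration of what the custom path consumes (see the completion() hunk further down), a hypothetical custom_prompt_dict entry might look like the sketch below; only the "roles", "initial_prompt_value", and "final_prompt_value" keys are taken from the diff, while the endpoint name and the per-role pre/post strings are assumptions.

# Illustrative sketch, not part of the diff: a hypothetical custom_prompt_dict entry
custom_prompt_dict = {
    "my-llama2-endpoint": {  # hypothetical SageMaker endpoint / model name
        "roles": {  # assumed per-role pre/post strings consumed by custom_prompt()
            "user": {"pre_message": "[INST] ", "post_message": " [/INST]\n"},
            "assistant": {"pre_message": "", "post_message": "\n"},
        },
        "initial_prompt_value": "<s>",
        "final_prompt_value": "",
    }
}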
@@ -61,6 +62,7 @@ def completion(
     print_verbose: Callable,
     encoding,
     logging_obj,
+    custom_prompt_dict={},
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
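
custom_prompt_dict is the new parameter that carries any registered template into this SageMaker completion call. A minimal end-to-end sketch, assuming litellm.register_prompt_template() is the helper that populates it and that the registered name must match the model string this function receives (the endpoint name and template values below are hypothetical):

# Sketch only: register a template for a custom SageMaker llama2 endpoint, then call completion
import litellm

litellm.register_prompt_template(
    model="my-llama2-endpoint",  # hypothetical endpoint name; prefix handling is an assumption
    roles={"user": {"pre_message": "[INST] ", "post_message": " [/INST]\n"}},
    initial_prompt_value="<s>",
    final_prompt_value="",
)

response = litellm.completion(
    model="sagemaker/my-llama2-endpoint",
    messages=[{"role": "user", "content": "Hey, how are you?"}],
)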
@@ -107,19 +109,23 @@
             inference_params[k] = v

     model = model
-    prompt = ""
-    for message in messages:
-        if "role" in message:
-            if message["role"] == "user":
-                prompt += (
-                    f"{message['content']}"
-                )
-            else:
-                prompt += (
-                    f"{message['content']}"
-                )
-        else:
-            prompt += f"{message['content']}"
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        prompt = custom_prompt(
+            role_dict=model_prompt_details.get("roles", None),
+            initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+            messages=messages
+        )
+    else:
+        hf_model_name = model
+        if "jumpstart-dft-meta-textgeneration-llama" in model: # llama2 model
+            if model.endswith("-f") or "-f-" in model: # sagemaker default for a chat model
+                hf_model_name = "meta-llama/Llama-2-7b-chat" # apply the prompt template for a llama2 chat model
+            else:
+                hf_model_name = "meta-llama/Llama-2-7b" # apply the normal prompt template
+        prompt = prompt_factory(model=hf_model_name, messages=messages)

     data = json.dumps({
         "inputs": prompt,