Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
fix(utils.py): support sagemaker llama2 custom endpoints
This commit is contained in:
parent 09c2c1610d
commit b4c78c7b9e

4 changed files with 53 additions and 45 deletions
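For context, this is how a SageMaker llama2 endpoint is reached through litellm's completion call; a minimal sketch assuming the documented "sagemaker/" model prefix — the endpoint name below is a hypothetical example, not one from this commit:

import litellm

# Call a SageMaker JumpStart llama2 chat endpoint through litellm.
# JumpStart appends "-f" to endpoint names for the fine-tuned chat
# variants, which is what the diff below keys on.
response = litellm.completion(
    model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b-f",  # hypothetical endpoint
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response.choices[0].message.content)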
@@ -9,6 +9,7 @@ from litellm.utils import ModelResponse, get_secret, Usage
 import sys
 from copy import deepcopy
 import httpx
+from .prompt_templates.factory import prompt_factory, custom_prompt

 class SagemakerError(Exception):
     def __init__(self, status_code, message):
@@ -61,6 +62,7 @@ def completion(
     print_verbose: Callable,
     encoding,
     logging_obj,
+    custom_prompt_dict={},
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
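The new custom_prompt_dict parameter is how per-model prompt templates reach the SageMaker handler. A minimal sketch of registering one, assuming litellm's register_prompt_template helper — the endpoint name and template values here are hypothetical, not part of this commit:

import litellm

# Register a llama2-style template for a specific SageMaker endpoint.
# Registered templates land in custom_prompt_dict and take priority
# over the built-in hf_model_name detection in the next hunk.
litellm.register_prompt_template(
    model="sagemaker/my-custom-llama2-endpoint",  # hypothetical endpoint name
    initial_prompt_value="<s>",
    roles={
        "system": {"pre_message": "[INST] <<SYS>>\n", "post_message": "\n<</SYS>>\n [/INST]\n"},
        "user": {"pre_message": "[INST] ", "post_message": " [/INST]\n"},
        "assistant": {"pre_message": "", "post_message": "</s>"},
    },
    final_prompt_value="",
)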
@@ -107,19 +109,23 @@ def completion(
             inference_params[k] = v

     model = model
-    prompt = ""
-    for message in messages:
-        if "role" in message:
-            if message["role"] == "user":
-                prompt += (
-                    f"{message['content']}"
-                )
-            else:
-                prompt += (
-                    f"{message['content']}"
-                )
-        else:
-            prompt += f"{message['content']}"
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        prompt = custom_prompt(
+            role_dict=model_prompt_details.get("roles", None),
+            initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+            messages=messages
+        )
+    else:
+        hf_model_name = model
+        if "jumpstart-dft-meta-textgeneration-llama" in model: # llama2 model
+            if model.endswith("-f") or "-f-" in model: # sagemaker default for a chat model
+                hf_model_name = "meta-llama/Llama-2-7b-chat" # apply the prompt template for a llama2 chat model
+            else:
+                hf_model_name = "meta-llama/Llama-2-7b" # apply the normal prompt template
+        prompt = prompt_factory(model=hf_model_name, messages=messages)

     data = json.dumps({
         "inputs": prompt,
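Standalone, the endpoint-name detection added above reduces to the following mapping; a sketch for illustration only — the helper name is ours, not litellm API:

def _llama2_template_model(endpoint_name: str) -> str:
    """Pick the Hugging Face model whose prompt template matches a
    SageMaker JumpStart llama2 endpoint (illustrative helper)."""
    if "jumpstart-dft-meta-textgeneration-llama" in endpoint_name:
        # JumpStart marks the fine-tuned chat variants with an "-f" suffix.
        if endpoint_name.endswith("-f") or "-f-" in endpoint_name:
            return "meta-llama/Llama-2-7b-chat"  # chat prompt template
        return "meta-llama/Llama-2-7b"  # base completion template
    return endpoint_name  # non-llama2 endpoints keep their own name

# e.g. _llama2_template_model("jumpstart-dft-meta-textgeneration-llama-2-7b-f")
# -> "meta-llama/Llama-2-7b-chat"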