forked from phoenix/litellm-mirror
fix(sagemaker.py): prompt templating fixes
parent 0eccc1b1f8
commit 648d41c96f
2 changed files with 24 additions and 12 deletions
@@ -121,10 +121,10 @@ def completion(
         )
     else:
         if hf_model_name is None:
-            if "llama2" in model.lower(): # llama2 model
-                if "chat" in model.lower():
+            if "llama-2" in model.lower(): # llama-2 model
+                if "chat" in model.lower(): # apply llama2 chat template
                     hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
-                else:
+                else: # apply regular llama2 template
                     hf_model_name = "meta-llama/Llama-2-7b"
         hf_model_name = hf_model_name or model # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
         prompt = prompt_factory(model=hf_model_name, messages=messages)
@@ -146,7 +146,7 @@ def completion(
     logging_obj.pre_call(
         input=prompt,
         api_key="",
-        additional_args={"complete_input_dict": data, "request_str": request_str},
+        additional_args={"complete_input_dict": data, "request_str": request_str, "hf_model_name": hf_model_name},
     )
     ## COMPLETION CALL
     try:
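For context: the hunks above come from sagemaker.py (per the commit title) and remap Llama-2 SageMaker endpoint names to a Hugging Face model name so prompt_factory can apply the matching prompt template; the hunks below update the SageMaker tests. A minimal sketch of that selection flow, assuming prompt_factory's import path and using an illustrative endpoint name and output:

# Illustrative sketch, not part of the diff.
from litellm.llms.prompt_templates.factory import prompt_factory  # assumed import path

# Hypothetical SageMaker endpoint name containing "llama-2" and "chat"
model = "sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4"
hf_model_name = None  # caller did not override it

# Mirrors the new mapping in the diff: choose a HF model name whose
# prompt template matches the deployed model.
if hf_model_name is None and "llama-2" in model.lower():
    if "chat" in model.lower():
        hf_model_name = "meta-llama/Llama-2-7b-chat-hf"  # llama2 chat template
    else:
        hf_model_name = "meta-llama/Llama-2-7b"          # regular llama2 template
hf_model_name = hf_model_name or model  # fall back to the raw endpoint name

messages = [{"role": "user", "content": "Hey, how's it going?"}]
prompt = prompt_factory(model=hf_model_name, messages=messages)
print(prompt)  # roughly "<s>[INST] Hey, how's it going? [/INST]" for the chat template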
@@ -1035,30 +1035,27 @@ def test_completion_sagemaker():
         print("testing sagemaker")
         litellm.set_verbose=True
         response = completion(
-            model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
+            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
             messages=messages,
             temperature=0.2,
             max_tokens=80,
-            hf_model_name="meta-llama/Llama-2-7b",
         )
         # Add any assertions here to check the response
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-# test_completion_sagemaker()
+test_completion_sagemaker()

 def test_completion_chat_sagemaker():
     try:
         messages = [{"role": "user", "content": "Hey, how's it going?"}]
-        print("testing sagemaker")
         litellm.set_verbose=True
         response = completion(
-            model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b-chat",
+            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
             messages=messages,
             max_tokens=100,
+            temperature=0.7,
             stream=True,
-            n=2,
-            hf_model_name="meta-llama/Llama-2-7b-chat-hf",
         )
         # Add any assertions here to check the response
         complete_response = ""
@@ -1068,8 +1065,23 @@ def test_completion_chat_sagemaker():
         assert len(complete_response) > 0
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_chat_sagemaker()
+# test_completion_chat_sagemaker()

+def test_completion_chat_sagemaker_mistral():
+    try:
+        messages = [{"role": "user", "content": "Hey, how's it going?"}]
+
+        response = completion(
+            model="sagemaker/jumpstart-dft-hf-llm-mistral-7b-instruct",
+            messages=messages,
+            max_tokens=100,
+        )
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"An error occurred: {str(e)}")
+
+# test_completion_chat_sagemaker_mistral()
 def test_completion_bedrock_titan():
     try:
         response = completion(
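The hf_model_name arguments removed from the tests rely on the auto-detection added above; for endpoint names that don't contain "llama-2", the handler still checks an explicitly passed hf_model_name first, so the template can be forced the way the old test arguments did. A minimal usage sketch, with a hypothetical endpoint name:

import litellm

# Hypothetical SageMaker endpoint; the name alone (no "llama-2" substring)
# would not trigger the new auto-detection, so the template is chosen explicitly.
response = litellm.completion(
    model="sagemaker/my-llama2-chat-endpoint",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    hf_model_name="meta-llama/Llama-2-7b-chat-hf",  # apply the llama2 chat template
    max_tokens=80,
)
print(response)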