diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py
index 2482c5457..cb5b56bdd 100644
--- a/litellm/llms/sagemaker.py
+++ b/litellm/llms/sagemaker.py
@@ -121,10 +121,10 @@ def completion(
         )
     else:
         if hf_model_name is None:
-            if "llama2" in model.lower(): # llama2 model
-                if "chat" in model.lower():
+            if "llama-2" in model.lower(): # llama-2 model
+                if "chat" in model.lower(): # apply llama2 chat template
                     hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
-                else:
+                else: # apply regular llama2 template
                     hf_model_name = "meta-llama/Llama-2-7b"
         hf_model_name = hf_model_name or model # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
         prompt = prompt_factory(model=hf_model_name, messages=messages)
@@ -146,7 +146,7 @@ def completion(
     logging_obj.pre_call(
         input=prompt,
         api_key="",
-        additional_args={"complete_input_dict": data, "request_str": request_str},
+        additional_args={"complete_input_dict": data, "request_str": request_str, "hf_model_name": hf_model_name},
     )
     ## COMPLETION CALL
     try:
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 0f0bff1cc..69aff761c 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -1035,30 +1035,27 @@ def test_completion_sagemaker():
         print("testing sagemaker")
         litellm.set_verbose=True
         response = completion(
-            model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
+            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
             messages=messages,
             temperature=0.2,
             max_tokens=80,
-            hf_model_name="meta-llama/Llama-2-7b",
         )
         # Add any assertions here to check the response
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-# test_completion_sagemaker()
+test_completion_sagemaker()
 
 def test_completion_chat_sagemaker():
     try:
         messages = [{"role": "user", "content": "Hey, how's it going?"}]
-        print("testing sagemaker")
         litellm.set_verbose=True
         response = completion(
-            model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b-chat",
+            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
             messages=messages,
             max_tokens=100,
+            temperature=0.7,
             stream=True,
-            n=2,
-            hf_model_name="meta-llama/Llama-2-7b-chat-hf",
         )
         # Add any assertions here to check the response
         complete_response = ""
@@ -1068,8 +1065,23 @@ def test_completion_chat_sagemaker():
         assert len(complete_response) > 0
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_chat_sagemaker()
+# test_completion_chat_sagemaker()
 
+def test_completion_chat_sagemaker_mistral():
+    try:
+        messages = [{"role": "user", "content": "Hey, how's it going?"}]
+
+        response = completion(
+            model="sagemaker/jumpstart-dft-hf-llm-mistral-7b-instruct",
+            messages=messages,
+            max_tokens=100,
+        )
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"An error occurred: {str(e)}")
+
+# test_completion_chat_sagemaker_mistral()
 def test_completion_bedrock_titan():
     try:
         response = completion(
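
For reviewers, here is a minimal sketch of how the changed sagemaker.py path is expected to behave from the caller's side. The endpoint names are hypothetical placeholders, and the explicit `hf_model_name` override in the second call is based on the existing keyword argument that this diff leaves in place (it is only consulted when no value was passed), not on anything the diff adds.

```python
import litellm

# Sketch only: endpoint names below are hypothetical placeholders.
# With this diff, sagemaker.py infers a Hugging Face model name from the
# SageMaker endpoint name so prompt_factory can apply the matching prompt
# template: an endpoint name containing "llama-2" and "chat" resolves to
# "meta-llama/Llama-2-7b-chat-hf"; "llama-2" without "chat" resolves to
# "meta-llama/Llama-2-7b".
response = litellm.completion(
    model="sagemaker/my-llama-2-7b-chat-endpoint",  # hypothetical endpoint
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    max_tokens=80,
)
print(response)

# Callers can still choose the prompt template explicitly and skip the
# inference, since hf_model_name is only derived when it is None.
response = litellm.completion(
    model="sagemaker/my-custom-endpoint",  # hypothetical endpoint
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    hf_model_name="meta-llama/Llama-2-7b-chat-hf",  # template to apply
    max_tokens=80,
)
print(response)
```

The added `"hf_model_name"` entry in `logging_obj.pre_call`'s `additional_args` makes the resolved template name visible in verbose logs, which is how the behavior above can be confirmed when running the tests.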