diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py
index b7f1c50236..4806a57e2a 100644
--- a/litellm/llms/bedrock.py
+++ b/litellm/llms/bedrock.py
@@ -492,6 +492,8 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
         prompt = prompt_factory(
             model=model, messages=messages, custom_llm_provider="bedrock"
         )
+    elif provider == "mistral":
+        prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="bedrock")
     else:
         prompt = ""
         for message in messages:
@@ -623,7 +625,16 @@ def completion(
                     "textGenerationConfig": inference_params,
                 }
             )
+        elif provider == "mistral":
+            ## LOAD CONFIG
+            config = litellm.AmazonLlamaConfig.get_config()
+            for k, v in config.items():
+                if (
+                    k not in inference_params
+                ):  # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
+                    inference_params[k] = v
+            data = json.dumps({"prompt": prompt, **inference_params})
         else:
             data = json.dumps({})
@@ -729,6 +740,9 @@ def completion(
             outputText = response_body["generations"][0]["text"]
         elif provider == "meta":
             outputText = response_body["generation"]
+        elif provider == "mistral":
+            outputText = response_body["outputs"][0]["text"]
+            model_response["finish_reason"] = response_body["outputs"][0]["stop_reason"]
         else:  # amazon titan
             outputText = response_body.get("results")[0].get("outputText")
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index 4ed4d9295d..cc8d3d49bd 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -674,6 +674,8 @@ def prompt_factory(
             return claude_2_1_pt(messages=messages)
         else:
             return anthropic_pt(messages=messages)
+    elif "mistral." in model:
+        return mistral_instruct_pt(messages=messages)
     try:
         if "meta-llama/llama-2" in model and "chat" in model:
             return llama_2_chat_pt(messages=messages)
diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py
index a448fc3a57..4a9164019e 100644
--- a/litellm/tests/test_bedrock_completion.py
+++ b/litellm/tests/test_bedrock_completion.py
@@ -1,33 +1,33 @@
 # @pytest.mark.skip(reason="AWS Suspended Account")
-# import sys, os
-# import traceback
-# from dotenv import load_dotenv
+import sys, os
+import traceback
+from dotenv import load_dotenv
 
-# load_dotenv()
-# import os, io
+load_dotenv()
+import os, io
 
-# sys.path.insert(
-#     0, os.path.abspath("../..")
-# )  # Adds the parent directory to the system path
-# import pytest
-# import litellm
-# from litellm import embedding, completion, completion_cost, Timeout
-# from litellm import RateLimitError
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest
+import litellm
+from litellm import embedding, completion, completion_cost, Timeout
+from litellm import RateLimitError
 
-# # litellm.num_retries = 3
-# litellm.cache = None
-# litellm.success_callback = []
-# user_message = "Write a short poem about the sky"
-# messages = [{"content": user_message, "role": "user"}]
+# litellm.num_retries = 3
+litellm.cache = None
+litellm.success_callback = []
+user_message = "Write a short poem about the sky"
+messages = [{"content": user_message, "role": "user"}]
 
-# @pytest.fixture(autouse=True)
-# def reset_callbacks():
-#     print("\npytest fixture - resetting callbacks")
-#     litellm.success_callback = []
-#     litellm._async_success_callback = []
-#     litellm.failure_callback = []
-#     litellm.callbacks = []
+@pytest.fixture(autouse=True)
+def reset_callbacks():
+    print("\npytest fixture - resetting callbacks")
+    litellm.success_callback = []
+    litellm._async_success_callback = []
+    litellm.failure_callback = []
+    litellm.callbacks = []
 
 
 # def test_completion_bedrock_claude_completion_auth():
@@ -257,3 +257,35 @@
 
 
 # # test_provisioned_throughput()
+
+def test_completion_bedrock_mistral_completion_auth():
+    print("calling bedrock mistral completion params auth")
+    import os
+
+    # aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
+    # aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
+    # aws_region_name = os.environ["AWS_REGION_NAME"]
+
+    # os.environ.pop("AWS_ACCESS_KEY_ID", None)
+    # os.environ.pop("AWS_SECRET_ACCESS_KEY", None)
+    # os.environ.pop("AWS_REGION_NAME", None)
+    try:
+        response = completion(
+            model="bedrock/mistral.mistral-7b-instruct-v0:2",
+            messages=messages,
+            max_tokens=10,
+            temperature=0.1,
+        )
+        # Add any assertions here to check the response
+        print(response)
+
+        # os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
+        # os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
+        # os.environ["AWS_REGION_NAME"] = aws_region_name
+    except RateLimitError:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+test_completion_bedrock_mistral_completion_auth()
\ No newline at end of file
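
Usage note (not part of the patch): a minimal sketch of what this change enables from the caller's side, based on the test added above. It assumes AWS credentials and region are already configured via the standard AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY / AWS_REGION_NAME environment variables.

    import litellm

    # Route a chat completion to the new Bedrock Mistral provider path.
    # Model id taken from the test above; params are illustrative.
    response = litellm.completion(
        model="bedrock/mistral.mistral-7b-instruct-v0:2",
        messages=[{"role": "user", "content": "Write a short poem about the sky"}],
        max_tokens=10,
        temperature=0.1,
    )
    print(response)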