add support for bedrock mistral models

Tim Xia 2024-03-01 22:45:54 -05:00
parent f9ef3ce32d
commit 739f4f05f6
3 changed files with 72 additions and 24 deletions
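For reference, a minimal sketch of how the new Bedrock Mistral route can be exercised once this change is applied. The model ID and parameters mirror the test added at the bottom of this commit; AWS credentials are assumed to already be configured in the environment.

import litellm

# Route a chat completion through Bedrock's Mistral 7B Instruct model.
# Authentication relies on whatever AWS credentials boto3 normally resolves
# (environment variables, shared profile, or an attached IAM role).
response = litellm.completion(
    model="bedrock/mistral.mistral-7b-instruct-v0:2",
    messages=[{"role": "user", "content": "Write a short poem about the sky"}],
    max_tokens=10,
    temperature=0.1,
)
print(response)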

View file

@@ -492,6 +492,8 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
            prompt = prompt_factory(
                model=model, messages=messages, custom_llm_provider="bedrock"
            )
    elif provider == "mistral":
        prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="bedrock")
    else:
        prompt = ""
        for message in messages:
@@ -623,7 +625,16 @@ def completion(
                    "textGenerationConfig": inference_params,
                }
            )
        elif provider == "mistral":
            ## LOAD CONFIG
            config = litellm.AmazonLlamaConfig.get_config()
            for k, v in config.items():
                if (
                    k not in inference_params
                ):  # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
                    inference_params[k] = v
            data = json.dumps({"prompt": prompt, **inference_params})
        else:
            data = json.dumps({})
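As the inline comment in the new branch notes, parameters passed directly to completion() take precedence over the shared provider config; the loop only fills in keys the caller did not set. A tiny illustration of that merge, with hypothetical values:

inference_params = {"max_tokens": 10}  # values passed by the caller
config = {"max_tokens": 256, "temperature": 0.7}  # provider config defaults (hypothetical)
for k, v in config.items():
    if k not in inference_params:  # only fill gaps, never override the caller
        inference_params[k] = v
# inference_params is now {"max_tokens": 10, "temperature": 0.7}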
@@ -729,6 +740,9 @@ def completion(
            outputText = response_body["generations"][0]["text"]
        elif provider == "meta":
            outputText = response_body["generation"]
        elif provider == "mistral":
            outputText = response_body["outputs"][0]["text"]
            model_response["finish_reason"] = response_body["outputs"][0]["stop_reason"]
        else:  # amazon titan
            outputText = response_body.get("results")[0].get("outputText")
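For orientation, these are roughly the request and response shapes the mistral branch above works with. The response layout follows directly from the parsing code; the request is simply the prompt string (whatever the prompt factory produced) plus the merged inference parameters, so the concrete values below are illustrative:

# Request body built by the json.dumps call above (illustrative values):
request_body = {"prompt": "<prompt from prompt_factory>", "max_tokens": 10, "temperature": 0.1}

# Response body the parsing above expects:
response_body = {"outputs": [{"text": "The sky is wide and blue...", "stop_reason": "stop"}]}
outputText = response_body["outputs"][0]["text"]
finish_reason = response_body["outputs"][0]["stop_reason"]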

View file

@@ -674,6 +674,8 @@ def prompt_factory(
                return claude_2_1_pt(messages=messages)
            else:
                return anthropic_pt(messages=messages)
        elif "mistral." in model:
            return mistral_instruct_pt(messages=messages)
    try:
        if "meta-llama/llama-2" in model and "chat" in model:
            return llama_2_chat_pt(messages=messages)
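The new elif routes every mistral.* Bedrock model through mistral_instruct_pt. In rough terms, that flattens the chat messages into Mistral's [INST] instruct template; the sketch below only illustrates the general shape, since the exact tokens and whitespace come from the helper itself:

messages = [{"role": "user", "content": "Write a short poem about the sky"}]
# Approximate shape of the prompt string produced for Mistral instruct models:
# "<s>[INST] Write a short poem about the sky [/INST]"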

View file

@@ -1,33 +1,33 @@
# @pytest.mark.skip(reason="AWS Suspended Account")
import sys, os
import traceback
from dotenv import load_dotenv

load_dotenv()
import os, io

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError

# litellm.num_retries = 3
litellm.cache = None
litellm.success_callback = []
user_message = "Write a short poem about the sky"
messages = [{"content": user_message, "role": "user"}]


@pytest.fixture(autouse=True)
def reset_callbacks():
    print("\npytest fixture - resetting callbacks")
    litellm.success_callback = []
    litellm._async_success_callback = []
    litellm.failure_callback = []
    litellm.callbacks = []


# def test_completion_bedrock_claude_completion_auth():

@@ -257,3 +257,35 @@
# # test_provisioned_throughput()
def test_completion_bedrock_mistral_completion_auth():
    print("calling bedrock mistral completion params auth")
    import os

    # aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
    # aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
    # aws_region_name = os.environ["AWS_REGION_NAME"]
    # os.environ.pop("AWS_ACCESS_KEY_ID", None)
    # os.environ.pop("AWS_SECRET_ACCESS_KEY", None)
    # os.environ.pop("AWS_REGION_NAME", None)
    try:
        response = completion(
            model="bedrock/mistral.mistral-7b-instruct-v0:2",
            messages=messages,
            max_tokens=10,
            temperature=0.1,
        )
        # Add any assertions here to check the response
        print(response)
        # os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
        # os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
        # os.environ["AWS_REGION_NAME"] = aws_region_name
    except RateLimitError:
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


test_completion_bedrock_mistral_completion_auth()