mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
add support for bedrock mistral models
This commit is contained in:
parent
f9ef3ce32d
commit
739f4f05f6
3 changed files with 72 additions and 24 deletions
|
@ -492,6 +492,8 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
|
||||||
prompt = prompt_factory(
|
prompt = prompt_factory(
|
||||||
model=model, messages=messages, custom_llm_provider="bedrock"
|
model=model, messages=messages, custom_llm_provider="bedrock"
|
||||||
)
|
)
|
||||||
|
elif provider == "mistral":
|
||||||
|
prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="bedrock")
|
||||||
else:
|
else:
|
||||||
prompt = ""
|
prompt = ""
|
||||||
for message in messages:
|
for message in messages:
|
||||||
|
@ -623,7 +625,16 @@ def completion(
|
||||||
"textGenerationConfig": inference_params,
|
"textGenerationConfig": inference_params,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
elif provider == "mistral":
|
||||||
|
## LOAD CONFIG
|
||||||
|
config = litellm.AmazonLlamaConfig.get_config()
|
||||||
|
for k, v in config.items():
|
||||||
|
if (
|
||||||
|
k not in inference_params
|
||||||
|
): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||||
|
inference_params[k] = v
|
||||||
|
|
||||||
|
data = json.dumps({"prompt": prompt, **inference_params})
|
||||||
else:
|
else:
|
||||||
data = json.dumps({})
|
data = json.dumps({})
|
||||||
|
|
||||||
|
@ -729,6 +740,9 @@ def completion(
|
||||||
outputText = response_body["generations"][0]["text"]
|
outputText = response_body["generations"][0]["text"]
|
||||||
elif provider == "meta":
|
elif provider == "meta":
|
||||||
outputText = response_body["generation"]
|
outputText = response_body["generation"]
|
||||||
|
elif provider == "mistral":
|
||||||
|
outputText = response_body["outputs"][0]["text"]
|
||||||
|
model_response["finish_reason"] = response_body["outputs"][0]["stop_reason"]
|
||||||
else: # amazon titan
|
else: # amazon titan
|
||||||
outputText = response_body.get("results")[0].get("outputText")
|
outputText = response_body.get("results")[0].get("outputText")
|
||||||
|
|
||||||
|
|
|
@ -674,6 +674,8 @@ def prompt_factory(
|
||||||
return claude_2_1_pt(messages=messages)
|
return claude_2_1_pt(messages=messages)
|
||||||
else:
|
else:
|
||||||
return anthropic_pt(messages=messages)
|
return anthropic_pt(messages=messages)
|
||||||
|
elif "mistral." in model:
|
||||||
|
return mistral_instruct_pt(messages=messages)
|
||||||
try:
|
try:
|
||||||
if "meta-llama/llama-2" in model and "chat" in model:
|
if "meta-llama/llama-2" in model and "chat" in model:
|
||||||
return llama_2_chat_pt(messages=messages)
|
return llama_2_chat_pt(messages=messages)
|
||||||
|
|
|
@ -1,33 +1,33 @@
|
||||||
# @pytest.mark.skip(reason="AWS Suspended Account")
|
# @pytest.mark.skip(reason="AWS Suspended Account")
|
||||||
# import sys, os
|
import sys, os
|
||||||
# import traceback
|
import traceback
|
||||||
# from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
# load_dotenv()
|
load_dotenv()
|
||||||
# import os, io
|
import os, io
|
||||||
|
|
||||||
# sys.path.insert(
|
sys.path.insert(
|
||||||
# 0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
# ) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
# import pytest
|
import pytest
|
||||||
# import litellm
|
import litellm
|
||||||
# from litellm import embedding, completion, completion_cost, Timeout
|
from litellm import embedding, completion, completion_cost, Timeout
|
||||||
# from litellm import RateLimitError
|
from litellm import RateLimitError
|
||||||
|
|
||||||
# # litellm.num_retries = 3
|
# litellm.num_retries = 3
|
||||||
# litellm.cache = None
|
litellm.cache = None
|
||||||
# litellm.success_callback = []
|
litellm.success_callback = []
|
||||||
# user_message = "Write a short poem about the sky"
|
user_message = "Write a short poem about the sky"
|
||||||
# messages = [{"content": user_message, "role": "user"}]
|
messages = [{"content": user_message, "role": "user"}]
|
||||||
|
|
||||||
|
|
||||||
# @pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
# def reset_callbacks():
|
def reset_callbacks():
|
||||||
# print("\npytest fixture - resetting callbacks")
|
print("\npytest fixture - resetting callbacks")
|
||||||
# litellm.success_callback = []
|
litellm.success_callback = []
|
||||||
# litellm._async_success_callback = []
|
litellm._async_success_callback = []
|
||||||
# litellm.failure_callback = []
|
litellm.failure_callback = []
|
||||||
# litellm.callbacks = []
|
litellm.callbacks = []
|
||||||
|
|
||||||
|
|
||||||
# def test_completion_bedrock_claude_completion_auth():
|
# def test_completion_bedrock_claude_completion_auth():
|
||||||
|
@ -257,3 +257,35 @@
|
||||||
|
|
||||||
|
|
||||||
# # test_provisioned_throughput()
|
# # test_provisioned_throughput()
|
||||||
|
|
||||||
|
def test_completion_bedrock_mistral_completion_auth():
|
||||||
|
print("calling bedrock mistral completion params auth")
|
||||||
|
import os
|
||||||
|
|
||||||
|
# aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
|
||||||
|
# aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
|
||||||
|
# aws_region_name = os.environ["AWS_REGION_NAME"]
|
||||||
|
|
||||||
|
# os.environ.pop("AWS_ACCESS_KEY_ID", None)
|
||||||
|
# os.environ.pop("AWS_SECRET_ACCESS_KEY", None)
|
||||||
|
# os.environ.pop("AWS_REGION_NAME", None)
|
||||||
|
try:
|
||||||
|
response = completion(
|
||||||
|
model="bedrock/mistral.mistral-7b-instruct-v0:2",
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=10,
|
||||||
|
temperature=0.1,
|
||||||
|
)
|
||||||
|
# Add any assertions here to check the response
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
# os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
|
||||||
|
# os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
|
||||||
|
# os.environ["AWS_REGION_NAME"] = aws_region_name
|
||||||
|
except RateLimitError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
test_completion_bedrock_mistral_completion_auth()
|
Loading…
Add table
Add a link
Reference in a new issue