fix bedrock claude test

This commit is contained in:
Ishaan Jaff 2024-06-29 18:46:06 -07:00
parent 9eee36b449
commit f9ba3cf668
3 changed files with 20 additions and 9 deletions

View file

@ -799,7 +799,11 @@ from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig from .llms.ollama_chat import OllamaChatConfig
from .llms.maritalk import MaritTalkConfig from .llms.maritalk import MaritTalkConfig
from .llms.bedrock_httpx import AmazonCohereChatConfig, AmazonConverseConfig from .llms.bedrock_httpx import (
AmazonCohereChatConfig,
AmazonConverseConfig,
BEDROCK_CONVERSE_MODELS,
)
from .llms.bedrock import ( from .llms.bedrock import (
AmazonTitanConfig, AmazonTitanConfig,
AmazonAI21Config, AmazonAI21Config,

View file

@ -60,6 +60,12 @@ from .prompt_templates.factory import (
prompt_factory, prompt_factory,
) )
BEDROCK_CONVERSE_MODELS = [
"anthropic.claude-3-opus-20240229-v1:0",
"anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-3-haiku-20240307-v1:0",
]
iam_cache = DualCache() iam_cache = DualCache()
@ -439,6 +445,7 @@ class BedrockLLM(BaseLLM):
and aws_session_token is not None and aws_session_token is not None
): ### CHECK FOR AWS SESSION TOKEN ### ): ### CHECK FOR AWS SESSION TOKEN ###
from botocore.credentials import Credentials from botocore.credentials import Credentials
credentials = Credentials( credentials = Credentials(
access_key=aws_access_key_id, access_key=aws_access_key_id,
secret_key=aws_secret_access_key, secret_key=aws_secret_access_key,
@ -1573,6 +1580,7 @@ class BedrockConverseLLM(BaseLLM):
and aws_session_token is not None and aws_session_token is not None
): ### CHECK FOR AWS SESSION TOKEN ### ): ### CHECK FOR AWS SESSION TOKEN ###
from botocore.credentials import Credentials from botocore.credentials import Credentials
credentials = Credentials( credentials = Credentials(
access_key=aws_access_key_id, access_key=aws_access_key_id,
secret_key=aws_secret_access_key, secret_key=aws_secret_access_key,

View file

@ -2200,8 +2200,7 @@ def completion(
# boto3 reads keys from .env # boto3 reads keys from .env
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
if "aws_bedrock_client" in optional_params:
if ("aws_bedrock_client" in optional_params):
# Extract credentials for legacy boto3 client and pass thru to httpx # Extract credentials for legacy boto3 client and pass thru to httpx
aws_bedrock_client = optional_params.pop("aws_bedrock_client") aws_bedrock_client = optional_params.pop("aws_bedrock_client")
creds = aws_bedrock_client._get_credentials().get_frozen_credentials() creds = aws_bedrock_client._get_credentials().get_frozen_credentials()
@ -2212,7 +2211,7 @@ def completion(
if creds.token: if creds.token:
optional_params["aws_session_token"] = creds.token optional_params["aws_session_token"] = creds.token
if model.startswith("anthropic"): if model in litellm.BEDROCK_CONVERSE_MODELS:
response = bedrock_converse_chat_completion.completion( response = bedrock_converse_chat_completion.completion(
model=model, model=model,
messages=messages, messages=messages,