From f9ba3cf6685a3ee5b160e191b8fc724b6df4ea42 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 29 Jun 2024 18:46:06 -0700 Subject: [PATCH] fix bedrock claude test --- litellm/__init__.py | 6 +++++- litellm/llms/bedrock_httpx.py | 16 ++++++++++++---- litellm/main.py | 7 +++---- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index a8d9a80a2..5bd5d1a16 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -799,7 +799,11 @@ from .llms.sagemaker import SagemakerConfig from .llms.ollama import OllamaConfig from .llms.ollama_chat import OllamaChatConfig from .llms.maritalk import MaritTalkConfig -from .llms.bedrock_httpx import AmazonCohereChatConfig, AmazonConverseConfig +from .llms.bedrock_httpx import ( + AmazonCohereChatConfig, + AmazonConverseConfig, + BEDROCK_CONVERSE_MODELS, +) from .llms.bedrock import ( AmazonTitanConfig, AmazonAI21Config, diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py index d376808b7..3faaf40f1 100644 --- a/litellm/llms/bedrock_httpx.py +++ b/litellm/llms/bedrock_httpx.py @@ -60,6 +60,12 @@ from .prompt_templates.factory import ( prompt_factory, ) +BEDROCK_CONVERSE_MODELS = [ + "anthropic.claude-3-opus-20240229-v1:0", + "anthropic.claude-3-sonnet-20240229-v1:0", + "anthropic.claude-3-haiku-20240307-v1:0", +] + iam_cache = DualCache() @@ -437,14 +443,15 @@ class BedrockLLM(BaseLLM): aws_access_key_id is not None and aws_secret_access_key is not None and aws_session_token is not None - ): ### CHECK FOR AWS SESSION TOKEN ### + ): ### CHECK FOR AWS SESSION TOKEN ### from botocore.credentials import Credentials + credentials = Credentials( access_key=aws_access_key_id, secret_key=aws_secret_access_key, token=aws_session_token, ) - return credentials + return credentials else: session = boto3.Session( aws_access_key_id=aws_access_key_id, @@ -1571,14 +1578,15 @@ class BedrockConverseLLM(BaseLLM): aws_access_key_id is not None and aws_secret_access_key is not None and aws_session_token is not None - ): ### CHECK FOR AWS SESSION TOKEN ### + ): ### CHECK FOR AWS SESSION TOKEN ### from botocore.credentials import Credentials + credentials = Credentials( access_key=aws_access_key_id, secret_key=aws_secret_access_key, token=aws_session_token, ) - return credentials + return credentials else: session = boto3.Session( aws_access_key_id=aws_access_key_id, diff --git a/litellm/main.py b/litellm/main.py index 951a79c45..10bcbe9e3 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -2200,8 +2200,7 @@ def completion( # boto3 reads keys from .env custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - - if ("aws_bedrock_client" in optional_params): + if "aws_bedrock_client" in optional_params: # Extract credentials for legacy boto3 client and pass thru to httpx aws_bedrock_client = optional_params.pop("aws_bedrock_client") creds = aws_bedrock_client._get_credentials().get_frozen_credentials() @@ -2210,9 +2209,9 @@ def completion( if creds.secret_key: optional_params["aws_secret_access_key"] = creds.secret_key if creds.token: - optional_params["aws_session_token"] = creds.token + optional_params["aws_session_token"] = creds.token - if model.startswith("anthropic"): + if model in litellm.BEDROCK_CONVERSE_MODELS: response = bedrock_converse_chat_completion.completion( model=model, messages=messages,