forked from phoenix/litellm-mirror
fix bedrock claude test
This commit is contained in:
parent
9eee36b449
commit
f9ba3cf668
3 changed files with 20 additions and 9 deletions
|
@ -799,7 +799,11 @@ from .llms.sagemaker import SagemakerConfig
|
||||||
from .llms.ollama import OllamaConfig
|
from .llms.ollama import OllamaConfig
|
||||||
from .llms.ollama_chat import OllamaChatConfig
|
from .llms.ollama_chat import OllamaChatConfig
|
||||||
from .llms.maritalk import MaritTalkConfig
|
from .llms.maritalk import MaritTalkConfig
|
||||||
from .llms.bedrock_httpx import AmazonCohereChatConfig, AmazonConverseConfig
|
from .llms.bedrock_httpx import (
|
||||||
|
AmazonCohereChatConfig,
|
||||||
|
AmazonConverseConfig,
|
||||||
|
BEDROCK_CONVERSE_MODELS,
|
||||||
|
)
|
||||||
from .llms.bedrock import (
|
from .llms.bedrock import (
|
||||||
AmazonTitanConfig,
|
AmazonTitanConfig,
|
||||||
AmazonAI21Config,
|
AmazonAI21Config,
|
||||||
|
|
|
@ -60,6 +60,12 @@ from .prompt_templates.factory import (
|
||||||
prompt_factory,
|
prompt_factory,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
BEDROCK_CONVERSE_MODELS = [
|
||||||
|
"anthropic.claude-3-opus-20240229-v1:0",
|
||||||
|
"anthropic.claude-3-sonnet-20240229-v1:0",
|
||||||
|
"anthropic.claude-3-haiku-20240307-v1:0",
|
||||||
|
]
|
||||||
|
|
||||||
iam_cache = DualCache()
|
iam_cache = DualCache()
|
||||||
|
|
||||||
|
|
||||||
|
@ -437,14 +443,15 @@ class BedrockLLM(BaseLLM):
|
||||||
aws_access_key_id is not None
|
aws_access_key_id is not None
|
||||||
and aws_secret_access_key is not None
|
and aws_secret_access_key is not None
|
||||||
and aws_session_token is not None
|
and aws_session_token is not None
|
||||||
): ### CHECK FOR AWS SESSION TOKEN ###
|
): ### CHECK FOR AWS SESSION TOKEN ###
|
||||||
from botocore.credentials import Credentials
|
from botocore.credentials import Credentials
|
||||||
|
|
||||||
credentials = Credentials(
|
credentials = Credentials(
|
||||||
access_key=aws_access_key_id,
|
access_key=aws_access_key_id,
|
||||||
secret_key=aws_secret_access_key,
|
secret_key=aws_secret_access_key,
|
||||||
token=aws_session_token,
|
token=aws_session_token,
|
||||||
)
|
)
|
||||||
return credentials
|
return credentials
|
||||||
else:
|
else:
|
||||||
session = boto3.Session(
|
session = boto3.Session(
|
||||||
aws_access_key_id=aws_access_key_id,
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
@ -1571,14 +1578,15 @@ class BedrockConverseLLM(BaseLLM):
|
||||||
aws_access_key_id is not None
|
aws_access_key_id is not None
|
||||||
and aws_secret_access_key is not None
|
and aws_secret_access_key is not None
|
||||||
and aws_session_token is not None
|
and aws_session_token is not None
|
||||||
): ### CHECK FOR AWS SESSION TOKEN ###
|
): ### CHECK FOR AWS SESSION TOKEN ###
|
||||||
from botocore.credentials import Credentials
|
from botocore.credentials import Credentials
|
||||||
|
|
||||||
credentials = Credentials(
|
credentials = Credentials(
|
||||||
access_key=aws_access_key_id,
|
access_key=aws_access_key_id,
|
||||||
secret_key=aws_secret_access_key,
|
secret_key=aws_secret_access_key,
|
||||||
token=aws_session_token,
|
token=aws_session_token,
|
||||||
)
|
)
|
||||||
return credentials
|
return credentials
|
||||||
else:
|
else:
|
||||||
session = boto3.Session(
|
session = boto3.Session(
|
||||||
aws_access_key_id=aws_access_key_id,
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
|
|
@ -2200,8 +2200,7 @@ def completion(
|
||||||
# boto3 reads keys from .env
|
# boto3 reads keys from .env
|
||||||
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
|
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
|
||||||
|
|
||||||
|
if "aws_bedrock_client" in optional_params:
|
||||||
if ("aws_bedrock_client" in optional_params):
|
|
||||||
# Extract credentials for legacy boto3 client and pass thru to httpx
|
# Extract credentials for legacy boto3 client and pass thru to httpx
|
||||||
aws_bedrock_client = optional_params.pop("aws_bedrock_client")
|
aws_bedrock_client = optional_params.pop("aws_bedrock_client")
|
||||||
creds = aws_bedrock_client._get_credentials().get_frozen_credentials()
|
creds = aws_bedrock_client._get_credentials().get_frozen_credentials()
|
||||||
|
@ -2210,9 +2209,9 @@ def completion(
|
||||||
if creds.secret_key:
|
if creds.secret_key:
|
||||||
optional_params["aws_secret_access_key"] = creds.secret_key
|
optional_params["aws_secret_access_key"] = creds.secret_key
|
||||||
if creds.token:
|
if creds.token:
|
||||||
optional_params["aws_session_token"] = creds.token
|
optional_params["aws_session_token"] = creds.token
|
||||||
|
|
||||||
if model.startswith("anthropic"):
|
if model in litellm.BEDROCK_CONVERSE_MODELS:
|
||||||
response = bedrock_converse_chat_completion.completion(
|
response = bedrock_converse_chat_completion.completion(
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue