mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
feat(bedrock.py): Support using OIDC tokens.
This commit is contained in:
parent
b0b87ba79b
commit
47449d19e8
2 changed files with 70 additions and 1 deletions
|
@ -550,6 +550,7 @@ def init_bedrock_client(
|
||||||
aws_session_name: Optional[str] = None,
|
aws_session_name: Optional[str] = None,
|
||||||
aws_profile_name: Optional[str] = None,
|
aws_profile_name: Optional[str] = None,
|
||||||
aws_role_name: Optional[str] = None,
|
aws_role_name: Optional[str] = None,
|
||||||
|
aws_web_identity_token: Optional[str] = None,
|
||||||
extra_headers: Optional[dict] = None,
|
extra_headers: Optional[dict] = None,
|
||||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||||
):
|
):
|
||||||
|
@ -566,6 +567,7 @@ def init_bedrock_client(
|
||||||
aws_session_name,
|
aws_session_name,
|
||||||
aws_profile_name,
|
aws_profile_name,
|
||||||
aws_role_name,
|
aws_role_name,
|
||||||
|
aws_web_identity_token,
|
||||||
]
|
]
|
||||||
|
|
||||||
# Iterate over parameters and update if needed
|
# Iterate over parameters and update if needed
|
||||||
|
@ -581,6 +583,7 @@ def init_bedrock_client(
|
||||||
aws_session_name,
|
aws_session_name,
|
||||||
aws_profile_name,
|
aws_profile_name,
|
||||||
aws_role_name,
|
aws_role_name,
|
||||||
|
aws_web_identity_token,
|
||||||
) = params_to_check
|
) = params_to_check
|
||||||
|
|
||||||
### SET REGION NAME
|
### SET REGION NAME
|
||||||
|
@ -619,7 +622,38 @@ def init_bedrock_client(
|
||||||
config = boto3.session.Config()
|
config = boto3.session.Config()
|
||||||
|
|
||||||
### CHECK STS ###
|
### CHECK STS ###
|
||||||
if aws_role_name is not None and aws_session_name is not None:
|
if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
|
||||||
|
oidc_token = get_secret(aws_web_identity_token)
|
||||||
|
|
||||||
|
if oidc_token is None:
|
||||||
|
raise BedrockError(
|
||||||
|
message="OIDC token could not be retrieved from secret manager.",
|
||||||
|
status_code=401,
|
||||||
|
)
|
||||||
|
|
||||||
|
sts_client = boto3.client(
|
||||||
|
"sts"
|
||||||
|
)
|
||||||
|
|
||||||
|
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
|
||||||
|
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
|
||||||
|
sts_response = sts_client.assume_role_with_web_identity(
|
||||||
|
RoleArn=aws_role_name,
|
||||||
|
RoleSessionName=aws_session_name,
|
||||||
|
WebIdentityToken=oidc_token,
|
||||||
|
DurationSeconds=3600,
|
||||||
|
)
|
||||||
|
|
||||||
|
client = boto3.client(
|
||||||
|
service_name="bedrock-runtime",
|
||||||
|
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
|
||||||
|
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
|
||||||
|
aws_session_token=sts_response["Credentials"]["SessionToken"],
|
||||||
|
region_name=region_name,
|
||||||
|
endpoint_url=endpoint_url,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
elif aws_role_name is not None and aws_session_name is not None:
|
||||||
# use sts if role name passed in
|
# use sts if role name passed in
|
||||||
sts_client = boto3.client(
|
sts_client = boto3.client(
|
||||||
"sts",
|
"sts",
|
||||||
|
@ -752,6 +786,7 @@ def completion(
|
||||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||||
"aws_bedrock_runtime_endpoint", None
|
"aws_bedrock_runtime_endpoint", None
|
||||||
)
|
)
|
||||||
|
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||||
|
|
||||||
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||||
client = optional_params.pop("aws_bedrock_client", None)
|
client = optional_params.pop("aws_bedrock_client", None)
|
||||||
|
@ -766,6 +801,7 @@ def completion(
|
||||||
aws_role_name=aws_role_name,
|
aws_role_name=aws_role_name,
|
||||||
aws_session_name=aws_session_name,
|
aws_session_name=aws_session_name,
|
||||||
aws_profile_name=aws_profile_name,
|
aws_profile_name=aws_profile_name,
|
||||||
|
aws_web_identity_token=aws_web_identity_token,
|
||||||
extra_headers=extra_headers,
|
extra_headers=extra_headers,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
@ -1288,6 +1324,7 @@ def embedding(
|
||||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||||
"aws_bedrock_runtime_endpoint", None
|
"aws_bedrock_runtime_endpoint", None
|
||||||
)
|
)
|
||||||
|
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||||
|
|
||||||
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||||
client = init_bedrock_client(
|
client = init_bedrock_client(
|
||||||
|
@ -1295,6 +1332,7 @@ def embedding(
|
||||||
aws_secret_access_key=aws_secret_access_key,
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
aws_region_name=aws_region_name,
|
aws_region_name=aws_region_name,
|
||||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||||
|
aws_web_identity_token=aws_web_identity_token,
|
||||||
aws_role_name=aws_role_name,
|
aws_role_name=aws_role_name,
|
||||||
aws_session_name=aws_session_name,
|
aws_session_name=aws_session_name,
|
||||||
)
|
)
|
||||||
|
@ -1377,6 +1415,7 @@ def image_generation(
|
||||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||||
"aws_bedrock_runtime_endpoint", None
|
"aws_bedrock_runtime_endpoint", None
|
||||||
)
|
)
|
||||||
|
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||||
|
|
||||||
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||||
client = init_bedrock_client(
|
client = init_bedrock_client(
|
||||||
|
@ -1384,6 +1423,7 @@ def image_generation(
|
||||||
aws_secret_access_key=aws_secret_access_key,
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
aws_region_name=aws_region_name,
|
aws_region_name=aws_region_name,
|
||||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||||
|
aws_web_identity_token=aws_web_identity_token,
|
||||||
aws_role_name=aws_role_name,
|
aws_role_name=aws_role_name,
|
||||||
aws_session_name=aws_session_name,
|
aws_session_name=aws_session_name,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
|
|
@ -206,6 +206,35 @@ def test_completion_bedrock_claude_sts_client_auth():
|
||||||
|
|
||||||
# test_completion_bedrock_claude_sts_client_auth()
|
# test_completion_bedrock_claude_sts_client_auth()
|
||||||
|
|
||||||
|
@pytest.mark.skipif(os.environ.get('CIRCLE_OIDC_TOKEN_V2') is None, reason="CIRCLE_OIDC_TOKEN_V2 is not set")
|
||||||
|
def test_completion_bedrock_claude_sts_oidc_auth():
|
||||||
|
print("\ncalling bedrock claude with oidc auth")
|
||||||
|
import os
|
||||||
|
|
||||||
|
aws_web_identity_token = "oidc/circleci_v2/"
|
||||||
|
aws_region_name = os.environ["AWS_REGION_NAME"]
|
||||||
|
aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="bedrock/anthropic.claude-instant-v1",
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=10,
|
||||||
|
temperature=0.1,
|
||||||
|
aws_region_name=aws_region_name,
|
||||||
|
aws_web_identity_token=aws_web_identity_token,
|
||||||
|
aws_role_name=aws_role_name,
|
||||||
|
aws_session_name="my-test-session",
|
||||||
|
)
|
||||||
|
# Add any assertions here to check the response
|
||||||
|
print(response)
|
||||||
|
except RateLimitError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
def test_bedrock_extra_headers():
|
def test_bedrock_extra_headers():
|
||||||
try:
|
try:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue