Merge pull request #3507 from Manouchehri/oidc-3505-part-1

Initial OIDC support (Google/GitHub/CircleCI -> Amazon Bedrock & Azure OpenAI)
This commit is contained in:
Krish Dholakia 2024-05-11 09:25:17 -07:00 committed by GitHub
commit 40063798bd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 248 additions and 1 deletions

View file

@ -8,6 +8,7 @@ from litellm.utils import (
CustomStreamWrapper,
convert_to_model_response_object,
TranscriptionResponse,
get_secret,
)
from typing import Callable, Optional, BinaryIO
from litellm import OpenAIConfig
@ -16,6 +17,7 @@ import httpx # type: ignore
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI
import uuid
import os
class AzureOpenAIError(Exception):
@ -126,6 +128,51 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):
return azure_client_params
def get_azure_ad_token_from_oidc(azure_ad_token: str):
azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
azure_tenant = os.getenv("AZURE_TENANT_ID", None)
if azure_client_id is None or azure_tenant is None:
raise AzureOpenAIError(
status_code=422,
message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
)
oidc_token = get_secret(azure_ad_token)
if oidc_token is None:
raise AzureOpenAIError(
status_code=401,
message="OIDC token could not be retrieved from secret manager.",
)
req_token = httpx.post(
f"https://login.microsoftonline.com/{azure_tenant}/oauth2/v2.0/token",
data={
"client_id": azure_client_id,
"grant_type": "client_credentials",
"scope": "https://cognitiveservices.azure.com/.default",
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": oidc_token,
},
)
if req_token.status_code != 200:
raise AzureOpenAIError(
status_code=req_token.status_code,
message=req_token.text,
)
possible_azure_ad_token = req_token.json().get("access_token", None)
if possible_azure_ad_token is None:
raise AzureOpenAIError(
status_code=422, message="Azure AD Token not returned"
)
return possible_azure_ad_token
class AzureChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
@ -137,6 +184,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
headers["api-key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
headers["Authorization"] = f"Bearer {azure_ad_token}"
return headers
@ -189,6 +238,9 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if acompletion is True:
@ -276,6 +328,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
azure_client = AzureOpenAI(**azure_client_params)
@ -351,6 +405,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
# setting Azure client
@ -422,6 +478,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
azure_client = AzureOpenAI(**azure_client_params)
@ -478,6 +536,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
azure_client = AsyncAzureOpenAI(**azure_client_params)
@ -599,6 +659,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
## LOGGING
@ -755,6 +817,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if aimg_generation == True:
@ -833,6 +897,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if max_retries is not None:

View file

@ -551,6 +551,7 @@ def init_bedrock_client(
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
aws_web_identity_token: Optional[str] = None,
extra_headers: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
):
@ -567,6 +568,7 @@ def init_bedrock_client(
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
]
# Iterate over parameters and update if needed
@ -582,6 +584,7 @@ def init_bedrock_client(
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
) = params_to_check
### SET REGION NAME
@ -620,7 +623,38 @@ def init_bedrock_client(
config = boto3.session.Config()
### CHECK STS ###
if aws_role_name is not None and aws_session_name is not None:
if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client(
"sts"
)
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
client = boto3.client(
service_name="bedrock-runtime",
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=region_name,
endpoint_url=endpoint_url,
config=config,
)
elif aws_role_name is not None and aws_session_name is not None:
# use sts if role name passed in
sts_client = boto3.client(
"sts",
@ -755,6 +789,7 @@ def completion(
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = optional_params.pop("aws_bedrock_client", None)
@ -769,6 +804,7 @@ def completion(
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
aws_web_identity_token=aws_web_identity_token,
extra_headers=extra_headers,
timeout=timeout,
)
@ -1291,6 +1327,7 @@ def embedding(
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = init_bedrock_client(
@ -1298,6 +1335,7 @@ def embedding(
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
)
@ -1380,6 +1418,7 @@ def image_generation(
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = init_bedrock_client(
@ -1387,6 +1426,7 @@ def image_generation(
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
timeout=timeout,

View file

@ -48,6 +48,7 @@ from litellm.types.router import (
AlertingConfig,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
class Router:
@ -2115,6 +2116,10 @@ class Router:
raise ValueError(
f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
)
azure_ad_token = litellm_params.get("azure_ad_token")
if azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
if api_version is None:
api_version = "2023-07-01-preview"
@ -2126,6 +2131,7 @@ class Router:
cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI(
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=timeout,
@ -2150,6 +2156,7 @@ class Router:
cache_key = f"{model_id}_client"
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=timeout,
@ -2174,6 +2181,7 @@ class Router:
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
@ -2198,6 +2206,7 @@ class Router:
cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
@ -2230,6 +2239,7 @@ class Router:
"api_key": api_key,
"azure_endpoint": api_base,
"api_version": api_version,
"azure_ad_token": azure_ad_token,
}
from litellm.llms.azure import select_azure_base_url_or_endpoint

View file

@ -206,6 +206,35 @@ def test_completion_bedrock_claude_sts_client_auth():
# test_completion_bedrock_claude_sts_client_auth()
@pytest.mark.skipif(os.environ.get('CIRCLE_OIDC_TOKEN_V2') is None, reason="CIRCLE_OIDC_TOKEN_V2 is not set")
def test_completion_bedrock_claude_sts_oidc_auth():
print("\ncalling bedrock claude with oidc auth")
import os
aws_web_identity_token = "oidc/circleci_v2/"
aws_region_name = os.environ["AWS_REGION_NAME"]
aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
try:
litellm.set_verbose = True
response = completion(
model="bedrock/anthropic.claude-instant-v1",
messages=messages,
max_tokens=10,
temperature=0.1,
aws_region_name=aws_region_name,
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name="my-test-session",
)
# Add any assertions here to check the response
print(response)
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_bedrock_extra_headers():
try:

View file

@ -23,3 +23,36 @@ def test_aws_secret_manager():
print(f"secret_val: {secret_val}")
assert secret_val == "sk-1234"
def redact_oidc_signature(secret_val):
# remove the last part of `.` and replace it with "SIGNATURE_REMOVED"
return secret_val.split(".")[:-1] + ["SIGNATURE_REMOVED"]
@pytest.mark.skipif(os.environ.get('K_SERVICE') is None, reason="Cannot run without being in GCP Cloud Run")
def test_oidc_google():
secret_val = get_secret("oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke")
print(f"secret_val: {redact_oidc_signature(secret_val)}")
@pytest.mark.skipif(os.environ.get('ACTIONS_ID_TOKEN_REQUEST_TOKEN') is None, reason="Cannot run without being in GitHub Actions")
def test_oidc_github():
secret_val = get_secret("oidc/github/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke")
print(f"secret_val: {redact_oidc_signature(secret_val)}")
@pytest.mark.skipif(os.environ.get('CIRCLE_OIDC_TOKEN') is None, reason="Cannot run without being in a CircleCI Runner")
def test_oidc_circleci():
secret_val = get_secret("oidc/circleci/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke")
print(f"secret_val: {redact_oidc_signature(secret_val)}")
@pytest.mark.skipif(os.environ.get('CIRCLE_OIDC_TOKEN_V2') is None, reason="Cannot run without being in a CircleCI Runner")
def test_oidc_circleci_v2():
secret_val = get_secret("oidc/circleci_v2/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke")
print(f"secret_val: {redact_oidc_signature(secret_val)}")

View file

@ -33,6 +33,9 @@ from dataclasses import (
)
import litellm._service_logger # for storing API inputs, outputs, and metadata
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.caching import DualCache
oidc_cache = DualCache()
try:
# this works in python 3.8
@ -9411,6 +9414,72 @@ def get_secret(
if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "")
# Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
if secret_name.startswith("oidc/"):
secret_name_split = secret_name.replace("oidc/", "")
oidc_provider, oidc_aud = secret_name_split.split("/", 1)
# TODO: Add caching for HTTP requests
match oidc_provider:
case "google":
oidc_token = oidc_cache.get_cache(key=secret_name)
if oidc_token is not None:
return oidc_token
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
# https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
response = client.get(
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
params={"audience": oidc_aud},
headers={"Metadata-Flavor": "Google"},
)
if response.status_code == 200:
oidc_token = response.text
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
return oidc_token
else:
raise ValueError("Google OIDC provider failed")
case "circleci":
# https://circleci.com/docs/openid-connect-tokens/
env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
if env_secret is None:
raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
return env_secret
case "circleci_v2":
# https://circleci.com/docs/openid-connect-tokens/
env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
if env_secret is None:
raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
return env_secret
case "github":
# https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
if actions_id_token_request_url is None or actions_id_token_request_token is None:
raise ValueError("ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment")
oidc_token = oidc_cache.get_cache(key=secret_name)
if oidc_token is not None:
return oidc_token
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = client.get(
actions_id_token_request_url,
params={"audience": oidc_aud},
headers={
"Authorization": f"Bearer {actions_id_token_request_token}",
"Accept": "application/json; api-version=2.0",
},
)
if response.status_code == 200:
oidc_token = response.text['value']
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
return oidc_token
else:
raise ValueError("Github OIDC provider failed")
case _:
raise ValueError("Unsupported OIDC provider")
try:
if litellm.secret_manager_client is not None:
try: