forked from phoenix/litellm-mirror
Merge pull request #3507 from Manouchehri/oidc-3505-part-1
Initial OIDC support (Google/GitHub/CircleCI -> Amazon Bedrock & Azure OpenAI)
Commit 40063798bd
6 changed files with 248 additions and 1 deletion
@@ -8,6 +8,7 @@ from litellm.utils import (
    CustomStreamWrapper,
    convert_to_model_response_object,
    TranscriptionResponse,
    get_secret,
)
from typing import Callable, Optional, BinaryIO
from litellm import OpenAIConfig

@@ -16,6 +17,7 @@ import httpx  # type: ignore
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI
import uuid
import os


class AzureOpenAIError(Exception):

@@ -126,6 +128,51 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):
    return azure_client_params


def get_azure_ad_token_from_oidc(azure_ad_token: str):
    azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
    azure_tenant = os.getenv("AZURE_TENANT_ID", None)

    if azure_client_id is None or azure_tenant is None:
        raise AzureOpenAIError(
            status_code=422,
            message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
        )

    oidc_token = get_secret(azure_ad_token)

    if oidc_token is None:
        raise AzureOpenAIError(
            status_code=401,
            message="OIDC token could not be retrieved from secret manager.",
        )

    req_token = httpx.post(
        f"https://login.microsoftonline.com/{azure_tenant}/oauth2/v2.0/token",
        data={
            "client_id": azure_client_id,
            "grant_type": "client_credentials",
            "scope": "https://cognitiveservices.azure.com/.default",
            "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
            "client_assertion": oidc_token,
        },
    )

    if req_token.status_code != 200:
        raise AzureOpenAIError(
            status_code=req_token.status_code,
            message=req_token.text,
        )

    possible_azure_ad_token = req_token.json().get("access_token", None)

    if possible_azure_ad_token is None:
        raise AzureOpenAIError(
            status_code=422, message="Azure AD Token not returned"
        )

    return possible_azure_ad_token


class AzureChatCompletion(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

@@ -137,6 +184,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
    headers["api-key"] = api_key
elif azure_ad_token is not None:
    if azure_ad_token.startswith("oidc/"):
        azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
    headers["Authorization"] = f"Bearer {azure_ad_token}"
return headers

@@ -189,6 +238,9 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
    azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
    if azure_ad_token.startswith("oidc/"):
        azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)

    azure_client_params["azure_ad_token"] = azure_ad_token

if acompletion is True:

@@ -276,6 +328,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
    azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
    if azure_ad_token.startswith("oidc/"):
        azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
    azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
    azure_client = AzureOpenAI(**azure_client_params)

@@ -351,6 +405,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
    azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
    if azure_ad_token.startswith("oidc/"):
        azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
    azure_client_params["azure_ad_token"] = azure_ad_token

# setting Azure client

@@ -422,6 +478,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
    azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
    if azure_ad_token.startswith("oidc/"):
        azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
    azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
    azure_client = AzureOpenAI(**azure_client_params)

@@ -478,6 +536,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
    azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
    if azure_ad_token.startswith("oidc/"):
        azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
    azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
    azure_client = AsyncAzureOpenAI(**azure_client_params)

@@ -599,6 +659,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
    azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
    if azure_ad_token.startswith("oidc/"):
        azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
    azure_client_params["azure_ad_token"] = azure_ad_token

## LOGGING

@@ -755,6 +817,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
    azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
    if azure_ad_token.startswith("oidc/"):
        azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
    azure_client_params["azure_ad_token"] = azure_ad_token

if aimg_generation == True:

@@ -833,6 +897,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None:
    azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
    if azure_ad_token.startswith("oidc/"):
        azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
    azure_client_params["azure_ad_token"] = azure_ad_token

if max_retries is not None:
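Taken together, these hunks let an Azure OpenAI deployment authenticate with a federated OIDC token instead of a static API key: any azure_ad_token that starts with "oidc/" is first resolved through get_secret and then exchanged at the Microsoft identity platform token endpoint (client_credentials grant with a client_assertion) for a bearer token. A minimal usage sketch follows; it is an assumption-laden illustration, not part of this diff, and the deployment name, api_base, and OIDC audience are placeholders. AZURE_CLIENT_ID and AZURE_TENANT_ID must be set, as the new helper requires.

import litellm

# Hypothetical values; the Azure AD app identified by AZURE_CLIENT_ID needs a
# federated credential that trusts the OIDC issuer producing the token.
response = litellm.completion(
    model="azure/my-gpt-deployment",                # placeholder deployment name
    api_base="https://example.openai.azure.com/",   # placeholder endpoint
    api_version="2023-07-01-preview",
    # "oidc/github/<audience>": get_secret() fetches the GitHub Actions OIDC token,
    # then get_azure_ad_token_from_oidc() exchanges it for an Azure AD access token.
    azure_ad_token="oidc/github/https://example-audience",
    messages=[{"role": "user", "content": "Hello"}],
)
print(response)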
@@ -551,6 +551,7 @@ def init_bedrock_client(
    aws_session_name: Optional[str] = None,
    aws_profile_name: Optional[str] = None,
    aws_role_name: Optional[str] = None,
    aws_web_identity_token: Optional[str] = None,
    extra_headers: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
):

@@ -567,6 +568,7 @@ def init_bedrock_client(
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ]

    # Iterate over parameters and update if needed

@@ -582,6 +584,7 @@ def init_bedrock_client(
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ) = params_to_check

    ### SET REGION NAME

@@ -620,7 +623,38 @@ def init_bedrock_client(
    config = boto3.session.Config()

    ### CHECK STS ###
    if aws_role_name is not None and aws_session_name is not None:
    if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
        oidc_token = get_secret(aws_web_identity_token)

        if oidc_token is None:
            raise BedrockError(
                message="OIDC token could not be retrieved from secret manager.",
                status_code=401,
            )

        sts_client = boto3.client(
            "sts"
        )

        # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
        sts_response = sts_client.assume_role_with_web_identity(
            RoleArn=aws_role_name,
            RoleSessionName=aws_session_name,
            WebIdentityToken=oidc_token,
            DurationSeconds=3600,
        )

        client = boto3.client(
            service_name="bedrock-runtime",
            aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
            aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
            aws_session_token=sts_response["Credentials"]["SessionToken"],
            region_name=region_name,
            endpoint_url=endpoint_url,
            config=config,
        )
    elif aws_role_name is not None and aws_session_name is not None:
        # use sts if role name passed in
        sts_client = boto3.client(
            "sts",

@@ -755,6 +789,7 @@ def completion(
    aws_bedrock_runtime_endpoint = optional_params.pop(
        "aws_bedrock_runtime_endpoint", None
    )
    aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)

    # use passed in BedrockRuntime.Client if provided, otherwise create a new one
    client = optional_params.pop("aws_bedrock_client", None)

@@ -769,6 +804,7 @@ def completion(
        aws_role_name=aws_role_name,
        aws_session_name=aws_session_name,
        aws_profile_name=aws_profile_name,
        aws_web_identity_token=aws_web_identity_token,
        extra_headers=extra_headers,
        timeout=timeout,
    )

@@ -1291,6 +1327,7 @@ def embedding(
    aws_bedrock_runtime_endpoint = optional_params.pop(
        "aws_bedrock_runtime_endpoint", None
    )
    aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)

    # use passed in BedrockRuntime.Client if provided, otherwise create a new one
    client = init_bedrock_client(

@@ -1298,6 +1335,7 @@ def embedding(
        aws_secret_access_key=aws_secret_access_key,
        aws_region_name=aws_region_name,
        aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
        aws_web_identity_token=aws_web_identity_token,
        aws_role_name=aws_role_name,
        aws_session_name=aws_session_name,
    )

@@ -1380,6 +1418,7 @@ def image_generation(
    aws_bedrock_runtime_endpoint = optional_params.pop(
        "aws_bedrock_runtime_endpoint", None
    )
    aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)

    # use passed in BedrockRuntime.Client if provided, otherwise create a new one
    client = init_bedrock_client(

@@ -1387,6 +1426,7 @@ def image_generation(
        aws_secret_access_key=aws_secret_access_key,
        aws_region_name=aws_region_name,
        aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
        aws_web_identity_token=aws_web_identity_token,
        aws_role_name=aws_role_name,
        aws_session_name=aws_session_name,
        timeout=timeout,
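With init_bedrock_client able to call AssumeRoleWithWebIdentity, a Bedrock request can authenticate from CI using only a role ARN and an OIDC token reference, with no long-lived AWS keys. The sketch below mirrors the CI test added later in this commit; the region, role ARN, and session name are placeholders, and "oidc/circleci_v2/" is the CircleCI provider prefix introduced here.

import litellm

# Hypothetical values: the role ARN and region are placeholders, and the role's trust
# policy must allow sts:AssumeRoleWithWebIdentity from the CI provider's OIDC issuer.
response = litellm.completion(
    model="bedrock/anthropic.claude-instant-v1",
    messages=[{"role": "user", "content": "Hello"}],
    aws_region_name="us-east-1",                                  # placeholder region
    aws_web_identity_token="oidc/circleci_v2/",                   # resolved via get_secret()
    aws_role_name="arn:aws:iam::123456789012:role/example-role",  # placeholder ARN
    aws_session_name="litellm-ci-session",
)
print(response)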
@@ -48,6 +48,7 @@ from litellm.types.router import (
    AlertingConfig,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc


class Router:

@@ -2115,6 +2116,10 @@ class Router:
raise ValueError(
    f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
)
azure_ad_token = litellm_params.get("azure_ad_token")
if azure_ad_token is not None:
    if azure_ad_token.startswith("oidc/"):
        azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
if api_version is None:
    api_version = "2023-07-01-preview"

@@ -2126,6 +2131,7 @@ class Router:
cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI(
    api_key=api_key,
    azure_ad_token=azure_ad_token,
    base_url=api_base,
    api_version=api_version,
    timeout=timeout,

@@ -2150,6 +2156,7 @@ class Router:
cache_key = f"{model_id}_client"
_client = openai.AzureOpenAI(  # type: ignore
    api_key=api_key,
    azure_ad_token=azure_ad_token,
    base_url=api_base,
    api_version=api_version,
    timeout=timeout,

@@ -2174,6 +2181,7 @@ class Router:
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI(  # type: ignore
    api_key=api_key,
    azure_ad_token=azure_ad_token,
    base_url=api_base,
    api_version=api_version,
    timeout=stream_timeout,

@@ -2198,6 +2206,7 @@ class Router:
cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI(  # type: ignore
    api_key=api_key,
    azure_ad_token=azure_ad_token,
    base_url=api_base,
    api_version=api_version,
    timeout=stream_timeout,

@@ -2230,6 +2239,7 @@ class Router:
    "api_key": api_key,
    "azure_endpoint": api_base,
    "api_version": api_version,
    "azure_ad_token": azure_ad_token,
}
from litellm.llms.azure import select_azure_base_url_or_endpoint
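The Router picks up the same convention from litellm_params: an azure_ad_token beginning with "oidc/" is exchanged once when the cached Azure clients for a deployment are built. A minimal sketch of such a model_list entry follows; it is illustrative only, and the deployment name, endpoint, and audience are placeholders.

from litellm import Router

# Hypothetical deployment; AZURE_CLIENT_ID and AZURE_TENANT_ID must be set for the
# OIDC-to-Azure-AD exchange performed by get_azure_ad_token_from_oidc().
router = Router(
    model_list=[
        {
            "model_name": "azure-gpt",
            "litellm_params": {
                "model": "azure/my-gpt-deployment",                         # placeholder
                "api_base": "https://example.openai.azure.com/",            # placeholder
                "api_version": "2023-07-01-preview",
                "azure_ad_token": "oidc/google/https://example-audience",   # placeholder audience
            },
        }
    ]
)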
@@ -206,6 +206,35 @@ def test_completion_bedrock_claude_sts_client_auth():

# test_completion_bedrock_claude_sts_client_auth()


@pytest.mark.skipif(os.environ.get('CIRCLE_OIDC_TOKEN_V2') is None, reason="CIRCLE_OIDC_TOKEN_V2 is not set")
def test_completion_bedrock_claude_sts_oidc_auth():
    print("\ncalling bedrock claude with oidc auth")
    import os

    aws_web_identity_token = "oidc/circleci_v2/"
    aws_region_name = os.environ["AWS_REGION_NAME"]
    aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]

    try:
        litellm.set_verbose = True

        response = completion(
            model="bedrock/anthropic.claude-instant-v1",
            messages=messages,
            max_tokens=10,
            temperature=0.1,
            aws_region_name=aws_region_name,
            aws_web_identity_token=aws_web_identity_token,
            aws_role_name=aws_role_name,
            aws_session_name="my-test-session",
        )
        # Add any assertions here to check the response
        print(response)
    except RateLimitError:
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


def test_bedrock_extra_headers():
    try:
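This test only runs inside CircleCI and assumes the role named by AWS_TEMP_ROLE_NAME already trusts the CI provider's OIDC issuer. As a rough, hypothetical illustration of that prerequisite (not part of this diff), the role's trust policy has to permit sts:AssumeRoleWithWebIdentity from the federated provider; the account ID, issuer, and audience below are placeholders that depend entirely on your AWS and CI configuration.

# Hypothetical IAM trust policy, written as a Python dict purely for illustration.
trust_policy = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Principal": {
                # placeholder OIDC provider registered in IAM
                "Federated": "arn:aws:iam::123456789012:oidc-provider/oidc.example-ci.com/org/EXAMPLE"
            },
            "Action": "sts:AssumeRoleWithWebIdentity",
            "Condition": {
                # placeholder audience claim check
                "StringEquals": {"oidc.example-ci.com/org/EXAMPLE:aud": "EXAMPLE-AUDIENCE"}
            },
        }
    ],
}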
@@ -23,3 +23,36 @@ def test_aws_secret_manager():
    print(f"secret_val: {secret_val}")

    assert secret_val == "sk-1234"


def redact_oidc_signature(secret_val):
    # remove the last part of `.` and replace it with "SIGNATURE_REMOVED"
    return secret_val.split(".")[:-1] + ["SIGNATURE_REMOVED"]


@pytest.mark.skipif(os.environ.get('K_SERVICE') is None, reason="Cannot run without being in GCP Cloud Run")
def test_oidc_google():
    secret_val = get_secret("oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke")

    print(f"secret_val: {redact_oidc_signature(secret_val)}")


@pytest.mark.skipif(os.environ.get('ACTIONS_ID_TOKEN_REQUEST_TOKEN') is None, reason="Cannot run without being in GitHub Actions")
def test_oidc_github():
    secret_val = get_secret("oidc/github/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke")

    print(f"secret_val: {redact_oidc_signature(secret_val)}")


@pytest.mark.skipif(os.environ.get('CIRCLE_OIDC_TOKEN') is None, reason="Cannot run without being in a CircleCI Runner")
def test_oidc_circleci():
    secret_val = get_secret("oidc/circleci/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke")

    print(f"secret_val: {redact_oidc_signature(secret_val)}")


@pytest.mark.skipif(os.environ.get('CIRCLE_OIDC_TOKEN_V2') is None, reason="Cannot run without being in a CircleCI Runner")
def test_oidc_circleci_v2():
    secret_val = get_secret("oidc/circleci_v2/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke")

    print(f"secret_val: {redact_oidc_signature(secret_val)}")
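These tests only print a redacted token. When debugging a provider it can also help to peek at the unverified claims of the returned JWT, since the audience ("aud") has to match what the downstream service expects. A small hypothetical helper (not part of this diff) using only the standard library:

import base64
import json

def peek_jwt_claims(jwt_token: str) -> dict:
    # Decode the JWT payload WITHOUT verifying the signature -- debugging only.
    payload_b64 = jwt_token.split(".")[1]
    payload_b64 += "=" * (-len(payload_b64) % 4)  # restore stripped base64 padding
    return json.loads(base64.urlsafe_b64decode(payload_b64))

# e.g. peek_jwt_claims(get_secret("oidc/github/https://example-audience"))["aud"]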
@@ -33,6 +33,9 @@ from dataclasses import (
)

import litellm._service_logger  # for storing API inputs, outputs, and metadata
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.caching import DualCache
oidc_cache = DualCache()

try:
    # this works in python 3.8

@@ -9411,6 +9414,72 @@ def get_secret(
    if secret_name.startswith("os.environ/"):
        secret_name = secret_name.replace("os.environ/", "")

    # Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
    if secret_name.startswith("oidc/"):
        secret_name_split = secret_name.replace("oidc/", "")
        oidc_provider, oidc_aud = secret_name_split.split("/", 1)
        # TODO: Add caching for HTTP requests
        match oidc_provider:
            case "google":
                oidc_token = oidc_cache.get_cache(key=secret_name)
                if oidc_token is not None:
                    return oidc_token

                client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
                # https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
                response = client.get(
                    "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
                    params={"audience": oidc_aud},
                    headers={"Metadata-Flavor": "Google"},
                )
                if response.status_code == 200:
                    oidc_token = response.text
                    oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
                    return oidc_token
                else:
                    raise ValueError("Google OIDC provider failed")
            case "circleci":
                # https://circleci.com/docs/openid-connect-tokens/
                env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
                if env_secret is None:
                    raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
                return env_secret
            case "circleci_v2":
                # https://circleci.com/docs/openid-connect-tokens/
                env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
                if env_secret is None:
                    raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
                return env_secret
            case "github":
                # https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
                actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
                actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
                if actions_id_token_request_url is None or actions_id_token_request_token is None:
                    raise ValueError("ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment")

                oidc_token = oidc_cache.get_cache(key=secret_name)
                if oidc_token is not None:
                    return oidc_token

                client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
                response = client.get(
                    actions_id_token_request_url,
                    params={"audience": oidc_aud},
                    headers={
                        "Authorization": f"Bearer {actions_id_token_request_token}",
                        "Accept": "application/json; api-version=2.0",
                    },
                )
                if response.status_code == 200:
                    oidc_token = response.json()["value"]  # endpoint returns JSON of the form {"value": "<token>"}
                    oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
                    return oidc_token
                else:
                    raise ValueError("Github OIDC provider failed")
            case _:
                raise ValueError("Unsupported OIDC provider")

    try:
        if litellm.secret_manager_client is not None:
            try:
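The provider dispatch splits the secret name exactly once after stripping the "oidc/" prefix, so everything after the provider segment is passed through as the audience. A small standalone example of that parsing, mirroring the logic in get_secret above and reusing the audience URL from the tests:

# Mirrors the parsing in get_secret(); the URL is the same audience used in the new tests.
secret_name = "oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke"
oidc_provider, oidc_aud = secret_name.replace("oidc/", "").split("/", 1)

assert oidc_provider == "google"
assert oidc_aud == "https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke"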