feat(util.py): Add OIDC support.

This commit is contained in:
David Manouchehri 2024-05-07 15:02:37 +00:00
parent ee1b1fe4f8
commit 4b655d8b33
No known key found for this signature in database
2 changed files with 87 additions and 0 deletions

View file

@ -33,6 +33,7 @@ from dataclasses import (
)
import litellm._service_logger # for storing API inputs, outputs, and metadata
from litellm.llms.custom_httpx.http_handler import HTTPHandler
try:
# this works in python 3.8
@ -9288,6 +9289,59 @@ def get_secret(
if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "")
# Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
if secret_name.startswith("oidc/"):
secret_name = secret_name.replace("oidc/", "")
oidc_provider, oidc_aud = secret_name.split("/", 1)
# TODO: Add caching for HTTP requests
match oidc_provider:
case "google":
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
# https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
response = client.get(
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
params={"audience": oidc_aud},
headers={"Metadata-Flavor": "Google"},
)
if response.status_code == 200:
return response.text
else:
raise ValueError("Google OIDC provider failed")
case "circleci":
# https://circleci.com/docs/openid-connect-tokens/
env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
if env_secret is None:
raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
return env_secret
case "circleci_v2":
# https://circleci.com/docs/openid-connect-tokens/
env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
if env_secret is None:
raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
return env_secret
case "github":
# https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
if actions_id_token_request_url is None or actions_id_token_request_token is None:
raise ValueError("ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment")
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = client.get(
actions_id_token_request_url,
params={"audience": oidc_aud},
headers={
"Authorization": f"Bearer {actions_id_token_request_token}",
"Accept": "application/json; api-version=2.0",
},
)
if response.status_code == 200:
return response.text['value']
else:
raise ValueError("Github OIDC provider failed")
case _:
raise ValueError("Unsupported OIDC provider")
try:
if litellm.secret_manager_client is not None:
try: