forked from phoenix/litellm-mirror
feat(util.py): Add OIDC support.
This commit is contained in:
parent
ee1b1fe4f8
commit
4b655d8b33
2 changed files with 87 additions and 0 deletions
|
@ -23,3 +23,36 @@ def test_aws_secret_manager():
|
||||||
print(f"secret_val: {secret_val}")
|
print(f"secret_val: {secret_val}")
|
||||||
|
|
||||||
assert secret_val == "sk-1234"
|
assert secret_val == "sk-1234"
|
||||||
|
|
||||||
|
|
||||||
|
def redact_oidc_signature(secret_val):
|
||||||
|
# remove the last part of `.` and replace it with "SIGNATURE_REMOVED"
|
||||||
|
return secret_val.split(".")[:-1] + ["SIGNATURE_REMOVED"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(os.environ.get('K_SERVICE') is None, reason="Cannot run without being in GCP Cloud Run")
|
||||||
|
def test_oidc_google():
|
||||||
|
secret_val = get_secret("oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke")
|
||||||
|
|
||||||
|
print(f"secret_val: {redact_oidc_signature(secret_val)}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(os.environ.get('ACTIONS_ID_TOKEN_REQUEST_TOKEN') is None, reason="Cannot run without being in GitHub Actions")
|
||||||
|
def test_oidc_github():
|
||||||
|
secret_val = get_secret("oidc/github/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke")
|
||||||
|
|
||||||
|
print(f"secret_val: {redact_oidc_signature(secret_val)}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(os.environ.get('CIRCLE_OIDC_TOKEN') is None, reason="Cannot run without being in a CircleCI Runner")
|
||||||
|
def test_oidc_circleci():
|
||||||
|
secret_val = get_secret("oidc/circleci/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke")
|
||||||
|
|
||||||
|
print(f"secret_val: {redact_oidc_signature(secret_val)}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(os.environ.get('CIRCLE_OIDC_TOKEN_V2') is None, reason="Cannot run without being in a CircleCI Runner")
|
||||||
|
def test_oidc_circleci_v2():
|
||||||
|
secret_val = get_secret("oidc/circleci_v2/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke")
|
||||||
|
|
||||||
|
print(f"secret_val: {redact_oidc_signature(secret_val)}")
|
||||||
|
|
|
@ -33,6 +33,7 @@ from dataclasses import (
|
||||||
)
|
)
|
||||||
|
|
||||||
import litellm._service_logger # for storing API inputs, outputs, and metadata
|
import litellm._service_logger # for storing API inputs, outputs, and metadata
|
||||||
|
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# this works in python 3.8
|
# this works in python 3.8
|
||||||
|
@ -9288,6 +9289,59 @@ def get_secret(
|
||||||
if secret_name.startswith("os.environ/"):
|
if secret_name.startswith("os.environ/"):
|
||||||
secret_name = secret_name.replace("os.environ/", "")
|
secret_name = secret_name.replace("os.environ/", "")
|
||||||
|
|
||||||
|
# Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
|
||||||
|
if secret_name.startswith("oidc/"):
|
||||||
|
secret_name = secret_name.replace("oidc/", "")
|
||||||
|
oidc_provider, oidc_aud = secret_name.split("/", 1)
|
||||||
|
# TODO: Add caching for HTTP requests
|
||||||
|
match oidc_provider:
|
||||||
|
case "google":
|
||||||
|
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
|
||||||
|
# https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
|
||||||
|
response = client.get(
|
||||||
|
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
|
||||||
|
params={"audience": oidc_aud},
|
||||||
|
headers={"Metadata-Flavor": "Google"},
|
||||||
|
)
|
||||||
|
if response.status_code == 200:
|
||||||
|
return response.text
|
||||||
|
else:
|
||||||
|
raise ValueError("Google OIDC provider failed")
|
||||||
|
case "circleci":
|
||||||
|
# https://circleci.com/docs/openid-connect-tokens/
|
||||||
|
env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
|
||||||
|
if env_secret is None:
|
||||||
|
raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
|
||||||
|
return env_secret
|
||||||
|
case "circleci_v2":
|
||||||
|
# https://circleci.com/docs/openid-connect-tokens/
|
||||||
|
env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
|
||||||
|
if env_secret is None:
|
||||||
|
raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
|
||||||
|
return env_secret
|
||||||
|
case "github":
|
||||||
|
# https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
|
||||||
|
actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
|
||||||
|
actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
|
||||||
|
if actions_id_token_request_url is None or actions_id_token_request_token is None:
|
||||||
|
raise ValueError("ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment")
|
||||||
|
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
|
||||||
|
response = client.get(
|
||||||
|
actions_id_token_request_url,
|
||||||
|
params={"audience": oidc_aud},
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {actions_id_token_request_token}",
|
||||||
|
"Accept": "application/json; api-version=2.0",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if response.status_code == 200:
|
||||||
|
return response.text['value']
|
||||||
|
else:
|
||||||
|
raise ValueError("Github OIDC provider failed")
|
||||||
|
case _:
|
||||||
|
raise ValueError("Unsupported OIDC provider")
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if litellm.secret_manager_client is not None:
|
if litellm.secret_manager_client is not None:
|
||||||
try:
|
try:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue