forked from phoenix/litellm-mirror
feat(utils.py) - Add OIDC caching for Google Cloud Run and GitHub Actions.
This commit is contained in:
parent
cb49fb004d
commit
44b1b21911
1 changed files with 19 additions and 4 deletions
|
@ -34,6 +34,8 @@ from dataclasses import (
|
||||||
|
|
||||||
import litellm._service_logger # for storing API inputs, outputs, and metadata
|
import litellm._service_logger # for storing API inputs, outputs, and metadata
|
||||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
oidc_cache = DualCache()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# this works in python 3.8
|
# this works in python 3.8
|
||||||
|
@ -9291,11 +9293,15 @@ def get_secret(
|
||||||
|
|
||||||
# Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
|
# Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
|
||||||
if secret_name.startswith("oidc/"):
|
if secret_name.startswith("oidc/"):
|
||||||
secret_name = secret_name.replace("oidc/", "")
|
secret_name_split = secret_name.replace("oidc/", "")
|
||||||
oidc_provider, oidc_aud = secret_name.split("/", 1)
|
oidc_provider, oidc_aud = secret_name_split.split("/", 1)
|
||||||
# TODO: Add caching for HTTP requests
|
# TODO: Add caching for HTTP requests
|
||||||
match oidc_provider:
|
match oidc_provider:
|
||||||
case "google":
|
case "google":
|
||||||
|
oidc_token = oidc_cache.get_cache(key=secret_name)
|
||||||
|
if oidc_token is not None:
|
||||||
|
return oidc_token
|
||||||
|
|
||||||
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
|
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
|
||||||
# https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
|
# https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
|
||||||
response = client.get(
|
response = client.get(
|
||||||
|
@ -9304,7 +9310,9 @@ def get_secret(
|
||||||
headers={"Metadata-Flavor": "Google"},
|
headers={"Metadata-Flavor": "Google"},
|
||||||
)
|
)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return response.text
|
oidc_token = response.text
|
||||||
|
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
|
||||||
|
return oidc_token
|
||||||
else:
|
else:
|
||||||
raise ValueError("Google OIDC provider failed")
|
raise ValueError("Google OIDC provider failed")
|
||||||
case "circleci":
|
case "circleci":
|
||||||
|
@ -9325,6 +9333,11 @@ def get_secret(
|
||||||
actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
|
actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
|
||||||
if actions_id_token_request_url is None or actions_id_token_request_token is None:
|
if actions_id_token_request_url is None or actions_id_token_request_token is None:
|
||||||
raise ValueError("ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment")
|
raise ValueError("ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment")
|
||||||
|
|
||||||
|
oidc_token = oidc_cache.get_cache(key=secret_name)
|
||||||
|
if oidc_token is not None:
|
||||||
|
return oidc_token
|
||||||
|
|
||||||
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
|
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
|
||||||
response = client.get(
|
response = client.get(
|
||||||
actions_id_token_request_url,
|
actions_id_token_request_url,
|
||||||
|
@ -9335,7 +9348,9 @@ def get_secret(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return response.text['value']
|
oidc_token = response.text['value']
|
||||||
|
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
|
||||||
|
return oidc_token
|
||||||
else:
|
else:
|
||||||
raise ValueError("Github OIDC provider failed")
|
raise ValueError("Github OIDC provider failed")
|
||||||
case _:
|
case _:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue