forked from phoenix/litellm-mirror
Merge pull request #4409 from BerriAI/litellm_enforce_premium_license_in_vpc
enterprise - allow verifying license in air gapped vpc
This commit is contained in:
commit
38ec55bd95
3 changed files with 82 additions and 0 deletions
|
@ -1,6 +1,11 @@
|
||||||
# What is this?
|
# What is this?
|
||||||
## If litellm license in env, checks if it's valid
|
## If litellm license in env, checks if it's valid
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||||
|
|
||||||
|
|
||||||
|
@ -15,6 +20,26 @@ class LicenseCheck:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.license_str = os.getenv("LITELLM_LICENSE", None)
|
self.license_str = os.getenv("LITELLM_LICENSE", None)
|
||||||
self.http_handler = HTTPHandler()
|
self.http_handler = HTTPHandler()
|
||||||
|
self.public_key = None
|
||||||
|
self.read_public_key()
|
||||||
|
|
||||||
|
def read_public_key(self):
|
||||||
|
try:
|
||||||
|
from cryptography.hazmat.primitives import hashes, serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import padding, rsa
|
||||||
|
|
||||||
|
# current dir
|
||||||
|
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
|
||||||
|
# check if public_key.pem exists
|
||||||
|
_path_to_public_key = os.path.join(current_dir, "public_key.pem")
|
||||||
|
if os.path.exists(_path_to_public_key):
|
||||||
|
with open(_path_to_public_key, "rb") as key_file:
|
||||||
|
self.public_key = serialization.load_pem_public_key(key_file.read())
|
||||||
|
else:
|
||||||
|
self.public_key = None
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.error(f"Error reading public key: {str(e)}")
|
||||||
|
|
||||||
def _verify(self, license_str: str) -> bool:
|
def _verify(self, license_str: str) -> bool:
|
||||||
url = "{}/verify_license/{}".format(self.base_url, license_str)
|
url = "{}/verify_license/{}".format(self.base_url, license_str)
|
||||||
|
@ -35,11 +60,58 @@ class LicenseCheck:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def is_premium(self) -> bool:
|
def is_premium(self) -> bool:
|
||||||
|
"""
|
||||||
|
1. verify_license_without_api_request: checks if license was generate using private / public key pair
|
||||||
|
2. _verify: checks if license is valid calling litellm API. This is the old way we were generating/validating license
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
if self.license_str is None:
|
if self.license_str is None:
|
||||||
return False
|
return False
|
||||||
|
elif self.verify_license_without_api_request(
|
||||||
|
public_key=self.public_key, license_key=self.license_str
|
||||||
|
):
|
||||||
|
return True
|
||||||
elif self._verify(license_str=self.license_str):
|
elif self._verify(license_str=self.license_str):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def verify_license_without_api_request(self, public_key, license_key):
|
||||||
|
try:
|
||||||
|
from cryptography.hazmat.primitives import hashes, serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import padding, rsa
|
||||||
|
|
||||||
|
# Decode the license key
|
||||||
|
decoded = base64.b64decode(license_key)
|
||||||
|
message, signature = decoded.split(b".", 1)
|
||||||
|
|
||||||
|
# Verify the signature
|
||||||
|
public_key.verify(
|
||||||
|
signature,
|
||||||
|
message,
|
||||||
|
padding.PSS(
|
||||||
|
mgf=padding.MGF1(hashes.SHA256()),
|
||||||
|
salt_length=padding.PSS.MAX_LENGTH,
|
||||||
|
),
|
||||||
|
hashes.SHA256(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Decode and parse the data
|
||||||
|
license_data = json.loads(message.decode())
|
||||||
|
|
||||||
|
# debug information provided in license data
|
||||||
|
verbose_proxy_logger.debug("License data: %s", license_data)
|
||||||
|
|
||||||
|
# Check expiration date
|
||||||
|
expiration_date = datetime.strptime(
|
||||||
|
license_data["expiration_date"], "%Y-%m-%d"
|
||||||
|
)
|
||||||
|
if expiration_date < datetime.now():
|
||||||
|
return False, "License has expired"
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.error(str(e))
|
||||||
|
return False
|
||||||
|
|
9
litellm/proxy/auth/public_key.pem
Normal file
9
litellm/proxy/auth/public_key.pem
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
-----BEGIN PUBLIC KEY-----
|
||||||
|
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAmfBuNiNzDkNWyce23koQ
|
||||||
|
w0vq3bSVHkq7fd9Sw/U1q7FwRwL221daLTyGWssd8xAoQSFXAJKoBwzJQ9wd+o44
|
||||||
|
lfL54E3a61nfjZuF+D9ntpXZFfEAxLVtIahDeQjUz4b/EpgciWIJyUfjCJrQo6LY
|
||||||
|
eyAZPTGSO8V3zHyaU+CFywq5XCuCnfZqCZeCw051St59A2v8W32mXSCJ+A+x0hYP
|
||||||
|
yXJyRRFcefSFG5IBuRHr4Y24Vx7NUIAoco5cnxJho9g2z3J/Hb0GKW+oBNvRVumk
|
||||||
|
nuA2Ljmjh4yI0OoTIW8ZWxemvCCJHSjdfKlMyb+QI4fmeiIUZzP5Au+F561Styqq
|
||||||
|
YQIDAQAB
|
||||||
|
-----END PUBLIC KEY-----
|
|
@ -31,6 +31,7 @@ azure-identity==1.16.1 # for azure content safety
|
||||||
opentelemetry-api==1.25.0
|
opentelemetry-api==1.25.0
|
||||||
opentelemetry-sdk==1.25.0
|
opentelemetry-sdk==1.25.0
|
||||||
opentelemetry-exporter-otlp==1.25.0
|
opentelemetry-exporter-otlp==1.25.0
|
||||||
|
cryptography==42.0.7
|
||||||
|
|
||||||
### LITELLM PACKAGE DEPENDENCIES
|
### LITELLM PACKAGE DEPENDENCIES
|
||||||
python-dotenv==1.0.0 # for env
|
python-dotenv==1.0.0 # for env
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue