fix - verify license without api request

This commit is contained in:
Ishaan Jaff 2024-06-25 13:55:54 -07:00
parent 6e02ac0056
commit 4c99010eee
3 changed files with 75 additions and 0 deletions

View file

@ -1,6 +1,14 @@
# What is this? # What is this?
## If litellm license in env, checks if it's valid ## If litellm license in env, checks if it's valid
import base64
import json
import os import os
from datetime import datetime
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from litellm._logging import verbose_proxy_logger
from litellm.llms.custom_httpx.http_handler import HTTPHandler from litellm.llms.custom_httpx.http_handler import HTTPHandler
@ -15,6 +23,20 @@ class LicenseCheck:
def __init__(self) -> None: def __init__(self) -> None:
self.license_str = os.getenv("LITELLM_LICENSE", None) self.license_str = os.getenv("LITELLM_LICENSE", None)
self.http_handler = HTTPHandler() self.http_handler = HTTPHandler()
self.public_key = None
self.read_public_key()
def read_public_key(self):
# current dir
current_dir = os.path.dirname(os.path.realpath(__file__))
# check if public_key.pem exists
_path_to_public_key = os.path.join(current_dir, "public_key.pem")
if os.path.exists(_path_to_public_key):
with open(_path_to_public_key, "rb") as key_file:
self.public_key = serialization.load_pem_public_key(key_file.read())
else:
self.public_key = None
def _verify(self, license_str: str) -> bool: def _verify(self, license_str: str) -> bool:
url = "{}/verify_license/{}".format(self.base_url, license_str) url = "{}/verify_license/{}".format(self.base_url, license_str)
@ -35,11 +57,54 @@ class LicenseCheck:
return False return False
def is_premium(self) -> bool: def is_premium(self) -> bool:
"""
1. verify_license_without_api_request: checks if license was generate using private / public key pair
2. _verify: checks if license is valid calling litellm API. This is the old way we were generating/validating license
"""
try: try:
if self.license_str is None: if self.license_str is None:
return False return False
elif self.verify_license_without_api_request(
public_key=self.public_key, license_key=self.license_str
):
return True
elif self._verify(license_str=self.license_str): elif self._verify(license_str=self.license_str):
return True return True
return False return False
except Exception as e: except Exception as e:
return False return False
def verify_license_without_api_request(self, public_key, license_key):
try:
# Decode the license key
decoded = base64.b64decode(license_key)
message, signature = decoded.split(b".", 1)
# Verify the signature
public_key.verify(
signature,
message,
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH,
),
hashes.SHA256(),
)
# Decode and parse the data
license_data = json.loads(message.decode())
# debug information provided in license data
verbose_proxy_logger.debug("License data: %s", license_data)
# Check expiration date
expiration_date = datetime.strptime(
license_data["expiration_date"], "%Y-%m-%d"
)
if expiration_date < datetime.now():
return False, "License has expired"
return True
except Exception as e:
return False

View file

@ -0,0 +1,9 @@
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAmfBuNiNzDkNWyce23koQ
w0vq3bSVHkq7fd9Sw/U1q7FwRwL221daLTyGWssd8xAoQSFXAJKoBwzJQ9wd+o44
lfL54E3a61nfjZuF+D9ntpXZFfEAxLVtIahDeQjUz4b/EpgciWIJyUfjCJrQo6LY
eyAZPTGSO8V3zHyaU+CFywq5XCuCnfZqCZeCw051St59A2v8W32mXSCJ+A+x0hYP
yXJyRRFcefSFG5IBuRHr4Y24Vx7NUIAoco5cnxJho9g2z3J/Hb0GKW+oBNvRVumk
nuA2Ljmjh4yI0OoTIW8ZWxemvCCJHSjdfKlMyb+QI4fmeiIUZzP5Au+F561Styqq
YQIDAQAB
-----END PUBLIC KEY-----

View file

@ -31,6 +31,7 @@ azure-identity==1.16.1 # for azure content safety
opentelemetry-api==1.25.0 opentelemetry-api==1.25.0
opentelemetry-sdk==1.25.0 opentelemetry-sdk==1.25.0
opentelemetry-exporter-otlp==1.25.0 opentelemetry-exporter-otlp==1.25.0
cryptography==42.0.7
### LITELLM PACKAGE DEPENDENCIES ### LITELLM PACKAGE DEPENDENCIES
python-dotenv==1.0.0 # for env python-dotenv==1.0.0 # for env