Merge branch 'main' into litellm_batch_write_redis_cache

This commit is contained in:
Ishaan Jaff 2024-03-25 16:41:29 -07:00 committed by GitHub
commit 7134d66fae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 359 additions and 51 deletions

View file

@ -32,8 +32,9 @@ litellm_settings:
cache: True # set cache responses to True, litellm defaults to using a redis cache cache: True # set cache responses to True, litellm defaults to using a redis cache
``` ```
#### [OPTIONAL] Step 1.5: Add redis namespaces #### [OPTIONAL] Step 1.5: Add redis namespaces, default ttl
## Namespace
If you want to create some folder for your keys, you can set a namespace, like this: If you want to create some folder for your keys, you can set a namespace, like this:
```yaml ```yaml
@ -50,6 +51,16 @@ and keys will be stored like:
litellm_caching:<hash> litellm_caching:<hash>
``` ```
## TTL
```yaml
litellm_settings:
cache: true
cache_params: # set cache params for redis
type: redis
ttl: 600 # will be cached on redis for 600s
```
#### Step 2: Add Redis Credentials to .env #### Step 2: Add Redis Credentials to .env
Set either `REDIS_URL` or the `REDIS_HOST` in your os environment, to enable caching. Set either `REDIS_URL` or the `REDIS_HOST` in your os environment, to enable caching.

View file

@ -107,4 +107,38 @@ general_settings:
master_key: sk-1234 master_key: sk-1234
enable_jwt_auth: True enable_jwt_auth: True
allowed_routes: ["/chat/completions", "/embeddings"] allowed_routes: ["/chat/completions", "/embeddings"]
```
## Advanced - Set Accepted JWT Scope Names
Change the scope name that litellm looks for in a JWT's `scope` claim when deciding whether a user has admin access.
```yaml
general_settings:
master_key: sk-1234
enable_jwt_auth: True
litellm_proxy_roles:
proxy_admin: "litellm-proxy-admin"
```
### Allowed LiteLLM scopes
```python
class LiteLLMProxyRoles(LiteLLMBase):
proxy_admin: str = "litellm_proxy_admin"
proxy_user: str = "litellm_user" # 👈 Not implemented yet, for JWT-Auth.
```
### JWT Scopes
Here's what the `scope` claim on JWT-Auth tokens can look like:
**Can be a list**
```
scope: ["litellm-proxy-admin",...]
```
**Can be a space-separated string**
```
scope: "litellm-proxy-admin ..."
``` ```

View file

@ -56,6 +56,7 @@ baseten_key: Optional[str] = None
aleph_alpha_key: Optional[str] = None aleph_alpha_key: Optional[str] = None
nlp_cloud_key: Optional[str] = None nlp_cloud_key: Optional[str] = None
use_client: bool = False use_client: bool = False
disable_streaming_logging: bool = False
### GUARDRAILS ### ### GUARDRAILS ###
llamaguard_model_name: Optional[str] = None llamaguard_model_name: Optional[str] = None
presidio_ad_hoc_recognizers: Optional[str] = None presidio_ad_hoc_recognizers: Optional[str] = None

View file

@ -899,6 +899,7 @@ class Cache:
port: Optional[str] = None, port: Optional[str] = None,
password: Optional[str] = None, password: Optional[str] = None,
namespace: Optional[str] = None, namespace: Optional[str] = None,
ttl: Optional[float] = None,
similarity_threshold: Optional[float] = None, similarity_threshold: Optional[float] = None,
supported_call_types: Optional[ supported_call_types: Optional[
List[ List[
@ -996,6 +997,7 @@ class Cache:
self.type = type self.type = type
self.namespace = namespace self.namespace = namespace
self.redis_flush_size = redis_flush_size self.redis_flush_size = redis_flush_size
self.ttl = ttl
def get_cache_key(self, *args, **kwargs): def get_cache_key(self, *args, **kwargs):
""" """
@ -1235,6 +1237,9 @@ class Cache:
if isinstance(result, OpenAIObject): if isinstance(result, OpenAIObject):
result = result.model_dump_json() result = result.model_dump_json()
## DEFAULT TTL ##
if self.ttl is not None:
kwargs["ttl"] = self.ttl
## Get Cache-Controls ## ## Get Cache-Controls ##
if kwargs.get("cache", None) is not None and isinstance( if kwargs.get("cache", None) is not None and isinstance(
kwargs.get("cache"), dict kwargs.get("cache"), dict
@ -1242,6 +1247,7 @@ class Cache:
for k, v in kwargs.get("cache").items(): for k, v in kwargs.get("cache").items():
if k == "ttl": if k == "ttl":
kwargs["ttl"] = v kwargs["ttl"] = v
cached_data = {"timestamp": time.time(), "response": result} cached_data = {"timestamp": time.time(), "response": result}
return cache_key, cached_data, kwargs return cache_key, cached_data, kwargs
else: else:

View file

@ -14,11 +14,6 @@ def hash_token(token: str):
return hashed_token return hashed_token
class LiteLLMProxyRoles(enum.Enum):
PROXY_ADMIN = "litellm_proxy_admin"
USER = "litellm_user"
class LiteLLMBase(BaseModel): class LiteLLMBase(BaseModel):
""" """
Implements default functions, all pydantic objects should have. Implements default functions, all pydantic objects should have.
@ -42,6 +37,11 @@ class LiteLLMBase(BaseModel):
protected_namespaces = () protected_namespaces = ()
# Configurable scope names used for JWT-Auth role checks.
# The defaults below can be overridden through `general_settings.litellm_proxy_roles`
# in the proxy config (the proxy builds this object from that dict at startup).
class LiteLLMProxyRoles(LiteLLMBase):
# scope string that JWTHandler.is_admin looks for in a token's scopes
proxy_admin: str = "litellm_proxy_admin"
# documented in this change as not implemented yet for JWT-Auth
proxy_user: str = "litellm_user"
class LiteLLMPromptInjectionParams(LiteLLMBase): class LiteLLMPromptInjectionParams(LiteLLMBase):
heuristics_check: bool = False heuristics_check: bool = False
vector_db_check: bool = False vector_db_check: bool = False

View file

@ -67,17 +67,21 @@ class JWTHandler:
self.http_handler = HTTPHandler() self.http_handler = HTTPHandler()
def update_environment( def update_environment(
self, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache self,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
litellm_proxy_roles: LiteLLMProxyRoles,
) -> None: ) -> None:
self.prisma_client = prisma_client self.prisma_client = prisma_client
self.user_api_key_cache = user_api_key_cache self.user_api_key_cache = user_api_key_cache
self.litellm_proxy_roles = litellm_proxy_roles
def is_jwt(self, token: str): def is_jwt(self, token: str):
parts = token.split(".") parts = token.split(".")
return len(parts) == 3 return len(parts) == 3
def is_admin(self, scopes: list) -> bool: def is_admin(self, scopes: list) -> bool:
if LiteLLMProxyRoles.PROXY_ADMIN.value in scopes: if self.litellm_proxy_roles.proxy_admin in scopes:
return True return True
return False return False
@ -90,7 +94,7 @@ class JWTHandler:
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try: try:
team_id = token["azp"] team_id = token["client_id"]
except KeyError: except KeyError:
team_id = default_value team_id = default_value
return team_id return team_id
@ -130,58 +134,94 @@ class JWTHandler:
def get_scopes(self, token: dict) -> list: def get_scopes(self, token: dict) -> list:
try: try:
# Assuming the scopes are stored in 'scope' claim and are space-separated if isinstance(token["scope"], str):
scopes = token["scope"].split() # Assuming the scopes are stored in 'scope' claim and are space-separated
scopes = token["scope"].split()
elif isinstance(token["scope"], list):
scopes = token["scope"]
else:
raise Exception(
f"Unmapped scope type - {type(token['scope'])}. Supported types - list, str."
)
except KeyError: except KeyError:
scopes = [] scopes = []
return scopes return scopes
async def auth_jwt(self, token: str) -> dict: async def get_public_key(self, kid: Optional[str]) -> dict:
from jwt.algorithms import RSAAlgorithm
keys_url = os.getenv("JWT_PUBLIC_KEY_URL") keys_url = os.getenv("JWT_PUBLIC_KEY_URL")
if keys_url is None: if keys_url is None:
raise Exception("Missing JWT Public Key URL from environment.") raise Exception("Missing JWT Public Key URL from environment.")
response = await self.http_handler.get(keys_url) cached_keys = await self.user_api_key_cache.async_get_cache(
"litellm_jwt_auth_keys"
)
if cached_keys is None:
response = await self.http_handler.get(keys_url)
keys = response.json()["keys"] keys = response.json()["keys"]
await self.user_api_key_cache.async_set_cache(
key="litellm_jwt_auth_keys", value=keys, ttl=600 # cache for 10 mins
)
else:
keys = cached_keys
public_key: Optional[dict] = None
if len(keys) == 1:
public_key = keys[0]
elif len(keys) > 1:
for key in keys:
if kid is not None and key["kid"] == kid:
public_key = key
if public_key is None:
raise Exception(
f"No matching public key found. kid={kid}, keys_url={keys_url}, cached_keys={cached_keys}"
)
return public_key
async def auth_jwt(self, token: str) -> dict:
from jwt.algorithms import RSAAlgorithm
header = jwt.get_unverified_header(token) header = jwt.get_unverified_header(token)
verbose_proxy_logger.debug("header: %s", header) verbose_proxy_logger.debug("header: %s", header)
if "kid" in header: kid = header.get("kid", None)
kid = header["kid"]
else:
raise Exception(f"Expected 'kid' in header. header={header}.")
for key in keys: public_key = await self.get_public_key(kid=kid)
if key["kid"] == kid:
jwk = {
"kty": key["kty"],
"kid": key["kid"],
"n": key["n"],
"e": key["e"],
}
public_key = RSAAlgorithm.from_jwk(json.dumps(jwk))
try: if public_key is not None and isinstance(public_key, dict):
# decode the token using the public key jwk = {}
payload = jwt.decode( if "kty" in public_key:
token, jwk["kty"] = public_key["kty"]
public_key, # type: ignore if "kid" in public_key:
algorithms=["RS256"], jwk["kid"] = public_key["kid"]
audience="account", if "n" in public_key:
) jwk["n"] = public_key["n"]
return payload if "e" in public_key:
jwk["e"] = public_key["e"]
except jwt.ExpiredSignatureError: public_key_rsa = RSAAlgorithm.from_jwk(json.dumps(jwk))
# the token is expired, do something to refresh it
raise Exception("Token Expired") try:
except Exception as e: # decode the token using the public key
raise Exception(f"Validation fails: {str(e)}") payload = jwt.decode(
token,
public_key_rsa, # type: ignore
algorithms=["RS256"],
options={"verify_aud": False},
)
return payload
except jwt.ExpiredSignatureError:
# the token is expired, do something to refresh it
raise Exception("Token Expired")
except Exception as e:
raise Exception(f"Validation fails: {str(e)}")
raise Exception("Invalid JWT Submitted") raise Exception("Invalid JWT Submitted")

View file

@ -2710,7 +2710,11 @@ async def startup_event():
## JWT AUTH ## ## JWT AUTH ##
jwt_handler.update_environment( jwt_handler.update_environment(
prisma_client=prisma_client, user_api_key_cache=user_api_key_cache prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
litellm_proxy_roles=LiteLLMProxyRoles(
**general_settings.get("litellm_proxy_roles", {})
),
) )
if use_background_health_checks: if use_background_health_checks:

View file

@ -116,6 +116,23 @@ def test_caching_with_ttl():
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
def test_caching_with_default_ttl():
    """
    Check the cache-level default ttl.

    A `Cache` is created with `ttl=0` and the same completion is requested
    twice with caching enabled. The test expects the two responses to carry
    different ids, i.e. the second call must not be served from the cache.
    """
    try:
        litellm.set_verbose = True
        litellm.cache = Cache(ttl=0)
        response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
        response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
        print(f"response1: {response1}")
        print(f"response2: {response2}")
        assert response2["id"] != response1["id"]
    except Exception as e:
        print(f"error occurred: {traceback.format_exc()}")
        pytest.fail(f"Error occurred: {e}")
    finally:
        # Reset global state even when a completion call or the assertion fails,
        # so a failure here cannot leave caching enabled for later tests.
        litellm.cache = None  # disable cache
        litellm.success_callback = []
        litellm._async_success_callback = []
def test_caching_with_cache_controls(): def test_caching_with_cache_controls():
try: try:
litellm.set_verbose = True litellm.set_verbose = True

179
litellm/tests/test_jwt.py Normal file
View file

@ -0,0 +1,179 @@
#### What this tests ####
# Unit tests for JWT-Auth
import sys, os, asyncio, time, random
import traceback
from dotenv import load_dotenv
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
from litellm.proxy._types import LiteLLMProxyRoles
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.caching import DualCache
from datetime import datetime, timedelta
# Sample RSA public key in JWK form (key type, exponent, modulus, algorithm).
# NOTE(review): the tests visible in this file build their own keys and bind a
# local `public_key`, so this module-level value looks unused - confirm.
public_key = {
"kty": "RSA",
"e": "AQAB",
"n": "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ",
"alg": "RS256",
}
def test_load_config_with_custom_role_names():
    """
    A custom `proxy_admin` scope name given under
    `general_settings.litellm_proxy_roles` must end up on LiteLLMProxyRoles.
    """
    role_settings = {"proxy_admin": "litellm-proxy-admin"}
    config = {"general_settings": {"litellm_proxy_roles": role_settings}}

    # Same lookup chain the config would go through: missing sections fall back to {}.
    general_settings = config.get("general_settings", {})
    role_overrides = general_settings.get("litellm_proxy_roles", {})
    proxy_roles = LiteLLMProxyRoles(**role_overrides)

    print(f"proxy_roles: {proxy_roles}")
    assert proxy_roles.proxy_admin == "litellm-proxy-admin"
# test_load_config_with_custom_role_names()
@pytest.mark.asyncio
async def test_token_single_public_key(monkeypatch):
    """
    When the key set holds exactly one key, `get_public_key(kid=None)` must
    return that key even though no `kid` is supplied.

    The key set is pre-loaded into the handler's cache under
    "litellm_jwt_auth_keys", so no HTTP request is made.
    """
    # get_public_key() raises if JWT_PUBLIC_KEY_URL is unset, before it ever
    # looks at the cache. Pin a placeholder so the test does not depend on the
    # developer's environment / .env file. The URL is never fetched.
    monkeypatch.setenv(
        "JWT_PUBLIC_KEY_URL", "https://example.com/.well-known/jwks.json"
    )

    expected_modulus = "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ"

    jwt_handler = JWTHandler()

    backend_keys = {
        "keys": [
            {
                "kty": "RSA",
                "use": "sig",
                "e": "AQAB",
                "n": expected_modulus,
                "alg": "RS256",
            }
        ]
    }

    # set cache
    cache = DualCache()
    await cache.async_set_cache(key="litellm_jwt_auth_keys", value=backend_keys["keys"])

    jwt_handler.user_api_key_cache = cache

    public_key = await jwt_handler.get_public_key(kid=None)

    assert public_key is not None
    assert isinstance(public_key, dict)
    assert public_key["n"] == expected_modulus
@pytest.mark.asyncio
async def test_valid_invalid_token(monkeypatch):
    """
    Tests `auth_jwt` with tokens signed by a freshly generated RSA key
    - token carrying the admin scope
    - token carrying an unrecognised scope

    Both tokens are correctly signed and unexpired. `auth_jwt` only verifies
    the signature / expiry and returns the payload (it does not evaluate
    scopes), so neither call is expected to raise.
    """
    import jwt, json
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from cryptography.hazmat.backends import default_backend

    # auth_jwt() -> get_public_key() raises if JWT_PUBLIC_KEY_URL is unset,
    # before the cache is consulted. Pin a placeholder so the test does not
    # depend on the ambient environment. The URL is never fetched.
    monkeypatch.setenv(
        "JWT_PUBLIC_KEY_URL", "https://example.com/.well-known/jwks.json"
    )

    # Generate a private / public key pair using RSA algorithm
    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )

    # Get private key in PEM format
    private_key = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    # Get public key in PEM format
    public_key = key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    public_key_obj = serialization.load_pem_public_key(
        public_key, backend=default_backend()
    )

    # Convert RSA public key object to JWK (JSON Web Key)
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))
    assert isinstance(public_jwk, dict)

    # set cache
    cache = DualCache()
    await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])

    jwt_handler = JWTHandler()
    jwt_handler.user_api_key_cache = cache

    # jwt.encode wants the PEM key as text, not bytes
    private_key_str = private_key.decode("utf-8")

    def _make_token(scope: str) -> str:
        """Sign an RS256 token for `scope` that expires in 10 minutes."""
        # `exp` must be a real epoch timestamp. time.time() is timezone
        # independent, whereas datetime.utcnow().timestamp() treats the naive
        # UTC value as local time and yields an already-expired token on hosts
        # east of UTC.
        expiration_time = int(time.time()) + 10 * 60
        payload = {
            "sub": "user123",
            "exp": expiration_time,  # set the token to expire in 10 minutes
            "scope": scope,
        }
        return jwt.encode(payload, private_key_str, algorithm="RS256")

    # TOKEN WITH ADMIN SCOPE
    token = _make_token("litellm-proxy-admin")

    ## VERIFY IT WORKS
    response = await jwt_handler.auth_jwt(token=token)

    assert response is not None
    assert isinstance(response, dict)

    print(f"response: {response}")

    # TOKEN WITH UNRECOGNISED SCOPE
    token = _make_token("litellm-NO-SCOPE")

    ## VERIFY IT IS STILL ACCEPTED (signature and expiry are valid)
    try:
        response = await jwt_handler.auth_jwt(token=token)
    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")

View file

@ -9617,15 +9617,31 @@ class CustomStreamWrapper:
) )
def set_logging_event_loop(self, loop): def set_logging_event_loop(self, loop):
"""
import litellm, asyncio
loop = asyncio.get_event_loop() # 👈 gets the current event loop
response = litellm.completion(.., stream=True)
response.set_logging_event_loop(loop=loop) # 👈 enables async_success callbacks for sync logging
for chunk in response:
...
"""
self.logging_loop = loop self.logging_loop = loop
async def your_async_function(self):
# Your asynchronous code here
return "Your asynchronous code is running"
def run_success_logging_in_thread(self, processed_chunk): def run_success_logging_in_thread(self, processed_chunk):
# Create an event loop for the new thread if litellm.disable_streaming_logging == True:
"""
[NOT RECOMMENDED]
Set this via `litellm.disable_streaming_logging = True`.
Disables streaming logging.
"""
return
## ASYNC LOGGING ## ASYNC LOGGING
# Create an event loop for the new thread
if self.logging_loop is not None: if self.logging_loop is not None:
future = asyncio.run_coroutine_threadsafe( future = asyncio.run_coroutine_threadsafe(
self.logging_obj.async_success_handler(processed_chunk), self.logging_obj.async_success_handler(processed_chunk),

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "litellm" name = "litellm"
version = "1.34.1" version = "1.34.2"
description = "Library to easily interface with LLM API providers" description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"] authors = ["BerriAI"]
license = "MIT" license = "MIT"
@ -80,7 +80,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.commitizen] [tool.commitizen]
version = "1.34.1" version = "1.34.2"
version_files = [ version_files = [
"pyproject.toml:^version" "pyproject.toml:^version"
] ]