mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
* fix(converse_transformation.py): handle cross region model name when getting openai param support Fixes https://github.com/BerriAI/litellm/issues/6291 * LiteLLM Minor Fixes & Improvements (10/17/2024) (#6293) * fix(ui_sso.py): fix faulty admin only check Fixes https://github.com/BerriAI/litellm/issues/6286 * refactor(sso_helper_utils.py): refactor /sso/callback to use helper utils, covered by unit testing Prevent future regressions * feat(prompt_factory): support 'ensure_alternating_roles' param Closes https://github.com/BerriAI/litellm/issues/6257 * fix(proxy/utils.py): add dailytagspend to expected views * feat(auth_utils.py): support setting regex for clientside auth credentials Fixes https://github.com/BerriAI/litellm/issues/6203 * build(cookbook): add tutorial for mlflow + langchain + litellm proxy tracing * feat(argilla.py): add argilla logging integration Closes https://github.com/BerriAI/litellm/issues/6201 * fix: fix linting errors * fix: fix ruff error * test: fix test * fix: update vertex ai assumption - parts not always guaranteed (#6296) * docs(configs.md): add argila env var to docs * docs(user_keys.md): add regex doc for clientside auth params * docs(argilla.md): add doc on argilla logging * docs(argilla.md): add sampling rate to argilla calls * bump: version 1.49.6 → 1.49.7 * add gpt-4o-audio models to model cost map (#6306) * (code quality) add ruff check PLR0915 for `too-many-statements` (#6309) * ruff add PLR0915 * add noqa for PLR0915 * fix noqa * add # noqa: PLR0915 * # noqa: PLR0915 * # noqa: PLR0915 * # noqa: PLR0915 * add # noqa: PLR0915 * # noqa: PLR0915 * # noqa: PLR0915 * # noqa: PLR0915 * # noqa: PLR0915 * doc fix Turn on / off caching per Key. 
(#6297) * (feat) Support `audio`, `modalities` params (#6304) * add audio, modalities param * add test for gpt audio models * add get_supported_openai_params for GPT audio models * add supported params for audio * test_audio_output_from_model * bump openai to openai==1.52.0 * bump openai on pyproject * fix audio test * fix test mock_chat_response * handle audio for Message * fix handling audio for OAI compatible API endpoints * fix linting * fix mock dbrx test * (feat) Support audio param in responses streaming (#6312) * add audio, modalities param * add test for gpt audio models * add get_supported_openai_params for GPT audio models * add supported params for audio * test_audio_output_from_model * bump openai to openai==1.52.0 * bump openai on pyproject * fix audio test * fix test mock_chat_response * handle audio for Message * fix handling audio for OAI compatible API endpoints * fix linting * fix mock dbrx test * add audio to Delta * handle model_response.choices.delta.audio * fix linting * build(model_prices_and_context_window.json): add gpt-4o-audio audio token cost tracking * refactor(model_prices_and_context_window.json): refactor 'supports_audio' to be 'supports_audio_input' and 'supports_audio_output' Allows for flag to be used for openai + gemini models (both support audio input) * feat(cost_calculation.py): support cost calc for audio model Closes https://github.com/BerriAI/litellm/issues/6302 * feat(utils.py): expose new `supports_audio_input` and `supports_audio_output` functions Closes https://github.com/BerriAI/litellm/issues/6303 * feat(handle_jwt.py): support single dict list * fix(cost_calculator.py): fix linting errors * fix: fix linting error * fix(cost_calculator): move to using standard openai usage cached tokens value * test: fix test --------- Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
313 lines
11 KiB
Python
313 lines
11 KiB
Python
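"""
Supports using JWTs for authenticating into the proxy.

Admin access requires the configured admin scope (`admin_jwt_scope`, e.g.
'litellm_proxy_admin') to be present in the token's scope claim. The handler
also reads user, team, org and end-user ids and the user email from
configurable token fields.
"""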
|
|
|
|
import json
|
|
import os
|
|
from typing import Optional, cast
|
|
|
|
from cryptography import x509
|
|
from cryptography.hazmat.backends import default_backend
|
|
from cryptography.hazmat.primitives import serialization
|
|
|
|
from litellm._logging import verbose_proxy_logger
|
|
from litellm.caching.caching import DualCache
|
|
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
|
|
from litellm.proxy._types import JWKKeyValue, JWTKeyItem, LiteLLM_JWTAuth
|
|
from litellm.proxy.utils import PrismaClient
|
|
|
|
|
|
class JWTHandler:
    """
    Verify proxy JWTs and read identity claims out of them.

    Signatures are checked against public keys fetched from the URL in the
    `JWT_PUBLIC_KEY_URL` environment variable (cached in `user_api_key_cache`).
    Helper methods read user / team / org / end-user ids, the user email and
    scopes from token fields configured on `LiteLLM_JWTAuth`.

    Design notes (these policies are not implemented inside this class):
    - treat the sub id passed in as the user id
    - return an error if id making request doesn't exist in proxy user table
    - track spend against the user id
    - if role="litellm_proxy_user" -> allow making calls + info. Can not edit budgets
    """

    prisma_client: Optional[PrismaClient]
    user_api_key_cache: DualCache
    litellm_jwtauth: LiteLLM_JWTAuth

    def __init__(self) -> None:
        # HTTP client used to download the public keys; released by `close()`.
        self.http_handler = HTTPHandler()

    def update_environment(
        self,
        prisma_client: Optional[PrismaClient],
        user_api_key_cache: DualCache,
        litellm_jwtauth: LiteLLM_JWTAuth,
    ) -> None:
        """
        Attach the DB client, cache and JWT auth settings to this handler.

        `__init__` does not set these attributes, so this must be called before
        any method that reads `self.litellm_jwtauth` or the key cache.
        """
        self.prisma_client = prisma_client
        self.user_api_key_cache = user_api_key_cache
        self.litellm_jwtauth = litellm_jwtauth

    def is_jwt(self, token: str) -> bool:
        """
        Return True if `token` has three dot-separated segments.

        This is only a shape check; nothing is decoded or verified here.
        """
        return len(token.split(".")) == 3

    def is_admin(self, scopes: list) -> bool:
        """Return True if the configured admin scope is among `scopes`."""
        return self.litellm_jwtauth.admin_jwt_scope in scopes

    def get_end_user_id(
        self, token: dict, default_value: Optional[str]
    ) -> Optional[str]:
        """
        Read the end-user id from the configured `end_user_id_jwt_field`.

        Returns:
            - the claim value, if the field is configured and present;
            - `default_value`, if the field is configured but missing;
            - None, if no field is configured.
        """
        try:
            if self.litellm_jwtauth.end_user_id_jwt_field is not None:
                user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
            else:
                user_id = None
        except KeyError:
            user_id = default_value
        return user_id

    def is_required_team_id(self) -> bool:
        """
        Returns:
            - True: if 'team_id_jwt_field' is set
            - False: if not
        """
        return self.litellm_jwtauth.team_id_jwt_field is not None

    def is_enforced_email_domain(self) -> bool:
        """
        Returns:
            - True: if 'user_allowed_email_domain' is set (as a string)
            - False: otherwise
        """
        return isinstance(self.litellm_jwtauth.user_allowed_email_domain, str)

    def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
        """
        Read the team id for this token.

        Returns:
            - the `team_id_jwt_field` claim, if that field is configured and present;
            - `default_value`, if that field is configured but missing from the token;
            - `team_id_default`, if no field is configured but a default team is;
            - None, otherwise.
        """
        try:
            if self.litellm_jwtauth.team_id_jwt_field is not None:
                team_id = token[self.litellm_jwtauth.team_id_jwt_field]
            elif self.litellm_jwtauth.team_id_default is not None:
                team_id = self.litellm_jwtauth.team_id_default
            else:
                team_id = None
        except KeyError:
            team_id = default_value
        return team_id

    def is_upsert_user_id(self, valid_user_email: Optional[bool] = None) -> bool:
        """
        Returns:
            - True: if 'user_id_upsert' is set AND valid_user_email is not False
            - False: if not
        """
        if valid_user_email is False:
            return False
        return self.litellm_jwtauth.user_id_upsert

    def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
        """
        Read the user id from the configured `user_id_jwt_field`.

        Unlike the other getters, `default_value` is returned both when no
        field is configured and when the configured field is missing.
        """
        try:
            if self.litellm_jwtauth.user_id_jwt_field is not None:
                user_id = token[self.litellm_jwtauth.user_id_jwt_field]
            else:
                user_id = default_value
        except KeyError:
            user_id = default_value
        return user_id

    def get_user_email(
        self, token: dict, default_value: Optional[str]
    ) -> Optional[str]:
        """
        Read the user email from the configured `user_email_jwt_field`.

        Returns the claim value, `default_value` if the configured field is
        missing from the token, or None if no field is configured.
        """
        try:
            if self.litellm_jwtauth.user_email_jwt_field is not None:
                user_email = token[self.litellm_jwtauth.user_email_jwt_field]
            else:
                user_email = None
        except KeyError:
            user_email = default_value
        return user_email

    def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
        """
        Read the org id from the configured `org_id_jwt_field`.

        Returns the claim value, `default_value` if the configured field is
        missing from the token, or None if no field is configured.
        """
        try:
            if self.litellm_jwtauth.org_id_jwt_field is not None:
                org_id = token[self.litellm_jwtauth.org_id_jwt_field]
            else:
                org_id = None
        except KeyError:
            org_id = default_value
        return org_id

    def get_scopes(self, token: dict) -> list:
        """
        Return the token's scopes from its 'scope' claim.

        A string claim is split on whitespace and a list claim is returned
        as-is. A missing claim yields [].

        Raises:
            Exception: if the claim is neither a str nor a list.
        """
        try:
            scope_claim = token["scope"]
        except KeyError:
            return []
        if isinstance(scope_claim, str):
            # Assuming the scopes are stored in 'scope' claim and are space-separated
            return scope_claim.split()
        if isinstance(scope_claim, list):
            return scope_claim
        raise Exception(
            f"Unmapped scope type - {type(scope_claim)}. Supported types - list, str."
        )

    async def get_public_key(self, kid: Optional[str]) -> dict:
        """
        Return the public key (as a JWK dict) matching `kid`.

        Keys are downloaded from `JWT_PUBLIC_KEY_URL` and cached under
        "litellm_jwt_auth_keys" for `litellm_jwtauth.public_key_ttl`.

        Raises:
            Exception: if the URL env var is unset or no key matches `kid`.
        """
        keys_url = os.getenv("JWT_PUBLIC_KEY_URL")
        if keys_url is None:
            raise Exception("Missing JWT Public Key URL from environment.")

        cached_keys = await self.user_api_key_cache.async_get_cache(
            "litellm_jwt_auth_keys"
        )
        if cached_keys is None:
            response = await self.http_handler.get(keys_url)
            response_json = response.json()  # parse the body once
            keys: JWKKeyValue
            if isinstance(response_json, dict) and "keys" in response_json:
                # JWK Set document: {"keys": [...]}
                keys = response_json["keys"]
            else:
                # Endpoint returned a bare key (or key list) instead of a JWK Set.
                keys = response_json

            await self.user_api_key_cache.async_set_cache(
                key="litellm_jwt_auth_keys",
                value=keys,
                ttl=self.litellm_jwtauth.public_key_ttl,  # configurable cache TTL
            )
        else:
            keys = cached_keys

        public_key = self.parse_keys(keys=keys, kid=kid)
        if public_key is None:
            raise Exception(
                f"No matching public key found. kid={kid}, keys_url={keys_url}, cached_keys={cached_keys}, len(keys)={len(keys)}"
            )
        return cast(dict, public_key)

    def parse_keys(self, keys: JWKKeyValue, kid: Optional[str]) -> Optional[JWTKeyItem]:
        """
        Select the key matching `kid` from a single JWK or a list of JWKs.

        - A single JWK (dict) is returned when `kid` is None or equals its "kid".
        - A one-element list is treated the same way as a single JWK.
        - For longer lists, `kid` is required; the first entry whose "kid"
          equals it is returned.

        Returns None when `keys` is empty or nothing matches.
        """
        if not keys:
            return None

        if isinstance(keys, dict):
            # A lone JWK object. Its len() is the number of JWK fields
            # (kty, n, e, ...), so it must not go through the list logic below.
            if kid is None or keys.get("kid", None) == kid:
                return keys
            return None

        if len(keys) == 1:
            only_key = keys[0]
            if isinstance(only_key, dict) and (
                kid is None or only_key.get("kid", None) == kid
            ):
                return only_key
            return None

        # Several candidate keys: only an explicit kid can disambiguate.
        if kid is None:
            return None
        for key in keys:
            if isinstance(key, dict) and key.get("kid", None) == kid:
                return key
        return None

    def is_allowed_domain(self, user_email: str) -> bool:
        """
        Return True if `user_email` belongs to the allowed email domain.

        Every email is allowed when no domain is configured. Domains are
        compared case-insensitively, since DNS names are case-insensitive.
        """
        allowed_domain = self.litellm_jwtauth.user_allowed_email_domain
        if allowed_domain is None:
            return True

        email_domain = user_email.split("@")[-1]  # Extract domain from email
        return email_domain.lower() == allowed_domain.lower()

    async def auth_jwt(self, token: str) -> dict:
        """
        Verify `token`'s signature (and audience, if `JWT_AUDIENCE` is set).

        Returns:
            The decoded payload.

        Raises:
            Exception: "Token Expired" for expired tokens,
                "Validation fails: ..." for any other verification failure,
                or "Invalid JWT Submitted" if no usable key was found.
        """
        # Supported algos: https://pyjwt.readthedocs.io/en/stable/algorithms.html
        # "Warning: Make sure not to mix symmetric and asymmetric algorithms that interpret
        # the key in different ways (e.g. HS* and RS*)."
        algorithms = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"]

        audience = os.getenv("JWT_AUDIENCE")
        decode_options = None
        if audience is None:
            # No expected audience configured -> skip the 'aud' check.
            decode_options = {"verify_aud": False}

        import jwt
        from jwt.algorithms import RSAAlgorithm

        header = jwt.get_unverified_header(token)
        verbose_proxy_logger.debug("header: %s", header)

        kid = header.get("kid", None)
        public_key = await self.get_public_key(kid=kid)

        if public_key is not None and isinstance(public_key, dict):
            # Keep only the RSA JWK members needed to rebuild the public key.
            jwk = {
                field: public_key[field]
                for field in ("kty", "kid", "n", "e")
                if field in public_key
            }
            public_key_rsa = RSAAlgorithm.from_jwk(json.dumps(jwk))

            try:
                # decode the token using the public key
                payload = jwt.decode(
                    token,
                    public_key_rsa,  # type: ignore
                    algorithms=algorithms,
                    options=decode_options,
                    audience=audience,
                )
                return payload
            except jwt.ExpiredSignatureError as e:
                raise Exception("Token Expired") from e
            except Exception as e:
                raise Exception(f"Validation fails: {str(e)}") from e
        elif public_key is not None and isinstance(public_key, str):
            # PEM-encoded X.509 certificate: extract its public key first.
            try:
                cert = x509.load_pem_x509_certificate(
                    public_key.encode(), default_backend()
                )
                key = cert.public_key().public_bytes(
                    serialization.Encoding.PEM,
                    serialization.PublicFormat.SubjectPublicKeyInfo,
                )
                # decode the token using the public key
                payload = jwt.decode(
                    token,
                    key,
                    algorithms=algorithms,
                    audience=audience,
                    options=decode_options,
                )
                return payload
            except jwt.ExpiredSignatureError as e:
                raise Exception("Token Expired") from e
            except Exception as e:
                raise Exception(f"Validation fails: {str(e)}") from e

        raise Exception("Invalid JWT Submitted")

    async def close(self):
        """Close the underlying HTTP client."""
        await self.http_handler.close()
|