mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
commit
b603d9784d
2 changed files with 58 additions and 29 deletions
|
@ -5097,7 +5097,15 @@ async def google_login(request: Request):
|
|||
scope=generic_scope,
|
||||
)
|
||||
with generic_sso:
|
||||
return await generic_sso.get_login_redirect()
|
||||
# TODO: state should be a random string and added to the user session with cookie
|
||||
# or a cryptographicly signed state that we can verify stateless
|
||||
# For simplification we are using a static state, this is not perfect but some
|
||||
# SSO providers do not allow stateless verification
|
||||
redirect_params = {}
|
||||
state = os.getenv("GENERIC_CLIENT_STATE", None)
|
||||
if state:
|
||||
redirect_params['state'] = state
|
||||
return await generic_sso.get_login_redirect(**redirect_params)
|
||||
elif ui_username is not None:
|
||||
# No Google, Microsoft SSO
|
||||
# Use UI Credentials set in .env
|
||||
|
@ -5265,7 +5273,7 @@ async def auth_callback(request: Request):
|
|||
result = await microsoft_sso.verify_and_process(request)
|
||||
elif generic_client_id is not None:
|
||||
# make generic sso provider
|
||||
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
|
||||
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument, OpenID
|
||||
|
||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
||||
|
@ -5274,6 +5282,9 @@ async def auth_callback(request: Request):
|
|||
)
|
||||
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
||||
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
||||
generic_include_client_id = (
|
||||
os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true"
|
||||
)
|
||||
if generic_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
||||
|
@ -5308,12 +5319,36 @@ async def auth_callback(request: Request):
|
|||
verbose_proxy_logger.debug(
|
||||
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
||||
)
|
||||
|
||||
generic_user_id_attribute_name = os.getenv("GENERIC_USER_ID_ATTRIBUTE", "email")
|
||||
generic_user_email_attribute_name = os.getenv(
|
||||
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
|
||||
)
|
||||
generic_user_role_attribute_name = os.getenv(
|
||||
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}"
|
||||
)
|
||||
|
||||
discovery = DiscoveryDocument(
|
||||
authorization_endpoint=generic_authorization_endpoint,
|
||||
token_endpoint=generic_token_endpoint,
|
||||
userinfo_endpoint=generic_userinfo_endpoint,
|
||||
)
|
||||
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
|
||||
|
||||
def response_convertor(response, client):
|
||||
return OpenID(
|
||||
id=response.get(generic_user_email_attribute_name),
|
||||
display_name=response.get(generic_user_email_attribute_name),
|
||||
email=response.get(generic_user_email_attribute_name),
|
||||
)
|
||||
|
||||
SSOProvider = create_provider(
|
||||
name="oidc",
|
||||
discovery_document=discovery,
|
||||
response_convertor=response_convertor,
|
||||
)
|
||||
generic_sso = SSOProvider(
|
||||
client_id=generic_client_id,
|
||||
client_secret=generic_client_secret,
|
||||
|
@ -5322,43 +5357,36 @@ async def auth_callback(request: Request):
|
|||
scope=generic_scope,
|
||||
)
|
||||
verbose_proxy_logger.debug(f"calling generic_sso.verify_and_process")
|
||||
request_body = await request.body()
|
||||
request_query_params = request.query_params
|
||||
# get "code" from query params
|
||||
code = request_query_params.get("code")
|
||||
result = await generic_sso.verify_and_process(request)
|
||||
result = await generic_sso.verify_and_process(
|
||||
request, params={"include_client_id": generic_include_client_id}
|
||||
)
|
||||
verbose_proxy_logger.debug(f"generic result: {result}")
|
||||
|
||||
# User is Authe'd in - generate key for the UI to access Proxy
|
||||
user_email = getattr(result, "email", None)
|
||||
user_id = getattr(result, "id", None)
|
||||
|
||||
# generic client id
|
||||
if generic_client_id is not None:
|
||||
generic_user_id_attribute_name = os.getenv("GENERIC_USER_ID_ATTRIBUTE", "email")
|
||||
generic_user_email_attribute_name = os.getenv(
|
||||
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
|
||||
)
|
||||
generic_user_role_attribute_name = os.getenv(
|
||||
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}"
|
||||
)
|
||||
|
||||
user_id = getattr(result, generic_user_id_attribute_name, None)
|
||||
user_email = getattr(result, generic_user_email_attribute_name, None)
|
||||
user_id = result.id
|
||||
user_email = result.email
|
||||
user_role = getattr(result, generic_user_role_attribute_name, None)
|
||||
|
||||
if user_id is None:
|
||||
user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "")
|
||||
# get user_info from litellm DB
|
||||
|
||||
user_info = None
|
||||
if prisma_client is not None:
|
||||
user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
|
||||
user_id_models: List = []
|
||||
if user_info is not None:
|
||||
user_id_models = getattr(user_info, "models", [])
|
||||
user_id_models = []
|
||||
|
||||
# User might not be already created on first generation of key
|
||||
# But if it is, we want its models preferences
|
||||
try:
|
||||
if prisma_client is not None:
|
||||
user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
|
||||
if user_info is not None:
|
||||
user_id_models = getattr(user_info, "models", [])
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
response = await generate_key_helper_fn(
|
||||
**{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue