mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
commit
b603d9784d
2 changed files with 58 additions and 29 deletions
|
@ -134,7 +134,8 @@ The following can be used to customize attribute names when interacting with the
|
||||||
GENERIC_USER_ID_ATTRIBUTE = "given_name"
|
GENERIC_USER_ID_ATTRIBUTE = "given_name"
|
||||||
GENERIC_USER_EMAIL_ATTRIBUTE = "family_name"
|
GENERIC_USER_EMAIL_ATTRIBUTE = "family_name"
|
||||||
GENERIC_USER_ROLE_ATTRIBUTE = "given_role"
|
GENERIC_USER_ROLE_ATTRIBUTE = "given_role"
|
||||||
|
GENERIC_CLIENT_STATE = "some-state" # if the provider needs a state parameter
|
||||||
|
GENERIC_INCLUDE_CLIENT_ID = "false" # some providers enforce that the client_id is not in the body
|
||||||
GENERIC_SCOPE = "openid profile email" # default scope openid is sometimes not enough to retrieve basic user info like first_name and last_name located in profile scope
|
GENERIC_SCOPE = "openid profile email" # default scope openid is sometimes not enough to retrieve basic user info like first_name and last_name located in profile scope
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -5097,7 +5097,15 @@ async def google_login(request: Request):
|
||||||
scope=generic_scope,
|
scope=generic_scope,
|
||||||
)
|
)
|
||||||
with generic_sso:
|
with generic_sso:
|
||||||
return await generic_sso.get_login_redirect()
|
# TODO: state should be a random string and added to the user session with cookie
|
||||||
|
# or a cryptographicly signed state that we can verify stateless
|
||||||
|
# For simplification we are using a static state, this is not perfect but some
|
||||||
|
# SSO providers do not allow stateless verification
|
||||||
|
redirect_params = {}
|
||||||
|
state = os.getenv("GENERIC_CLIENT_STATE", None)
|
||||||
|
if state:
|
||||||
|
redirect_params['state'] = state
|
||||||
|
return await generic_sso.get_login_redirect(**redirect_params)
|
||||||
elif ui_username is not None:
|
elif ui_username is not None:
|
||||||
# No Google, Microsoft SSO
|
# No Google, Microsoft SSO
|
||||||
# Use UI Credentials set in .env
|
# Use UI Credentials set in .env
|
||||||
|
@ -5265,7 +5273,7 @@ async def auth_callback(request: Request):
|
||||||
result = await microsoft_sso.verify_and_process(request)
|
result = await microsoft_sso.verify_and_process(request)
|
||||||
elif generic_client_id is not None:
|
elif generic_client_id is not None:
|
||||||
# make generic sso provider
|
# make generic sso provider
|
||||||
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
|
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument, OpenID
|
||||||
|
|
||||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
||||||
|
@ -5274,6 +5282,9 @@ async def auth_callback(request: Request):
|
||||||
)
|
)
|
||||||
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
||||||
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
||||||
|
generic_include_client_id = (
|
||||||
|
os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true"
|
||||||
|
)
|
||||||
if generic_client_secret is None:
|
if generic_client_secret is None:
|
||||||
raise ProxyException(
|
raise ProxyException(
|
||||||
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
||||||
|
@ -5308,12 +5319,36 @@ async def auth_callback(request: Request):
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
generic_user_id_attribute_name = os.getenv("GENERIC_USER_ID_ATTRIBUTE", "email")
|
||||||
|
generic_user_email_attribute_name = os.getenv(
|
||||||
|
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
|
||||||
|
)
|
||||||
|
generic_user_role_attribute_name = os.getenv(
|
||||||
|
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}"
|
||||||
|
)
|
||||||
|
|
||||||
discovery = DiscoveryDocument(
|
discovery = DiscoveryDocument(
|
||||||
authorization_endpoint=generic_authorization_endpoint,
|
authorization_endpoint=generic_authorization_endpoint,
|
||||||
token_endpoint=generic_token_endpoint,
|
token_endpoint=generic_token_endpoint,
|
||||||
userinfo_endpoint=generic_userinfo_endpoint,
|
userinfo_endpoint=generic_userinfo_endpoint,
|
||||||
)
|
)
|
||||||
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
|
|
||||||
|
def response_convertor(response, client):
|
||||||
|
return OpenID(
|
||||||
|
id=response.get(generic_user_email_attribute_name),
|
||||||
|
display_name=response.get(generic_user_email_attribute_name),
|
||||||
|
email=response.get(generic_user_email_attribute_name),
|
||||||
|
)
|
||||||
|
|
||||||
|
SSOProvider = create_provider(
|
||||||
|
name="oidc",
|
||||||
|
discovery_document=discovery,
|
||||||
|
response_convertor=response_convertor,
|
||||||
|
)
|
||||||
generic_sso = SSOProvider(
|
generic_sso = SSOProvider(
|
||||||
client_id=generic_client_id,
|
client_id=generic_client_id,
|
||||||
client_secret=generic_client_secret,
|
client_secret=generic_client_secret,
|
||||||
|
@ -5322,43 +5357,36 @@ async def auth_callback(request: Request):
|
||||||
scope=generic_scope,
|
scope=generic_scope,
|
||||||
)
|
)
|
||||||
verbose_proxy_logger.debug(f"calling generic_sso.verify_and_process")
|
verbose_proxy_logger.debug(f"calling generic_sso.verify_and_process")
|
||||||
request_body = await request.body()
|
result = await generic_sso.verify_and_process(
|
||||||
request_query_params = request.query_params
|
request, params={"include_client_id": generic_include_client_id}
|
||||||
# get "code" from query params
|
)
|
||||||
code = request_query_params.get("code")
|
|
||||||
result = await generic_sso.verify_and_process(request)
|
|
||||||
verbose_proxy_logger.debug(f"generic result: {result}")
|
verbose_proxy_logger.debug(f"generic result: {result}")
|
||||||
|
|
||||||
# User is Authe'd in - generate key for the UI to access Proxy
|
# User is Authe'd in - generate key for the UI to access Proxy
|
||||||
user_email = getattr(result, "email", None)
|
user_email = getattr(result, "email", None)
|
||||||
user_id = getattr(result, "id", None)
|
user_id = getattr(result, "id", None)
|
||||||
|
|
||||||
# generic client id
|
# generic client id
|
||||||
if generic_client_id is not None:
|
if generic_client_id is not None:
|
||||||
generic_user_id_attribute_name = os.getenv("GENERIC_USER_ID_ATTRIBUTE", "email")
|
user_id = result.id
|
||||||
generic_user_email_attribute_name = os.getenv(
|
user_email = result.email
|
||||||
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
|
|
||||||
)
|
|
||||||
generic_user_role_attribute_name = os.getenv(
|
|
||||||
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
|
|
||||||
)
|
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
user_id = getattr(result, generic_user_id_attribute_name, None)
|
|
||||||
user_email = getattr(result, generic_user_email_attribute_name, None)
|
|
||||||
user_role = getattr(result, generic_user_role_attribute_name, None)
|
user_role = getattr(result, generic_user_role_attribute_name, None)
|
||||||
|
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "")
|
user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "")
|
||||||
# get user_info from litellm DB
|
|
||||||
user_info = None
|
user_info = None
|
||||||
|
user_id_models = []
|
||||||
|
|
||||||
|
# User might not be already created on first generation of key
|
||||||
|
# But if it is, we want its models preferences
|
||||||
|
try:
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
|
user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
|
||||||
user_id_models: List = []
|
|
||||||
if user_info is not None:
|
if user_info is not None:
|
||||||
user_id_models = getattr(user_info, "models", [])
|
user_id_models = getattr(user_info, "models", [])
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
response = await generate_key_helper_fn(
|
response = await generate_key_helper_fn(
|
||||||
**{
|
**{
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue