Merge branch 'main' into litellm_team_settings

commit 856aa9c30b
Krish Dholakia, 2024-02-14 21:45:59 -08:00, committed by GitHub
12 changed files with 681 additions and 113 deletions


@@ -403,34 +403,43 @@ async def user_api_key_auth(
verbose_proxy_logger.debug(
f"LLM Model List pre access group check: {llm_model_list}"
)
-access_groups = []
+from collections import defaultdict
+access_groups = defaultdict(list)
if llm_model_list is not None:
for m in llm_model_list:
for group in m.get("model_info", {}).get("access_groups", []):
-access_groups.append((m["model_name"], group))
+model_name = m["model_name"]
+access_groups[group].append(model_name)
-allowed_models = valid_token.models
-access_group_idx = set()
+models_in_current_access_groups = []
if (
len(access_groups) > 0
): # check if token contains any model access groups
-for idx, m in enumerate(valid_token.models):
-for model_name, group in access_groups:
-if m == group:
-access_group_idx.add(idx)
-allowed_models.append(model_name)
+for idx, m in enumerate(
+valid_token.models
+): # loop token models, if any of them are an access group add the access group
+if m in access_groups:
+# if it is an access group we need to remove it from valid_token.models
+models_in_group = access_groups[m]
+models_in_current_access_groups.extend(models_in_group)
+# Filter out models that are access_groups
+filtered_models = [
+m for m in valid_token.models if m not in access_groups
+]
+filtered_models += models_in_current_access_groups
verbose_proxy_logger.debug(
-f"model: {model}; allowed_models: {allowed_models}"
+f"model: {model}; allowed_models: {filtered_models}"
)
-if model is not None and model not in allowed_models:
+if model is not None and model not in filtered_models:
raise ValueError(
f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
)
-for val in access_group_idx:
-allowed_models.pop(val)
-valid_token.models = allowed_models
+valid_token.models = filtered_models
verbose_proxy_logger.debug(
-f"filtered allowed_models: {allowed_models}; valid_token.models: {valid_token.models}"
+f"filtered allowed_models: {filtered_models}; valid_token.models: {valid_token.models}"
)
)
# Check 2. If user_id for this token is in budget
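The hunk above replaces the tuple-list scan with a group-to-models map before filtering a token's allowed models. A minimal standalone sketch of that expansion, using made-up model and token data (the names here are illustrative, not from any real config):

```
from collections import defaultdict

# Illustrative stand-ins for llm_model_list and valid_token.models
model_list = [
    {"model_name": "gpt-4", "model_info": {"access_groups": ["prod-models"]}},
    {"model_name": "gpt-3.5-turbo", "model_info": {"access_groups": ["prod-models", "cheap-models"]}},
    {"model_name": "claude-3", "model_info": {}},
]
token_models = ["prod-models", "claude-3"]  # a token may mix group names and plain model names

# Build group -> member models, mirroring the defaultdict refactor above
access_groups = defaultdict(list)
for m in model_list:
    for group in m.get("model_info", {}).get("access_groups", []):
        access_groups[group].append(m["model_name"])

# Expand group names into their member models; keep plain model names as-is
models_in_current_access_groups = []
for m in token_models:
    if m in access_groups:
        models_in_current_access_groups.extend(access_groups[m])

filtered_models = [m for m in token_models if m not in access_groups]
filtered_models += models_in_current_access_groups
print(filtered_models)  # ['claude-3', 'gpt-4', 'gpt-3.5-turbo']
```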
@@ -2087,14 +2096,6 @@ def model_list(
if user_model is not None:
all_models += [user_model]
verbose_proxy_logger.debug(f"all_models: {all_models}")
-### CHECK OLLAMA MODELS ###
-try:
-response = requests.get("http://0.0.0.0:11434/api/tags")
-models = response.json()["models"]
-ollama_models = ["ollama/" + m["name"].replace(":latest", "") for m in models]
-all_models.extend(ollama_models)
-except Exception as e:
-pass
return dict(
data=[
{
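The hunk above drops the on-the-fly Ollama discovery from `/model/list`. For reference, a standalone sketch of that lookup, assuming a local Ollama server on its default port (the same `/api/tags` endpoint the removed block called):

```
import requests

# Ask a local Ollama server which models it has pulled (same call the removed block made)
try:
    response = requests.get("http://0.0.0.0:11434/api/tags", timeout=5)
    models = response.json().get("models", [])
    ollama_models = ["ollama/" + m["name"].replace(":latest", "") for m in models]
    print(ollama_models)
except requests.RequestException:
    print("No local Ollama server reachable")
```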
@@ -2798,7 +2799,161 @@ async def image_generation(
)
@router.post(
"/v1/moderations",
dependencies=[Depends(user_api_key_auth)],
response_class=ORJSONResponse,
tags=["moderations"],
)
@router.post(
"/moderations",
dependencies=[Depends(user_api_key_auth)],
response_class=ORJSONResponse,
tags=["moderations"],
)
async def moderations(
request: Request,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
The moderations endpoint is a tool you can use to check whether content complies with an LLM provider's policies.
Quick Start
```
curl --location 'http://0.0.0.0:4000/moderations' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data '{"input": "Sample text goes here", "model": "text-moderation-stable"}'
```
"""
global proxy_logging_obj
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
data = orjson.loads(body)
# Include original request and headers in the data
data["proxy_server_request"] = {
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
"body": copy.copy(data), # use copy instead of deepcopy
}
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
data["user"] = user_api_key_dict.user_id
data["model"] = (
general_settings.get("moderation_model", None) # server default
or user_model # model name passed via cli args
or data["model"] # default passed in http request
)
if user_model:
data["model"] = user_model
if "metadata" not in data:
data["metadata"] = {}
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)
_headers.pop(
"authorization", None
) # do not store the original `sk-..` api key in the db
data["metadata"]["headers"] = _headers
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
if user_api_key_dict.team_id is not None:
team_config = await proxy_config.load_team_config(
team_id=user_api_key_dict.team_id
)
if len(team_config) == 0:
pass
else:
team_id = team_config.pop("team_id", None)
data["metadata"]["team_id"] = team_id
data = {
**team_config,
**data,
} # add the team-specific configs to the completion call
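# note: because **data is spread last, per-request values win over the team defaults,
# e.g. {**{"max_tokens": 100}, **{"max_tokens": 256}} resolves to {"max_tokens": 256}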
router_model_names = (
[m["model_name"] for m in llm_model_list]
if llm_model_list is not None
else []
)
### CALL HOOKS ### - modify incoming data / reject request before calling the model
data = await proxy_logging_obj.pre_call_hook(
user_api_key_dict=user_api_key_dict, data=data, call_type="moderation"
)
start_time = time.time()
## ROUTE TO CORRECT ENDPOINT ##
# skip router if user passed their key
if "api_key" in data:
response = await litellm.amoderation(**data)
elif (
llm_router is not None and data["model"] in router_model_names
): # model in router model list
response = await llm_router.amoderation(**data)
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.amoderation(**data, specific_deployment=True)
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
response = await llm_router.amoderation(
**data
) # ensure this goes the llm_router, router will do the correct alias mapping
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.amoderation(**data)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "Invalid model name passed in"},
)
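# routing precedence, top to bottom: explicit api_key in the body -> call litellm directly;
# model in the router's model list -> llm_router; model matches a deployment name -> that
# specific deployment; model in model_group_alias -> router alias mapping; CLI --model
# fallback; anything else -> 400 invalid model name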
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
end_time = time.time()
asyncio.create_task(
proxy_logging_obj.response_taking_too_long(
start_time=start_time, end_time=end_time, type="slow_response"
)
)
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e
)
traceback.print_exc()
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
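Beyond the curl quick start in the docstring, the route mirrors OpenAI's `/v1/moderations`, so the OpenAI Python client pointed at the proxy should also work. A sketch, assuming the proxy runs at http://0.0.0.0:4000 with the example virtual key sk-1234 from the docstring and the openai>=1.x SDK installed:

```
from openai import OpenAI

# Point the standard OpenAI client at the LiteLLM proxy (URL and key are the docstring's example values)
client = OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.moderations.create(
    input="Sample text goes here",
    model="text-moderation-stable",  # may be overridden server-side by a configured moderation_model default
)
print(response.results[0].flagged)
```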
#### KEY MANAGEMENT ####
@router.post(
@@ -3684,7 +3839,6 @@ async def user_update(data: UpdateUserRequest):
code=status.HTTP_400_BAD_REQUEST,
)
#### TEAM MANAGEMENT ####
@@ -3766,75 +3920,182 @@ async def team_info(
):
"""
get info on team + related keys
```
curl --location 'http://localhost:4000/team/info' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"teams": ["<team-id>",..]
}'
```
"""
global prisma_client
try:
if prisma_client is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"error": f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
},
)
if team_id is None:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail={"message": "Malformed request. No team id passed in."},
)
pass
team_info = await prisma_client.get_data(
team_id=team_id, table_name="team", query_type="find_unique"
)
## GET ALL KEYS ##
keys = await prisma_client.get_data(
team_id=team_id,
table_name="key",
query_type="find_all",
expires=datetime.now(),
)
if team_info is None:
## make sure we still return a total spend ##
spend = 0
for k in keys:
spend += getattr(k, "spend", 0)
team_info = {"spend": spend}
## REMOVE HASHED TOKEN INFO before returning ##
for key in keys:
try:
key = key.model_dump() # noqa
except:
# if using pydantic v1
key = key.dict()
key.pop("token", None)
return {"team_id": team_id, "team_info": team_info, "keys": keys}
except Exception as e:
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
type="auth_error",
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
elif isinstance(e, ProxyException):
raise e
raise ProxyException(
message="Authentication Error, " + str(e),
type="auth_error",
param=getattr(e, "param", "None"),
code=status.HTTP_400_BAD_REQUEST,
)

@app.get("/sso/callback", tags=["experimental"])
async def auth_callback(request: Request):
"""Verify login"""
global general_settings
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
# get url from request
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
if redirect_url.endswith("/"):
redirect_url += "sso/callback"
else:
redirect_url += "/sso/callback"
if google_client_id is not None:
from fastapi_sso.sso.google import GoogleSSO
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
if google_client_secret is None:
raise ProxyException(
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
type="auth_error",
param="GOOGLE_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
google_sso = GoogleSSO(
client_id=google_client_id,
redirect_uri=redirect_url,
client_secret=google_client_secret,
)
result = await google_sso.verify_and_process(request)
elif microsoft_client_id is not None:
from fastapi_sso.sso.microsoft import MicrosoftSSO
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
if microsoft_client_secret is None:
raise ProxyException(
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
type="auth_error",
param="MICROSOFT_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if microsoft_tenant is None:
raise ProxyException(
message="MICROSOFT_TENANT not set. Set it in .env file",
type="auth_error",
param="MICROSOFT_TENANT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
microsoft_sso = MicrosoftSSO(
client_id=microsoft_client_id,
client_secret=microsoft_client_secret,
tenant=microsoft_tenant,
redirect_uri=redirect_url,
allow_insecure_http=True,
)
result = await microsoft_sso.verify_and_process(request)
elif generic_client_id is not None:
# make generic sso provider
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_authorization_endpoint = os.getenv(
"GENERIC_AUTHORIZATION_ENDPOINT", None
)
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
if generic_client_secret is None:
raise ProxyException(
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
type="auth_error",
param="GENERIC_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_authorization_endpoint is None:
raise ProxyException(
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
type="auth_error",
param="GENERIC_AUTHORIZATION_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_token_endpoint is None:
raise ProxyException(
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
type="auth_error",
param="GENERIC_TOKEN_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_userinfo_endpoint is None:
raise ProxyException(
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
type="auth_error",
param="GENERIC_USERINFO_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
verbose_proxy_logger.debug(
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
)
verbose_proxy_logger.debug(
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
)
discovery = DiscoveryDocument(
authorization_endpoint=generic_authorization_endpoint,
token_endpoint=generic_token_endpoint,
userinfo_endpoint=generic_userinfo_endpoint,
)
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
generic_sso = SSOProvider(
client_id=generic_client_id,
client_secret=generic_client_secret,
redirect_uri=redirect_url,
allow_insecure_http=True,
)
verbose_proxy_logger.debug(f"calling generic_sso.verify_and_process")
request_body = await request.body()
request_query_params = request.query_params
# get "code" from query params
code = request_query_params.get("code")
result = await generic_sso.verify_and_process(request)
verbose_proxy_logger.debug(f"generic result: {result}")
# User is authenticated - generate a key for the UI to access the proxy
user_email = getattr(result, "email", None)
user_id = getattr(result, "id", None)
if user_id is None:
user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "")
response = await generate_key_helper_fn(
**{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard", "user_email": user_email} # type: ignore
)
key = response["token"] # type: ignore
user_id = response["user_id"] # type: ignore
litellm_dashboard_ui = "/ui/"
user_role = "app_owner"
if (
os.getenv("PROXY_ADMIN_ID", None) is not None
and os.environ["PROXY_ADMIN_ID"] == user_id
):
# checks if user is admin
user_role = "app_admin"
import jwt
jwt_token = jwt.encode(
{
"user_id": user_id,
"key": key,
"user_email": user_email,
"user_role": user_role,
},
"secret",
algorithm="HS256",
)
litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token
# if a user has logged in they should be allowed to create keys - this ensures that it's set to True
general_settings["allow_user_auth"] = True
return RedirectResponse(url=litellm_dashboard_ui)
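The callback hands the UI a dashboard token encoded with the `jwt` library (PyJWT-style API), HS256, and the literal signing key "secret", carrying user_id, key, user_email, and user_role. A minimal round-trip sketch of that token format (claim values here are placeholders, not real credentials):

```
import jwt  # PyJWT, the same library imported in the callback above

claims = {
    "user_id": "example-user",        # placeholder values
    "key": "sk-generated-ui-key",
    "user_email": "user@example.com",
    "user_role": "app_owner",
}
token = jwt.encode(claims, "secret", algorithm="HS256")
decoded = jwt.decode(token, "secret", algorithms=["HS256"])
print(decoded["user_id"], decoded["user_role"])
```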
#### MODEL MANAGEMENT ####
@@ -4260,6 +4521,73 @@ async def google_login(request: Request):
)
with microsoft_sso:
return await microsoft_sso.get_login_redirect()
elif generic_client_id is not None:
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_authorization_endpoint = os.getenv(
"GENERIC_AUTHORIZATION_ENDPOINT", None
)
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
if generic_client_secret is None:
raise ProxyException(
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
type="auth_error",
param="GENERIC_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_authorization_endpoint is None:
raise ProxyException(
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
type="auth_error",
param="GENERIC_AUTHORIZATION_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_token_endpoint is None:
raise ProxyException(
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
type="auth_error",
param="GENERIC_TOKEN_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_userinfo_endpoint is None:
raise ProxyException(
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
type="auth_error",
param="GENERIC_USERINFO_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
verbose_proxy_logger.debug(
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
)
verbose_proxy_logger.debug(
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
)
discovery = DiscoveryDocument(
authorization_endpoint=generic_authorization_endpoint,
token_endpoint=generic_token_endpoint,
userinfo_endpoint=generic_userinfo_endpoint,
)
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
generic_sso = SSOProvider(
client_id=generic_client_id,
client_secret=generic_client_secret,
redirect_uri=redirect_url,
allow_insecure_http=True,
)
with generic_sso:
return await generic_sso.get_login_redirect()
elif ui_username is not None:
# No Google, Microsoft SSO
# Use UI Credentials set in .env
from fastapi.responses import HTMLResponse