Merge branch 'main' into litellm_aioboto3_sagemaker

Krish Dholakia 2024-02-14 21:46:58 -08:00 committed by GitHub
commit 57654f4533
79 changed files with 3440 additions and 253 deletions


@@ -403,34 +403,43 @@ async def user_api_key_auth(
verbose_proxy_logger.debug(
f"LLM Model List pre access group check: {llm_model_list}"
)
access_groups = []
from collections import defaultdict
access_groups = defaultdict(list)
if llm_model_list is not None:
for m in llm_model_list:
for group in m.get("model_info", {}).get("access_groups", []):
access_groups.append((m["model_name"], group))
model_name = m["model_name"]
access_groups[group].append(model_name)
allowed_models = valid_token.models
access_group_idx = set()
models_in_current_access_groups = []
if (
len(access_groups) > 0
): # check if token contains any model access groups
for idx, m in enumerate(valid_token.models):
for model_name, group in access_groups:
if m == group:
access_group_idx.add(idx)
allowed_models.append(model_name)
for idx, m in enumerate(
valid_token.models
): # loop token models, if any of them are an access group add the access group
if m in access_groups:
# if it is an access group we need to remove it from valid_token.models
models_in_group = access_groups[m]
models_in_current_access_groups.extend(models_in_group)
# Filter out models that are access_groups
filtered_models = [
m for m in valid_token.models if m not in access_groups
]
filtered_models += models_in_current_access_groups
verbose_proxy_logger.debug(
f"model: {model}; allowed_models: {allowed_models}"
f"model: {model}; allowed_models: {filtered_models}"
)
if model is not None and model not in allowed_models:
if model is not None and model not in filtered_models:
raise ValueError(
f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
)
for val in access_group_idx:
allowed_models.pop(val)
valid_token.models = allowed_models
valid_token.models = filtered_models
verbose_proxy_logger.debug(
f"filtered allowed_models: {allowed_models}; valid_token.models: {valid_token.models}"
f"filtered allowed_models: {filtered_models}; valid_token.models: {valid_token.models}"
)
# Check 2. If user_id for this token is in budget
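The rework above swaps the (model_name, group) tuple list for a defaultdict keyed by access-group name, then expands any group names on the token into their member models before the allow-list check. A minimal standalone sketch of that expansion, with illustrative model and group names rather than the proxy's real objects:
```python
from collections import defaultdict

def expand_token_models(token_models, llm_model_list):
    # Build group -> [model_name, ...] from each model's model_info.access_groups
    access_groups = defaultdict(list)
    for m in llm_model_list or []:
        for group in m.get("model_info", {}).get("access_groups", []):
            access_groups[group].append(m["model_name"])

    # Expand any access-group entries on the token into their member models
    expanded = []
    for entry in token_models:
        if entry in access_groups:
            expanded.extend(access_groups[entry])

    # Drop the raw group names; keep direct model grants plus the expansions
    filtered = [m for m in token_models if m not in access_groups]
    return filtered + expanded

# Example: a token granted the "beta-models" group plus one direct model
print(expand_token_models(
    ["beta-models", "gpt-3.5-turbo"],
    [
        {"model_name": "gpt-4", "model_info": {"access_groups": ["beta-models"]}},
        {"model_name": "gpt-3.5-turbo", "model_info": {}},
    ],
))  # -> ['gpt-3.5-turbo', 'gpt-4']
```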
@@ -682,34 +691,31 @@ async def user_api_key_auth(
# sso/login, ui/login, /key functions and /user functions
# this will never be allowed to call /chat/completions
token_team = getattr(valid_token, "team_id", None)
if token_team is not None:
if token_team == "litellm-dashboard":
# this token is only used for managing the ui
allowed_routes = [
"/sso",
"/login",
"/key",
"/spend",
"/user",
]
# check if the current route startswith any of the allowed routes
if (
route is not None
and isinstance(route, str)
and any(
route.startswith(allowed_route)
for allowed_route in allowed_routes
)
):
# Do something if the current route starts with any of the allowed routes
pass
else:
raise Exception(
f"This key is made for LiteLLM UI, Tried to access route: {route}. Not allowed"
)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else:
raise Exception(f"Invalid Key Passed to LiteLLM Proxy")
if token_team is not None and token_team == "litellm-dashboard":
# this token is only used for managing the ui
allowed_routes = [
"/sso",
"/login",
"/key",
"/spend",
"/user",
"/model/info",
]
# check if the current route startswith any of the allowed routes
if (
route is not None
and isinstance(route, str)
and any(
route.startswith(allowed_route) for allowed_route in allowed_routes
)
):
# Do something if the current route starts with any of the allowed routes
pass
else:
raise Exception(
f"This key is made for LiteLLM UI, Tried to access route: {route}. Not allowed"
)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
except Exception as e:
# verbose_proxy_logger.debug(f"An exception occurred - {traceback.format_exc()}")
traceback.print_exc()
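The dashboard-key branch above collapses the nested `if` and adds `/model/info` to the allowed prefixes. A small sketch of the prefix gate, with an illustrative sub-route:
```python
allowed_routes = ["/sso", "/login", "/key", "/spend", "/user", "/model/info"]

def is_allowed(route) -> bool:
    # A litellm-dashboard key may only call routes starting with an allowed prefix
    return isinstance(route, str) and any(
        route.startswith(prefix) for prefix in allowed_routes
    )

assert is_allowed("/key/generate")          # illustrative sub-route
assert is_allowed("/model/info")
assert not is_allowed("/chat/completions")  # dashboard keys cannot call completions
```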
@@ -1443,6 +1449,24 @@ class ProxyConfig:
database_type == "dynamo_db" or database_type == "dynamodb"
):
database_args = general_settings.get("database_args", None)
### LOAD FROM os.environ/ ###
for k, v in database_args.items():
if isinstance(v, str) and v.startswith("os.environ/"):
database_args[k] = litellm.get_secret(v)
if isinstance(k, str) and k == "aws_web_identity_token":
value = database_args[k]
verbose_proxy_logger.debug(
f"Loading AWS Web Identity Token from file: {value}"
)
if os.path.exists(value):
with open(value, "r") as file:
token_content = file.read()
database_args[k] = token_content
else:
verbose_proxy_logger.info(
f"DynamoDB Loading - {value} is not a valid file path"
)
verbose_proxy_logger.debug(f"database_args: {database_args}")
custom_db_client = DBClient(
custom_db_args=database_args, custom_db_type=database_type
)
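The new block above resolves `os.environ/`-style values in `database_args` and, for `aws_web_identity_token`, reads the token file contents into the arg. A hedged standalone sketch of that resolution, using plain `os.environ` in place of `litellm.get_secret` and placeholder values:
```python
import os

def resolve_database_args(database_args: dict) -> dict:
    for k, v in database_args.items():
        # "os.environ/MY_VAR" values are replaced with the env var's value
        if isinstance(v, str) and v.startswith("os.environ/"):
            database_args[k] = os.environ.get(v.split("/", 1)[1], "")
        # aws_web_identity_token is treated as a file path and read in full
        if k == "aws_web_identity_token":
            token_path = database_args[k]
            if isinstance(token_path, str) and os.path.exists(token_path):
                with open(token_path, "r") as f:
                    database_args[k] = f.read()
    return database_args

# Example (keys and values are placeholders):
# resolve_database_args({
#     "aws_region_name": "os.environ/AWS_REGION",
#     "aws_web_identity_token": "/var/run/secrets/token",
# })
```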
@@ -1580,8 +1604,6 @@ async def generate_key_helper_fn(
tpm_limit = tpm_limit
rpm_limit = rpm_limit
allowed_cache_controls = allowed_cache_controls
if type(team_id) is not str:
team_id = str(team_id)
try:
# Create a new verification token (you may want to enhance this logic based on your needs)
user_data = {
@@ -2057,14 +2079,6 @@ def model_list(
if user_model is not None:
all_models += [user_model]
verbose_proxy_logger.debug(f"all_models: {all_models}")
### CHECK OLLAMA MODELS ###
try:
response = requests.get("http://0.0.0.0:11434/api/tags")
models = response.json()["models"]
ollama_models = ["ollama/" + m["name"].replace(":latest", "") for m in models]
all_models.extend(ollama_models)
except Exception as e:
pass
return dict(
data=[
{
@@ -2355,8 +2369,13 @@ async def chat_completion(
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.acompletion(**data, specific_deployment=True)
else: # router is not set
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.acompletion(**data)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "Invalid model name passed in"},
)
# Post Call Processing
data["litellm_status"] = "success" # used for alerting
@@ -2387,6 +2406,11 @@ async def chat_completion(
)
fastapi_response.headers["x-litellm-model-id"] = model_id
### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook(
user_api_key_dict=user_api_key_dict, response=response
)
return response
except Exception as e:
traceback.print_exc()
@@ -2417,7 +2441,12 @@ async def chat_completion(
traceback.print_exc()
if isinstance(e, HTTPException):
raise e
raise ProxyException(
message=getattr(e, "detail", str(e)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
@@ -2567,8 +2596,13 @@ async def embeddings(
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.aembedding(**data, specific_deployment=True)
else:
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.aembedding(**data)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "Invalid model name passed in"},
)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
@@ -2586,7 +2620,12 @@ async def embeddings(
)
traceback.print_exc()
if isinstance(e, HTTPException):
raise e
raise ProxyException(
message=getattr(e, "message", str(e)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
@@ -2702,8 +2741,13 @@ async def image_generation(
response = await llm_router.aimage_generation(
**data
) # ensure this goes through the llm_router; the router will do the correct alias mapping
else:
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.aimage_generation(**data)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "Invalid model name passed in"},
)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
@@ -2721,7 +2765,165 @@ async def image_generation(
)
traceback.print_exc()
if isinstance(e, HTTPException):
raise e
raise ProxyException(
message=getattr(e, "message", str(e)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
@router.post(
"/v1/moderations",
dependencies=[Depends(user_api_key_auth)],
response_class=ORJSONResponse,
tags=["moderations"],
)
@router.post(
"/moderations",
dependencies=[Depends(user_api_key_auth)],
response_class=ORJSONResponse,
tags=["moderations"],
)
async def moderations(
request: Request,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
The moderations endpoint is a tool you can use to check whether content complies with an LLM provider's policies.
Quick Start
```
curl --location 'http://0.0.0.0:4000/moderations' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data '{"input": "Sample text goes here", "model": "text-moderation-stable"}'
```
"""
global proxy_logging_obj
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
data = orjson.loads(body)
# Include original request and headers in the data
data["proxy_server_request"] = {
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
"body": copy.copy(data), # use copy instead of deepcopy
}
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
data["user"] = user_api_key_dict.user_id
data["model"] = (
general_settings.get("moderation_model", None) # server default
or user_model # model name passed via cli args
or data["model"] # default passed in http request
)
if user_model:
data["model"] = user_model
if "metadata" not in data:
data["metadata"] = {}
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)
_headers.pop(
"authorization", None
) # do not store the original `sk-..` api key in the db
data["metadata"]["headers"] = _headers
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
if user_api_key_dict.team_id is not None:
team_config = await proxy_config.load_team_config(
team_id=user_api_key_dict.team_id
)
if len(team_config) == 0:
pass
else:
team_id = team_config.pop("team_id", None)
data["metadata"]["team_id"] = team_id
data = {
**team_config,
**data,
} # add the team-specific configs to the completion call
router_model_names = (
[m["model_name"] for m in llm_model_list]
if llm_model_list is not None
else []
)
### CALL HOOKS ### - modify incoming data / reject request before calling the model
data = await proxy_logging_obj.pre_call_hook(
user_api_key_dict=user_api_key_dict, data=data, call_type="moderation"
)
start_time = time.time()
## ROUTE TO CORRECT ENDPOINT ##
# skip router if user passed their key
if "api_key" in data:
response = await litellm.amoderation(**data)
elif (
llm_router is not None and data["model"] in router_model_names
): # model in router model list
response = await llm_router.amoderation(**data)
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.amoderation(**data, specific_deployment=True)
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
response = await llm_router.amoderation(
**data
) # ensure this goes through the llm_router; the router will do the correct alias mapping
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.amoderation(**data)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "Invalid model name passed in"},
)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
end_time = time.time()
asyncio.create_task(
proxy_logging_obj.response_taking_too_long(
start_time=start_time, end_time=end_time, type="slow_response"
)
)
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e
)
traceback.print_exc()
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
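For reference, the same quick-start request from the docstring above, written with `requests` instead of curl (URL, key, and model are the placeholder values from the docstring):
```python
import requests

resp = requests.post(
    "http://0.0.0.0:4000/moderations",
    headers={
        "Content-Type": "application/json",
        "Authorization": "Bearer sk-1234",
    },
    json={"input": "Sample text goes here", "model": "text-moderation-stable"},
)
print(resp.json())
```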
@@ -3516,6 +3718,7 @@ async def google_login(request: Request):
"""
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
# get url from request
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
@@ -3574,6 +3777,69 @@ async def google_login(request: Request):
)
with microsoft_sso:
return await microsoft_sso.get_login_redirect()
elif generic_client_id is not None:
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_authorization_endpoint = os.getenv(
"GENERIC_AUTHORIZATION_ENDPOINT", None
)
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
if generic_client_secret is None:
raise ProxyException(
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
type="auth_error",
param="GENERIC_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_authorization_endpoint is None:
raise ProxyException(
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
type="auth_error",
param="GENERIC_AUTHORIZATION_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_token_endpoint is None:
raise ProxyException(
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
type="auth_error",
param="GENERIC_TOKEN_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_userinfo_endpoint is None:
raise ProxyException(
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
type="auth_error",
param="GENERIC_USERINFO_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
verbose_proxy_logger.debug(
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
)
verbose_proxy_logger.debug(
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
)
discovery = DiscoveryDocument(
authorization_endpoint=generic_authorization_endpoint,
token_endpoint=generic_token_endpoint,
userinfo_endpoint=generic_userinfo_endpoint,
)
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
generic_sso = SSOProvider(
client_id=generic_client_id,
client_secret=generic_client_secret,
redirect_uri=redirect_url,
allow_insecure_http=True,
)
with generic_sso:
return await generic_sso.get_login_redirect()
elif ui_username is not None:
# No Google, Microsoft SSO
# Use UI Credentials set in .env
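The generic SSO branch above only needs the five `GENERIC_*` variables read from the environment. A minimal sketch of that configuration, with placeholder endpoint URLs for an illustrative identity provider:
```python
import os

# Values below are placeholders; point them at your identity provider
os.environ.setdefault("GENERIC_CLIENT_ID", "my-client-id")
os.environ.setdefault("GENERIC_CLIENT_SECRET", "my-client-secret")
os.environ.setdefault("GENERIC_AUTHORIZATION_ENDPOINT", "https://idp.example.com/authorize")
os.environ.setdefault("GENERIC_TOKEN_ENDPOINT", "https://idp.example.com/token")
os.environ.setdefault("GENERIC_USERINFO_ENDPOINT", "https://idp.example.com/userinfo")
```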
@@ -3673,6 +3939,7 @@ async def auth_callback(request: Request):
global general_settings
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
# get url from request
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
@@ -3728,6 +3995,77 @@ async def auth_callback(request: Request):
allow_insecure_http=True,
)
result = await microsoft_sso.verify_and_process(request)
elif generic_client_id is not None:
# make generic sso provider
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_authorization_endpoint = os.getenv(
"GENERIC_AUTHORIZATION_ENDPOINT", None
)
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
if generic_client_secret is None:
raise ProxyException(
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
type="auth_error",
param="GENERIC_CLIENT_SECRET",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_authorization_endpoint is None:
raise ProxyException(
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
type="auth_error",
param="GENERIC_AUTHORIZATION_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_token_endpoint is None:
raise ProxyException(
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
type="auth_error",
param="GENERIC_TOKEN_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if generic_userinfo_endpoint is None:
raise ProxyException(
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
type="auth_error",
param="GENERIC_USERINFO_ENDPOINT",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
verbose_proxy_logger.debug(
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
)
verbose_proxy_logger.debug(
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
)
discovery = DiscoveryDocument(
authorization_endpoint=generic_authorization_endpoint,
token_endpoint=generic_token_endpoint,
userinfo_endpoint=generic_userinfo_endpoint,
)
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
generic_sso = SSOProvider(
client_id=generic_client_id,
client_secret=generic_client_secret,
redirect_uri=redirect_url,
allow_insecure_http=True,
)
verbose_proxy_logger.debug(f"calling generic_sso.verify_and_process")
request_body = await request.body()
request_query_params = request.query_params
# get "code" from query params
code = request_query_params.get("code")
result = await generic_sso.verify_and_process(request)
verbose_proxy_logger.debug(f"generic result: {result}")
# User is Authe'd in - generate key for the UI to access Proxy
user_email = getattr(result, "email", None)
@@ -3936,7 +4274,6 @@ async def add_new_model(model_params: ModelParams):
)
#### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use /model/info
@router.get(
"/model/info",
description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)",
@@ -3969,6 +4306,28 @@ async def model_info_v1(
# read litellm model_prices_and_context_window.json to get the following:
# input_cost_per_token, output_cost_per_token, max_tokens
litellm_model_info = get_litellm_model_info(model=model)
# 2nd pass on the model, try seeing if we can find model in litellm model_cost map
if litellm_model_info == {}:
# use litellm_param model_name to get model_info
litellm_params = model.get("litellm_params", {})
litellm_model = litellm_params.get("model", None)
try:
litellm_model_info = litellm.get_model_info(model=litellm_model)
except:
litellm_model_info = {}
# 3rd pass on the model, try seeing if we can find model but without the "/" in model cost map
if litellm_model_info == {}:
# use litellm_param model_name to get model_info
litellm_params = model.get("litellm_params", {})
litellm_model = litellm_params.get("model", None)
split_model = litellm_model.split("/")
if len(split_model) > 0:
litellm_model = split_model[-1]
try:
litellm_model_info = litellm.get_model_info(model=litellm_model)
except:
litellm_model_info = {}
for k, v in litellm_model_info.items():
if k not in model_info:
model_info[k] = v
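The two fallback passes above retry `litellm.get_model_info` with the raw `litellm_params.model` string and then with its provider prefix stripped. A condensed sketch of that lookup order (model strings are illustrative):
```python
import litellm

def lookup_model_info(litellm_model: str) -> dict:
    # 1st try: the full litellm model string, e.g. "azure/chatgpt-v-2"
    try:
        return litellm.get_model_info(model=litellm_model)
    except Exception:
        pass
    # 2nd try: strip the provider prefix, e.g. "azure/chatgpt-v-2" -> "chatgpt-v-2"
    stripped = litellm_model.split("/")[-1]
    try:
        return litellm.get_model_info(model=stripped)
    except Exception:
        return {}
```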