forked from phoenix/litellm-mirror
LiteLLM Minor Fixes and Improvements (09/13/2024) (#5689)
* refactor: cleanup unused variables + fix pyright errors
* feat(health_check.py): Closes https://github.com/BerriAI/litellm/issues/5686
* fix(o1_reasoning.py): add stricter check for o-1 reasoning model
* refactor(mistral/): make it easier to see mistral transformation logic
* fix(openai.py): fix openai o-1 model param mapping. Fixes https://github.com/BerriAI/litellm/issues/5685
* feat(main.py): infer finetuned gemini model from base model. Fixes https://github.com/BerriAI/litellm/issues/5678
* docs(vertex.md): update docs to call finetuned gemini models
* feat(proxy_server.py): allow admin to hide proxy model aliases. Closes https://github.com/BerriAI/litellm/issues/5692
* docs(load_balancing.md): add docs on hiding alias models from proxy config
* fix(base.py): don't raise notimplemented error
* fix(user_api_key_auth.py): fix model max budget check
* fix(router.py): fix elif
* fix(user_api_key_auth.py): don't set team_id to empty str
* fix(team_endpoints.py): fix response type
* test(test_completion.py): handle predibase error
* test(test_proxy_server.py): fix test
* fix(o1_transformation.py): fix max_completion_token mapping
* test(test_image_generation.py): mark flaky test
This commit is contained in:
parent
db3af20d84
commit
60709a0753
35 changed files with 1020 additions and 539 deletions
|
@ -125,16 +125,7 @@ from litellm.proxy._types import *
|
|||
from litellm.proxy.analytics_endpoints.analytics_endpoints import (
|
||||
router as analytics_router,
|
||||
)
|
||||
from litellm.proxy.auth.auth_checks import (
|
||||
allowed_routes_check,
|
||||
common_checks,
|
||||
get_actual_routes,
|
||||
get_end_user_object,
|
||||
get_org_object,
|
||||
get_team_object,
|
||||
get_user_object,
|
||||
log_to_opentelemetry,
|
||||
)
|
||||
from litellm.proxy.auth.auth_checks import log_to_opentelemetry
|
||||
from litellm.proxy.auth.auth_utils import check_response_size_is_safe
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
from litellm.proxy.auth.litellm_license import LicenseCheck
|
||||
|
@ -260,6 +251,7 @@ from litellm.secret_managers.aws_secret_manager import (
|
|||
load_aws_secret_manager,
|
||||
)
|
||||
from litellm.secret_managers.google_kms import load_google_kms
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.llms.anthropic import (
|
||||
AnthropicMessagesRequest,
|
||||
AnthropicResponse,
|
||||
|
@ -484,7 +476,7 @@ general_settings: dict = {}
|
|||
callback_settings: dict = {}
|
||||
log_file = "api_log.json"
|
||||
worker_config = None
|
||||
master_key = None
|
||||
master_key: Optional[str] = None
|
||||
otel_logging = False
|
||||
prisma_client: Optional[PrismaClient] = None
|
||||
custom_db_client: Optional[DBClient] = None
|
||||
|
@ -874,7 +866,9 @@ def error_tracking():
|
|||
|
||||
|
||||
def _set_spend_logs_payload(
|
||||
payload: dict, prisma_client: PrismaClient, spend_logs_url: Optional[str] = None
|
||||
payload: Union[dict, SpendLogsPayload],
|
||||
prisma_client: PrismaClient,
|
||||
spend_logs_url: Optional[str] = None,
|
||||
):
|
||||
if prisma_client is not None and spend_logs_url is not None:
|
||||
if isinstance(payload["startTime"], datetime):
|
||||
|
@ -1341,6 +1335,9 @@ async def _run_background_health_check():
|
|||
# make 1 deep copy of llm_model_list -> use this for all background health checks
|
||||
_llm_model_list = copy.deepcopy(llm_model_list)
|
||||
|
||||
if _llm_model_list is None:
|
||||
return
|
||||
|
||||
while True:
|
||||
healthy_endpoints, unhealthy_endpoints = await perform_health_check(
|
||||
model_list=_llm_model_list, details=health_check_details
|
||||
|
@ -1352,7 +1349,10 @@ async def _run_background_health_check():
|
|||
health_check_results["healthy_count"] = len(healthy_endpoints)
|
||||
health_check_results["unhealthy_count"] = len(unhealthy_endpoints)
|
||||
|
||||
await asyncio.sleep(health_check_interval)
|
||||
if health_check_interval is not None and isinstance(
|
||||
health_check_interval, float
|
||||
):
|
||||
await asyncio.sleep(health_check_interval)
|
||||
|
||||
|
||||
class ProxyConfig:
|
||||
|
@ -1467,7 +1467,7 @@ class ProxyConfig:
|
|||
break
|
||||
for k, v in team_config.items():
|
||||
if isinstance(v, str) and v.startswith("os.environ/"):
|
||||
team_config[k] = litellm.get_secret(v)
|
||||
team_config[k] = get_secret(v)
|
||||
return team_config
|
||||
|
||||
def _init_cache(
|
||||
|
@ -1513,6 +1513,9 @@ class ProxyConfig:
|
|||
config = get_file_contents_from_s3(
|
||||
bucket_name=bucket_name, object_key=object_key
|
||||
)
|
||||
|
||||
if config is None:
|
||||
raise Exception("Unable to load config from given source.")
|
||||
else:
|
||||
# default to file
|
||||
config = await self.get_config(config_file_path=config_file_path)
|
||||
|
@ -1528,9 +1531,7 @@ class ProxyConfig:
|
|||
environment_variables = config.get("environment_variables", None)
|
||||
if environment_variables:
|
||||
for key, value in environment_variables.items():
|
||||
os.environ[key] = str(
|
||||
litellm.get_secret(secret_name=key, default_value=value)
|
||||
)
|
||||
os.environ[key] = str(get_secret(secret_name=key, default_value=value))
|
||||
|
||||
# check if litellm_license in general_settings
|
||||
if "LITELLM_LICENSE" in environment_variables:
|
||||
|
@ -1566,8 +1567,8 @@ class ProxyConfig:
|
|||
if (
|
||||
cache_type == "redis" or cache_type == "redis-semantic"
|
||||
) and len(cache_params.keys()) == 0:
|
||||
cache_host = litellm.get_secret("REDIS_HOST", None)
|
||||
cache_port = litellm.get_secret("REDIS_PORT", None)
|
||||
cache_host = get_secret("REDIS_HOST", None)
|
||||
cache_port = get_secret("REDIS_PORT", None)
|
||||
cache_password = None
|
||||
cache_params.update(
|
||||
{
|
||||
|
@ -1577,8 +1578,8 @@ class ProxyConfig:
|
|||
}
|
||||
)
|
||||
|
||||
if litellm.get_secret("REDIS_PASSWORD", None) is not None:
|
||||
cache_password = litellm.get_secret("REDIS_PASSWORD", None)
|
||||
if get_secret("REDIS_PASSWORD", None) is not None:
|
||||
cache_password = get_secret("REDIS_PASSWORD", None)
|
||||
cache_params.update(
|
||||
{
|
||||
"password": cache_password,
|
||||
|
@ -1617,7 +1618,7 @@ class ProxyConfig:
|
|||
# users can pass os.environ/ variables on the proxy - we should read them from the env
|
||||
for key, value in cache_params.items():
|
||||
if type(value) is str and value.startswith("os.environ/"):
|
||||
cache_params[key] = litellm.get_secret(value)
|
||||
cache_params[key] = get_secret(value)
|
||||
|
||||
## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
|
||||
self._init_cache(cache_params=cache_params)
|
||||
|
@ -1738,7 +1739,7 @@ class ProxyConfig:
|
|||
if value is not None and isinstance(value, dict):
|
||||
for _k, _v in value.items():
|
||||
if isinstance(_v, str) and _v.startswith("os.environ/"):
|
||||
value[_k] = litellm.get_secret(_v)
|
||||
value[_k] = get_secret(_v)
|
||||
litellm.upperbound_key_generate_params = (
|
||||
LiteLLM_UpperboundKeyGenerateParams(**value)
|
||||
)
|
||||
|
@ -1812,15 +1813,15 @@ class ProxyConfig:
|
|||
database_url = general_settings.get("database_url", None)
|
||||
if database_url and database_url.startswith("os.environ/"):
|
||||
verbose_proxy_logger.debug("GOING INTO LITELLM.GET_SECRET!")
|
||||
database_url = litellm.get_secret(database_url)
|
||||
database_url = get_secret(database_url)
|
||||
verbose_proxy_logger.debug("RETRIEVED DB URL: %s", database_url)
|
||||
### MASTER KEY ###
|
||||
master_key = general_settings.get(
|
||||
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
|
||||
"master_key", get_secret("LITELLM_MASTER_KEY", None)
|
||||
)
|
||||
|
||||
if master_key and master_key.startswith("os.environ/"):
|
||||
master_key = litellm.get_secret(master_key)
|
||||
master_key = get_secret(master_key) # type: ignore
|
||||
if not isinstance(master_key, str):
|
||||
raise Exception(
|
||||
"Master key must be a string. Current type - {}".format(
|
||||
|
@ -1861,33 +1862,6 @@ class ProxyConfig:
|
|||
await initialize_pass_through_endpoints(
|
||||
pass_through_endpoints=general_settings["pass_through_endpoints"]
|
||||
)
|
||||
## dynamodb
|
||||
database_type = general_settings.get("database_type", None)
|
||||
if database_type is not None and (
|
||||
database_type == "dynamo_db" or database_type == "dynamodb"
|
||||
):
|
||||
database_args = general_settings.get("database_args", None)
|
||||
### LOAD FROM os.environ/ ###
|
||||
for k, v in database_args.items():
|
||||
if isinstance(v, str) and v.startswith("os.environ/"):
|
||||
database_args[k] = litellm.get_secret(v)
|
||||
if isinstance(k, str) and k == "aws_web_identity_token":
|
||||
value = database_args[k]
|
||||
verbose_proxy_logger.debug(
|
||||
f"Loading AWS Web Identity Token from file: {value}"
|
||||
)
|
||||
if os.path.exists(value):
|
||||
with open(value, "r") as file:
|
||||
token_content = file.read()
|
||||
database_args[k] = token_content
|
||||
else:
|
||||
verbose_proxy_logger.info(
|
||||
f"DynamoDB Loading - {value} is not a valid file path"
|
||||
)
|
||||
verbose_proxy_logger.debug("database_args: %s", database_args)
|
||||
custom_db_client = DBClient(
|
||||
custom_db_args=database_args, custom_db_type=database_type
|
||||
)
|
||||
## ADMIN UI ACCESS ##
|
||||
ui_access_mode = general_settings.get(
|
||||
"ui_access_mode", "all"
|
||||
|
@ -1951,7 +1925,7 @@ class ProxyConfig:
|
|||
### LOAD FROM os.environ/ ###
|
||||
for k, v in model["litellm_params"].items():
|
||||
if isinstance(v, str) and v.startswith("os.environ/"):
|
||||
model["litellm_params"][k] = litellm.get_secret(v)
|
||||
model["litellm_params"][k] = get_secret(v)
|
||||
print(f"\033[32m {model.get('model_name', '')}\033[0m") # noqa
|
||||
litellm_model_name = model["litellm_params"]["model"]
|
||||
litellm_model_api_base = model["litellm_params"].get("api_base", None)
|
||||
|
@ -2005,7 +1979,10 @@ class ProxyConfig:
|
|||
) # type:ignore
|
||||
|
||||
# Guardrail settings
|
||||
guardrails_v2 = config.get("guardrails", None)
|
||||
guardrails_v2: Optional[dict] = None
|
||||
|
||||
if config is not None:
|
||||
guardrails_v2 = config.get("guardrails", None)
|
||||
if guardrails_v2:
|
||||
init_guardrails_v2(
|
||||
all_guardrails=guardrails_v2, config_file_path=config_file_path
|
||||
|
@ -2074,7 +2051,7 @@ class ProxyConfig:
|
|||
### LOAD FROM os.environ/ ###
|
||||
for k, v in model["litellm_params"].items():
|
||||
if isinstance(v, str) and v.startswith("os.environ/"):
|
||||
model["litellm_params"][k] = litellm.get_secret(v)
|
||||
model["litellm_params"][k] = get_secret(v)
|
||||
|
||||
## check if they have model-id's ##
|
||||
model_id = model.get("model_info", {}).get("id", None)
|
||||
|
@ -2234,7 +2211,8 @@ class ProxyConfig:
|
|||
for k, v in environment_variables.items():
|
||||
try:
|
||||
decrypted_value = decrypt_value_helper(value=v)
|
||||
os.environ[k] = decrypted_value
|
||||
if decrypted_value is not None:
|
||||
os.environ[k] = decrypted_value
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
"Error setting env variable: %s - %s", k, str(e)
|
||||
|
@ -2536,7 +2514,7 @@ async def async_assistants_data_generator(
|
|||
)
|
||||
|
||||
# chunk = chunk.model_dump_json(exclude_none=True)
|
||||
async for c in chunk:
|
||||
async for c in chunk: # type: ignore
|
||||
c = c.model_dump_json(exclude_none=True)
|
||||
try:
|
||||
yield f"data: {c}\n\n"
|
||||
|
@ -2745,17 +2723,22 @@ async def startup_event():
|
|||
|
||||
### LOAD MASTER KEY ###
|
||||
# check if master key set in environment - load from there
|
||||
master_key = litellm.get_secret("LITELLM_MASTER_KEY", None)
|
||||
master_key = get_secret("LITELLM_MASTER_KEY", None) # type: ignore
|
||||
# check if DATABASE_URL in environment - load from there
|
||||
if prisma_client is None:
|
||||
prisma_setup(database_url=litellm.get_secret("DATABASE_URL", None))
|
||||
_db_url: Optional[str] = get_secret("DATABASE_URL", None) # type: ignore
|
||||
prisma_setup(database_url=_db_url)
|
||||
|
||||
### LOAD CONFIG ###
|
||||
worker_config = litellm.get_secret("WORKER_CONFIG")
|
||||
worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore
|
||||
verbose_proxy_logger.debug("worker_config: %s", worker_config)
|
||||
# check if it's a valid file path
|
||||
if os.path.isfile(worker_config):
|
||||
if proxy_config.is_yaml(config_file_path=worker_config):
|
||||
if worker_config is not None:
|
||||
if (
|
||||
isinstance(worker_config, str)
|
||||
and os.path.isfile(worker_config)
|
||||
and proxy_config.is_yaml(config_file_path=worker_config)
|
||||
):
|
||||
(
|
||||
llm_router,
|
||||
llm_model_list,
|
||||
|
@ -2763,21 +2746,23 @@ async def startup_event():
|
|||
) = await proxy_config.load_config(
|
||||
router=llm_router, config_file_path=worker_config
|
||||
)
|
||||
else:
|
||||
elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None and isinstance(
|
||||
worker_config, str
|
||||
):
|
||||
(
|
||||
llm_router,
|
||||
llm_model_list,
|
||||
general_settings,
|
||||
) = await proxy_config.load_config(
|
||||
router=llm_router, config_file_path=worker_config
|
||||
)
|
||||
elif isinstance(worker_config, dict):
|
||||
await initialize(**worker_config)
|
||||
elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
|
||||
(
|
||||
llm_router,
|
||||
llm_model_list,
|
||||
general_settings,
|
||||
) = await proxy_config.load_config(
|
||||
router=llm_router, config_file_path=worker_config
|
||||
)
|
||||
|
||||
else:
|
||||
# if not, assume it's a json string
|
||||
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
|
||||
await initialize(**worker_config)
|
||||
else:
|
||||
# if not, assume it's a json string
|
||||
worker_config = json.loads(worker_config)
|
||||
if isinstance(worker_config, dict):
|
||||
await initialize(**worker_config)
|
||||
|
||||
## CHECK PREMIUM USER
|
||||
verbose_proxy_logger.debug(
|
||||
|
@ -2825,7 +2810,7 @@ async def startup_event():
|
|||
if general_settings.get("litellm_jwtauth", None) is not None:
|
||||
for k, v in general_settings["litellm_jwtauth"].items():
|
||||
if isinstance(v, str) and v.startswith("os.environ/"):
|
||||
general_settings["litellm_jwtauth"][k] = litellm.get_secret(v)
|
||||
general_settings["litellm_jwtauth"][k] = get_secret(v)
|
||||
litellm_jwtauth = LiteLLM_JWTAuth(**general_settings["litellm_jwtauth"])
|
||||
else:
|
||||
litellm_jwtauth = LiteLLM_JWTAuth()
|
||||
|
@ -2948,8 +2933,7 @@ async def startup_event():
|
|||
|
||||
### ADD NEW MODELS ###
|
||||
store_model_in_db = (
|
||||
litellm.get_secret("STORE_MODEL_IN_DB", store_model_in_db)
|
||||
or store_model_in_db
|
||||
get_secret("STORE_MODEL_IN_DB", store_model_in_db) or store_model_in_db
|
||||
) # type: ignore
|
||||
if store_model_in_db == True:
|
||||
scheduler.add_job(
|
||||
|
@ -3498,7 +3482,7 @@ async def completion(
|
|||
)
|
||||
### CALL HOOKS ### - modify outgoing data
|
||||
response = await proxy_logging_obj.post_call_success_hook(
|
||||
data=data, user_api_key_dict=user_api_key_dict, response=response
|
||||
data=data, user_api_key_dict=user_api_key_dict, response=response # type: ignore
|
||||
)
|
||||
|
||||
fastapi_response.headers.update(
|
||||
|
@ -4000,7 +3984,7 @@ async def audio_speech(
|
|||
request_data=data,
|
||||
)
|
||||
return StreamingResponse(
|
||||
generate(response), media_type="audio/mpeg", headers=custom_headers
|
||||
generate(response), media_type="audio/mpeg", headers=custom_headers # type: ignore
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
@ -4288,6 +4272,7 @@ async def create_assistant(
|
|||
API Reference docs - https://platform.openai.com/docs/api-reference/assistants/createAssistant
|
||||
"""
|
||||
global proxy_logging_obj
|
||||
data = {} # ensure data always dict
|
||||
try:
|
||||
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
||||
body = await request.body()
|
||||
|
@ -7642,6 +7627,7 @@ async def model_group_info(
|
|||
)
|
||||
|
||||
model_groups: List[ModelGroupInfo] = []
|
||||
|
||||
for model in all_models_str:
|
||||
|
||||
_model_group_info = llm_router.get_model_group_info(model_group=model)
|
||||
|
@ -8051,7 +8037,8 @@ async def google_login(request: Request):
|
|||
with microsoft_sso:
|
||||
return await microsoft_sso.get_login_redirect()
|
||||
elif generic_client_id is not None:
|
||||
from fastapi_sso.sso.generic import DiscoveryDocument, create_provider
|
||||
from fastapi_sso.sso.base import DiscoveryDocument
|
||||
from fastapi_sso.sso.generic import create_provider
|
||||
|
||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
||||
|
@ -8616,6 +8603,8 @@ async def auth_callback(request: Request):
|
|||
redirect_url += "sso/callback"
|
||||
else:
|
||||
redirect_url += "/sso/callback"
|
||||
|
||||
result = None
|
||||
if google_client_id is not None:
|
||||
from fastapi_sso.sso.google import GoogleSSO
|
||||
|
||||
|
@ -8662,7 +8651,8 @@ async def auth_callback(request: Request):
|
|||
result = await microsoft_sso.verify_and_process(request)
|
||||
elif generic_client_id is not None:
|
||||
# make generic sso provider
|
||||
from fastapi_sso.sso.generic import DiscoveryDocument, OpenID, create_provider
|
||||
from fastapi_sso.sso.base import DiscoveryDocument, OpenID
|
||||
from fastapi_sso.sso.generic import create_provider
|
||||
|
||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
||||
|
@ -8766,8 +8756,8 @@ async def auth_callback(request: Request):
|
|||
verbose_proxy_logger.debug("generic result: %s", result)
|
||||
|
||||
# User is authenticated - generate key for the UI to access Proxy
|
||||
user_email = getattr(result, "email", None)
|
||||
user_id = getattr(result, "id", None)
|
||||
user_email: Optional[str] = getattr(result, "email", None)
|
||||
user_id: Optional[str] = getattr(result, "id", None) if result is not None else None
|
||||
|
||||
if user_email is not None and os.getenv("ALLOWED_EMAIL_DOMAINS") is not None:
|
||||
email_domain = user_email.split("@")[1]
|
||||
|
@ -8783,12 +8773,12 @@ async def auth_callback(request: Request):
|
|||
)
|
||||
|
||||
# generic client id
|
||||
if generic_client_id is not None:
|
||||
if generic_client_id is not None and result is not None:
|
||||
user_id = getattr(result, "id", None)
|
||||
user_email = getattr(result, "email", None)
|
||||
user_role = getattr(result, generic_user_role_attribute_name, None)
|
||||
user_role = getattr(result, generic_user_role_attribute_name, None) # type: ignore
|
||||
|
||||
if user_id is None:
|
||||
if user_id is None and result is not None:
|
||||
_first_name = getattr(result, "first_name", "") or ""
|
||||
_last_name = getattr(result, "last_name", "") or ""
|
||||
user_id = _first_name + _last_name
|
||||
|
@ -8811,54 +8801,45 @@ async def auth_callback(request: Request):
|
|||
"spend": 0,
|
||||
"team_id": "litellm-dashboard",
|
||||
}
|
||||
user_defined_values: SSOUserDefinedValues = {
|
||||
"models": user_id_models,
|
||||
"user_id": user_id,
|
||||
"user_email": user_email,
|
||||
"max_budget": max_internal_user_budget,
|
||||
"user_role": None,
|
||||
"budget_duration": internal_user_budget_duration,
|
||||
}
|
||||
user_defined_values: Optional[SSOUserDefinedValues] = None
|
||||
if user_id is not None:
|
||||
user_defined_values = SSOUserDefinedValues(
|
||||
models=user_id_models,
|
||||
user_id=user_id,
|
||||
user_email=user_email,
|
||||
max_budget=max_internal_user_budget,
|
||||
user_role=None,
|
||||
budget_duration=internal_user_budget_duration,
|
||||
)
|
||||
|
||||
_user_id_from_sso = user_id
|
||||
user_role = None
|
||||
try:
|
||||
user_role = None
|
||||
if prisma_client is not None:
|
||||
user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_info: {user_info}; litellm.default_user_params: {litellm.default_user_params}"
|
||||
)
|
||||
if user_info is not None:
|
||||
user_defined_values = {
|
||||
"models": getattr(user_info, "models", user_id_models),
|
||||
"user_id": getattr(user_info, "user_id", user_id),
|
||||
"user_email": getattr(user_info, "user_id", user_email),
|
||||
"user_role": getattr(user_info, "user_role", None),
|
||||
"max_budget": getattr(
|
||||
user_info, "max_budget", max_internal_user_budget
|
||||
),
|
||||
"budget_duration": getattr(
|
||||
user_info, "budget_duration", internal_user_budget_duration
|
||||
),
|
||||
}
|
||||
user_role = getattr(user_info, "user_role", None)
|
||||
if user_info is None:
|
||||
## check if user-email in db ##
|
||||
user_info = await prisma_client.db.litellm_usertable.find_first(
|
||||
where={"user_email": user_email}
|
||||
)
|
||||
|
||||
## check if user-email in db ##
|
||||
user_info = await prisma_client.db.litellm_usertable.find_first(
|
||||
where={"user_email": user_email}
|
||||
)
|
||||
if user_info is not None:
|
||||
user_defined_values = {
|
||||
"models": getattr(user_info, "models", user_id_models),
|
||||
"user_id": user_id,
|
||||
"user_email": getattr(user_info, "user_id", user_email),
|
||||
"user_role": getattr(user_info, "user_role", None),
|
||||
"max_budget": getattr(
|
||||
if user_info is not None and user_id is not None:
|
||||
user_defined_values = SSOUserDefinedValues(
|
||||
models=getattr(user_info, "models", user_id_models),
|
||||
user_id=user_id,
|
||||
user_email=getattr(user_info, "user_email", user_email),
|
||||
user_role=getattr(user_info, "user_role", None),
|
||||
max_budget=getattr(
|
||||
user_info, "max_budget", max_internal_user_budget
|
||||
),
|
||||
"budget_duration": getattr(
|
||||
budget_duration=getattr(
|
||||
user_info, "budget_duration", internal_user_budget_duration
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
user_role = getattr(user_info, "user_role", None)
|
||||
|
||||
# update id
|
||||
|
@ -8886,6 +8867,11 @@ async def auth_callback(request: Request):
|
|||
except Exception as e:
|
||||
pass
|
||||
|
||||
if user_defined_values is None:
|
||||
raise Exception(
|
||||
"Unable to map user identity to known values. 'user_defined_values' is None. File an issue - https://github.com/BerriAI/litellm/issues"
|
||||
)
|
||||
|
||||
is_internal_user = False
|
||||
if (
|
||||
user_defined_values["user_role"] is not None
|
||||
|
@ -8960,7 +8946,8 @@ async def auth_callback(request: Request):
|
|||
master_key,
|
||||
algorithm="HS256",
|
||||
)
|
||||
litellm_dashboard_ui += "?userID=" + user_id
|
||||
if user_id is not None and isinstance(user_id, str):
|
||||
litellm_dashboard_ui += "?userID=" + user_id
|
||||
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
||||
redirect_response.set_cookie(key="token", value=jwt_token)
|
||||
return redirect_response
|
||||
|
@ -9023,6 +9010,7 @@ async def new_invitation(
|
|||
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
} # type: ignore
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
if "Foreign key constraint failed on the field" in str(e):
|
||||
raise HTTPException(
|
||||
|
@ -9031,7 +9019,7 @@ async def new_invitation(
|
|||
"error": "User id does not exist in 'LiteLLM_UserTable'. Fix this by creating user via `/user/new`."
|
||||
},
|
||||
)
|
||||
return response
|
||||
raise HTTPException(status_code=500, detail={"error": str(e)})
|
||||
|
||||
|
||||
@router.get(
|
||||
|
@ -9951,44 +9939,46 @@ async def get_routes():
|
|||
"""
|
||||
routes = []
|
||||
for route in app.routes:
|
||||
route_info = {
|
||||
"path": getattr(route, "path", None),
|
||||
"methods": getattr(route, "methods", None),
|
||||
"name": getattr(route, "name", None),
|
||||
"endpoint": (
|
||||
getattr(route, "endpoint", None).__name__
|
||||
if getattr(route, "endpoint", None)
|
||||
else None
|
||||
),
|
||||
}
|
||||
routes.append(route_info)
|
||||
endpoint_route = getattr(route, "endpoint", None)
|
||||
if endpoint_route is not None:
|
||||
route_info = {
|
||||
"path": getattr(route, "path", None),
|
||||
"methods": getattr(route, "methods", None),
|
||||
"name": getattr(route, "name", None),
|
||||
"endpoint": (
|
||||
endpoint_route.__name__
|
||||
if getattr(route, "endpoint", None)
|
||||
else None
|
||||
),
|
||||
}
|
||||
routes.append(route_info)
|
||||
|
||||
return {"routes": routes}
|
||||
|
||||
|
||||
#### TEST ENDPOINTS ####
|
||||
@router.get(
|
||||
"/token/generate",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def token_generate():
|
||||
"""
|
||||
Test endpoint. Admin-only access. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc.
|
||||
"""
|
||||
# Initialize AuthJWTSSO with your OpenID Provider configuration
|
||||
from fastapi_sso import AuthJWTSSO
|
||||
# @router.get(
|
||||
# "/token/generate",
|
||||
# dependencies=[Depends(user_api_key_auth)],
|
||||
# include_in_schema=False,
|
||||
# )
|
||||
# async def token_generate():
|
||||
# """
|
||||
# Test endpoint. Admin-only access. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc.
|
||||
# """
|
||||
# # Initialize AuthJWTSSO with your OpenID Provider configuration
|
||||
# from fastapi_sso import AuthJWTSSO
|
||||
|
||||
auth_jwt_sso = AuthJWTSSO(
|
||||
issuer=os.getenv("OPENID_BASE_URL"),
|
||||
client_id=os.getenv("OPENID_CLIENT_ID"),
|
||||
client_secret=os.getenv("OPENID_CLIENT_SECRET"),
|
||||
scopes=["litellm_proxy_admin"],
|
||||
)
|
||||
# auth_jwt_sso = AuthJWTSSO(
|
||||
# issuer=os.getenv("OPENID_BASE_URL"),
|
||||
# client_id=os.getenv("OPENID_CLIENT_ID"),
|
||||
# client_secret=os.getenv("OPENID_CLIENT_SECRET"),
|
||||
# scopes=["litellm_proxy_admin"],
|
||||
# )
|
||||
|
||||
token = auth_jwt_sso.create_access_token()
|
||||
# token = auth_jwt_sso.create_access_token()
|
||||
|
||||
return {"token": token}
|
||||
# return {"token": token}
|
||||
|
||||
|
||||
@router.on_event("shutdown")
|
||||
|
@ -10013,7 +10003,8 @@ async def shutdown_event():
|
|||
# flush langfuse logs on shutdown
|
||||
from litellm.utils import langFuseLogger
|
||||
|
||||
langFuseLogger.Langfuse.flush()
|
||||
if langFuseLogger is not None:
|
||||
langFuseLogger.Langfuse.flush()
|
||||
except:
|
||||
# [DO NOT BLOCK shutdown events for this]
|
||||
pass
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue