mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Merge branch 'main' into litellm_aioboto3_sagemaker
This commit is contained in:
commit
57654f4533
79 changed files with 3440 additions and 253 deletions
|
@ -403,34 +403,43 @@ async def user_api_key_auth(
|
|||
verbose_proxy_logger.debug(
|
||||
f"LLM Model List pre access group check: {llm_model_list}"
|
||||
)
|
||||
access_groups = []
|
||||
from collections import defaultdict
|
||||
|
||||
access_groups = defaultdict(list)
|
||||
if llm_model_list is not None:
|
||||
for m in llm_model_list:
|
||||
for group in m.get("model_info", {}).get("access_groups", []):
|
||||
access_groups.append((m["model_name"], group))
|
||||
model_name = m["model_name"]
|
||||
access_groups[group].append(model_name)
|
||||
|
||||
allowed_models = valid_token.models
|
||||
access_group_idx = set()
|
||||
models_in_current_access_groups = []
|
||||
if (
|
||||
len(access_groups) > 0
|
||||
): # check if token contains any model access groups
|
||||
for idx, m in enumerate(valid_token.models):
|
||||
for model_name, group in access_groups:
|
||||
if m == group:
|
||||
access_group_idx.add(idx)
|
||||
allowed_models.append(model_name)
|
||||
for idx, m in enumerate(
|
||||
valid_token.models
|
||||
): # loop token models, if any of them are an access group add the access group
|
||||
if m in access_groups:
|
||||
# if it is an access group we need to remove it from valid_token.models
|
||||
models_in_group = access_groups[m]
|
||||
models_in_current_access_groups.extend(models_in_group)
|
||||
|
||||
# Filter out models that are access_groups
|
||||
filtered_models = [
|
||||
m for m in valid_token.models if m not in access_groups
|
||||
]
|
||||
|
||||
filtered_models += models_in_current_access_groups
|
||||
verbose_proxy_logger.debug(
|
||||
f"model: {model}; allowed_models: {allowed_models}"
|
||||
f"model: {model}; allowed_models: {filtered_models}"
|
||||
)
|
||||
if model is not None and model not in allowed_models:
|
||||
if model is not None and model not in filtered_models:
|
||||
raise ValueError(
|
||||
f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
|
||||
)
|
||||
for val in access_group_idx:
|
||||
allowed_models.pop(val)
|
||||
valid_token.models = allowed_models
|
||||
valid_token.models = filtered_models
|
||||
verbose_proxy_logger.debug(
|
||||
f"filtered allowed_models: {allowed_models}; valid_token.models: {valid_token.models}"
|
||||
f"filtered allowed_models: {filtered_models}; valid_token.models: {valid_token.models}"
|
||||
)
|
||||
|
||||
# Check 2. If user_id for this token is in budget
|
||||
|
@ -682,34 +691,31 @@ async def user_api_key_auth(
|
|||
# sso/login, ui/login, /key functions and /user functions
|
||||
# this will never be allowed to call /chat/completions
|
||||
token_team = getattr(valid_token, "team_id", None)
|
||||
if token_team is not None:
|
||||
if token_team == "litellm-dashboard":
|
||||
# this token is only used for managing the ui
|
||||
allowed_routes = [
|
||||
"/sso",
|
||||
"/login",
|
||||
"/key",
|
||||
"/spend",
|
||||
"/user",
|
||||
]
|
||||
# check if the current route startswith any of the allowed routes
|
||||
if (
|
||||
route is not None
|
||||
and isinstance(route, str)
|
||||
and any(
|
||||
route.startswith(allowed_route)
|
||||
for allowed_route in allowed_routes
|
||||
)
|
||||
):
|
||||
# Do something if the current route starts with any of the allowed routes
|
||||
pass
|
||||
else:
|
||||
raise Exception(
|
||||
f"This key is made for LiteLLM UI, Tried to access route: {route}. Not allowed"
|
||||
)
|
||||
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
|
||||
else:
|
||||
raise Exception(f"Invalid Key Passed to LiteLLM Proxy")
|
||||
if token_team is not None and token_team == "litellm-dashboard":
|
||||
# this token is only used for managing the ui
|
||||
allowed_routes = [
|
||||
"/sso",
|
||||
"/login",
|
||||
"/key",
|
||||
"/spend",
|
||||
"/user",
|
||||
"/model/info",
|
||||
]
|
||||
# check if the current route startswith any of the allowed routes
|
||||
if (
|
||||
route is not None
|
||||
and isinstance(route, str)
|
||||
and any(
|
||||
route.startswith(allowed_route) for allowed_route in allowed_routes
|
||||
)
|
||||
):
|
||||
# Do something if the current route starts with any of the allowed routes
|
||||
pass
|
||||
else:
|
||||
raise Exception(
|
||||
f"This key is made for LiteLLM UI, Tried to access route: {route}. Not allowed"
|
||||
)
|
||||
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
|
||||
except Exception as e:
|
||||
# verbose_proxy_logger.debug(f"An exception occurred - {traceback.format_exc()}")
|
||||
traceback.print_exc()
|
||||
|
@ -1443,6 +1449,24 @@ class ProxyConfig:
|
|||
database_type == "dynamo_db" or database_type == "dynamodb"
|
||||
):
|
||||
database_args = general_settings.get("database_args", None)
|
||||
### LOAD FROM os.environ/ ###
|
||||
for k, v in database_args.items():
|
||||
if isinstance(v, str) and v.startswith("os.environ/"):
|
||||
database_args[k] = litellm.get_secret(v)
|
||||
if isinstance(k, str) and k == "aws_web_identity_token":
|
||||
value = database_args[k]
|
||||
verbose_proxy_logger.debug(
|
||||
f"Loading AWS Web Identity Token from file: {value}"
|
||||
)
|
||||
if os.path.exists(value):
|
||||
with open(value, "r") as file:
|
||||
token_content = file.read()
|
||||
database_args[k] = token_content
|
||||
else:
|
||||
verbose_proxy_logger.info(
|
||||
f"DynamoDB Loading - {value} is not a valid file path"
|
||||
)
|
||||
verbose_proxy_logger.debug(f"database_args: {database_args}")
|
||||
custom_db_client = DBClient(
|
||||
custom_db_args=database_args, custom_db_type=database_type
|
||||
)
|
||||
|
@ -1580,8 +1604,6 @@ async def generate_key_helper_fn(
|
|||
tpm_limit = tpm_limit
|
||||
rpm_limit = rpm_limit
|
||||
allowed_cache_controls = allowed_cache_controls
|
||||
if type(team_id) is not str:
|
||||
team_id = str(team_id)
|
||||
try:
|
||||
# Create a new verification token (you may want to enhance this logic based on your needs)
|
||||
user_data = {
|
||||
|
@ -2057,14 +2079,6 @@ def model_list(
|
|||
if user_model is not None:
|
||||
all_models += [user_model]
|
||||
verbose_proxy_logger.debug(f"all_models: {all_models}")
|
||||
### CHECK OLLAMA MODELS ###
|
||||
try:
|
||||
response = requests.get("http://0.0.0.0:11434/api/tags")
|
||||
models = response.json()["models"]
|
||||
ollama_models = ["ollama/" + m["name"].replace(":latest", "") for m in models]
|
||||
all_models.extend(ollama_models)
|
||||
except Exception as e:
|
||||
pass
|
||||
return dict(
|
||||
data=[
|
||||
{
|
||||
|
@ -2355,8 +2369,13 @@ async def chat_completion(
|
|||
llm_router is not None and data["model"] in llm_router.deployment_names
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
response = await llm_router.acompletion(**data, specific_deployment=True)
|
||||
else: # router is not set
|
||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
||||
response = await litellm.acompletion(**data)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"error": "Invalid model name passed in"},
|
||||
)
|
||||
|
||||
# Post Call Processing
|
||||
data["litellm_status"] = "success" # used for alerting
|
||||
|
@ -2387,6 +2406,11 @@ async def chat_completion(
|
|||
)
|
||||
|
||||
fastapi_response.headers["x-litellm-model-id"] = model_id
|
||||
|
||||
### CALL HOOKS ### - modify outgoing data
|
||||
response = await proxy_logging_obj.post_call_success_hook(
|
||||
user_api_key_dict=user_api_key_dict, response=response
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
@ -2417,7 +2441,12 @@ async def chat_completion(
|
|||
traceback.print_exc()
|
||||
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", str(e)),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
else:
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}\n\n{error_traceback}"
|
||||
|
@ -2567,8 +2596,13 @@ async def embeddings(
|
|||
llm_router is not None and data["model"] in llm_router.deployment_names
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
response = await llm_router.aembedding(**data, specific_deployment=True)
|
||||
else:
|
||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
||||
response = await litellm.aembedding(**data)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"error": "Invalid model name passed in"},
|
||||
)
|
||||
|
||||
### ALERTING ###
|
||||
data["litellm_status"] = "success" # used for alerting
|
||||
|
@ -2586,7 +2620,12 @@ async def embeddings(
|
|||
)
|
||||
traceback.print_exc()
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", str(e)),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
else:
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}\n\n{error_traceback}"
|
||||
|
@ -2702,8 +2741,13 @@ async def image_generation(
|
|||
response = await llm_router.aimage_generation(
|
||||
**data
|
||||
) # ensure this goes the llm_router, router will do the correct alias mapping
|
||||
else:
|
||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
||||
response = await litellm.aimage_generation(**data)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"error": "Invalid model name passed in"},
|
||||
)
|
||||
|
||||
### ALERTING ###
|
||||
data["litellm_status"] = "success" # used for alerting
|
||||
|
@ -2721,7 +2765,165 @@ async def image_generation(
|
|||
)
|
||||
traceback.print_exc()
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", str(e)),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
else:
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}\n\n{error_traceback}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", 500),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/moderations",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_class=ORJSONResponse,
|
||||
tags=["moderations"],
|
||||
)
|
||||
@router.post(
|
||||
"/moderations",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_class=ORJSONResponse,
|
||||
tags=["moderations"],
|
||||
)
|
||||
async def moderations(
|
||||
request: Request,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
The moderations endpoint is a tool you can use to check whether content complies with an LLM Providers policies.
|
||||
|
||||
Quick Start
|
||||
```
|
||||
curl --location 'http://0.0.0.0:4000/moderations' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--data '{"input": "Sample text goes here", "model": "text-moderation-stable"}'
|
||||
```
|
||||
"""
|
||||
global proxy_logging_obj
|
||||
try:
|
||||
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
||||
body = await request.body()
|
||||
data = orjson.loads(body)
|
||||
|
||||
# Include original request and headers in the data
|
||||
data["proxy_server_request"] = {
|
||||
"url": str(request.url),
|
||||
"method": request.method,
|
||||
"headers": dict(request.headers),
|
||||
"body": copy.copy(data), # use copy instead of deepcopy
|
||||
}
|
||||
|
||||
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
|
||||
data["user"] = user_api_key_dict.user_id
|
||||
|
||||
data["model"] = (
|
||||
general_settings.get("moderation_model", None) # server default
|
||||
or user_model # model name passed via cli args
|
||||
or data["model"] # default passed in http request
|
||||
)
|
||||
if user_model:
|
||||
data["model"] = user_model
|
||||
|
||||
if "metadata" not in data:
|
||||
data["metadata"] = {}
|
||||
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
|
||||
data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
|
||||
_headers = dict(request.headers)
|
||||
_headers.pop(
|
||||
"authorization", None
|
||||
) # do not store the original `sk-..` api key in the db
|
||||
data["metadata"]["headers"] = _headers
|
||||
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
|
||||
data["metadata"]["endpoint"] = str(request.url)
|
||||
|
||||
### TEAM-SPECIFIC PARAMS ###
|
||||
if user_api_key_dict.team_id is not None:
|
||||
team_config = await proxy_config.load_team_config(
|
||||
team_id=user_api_key_dict.team_id
|
||||
)
|
||||
if len(team_config) == 0:
|
||||
pass
|
||||
else:
|
||||
team_id = team_config.pop("team_id", None)
|
||||
data["metadata"]["team_id"] = team_id
|
||||
data = {
|
||||
**team_config,
|
||||
**data,
|
||||
} # add the team-specific configs to the completion call
|
||||
|
||||
router_model_names = (
|
||||
[m["model_name"] for m in llm_model_list]
|
||||
if llm_model_list is not None
|
||||
else []
|
||||
)
|
||||
|
||||
### CALL HOOKS ### - modify incoming data / reject request before calling the model
|
||||
data = await proxy_logging_obj.pre_call_hook(
|
||||
user_api_key_dict=user_api_key_dict, data=data, call_type="moderation"
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
## ROUTE TO CORRECT ENDPOINT ##
|
||||
# skip router if user passed their key
|
||||
if "api_key" in data:
|
||||
response = await litellm.amoderation(**data)
|
||||
elif (
|
||||
llm_router is not None and data["model"] in router_model_names
|
||||
): # model in router model list
|
||||
response = await llm_router.amoderation(**data)
|
||||
elif (
|
||||
llm_router is not None and data["model"] in llm_router.deployment_names
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
response = await llm_router.amoderation(**data, specific_deployment=True)
|
||||
elif (
|
||||
llm_router is not None
|
||||
and llm_router.model_group_alias is not None
|
||||
and data["model"] in llm_router.model_group_alias
|
||||
): # model set in model_group_alias
|
||||
response = await llm_router.amoderation(
|
||||
**data
|
||||
) # ensure this goes the llm_router, router will do the correct alias mapping
|
||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
||||
response = await litellm.amoderation(**data)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"error": "Invalid model name passed in"},
|
||||
)
|
||||
|
||||
### ALERTING ###
|
||||
data["litellm_status"] = "success" # used for alerting
|
||||
end_time = time.time()
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.response_taking_too_long(
|
||||
start_time=start_time, end_time=end_time, type="slow_response"
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e
|
||||
)
|
||||
traceback.print_exc()
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", str(e)),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
else:
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}\n\n{error_traceback}"
|
||||
|
@ -3516,6 +3718,7 @@ async def google_login(request: Request):
|
|||
"""
|
||||
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
||||
|
||||
# get url from request
|
||||
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
||||
|
@ -3574,6 +3777,69 @@ async def google_login(request: Request):
|
|||
)
|
||||
with microsoft_sso:
|
||||
return await microsoft_sso.get_login_redirect()
|
||||
elif generic_client_id is not None:
|
||||
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
|
||||
|
||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||
generic_authorization_endpoint = os.getenv(
|
||||
"GENERIC_AUTHORIZATION_ENDPOINT", None
|
||||
)
|
||||
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
||||
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
||||
if generic_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
||||
type="auth_error",
|
||||
param="GENERIC_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_authorization_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
||||
type="auth_error",
|
||||
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_token_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
||||
type="auth_error",
|
||||
param="GENERIC_TOKEN_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_userinfo_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
||||
type="auth_error",
|
||||
param="GENERIC_USERINFO_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
||||
)
|
||||
|
||||
discovery = DiscoveryDocument(
|
||||
authorization_endpoint=generic_authorization_endpoint,
|
||||
token_endpoint=generic_token_endpoint,
|
||||
userinfo_endpoint=generic_userinfo_endpoint,
|
||||
)
|
||||
|
||||
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
|
||||
generic_sso = SSOProvider(
|
||||
client_id=generic_client_id,
|
||||
client_secret=generic_client_secret,
|
||||
redirect_uri=redirect_url,
|
||||
allow_insecure_http=True,
|
||||
)
|
||||
|
||||
with generic_sso:
|
||||
return await generic_sso.get_login_redirect()
|
||||
|
||||
elif ui_username is not None:
|
||||
# No Google, Microsoft SSO
|
||||
# Use UI Credentials set in .env
|
||||
|
@ -3673,6 +3939,7 @@ async def auth_callback(request: Request):
|
|||
global general_settings
|
||||
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
||||
|
||||
# get url from request
|
||||
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
||||
|
@ -3728,6 +3995,77 @@ async def auth_callback(request: Request):
|
|||
allow_insecure_http=True,
|
||||
)
|
||||
result = await microsoft_sso.verify_and_process(request)
|
||||
elif generic_client_id is not None:
|
||||
# make generic sso provider
|
||||
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
|
||||
|
||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||
generic_authorization_endpoint = os.getenv(
|
||||
"GENERIC_AUTHORIZATION_ENDPOINT", None
|
||||
)
|
||||
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
||||
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
||||
if generic_client_secret is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
||||
type="auth_error",
|
||||
param="GENERIC_CLIENT_SECRET",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_authorization_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
||||
type="auth_error",
|
||||
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_token_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
||||
type="auth_error",
|
||||
param="GENERIC_TOKEN_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
if generic_userinfo_endpoint is None:
|
||||
raise ProxyException(
|
||||
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
||||
type="auth_error",
|
||||
param="GENERIC_USERINFO_ENDPOINT",
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
||||
)
|
||||
|
||||
discovery = DiscoveryDocument(
|
||||
authorization_endpoint=generic_authorization_endpoint,
|
||||
token_endpoint=generic_token_endpoint,
|
||||
userinfo_endpoint=generic_userinfo_endpoint,
|
||||
)
|
||||
|
||||
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
|
||||
generic_sso = SSOProvider(
|
||||
client_id=generic_client_id,
|
||||
client_secret=generic_client_secret,
|
||||
redirect_uri=redirect_url,
|
||||
allow_insecure_http=True,
|
||||
)
|
||||
verbose_proxy_logger.debug(f"calling generic_sso.verify_and_process")
|
||||
|
||||
request_body = await request.body()
|
||||
|
||||
request_query_params = request.query_params
|
||||
|
||||
# get "code" from query params
|
||||
code = request_query_params.get("code")
|
||||
|
||||
result = await generic_sso.verify_and_process(request)
|
||||
verbose_proxy_logger.debug(f"generic result: {result}")
|
||||
|
||||
# User is Authe'd in - generate key for the UI to access Proxy
|
||||
user_email = getattr(result, "email", None)
|
||||
|
@ -3936,7 +4274,6 @@ async def add_new_model(model_params: ModelParams):
|
|||
)
|
||||
|
||||
|
||||
#### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use /model/info
|
||||
@router.get(
|
||||
"/model/info",
|
||||
description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)",
|
||||
|
@ -3969,6 +4306,28 @@ async def model_info_v1(
|
|||
# read litellm model_prices_and_context_window.json to get the following:
|
||||
# input_cost_per_token, output_cost_per_token, max_tokens
|
||||
litellm_model_info = get_litellm_model_info(model=model)
|
||||
|
||||
# 2nd pass on the model, try seeing if we can find model in litellm model_cost map
|
||||
if litellm_model_info == {}:
|
||||
# use litellm_param model_name to get model_info
|
||||
litellm_params = model.get("litellm_params", {})
|
||||
litellm_model = litellm_params.get("model", None)
|
||||
try:
|
||||
litellm_model_info = litellm.get_model_info(model=litellm_model)
|
||||
except:
|
||||
litellm_model_info = {}
|
||||
# 3rd pass on the model, try seeing if we can find model but without the "/" in model cost map
|
||||
if litellm_model_info == {}:
|
||||
# use litellm_param model_name to get model_info
|
||||
litellm_params = model.get("litellm_params", {})
|
||||
litellm_model = litellm_params.get("model", None)
|
||||
split_model = litellm_model.split("/")
|
||||
if len(split_model) > 0:
|
||||
litellm_model = split_model[-1]
|
||||
try:
|
||||
litellm_model_info = litellm.get_model_info(model=litellm_model)
|
||||
except:
|
||||
litellm_model_info = {}
|
||||
for k, v in litellm_model_info.items():
|
||||
if k not in model_info:
|
||||
model_info[k] = v
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue