diff --git a/docs/my-website/docs/proxy/ui.md b/docs/my-website/docs/proxy/ui.md index cc62ce096..c98c85f2d 100644 --- a/docs/my-website/docs/proxy/ui.md +++ b/docs/my-website/docs/proxy/ui.md @@ -37,12 +37,12 @@ http://0.0.0.0:8000/ui # /ui ``` -## Get Admin UI Link on Swagger +### 3. Get Admin UI Link on Swagger Your Proxy Swagger is available on the root of the Proxy: e.g.: `http://localhost:4000/` -## Change default username + password +### 4. Change default username + password Set the following in your .env on the Proxy @@ -111,6 +111,29 @@ MICROSOFT_TENANT="5a39737 + + + +A generic OAuth client that can be used to quickly create support for any OAuth provider with close to no code + +**Required .env variables on your Proxy** +```shell + +GENERIC_CLIENT_ID = "******" +GENERIC_CLIENT_SECRET = "G*******" +GENERIC_AUTHORIZATION_ENDPOINT = "http://localhost:9090/auth" +GENERIC_TOKEN_ENDPOINT = "http://localhost:9090/token" +GENERIC_USERINFO_ENDPOINT = "http://localhost:9090/me" +``` + +- Set Redirect URI, if your provider requires it + - Set a redirect url = `/sso/callback` + ```shell + http://localhost:4000/sso/callback + ``` + + + ### Step 3. Test flow diff --git a/docs/my-website/docs/proxy/user_keys.md b/docs/my-website/docs/proxy/user_keys.md index 47cfef9c3..fcccffaa0 100644 --- a/docs/my-website/docs/proxy/user_keys.md +++ b/docs/my-website/docs/proxy/user_keys.md @@ -197,7 +197,7 @@ from openai import OpenAI # set api_key to send to proxy server client = OpenAI(api_key="", base_url="http://0.0.0.0:8000") -response = openai.embeddings.create( +response = client.embeddings.create( input=["hello from litellm"], model="text-embedding-ada-002" ) @@ -281,6 +281,84 @@ print(query_result[:5]) ``` +## `/moderations` + + +### Request Format +Input, Output and Exceptions are mapped to the OpenAI format for all supported models + + + + +```python +import openai +from openai import OpenAI + +# set base_url to your proxy server +# set api_key to send to proxy server +client = OpenAI(api_key="", base_url="http://0.0.0.0:8000") + +response = client.moderations.create( + input="hello from litellm", + model="text-moderation-stable" +) + +print(response) + +``` + + + +```shell +curl --location 'http://0.0.0.0:8000/moderations' \ + --header 'Content-Type: application/json' \ + --header 'Authorization: Bearer sk-1234' \ + --data '{"input": "Sample text goes here", "model": "text-moderation-stable"}' +``` + + + + +### Response Format + +```json +{ + "id": "modr-8sFEN22QCziALOfWTa77TodNLgHwA", + "model": "text-moderation-007", + "results": [ + { + "categories": { + "harassment": false, + "harassment/threatening": false, + "hate": false, + "hate/threatening": false, + "self-harm": false, + "self-harm/instructions": false, + "self-harm/intent": false, + "sexual": false, + "sexual/minors": false, + "violence": false, + "violence/graphic": false + }, + "category_scores": { + "harassment": 0.000019947197870351374, + "harassment/threatening": 5.5971017900446896e-6, + "hate": 0.000028560316422954202, + "hate/threatening": 2.2631787999216613e-8, + "self-harm": 2.9121162015144364e-7, + "self-harm/instructions": 9.314219084899378e-8, + "self-harm/intent": 8.093739012338119e-8, + "sexual": 0.00004414955765241757, + "sexual/minors": 0.0000156943697220413, + "violence": 0.00022354527027346194, + "violence/graphic": 8.804164281173144e-6 + }, + "flagged": false + } + ] +} +``` + ## Advanced diff --git a/litellm/main.py b/litellm/main.py index a7990ecfb..352ce1882 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -2962,16 +2962,39 @@ def text_completion( ##### Moderation ####################### -def moderation(input: str, api_key: Optional[str] = None): + + +def moderation( + input: str, model: Optional[str] = None, api_key: Optional[str] = None, **kwargs +): # only supports open ai for now api_key = ( api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY") ) - openai.api_key = api_key - openai.api_type = "open_ai" # type: ignore - openai.api_version = None - openai.base_url = "https://api.openai.com/v1/" - response = openai.moderations.create(input=input) + + openai_client = kwargs.get("client", None) + if openai_client is None: + openai_client = openai.OpenAI( + api_key=api_key, + ) + + response = openai_client.moderations.create(input=input, model=model) + return response + + +##### Moderation ####################### +@client +async def amoderation(input: str, model: str, api_key: Optional[str] = None, **kwargs): + # only supports open ai for now + api_key = ( + api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY") + ) + openai_client = kwargs.get("client", None) + if openai_client is None: + openai_client = openai.AsyncOpenAI( + api_key=api_key, + ) + response = await openai_client.moderations.create(input=input, model=model) return response diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 8d35bcae8..74a780c71 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -9,14 +9,19 @@ model_list: mode: chat max_tokens: 4096 base_model: azure/gpt-4-1106-preview + access_groups: ["public"] - model_name: openai-gpt-3.5 litellm_params: model: gpt-3.5-turbo api_key: os.environ/OPENAI_API_KEY + model_info: + access_groups: ["public"] - model_name: anthropic-claude-v2.1 litellm_params: model: bedrock/anthropic.claude-v2:1 timeout: 300 # sets a 5 minute timeout + model_info: + access_groups: ["private"] - model_name: anthropic-claude-v2 litellm_params: model: bedrock/anthropic.claude-v2 @@ -32,19 +37,13 @@ model_list: api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault model_info: base_model: azure/gpt-4 + - model_name: text-moderation-stable + litellm_params: + model: text-moderation-stable + api_key: os.environ/OPENAI_API_KEY litellm_settings: fallbacks: [{"openai-gpt-3.5": ["azure-gpt-3.5"]}] success_callback: ['langfuse'] - max_budget: 50 # global budget for proxy - max_user_budget: 0.0001 - budget_duration: 30d # global budget duration, will reset after 30d - default_key_generate_params: - max_budget: 1.5000 - models: ["azure-gpt-3.5"] - duration: None - upperbound_key_generate_params: - max_budget: 100 - duration: "30d" # setting callback class # callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance] diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index ff10b5a17..07e08ac61 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -403,34 +403,43 @@ async def user_api_key_auth( verbose_proxy_logger.debug( f"LLM Model List pre access group check: {llm_model_list}" ) - access_groups = [] + from collections import defaultdict + + access_groups = defaultdict(list) if llm_model_list is not None: for m in llm_model_list: for group in m.get("model_info", {}).get("access_groups", []): - access_groups.append((m["model_name"], group)) + model_name = m["model_name"] + access_groups[group].append(model_name) - allowed_models = valid_token.models - access_group_idx = set() + models_in_current_access_groups = [] if ( len(access_groups) > 0 ): # check if token contains any model access groups - for idx, m in enumerate(valid_token.models): - for model_name, group in access_groups: - if m == group: - access_group_idx.add(idx) - allowed_models.append(model_name) + for idx, m in enumerate( + valid_token.models + ): # loop token models, if any of them are an access group add the access group + if m in access_groups: + # if it is an access group we need to remove it from valid_token.models + models_in_group = access_groups[m] + models_in_current_access_groups.extend(models_in_group) + + # Filter out models that are access_groups + filtered_models = [ + m for m in valid_token.models if m not in access_groups + ] + + filtered_models += models_in_current_access_groups verbose_proxy_logger.debug( - f"model: {model}; allowed_models: {allowed_models}" + f"model: {model}; allowed_models: {filtered_models}" ) - if model is not None and model not in allowed_models: + if model is not None and model not in filtered_models: raise ValueError( f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}" ) - for val in access_group_idx: - allowed_models.pop(val) - valid_token.models = allowed_models + valid_token.models = filtered_models verbose_proxy_logger.debug( - f"filtered allowed_models: {allowed_models}; valid_token.models: {valid_token.models}" + f"filtered allowed_models: {filtered_models}; valid_token.models: {valid_token.models}" ) # Check 2. If user_id for this token is in budget @@ -2087,14 +2096,6 @@ def model_list( if user_model is not None: all_models += [user_model] verbose_proxy_logger.debug(f"all_models: {all_models}") - ### CHECK OLLAMA MODELS ### - try: - response = requests.get("http://0.0.0.0:11434/api/tags") - models = response.json()["models"] - ollama_models = ["ollama/" + m["name"].replace(":latest", "") for m in models] - all_models.extend(ollama_models) - except Exception as e: - pass return dict( data=[ { @@ -2798,7 +2799,161 @@ async def image_generation( ) -#### KEY MANAGEMENT ##### + +@router.post( + "/v1/moderations", + dependencies=[Depends(user_api_key_auth)], + response_class=ORJSONResponse, + tags=["moderations"], +) +@router.post( + "/moderations", + dependencies=[Depends(user_api_key_auth)], + response_class=ORJSONResponse, + tags=["moderations"], +) +async def moderations( + request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + The moderations endpoint is a tool you can use to check whether content complies with an LLM Providers policies. + + Quick Start + ``` + curl --location 'http://0.0.0.0:4000/moderations' \ + --header 'Content-Type: application/json' \ + --header 'Authorization: Bearer sk-1234' \ + --data '{"input": "Sample text goes here", "model": "text-moderation-stable"}' + ``` + """ + global proxy_logging_obj + try: + # Use orjson to parse JSON data, orjson speeds up requests significantly + body = await request.body() + data = orjson.loads(body) + + # Include original request and headers in the data + data["proxy_server_request"] = { + "url": str(request.url), + "method": request.method, + "headers": dict(request.headers), + "body": copy.copy(data), # use copy instead of deepcopy + } + + if data.get("user", None) is None and user_api_key_dict.user_id is not None: + data["user"] = user_api_key_dict.user_id + + data["model"] = ( + general_settings.get("moderation_model", None) # server default + or user_model # model name passed via cli args + or data["model"] # default passed in http request + ) + if user_model: + data["model"] = user_model + + if "metadata" not in data: + data["metadata"] = {} + data["metadata"]["user_api_key"] = user_api_key_dict.api_key + data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata + _headers = dict(request.headers) + _headers.pop( + "authorization", None + ) # do not store the original `sk-..` api key in the db + data["metadata"]["headers"] = _headers + data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id + data["metadata"]["endpoint"] = str(request.url) + + ### TEAM-SPECIFIC PARAMS ### + if user_api_key_dict.team_id is not None: + team_config = await proxy_config.load_team_config( + team_id=user_api_key_dict.team_id + ) + if len(team_config) == 0: + pass + else: + team_id = team_config.pop("team_id", None) + data["metadata"]["team_id"] = team_id + data = { + **team_config, + **data, + } # add the team-specific configs to the completion call + + router_model_names = ( + [m["model_name"] for m in llm_model_list] + if llm_model_list is not None + else [] + ) + + ### CALL HOOKS ### - modify incoming data / reject request before calling the model + data = await proxy_logging_obj.pre_call_hook( + user_api_key_dict=user_api_key_dict, data=data, call_type="moderation" + ) + + start_time = time.time() + + ## ROUTE TO CORRECT ENDPOINT ## + # skip router if user passed their key + if "api_key" in data: + response = await litellm.amoderation(**data) + elif ( + llm_router is not None and data["model"] in router_model_names + ): # model in router model list + response = await llm_router.amoderation(**data) + elif ( + llm_router is not None and data["model"] in llm_router.deployment_names + ): # model in router deployments, calling a specific deployment on the router + response = await llm_router.amoderation(**data, specific_deployment=True) + elif ( + llm_router is not None + and llm_router.model_group_alias is not None + and data["model"] in llm_router.model_group_alias + ): # model set in model_group_alias + response = await llm_router.amoderation( + **data + ) # ensure this goes the llm_router, router will do the correct alias mapping + elif user_model is not None: # `litellm --model ` + response = await litellm.amoderation(**data) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": "Invalid model name passed in"}, + ) + + ### ALERTING ### + data["litellm_status"] = "success" # used for alerting + end_time = time.time() + asyncio.create_task( + proxy_logging_obj.response_taking_too_long( + start_time=start_time, end_time=end_time, type="slow_response" + ) + ) + + return response + except Exception as e: + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, original_exception=e + ) + traceback.print_exc() + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "message", str(e)), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + else: + error_traceback = traceback.format_exc() + error_msg = f"{str(e)}\n\n{error_traceback}" + raise ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + ) + + +#### KEY MANAGEMENT #### @router.post( @@ -3684,7 +3839,6 @@ async def user_update(data: UpdateUserRequest): code=status.HTTP_400_BAD_REQUEST, ) - #### TEAM MANAGEMENT #### @@ -3766,75 +3920,182 @@ async def team_info( ): """ get info on team + related keys - - ``` - curl --location 'http://localhost:4000/team/info' \ - --header 'Authorization: Bearer sk-1234' \ - --header 'Content-Type: application/json' \ - --data '{ - "teams": ["",..] - }' - ``` """ - global prisma_client - try: - if prisma_client is None: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail={ - "error": f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" - }, - ) - if team_id is None: - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail={"message": "Malformed request. No team id passed in."}, - ) + pass + +@app.get("/sso/callback", tags=["experimental"]) +async def auth_callback(request: Request): + """Verify login""" + global general_settings + microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) + google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) + generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) - team_info = await prisma_client.get_data( - team_id=team_id, table_name="team", query_type="find_unique" - ) - ## GET ALL KEYS ## - keys = await prisma_client.get_data( - team_id=team_id, - table_name="key", - query_type="find_all", - expires=datetime.now(), - ) + # get url from request + redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) - if team_info is None: - ## make sure we still return a total spend ## - spend = 0 - for k in keys: - spend += getattr(k, "spend", 0) - team_info = {"spend": spend} + if redirect_url.endswith("/"): + redirect_url += "sso/callback" + else: + redirect_url += "/sso/callback" - ## REMOVE HASHED TOKEN INFO before returning ## - for key in keys: - try: - key = key.model_dump() # noqa - except: - # if using pydantic v1 - key = key.dict() - key.pop("token", None) - return {"team_id": team_id, "team_info": team_info, "keys": keys} + if google_client_id is not None: + from fastapi_sso.sso.google import GoogleSSO - except Exception as e: - if isinstance(e, HTTPException): + google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) + if google_client_secret is None: raise ProxyException( - message=getattr(e, "detail", f"Authentication Error({str(e)})"), + message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", type="auth_error", - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + param="GOOGLE_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) - elif isinstance(e, ProxyException): - raise e - raise ProxyException( - message="Authentication Error, " + str(e), - type="auth_error", - param=getattr(e, "param", "None"), - code=status.HTTP_400_BAD_REQUEST, + google_sso = GoogleSSO( + client_id=google_client_id, + redirect_uri=redirect_url, + client_secret=google_client_secret, ) + result = await google_sso.verify_and_process(request) + + elif microsoft_client_id is not None: + from fastapi_sso.sso.microsoft import MicrosoftSSO + + microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) + microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) + if microsoft_client_secret is None: + raise ProxyException( + message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", + type="auth_error", + param="MICROSOFT_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if microsoft_tenant is None: + raise ProxyException( + message="MICROSOFT_TENANT not set. Set it in .env file", + type="auth_error", + param="MICROSOFT_TENANT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + microsoft_sso = MicrosoftSSO( + client_id=microsoft_client_id, + client_secret=microsoft_client_secret, + tenant=microsoft_tenant, + redirect_uri=redirect_url, + allow_insecure_http=True, + ) + result = await microsoft_sso.verify_and_process(request) + elif generic_client_id is not None: + # make generic sso provider + from fastapi_sso.sso.generic import create_provider, DiscoveryDocument + + generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) + generic_authorization_endpoint = os.getenv( + "GENERIC_AUTHORIZATION_ENDPOINT", None + ) + generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None) + generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None) + if generic_client_secret is None: + raise ProxyException( + message="GENERIC_CLIENT_SECRET not set. Set it in .env file", + type="auth_error", + param="GENERIC_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_authorization_endpoint is None: + raise ProxyException( + message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file", + type="auth_error", + param="GENERIC_AUTHORIZATION_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_token_endpoint is None: + raise ProxyException( + message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file", + type="auth_error", + param="GENERIC_TOKEN_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_userinfo_endpoint is None: + raise ProxyException( + message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file", + type="auth_error", + param="GENERIC_USERINFO_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + verbose_proxy_logger.debug( + f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}" + ) + + verbose_proxy_logger.debug( + f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n" + ) + + discovery = DiscoveryDocument( + authorization_endpoint=generic_authorization_endpoint, + token_endpoint=generic_token_endpoint, + userinfo_endpoint=generic_userinfo_endpoint, + ) + + SSOProvider = create_provider(name="oidc", discovery_document=discovery) + generic_sso = SSOProvider( + client_id=generic_client_id, + client_secret=generic_client_secret, + redirect_uri=redirect_url, + allow_insecure_http=True, + ) + verbose_proxy_logger.debug(f"calling generic_sso.verify_and_process") + + request_body = await request.body() + + request_query_params = request.query_params + + # get "code" from query params + code = request_query_params.get("code") + + result = await generic_sso.verify_and_process(request) + verbose_proxy_logger.debug(f"generic result: {result}") + + # User is Authe'd in - generate key for the UI to access Proxy + user_email = getattr(result, "email", None) + user_id = getattr(result, "id", None) + if user_id is None: + user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "") + + response = await generate_key_helper_fn( + **{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard", "user_email": user_email} # type: ignore + ) + key = response["token"] # type: ignore + user_id = response["user_id"] # type: ignore + + litellm_dashboard_ui = "/ui/" + + user_role = "app_owner" + if ( + os.getenv("PROXY_ADMIN_ID", None) is not None + and os.environ["PROXY_ADMIN_ID"] == user_id + ): + # checks if user is admin + user_role = "app_admin" + + import jwt + + jwt_token = jwt.encode( + { + "user_id": user_id, + "key": key, + "user_email": user_email, + "user_role": user_role, + }, + "secret", + algorithm="HS256", + ) + litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token + + # if a user has logged in they should be allowed to create keys - this ensures that it's set to True + general_settings["allow_user_auth"] = True + return RedirectResponse(url=litellm_dashboard_ui) #### MODEL MANAGEMENT #### @@ -4260,6 +4521,73 @@ async def google_login(request: Request): ) with microsoft_sso: return await microsoft_sso.get_login_redirect() + elif generic_client_id is not None: + from fastapi_sso.sso.generic import create_provider, DiscoveryDocument + + generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) + generic_authorization_endpoint = os.getenv( + "GENERIC_AUTHORIZATION_ENDPOINT", None + ) + generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None) + generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None) + if generic_client_secret is None: + raise ProxyException( + message="GENERIC_CLIENT_SECRET not set. Set it in .env file", + type="auth_error", + param="GENERIC_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_authorization_endpoint is None: + raise ProxyException( + message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file", + type="auth_error", + param="GENERIC_AUTHORIZATION_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_token_endpoint is None: + raise ProxyException( + message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file", + type="auth_error", + param="GENERIC_TOKEN_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_userinfo_endpoint is None: + raise ProxyException( + message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file", + type="auth_error", + param="GENERIC_USERINFO_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + verbose_proxy_logger.debug( + f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}" + ) + + verbose_proxy_logger.debug( + f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n" + ) + + discovery = DiscoveryDocument( + authorization_endpoint=generic_authorization_endpoint, + token_endpoint=generic_token_endpoint, + userinfo_endpoint=generic_userinfo_endpoint, + ) + + SSOProvider = create_provider(name="oidc", discovery_document=discovery) + generic_sso = SSOProvider( + client_id=generic_client_id, + client_secret=generic_client_secret, + redirect_uri=redirect_url, + allow_insecure_http=True, + ) + + with generic_sso: + return await generic_sso.get_login_redirect() + + elif ui_username is not None: + # No Google, Microsoft SSO + # Use UI Credentials set in .env + from fastapi.responses import HTMLResponse elif ui_username is not None: # No Google, Microsoft SSO # Use UI Credentials set in .env diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index d82a7231f..d3c95f350 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -93,7 +93,9 @@ class ProxyLogging: self, user_api_key_dict: UserAPIKeyAuth, data: dict, - call_type: Literal["completion", "embeddings", "image_generation"], + call_type: Literal[ + "completion", "embeddings", "image_generation", "moderation" + ], ): """ Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body. diff --git a/litellm/router.py b/litellm/router.py index 21e967576..b64b111a1 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -599,6 +599,98 @@ class Router: self.fail_calls[model_name] += 1 raise e + async def amoderation(self, model: str, input: str, **kwargs): + try: + kwargs["model"] = model + kwargs["input"] = input + kwargs["original_function"] = self._amoderation + kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) + timeout = kwargs.get("request_timeout", self.timeout) + kwargs.setdefault("metadata", {}).update({"model_group": model}) + + response = await self.async_function_with_fallbacks(**kwargs) + + return response + except Exception as e: + raise e + + async def _amoderation(self, model: str, input: str, **kwargs): + model_name = None + try: + verbose_router_logger.debug( + f"Inside _moderation()- model: {model}; kwargs: {kwargs}" + ) + deployment = self.get_available_deployment( + model=model, + input=input, + specific_deployment=kwargs.pop("specific_deployment", None), + ) + kwargs.setdefault("metadata", {}).update( + { + "deployment": deployment["litellm_params"]["model"], + "model_info": deployment.get("model_info", {}), + } + ) + kwargs["model_info"] = deployment.get("model_info", {}) + data = deployment["litellm_params"].copy() + model_name = data["model"] + for k, v in self.default_litellm_params.items(): + if ( + k not in kwargs and v is not None + ): # prioritize model-specific params > default router params + kwargs[k] = v + elif k == "metadata": + kwargs[k].update(v) + + potential_model_client = self._get_client( + deployment=deployment, kwargs=kwargs, client_type="async" + ) + # check if provided keys == client keys # + dynamic_api_key = kwargs.get("api_key", None) + if ( + dynamic_api_key is not None + and potential_model_client is not None + and dynamic_api_key != potential_model_client.api_key + ): + model_client = None + else: + model_client = potential_model_client + self.total_calls[model_name] += 1 + + timeout = ( + data.get( + "timeout", None + ) # timeout set on litellm_params for this deployment + or self.timeout # timeout set on router + or kwargs.get( + "timeout", None + ) # this uses default_litellm_params when nothing is set + ) + + response = await litellm.amoderation( + **{ + **data, + "input": input, + "caching": self.cache_responses, + "client": model_client, + "timeout": timeout, + **kwargs, + } + ) + + self.success_calls[model_name] += 1 + verbose_router_logger.info( + f"litellm.amoderation(model={model_name})\033[32m 200 OK\033[0m" + ) + return response + except Exception as e: + verbose_router_logger.info( + f"litellm.amoderation(model={model_name})\033[31m Exception {str(e)}\033[0m" + ) + if model_name is not None: + self.fail_calls[model_name] += 1 + raise e + def text_completion( self, model: str, diff --git a/litellm/router_strategy/lowest_latency.py b/litellm/router_strategy/lowest_latency.py index 3f8cb513b..57b56e87f 100644 --- a/litellm/router_strategy/lowest_latency.py +++ b/litellm/router_strategy/lowest_latency.py @@ -86,7 +86,7 @@ class LowestLatencyLoggingHandler(CustomLogger): if isinstance(response_obj, ModelResponse): completion_tokens = response_obj.usage.completion_tokens total_tokens = response_obj.usage.total_tokens - final_value = float(completion_tokens / response_ms.total_seconds()) + final_value = float(response_ms.total_seconds() / completion_tokens) # ------------ # Update usage @@ -168,7 +168,7 @@ class LowestLatencyLoggingHandler(CustomLogger): if isinstance(response_obj, ModelResponse): completion_tokens = response_obj.usage.completion_tokens total_tokens = response_obj.usage.total_tokens - final_value = float(completion_tokens / response_ms.total_seconds()) + final_value = float(response_ms.total_seconds() / completion_tokens) # ------------ # Update usage diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index e93b00ef6..17ced7382 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -2093,10 +2093,6 @@ def test_completion_cloudflare(): def test_moderation(): - import openai - - openai.api_type = "azure" - openai.api_version = "GM" response = litellm.moderation(input="i'm ishaan cto of litellm") print(response) output = response.results[0] diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index b9ca29cee..ab329e14a 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -991,3 +991,23 @@ def test_router_timeout(): print(e) print(vars(e)) pass + + +@pytest.mark.asyncio +async def test_router_amoderation(): + model_list = [ + { + "model_name": "openai-moderations", + "litellm_params": { + "model": "text-moderation-stable", + "api_key": os.getenv("OPENAI_API_KEY", None), + }, + } + ] + + router = Router(model_list=model_list) + result = await router.amoderation( + model="openai-moderations", input="this is valid good text" + ) + + print("moderation result", result) diff --git a/litellm/utils.py b/litellm/utils.py index e238b84d7..b15be366d 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -738,6 +738,8 @@ class CallTypes(Enum): text_completion = "text_completion" image_generation = "image_generation" aimage_generation = "aimage_generation" + moderation = "moderation" + amoderation = "amoderation" # Logging function -> log the exact model details + what's being sent | Non-BlockingP @@ -2100,6 +2102,11 @@ def client(original_function): or call_type == CallTypes.aimage_generation.value ): messages = args[0] if len(args) > 0 else kwargs["prompt"] + elif ( + call_type == CallTypes.moderation.value + or call_type == CallTypes.amoderation.value + ): + messages = args[1] if len(args) > 1 else kwargs["input"] elif ( call_type == CallTypes.atext_completion.value or call_type == CallTypes.text_completion.value diff --git a/pyproject.toml b/pyproject.toml index 921547692..5a7ecbc50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.23.15" +version = "1.23.16" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT" @@ -69,7 +69,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.23.15" +version = "1.23.16" version_files = [ "pyproject.toml:^version" ]