diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 9a5368049..b900b623b 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -88,7 +88,7 @@ class LiteLLMRoutes(enum.Enum):
         "/models",
         "/v1/models",
         # token counter
-        "utils/token_counter",
+        "/utils/token_counter",
     ]
 
     info_routes: List = [
diff --git a/litellm/proxy/auth/litellm_license.py b/litellm/proxy/auth/litellm_license.py
new file mode 100644
index 000000000..59c193e1d
--- /dev/null
+++ b/litellm/proxy/auth/litellm_license.py
@@ -0,0 +1,42 @@
+# What is this?
+## If litellm license in env, checks if it's valid
+import os
+from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+
+class LicenseCheck:
+    """
+    - Check if license in env
+    - Returns if license is valid
+    """
+
+    base_url = "https://license.litellm.ai"
+
+    def __init__(self) -> None:
+        self.license_str = os.getenv("LITELLM_LICENSE", None)
+        self.http_handler = HTTPHandler()
+
+    def _verify(self, license_str: str) -> bool:
+        url = "{}/verify_license/{}".format(self.base_url, license_str)
+
+        try:  # don't impact user, if call fails
+            response = self.http_handler.get(url=url)
+
+            response.raise_for_status()
+
+            response_json = response.json()
+
+            premium = response_json["valid"]
+
+            assert isinstance(premium, bool)
+
+            return premium
+        except Exception as e:
+            return False
+
+    def is_premium(self) -> bool:
+        if self.license_str is None:
+            return False
+        elif self._verify(license_str=self.license_str):
+            return True
+        return False
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 763a53daf..08dd8c906 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -110,6 +110,7 @@ from litellm.router import LiteLLM_Params, Deployment, updateDeployment
 from litellm.router import ModelInfo as RouterModelInfo
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
+from litellm.proxy.auth.litellm_license import LicenseCheck
 from litellm.proxy.hooks.prompt_injection_detection import (
     _OPTIONAL_PromptInjectionDetection,
 )
@@ -150,6 +151,7 @@ from fastapi.responses import (
     ORJSONResponse,
     JSONResponse,
 )
+from fastapi.openapi.utils import get_openapi
 from fastapi.responses import RedirectResponse
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles
@@ -169,17 +171,30 @@ except Exception as e:
 except Exception as e:
     pass
 
+_license_check = LicenseCheck()
+premium_user: bool = _license_check.is_premium()
+
 ui_link = f"/ui/"
 ui_message = (
     f"👉 [```LiteLLM Admin Panel on /ui```]({ui_link}). Create, Edit Keys with SSO"
 )
 
+### CUSTOM BRANDING [ENTERPRISE FEATURE] ###
 _docs_url = None if os.getenv("NO_DOCS", "False") == "True" else "/"
+_title = os.getenv("DOCS_TITLE", "LiteLLM API") if premium_user else "LiteLLM API"
+_description = (
+    os.getenv(
+        "DOCS_DESCRIPTION",
+        f"Proxy Server to call 100+ LLMs in the OpenAI format\n\n{ui_message}",
+    )
+    if premium_user
+    else f"Proxy Server to call 100+ LLMs in the OpenAI format\n\n{ui_message}"
+)
 
 app = FastAPI(
     docs_url=_docs_url,
-    title="LiteLLM API",
-    description=f"Proxy Server to call 100+ LLMs in the OpenAI format\n\n{ui_message}",
+    title=_title,
+    description=_description,
     version=version,
     root_path=os.environ.get(
         "SERVER_ROOT_PATH", ""
@@ -187,6 +202,31 @@ app = FastAPI(
 )
 
 
+### CUSTOM API DOCS [ENTERPRISE FEATURE] ###
+# Custom OpenAPI schema generator to include only selected routes
+def custom_openapi():
+    if app.openapi_schema:
+        return app.openapi_schema
+    openapi_schema = get_openapi(
+        title=app.title,
+        version=app.version,
+        description=app.description,
+        routes=app.routes,
+    )
+    # Filter routes to include only specific ones
+    openai_routes = LiteLLMRoutes.openai_routes.value
+    paths_to_include: dict = {}
+    for route in openai_routes:
+        paths_to_include[route] = openapi_schema["paths"][route]
+    openapi_schema["paths"] = paths_to_include
+    app.openapi_schema = openapi_schema
+    return app.openapi_schema
+
+
+if os.getenv("DOCS_FILTERED", "False") == "True" and premium_user:
+    app.openapi = custom_openapi  # type: ignore
+
+
 class ProxyException(Exception):
     # NOTE: DO NOT MODIFY THIS
     # This is used to map exactly to OPENAI Exceptions
@@ -4874,11 +4914,12 @@ async def token_counter(request: TokenCountRequest):
     model_to_use = (
         litellm_model_name or request.model
     )  # use litellm model name, if it's not avalable then fallback to request.model
-    total_tokens, tokenizer_used = token_counter(
+    _tokenizer_used = litellm.utils._select_tokenizer(model=model_to_use)
+    tokenizer_used = _tokenizer_used["type"]
+    total_tokens = token_counter(
         model=model_to_use,
         text=prompt,
         messages=messages,
-        return_tokenizer_used=True,
     )
     return TokenCountResponse(
         total_tokens=total_tokens,
diff --git a/litellm/utils.py b/litellm/utils.py
index ebf5d980c..5d5c2b69c 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4123,8 +4123,7 @@ def token_counter(
     text: Optional[Union[str, List[str]]] = None,
     messages: Optional[List] = None,
     count_response_tokens: Optional[bool] = False,
-    return_tokenizer_used: Optional[bool] = False,
-):
+) -> int:
     """
     Count the number of tokens in a given text using a specified model.
 
@@ -4216,10 +4215,6 @@ def token_counter(
             )
     else:
         num_tokens = len(encoding.encode(text, disallowed_special=()))  # type: ignore
-    _tokenizer_type = tokenizer_json["type"]
-    if return_tokenizer_used:
-        # used by litellm proxy server -> POST /utils/token_counter
-        return num_tokens, _tokenizer_type
     return num_tokens
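Reviewer note: a minimal sketch (not part of the diff) of how the enterprise settings introduced above are exercised. It only sets the environment variables this PR reads (`LITELLM_LICENSE`, `DOCS_TITLE`, `DOCS_DESCRIPTION`, `DOCS_FILTERED`) and runs the same license check that `proxy_server.py` now performs at import time; the `DOCS_*` values and the license key are placeholders, and branding/doc filtering only takes effect when the check returns True.

```python
# Illustrative sketch only -- mirrors how proxy_server.py gates the new features.
import os

# Placeholder values; a real LITELLM_LICENSE key is required for is_premium() to pass.
os.environ["LITELLM_LICENSE"] = "<your-license-key>"
os.environ["DOCS_TITLE"] = "Example Gateway API"            # custom Swagger title (premium only)
os.environ["DOCS_DESCRIPTION"] = "Internal LLM gateway"     # custom Swagger description (premium only)
os.environ["DOCS_FILTERED"] = "True"                        # restrict /docs to LiteLLMRoutes.openai_routes

# Set env vars before constructing LicenseCheck, since it reads LITELLM_LICENSE in __init__.
from litellm.proxy.auth.litellm_license import LicenseCheck

# Calls https://license.litellm.ai/verify_license/<key>; any failure is swallowed and treated as False.
premium_user = LicenseCheck().is_premium()
print("premium_user:", premium_user)
```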