feat(proxy_server.py): enable custom branding + routes on openapi docs

Allows user to add their branding + show only openai routes on docs

commit c0d62e94ae (parent f1a482f358)
4 changed files with 89 additions and 11 deletions
litellm/proxy/_types.py

@@ -88,7 +88,7 @@ class LiteLLMRoutes(enum.Enum):
         "/models",
         "/v1/models",
         # token counter
-        "utils/token_counter",
+        "/utils/token_counter",
     ]

     info_routes: List = [
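The leading-slash fix matters beyond cosmetics: the schema-filtering code added further down indexes the generated OpenAPI paths dict by these strings, and that dict keys every route with a leading "/". A minimal stand-in (the schema fragment is illustrative, not the real generated one):

    # OpenAPI "paths" are keyed exactly as routes are mounted, with a
    # leading slash; a stand-in schema fragment to show the lookup:
    openapi_paths = {"/utils/token_counter": {"post": {"summary": "token counter"}}}

    print(openapi_paths.get("utils/token_counter"))   # None - the old value never matches
    print(openapi_paths.get("/utils/token_counter"))  # the route's spec - matches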
litellm/proxy/auth/litellm_license.py (new file, 42 lines)

@@ -0,0 +1,42 @@
+# What is this?
+## If litellm license in env, checks if it's valid
+import os
+from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+
+class LicenseCheck:
+    """
+    - Check if license in env
+    - Returns if license is valid
+    """
+
+    base_url = "https://license.litellm.ai"
+
+    def __init__(self) -> None:
+        self.license_str = os.getenv("LITELLM_LICENSE", None)
+        self.http_handler = HTTPHandler()
+
+    def _verify(self, license_str: str) -> bool:
+        url = "{}/verify_license/{}".format(self.base_url, license_str)
+
+        try:  # don't impact user, if call fails
+            response = self.http_handler.get(url=url)
+
+            response.raise_for_status()
+
+            response_json = response.json()
+
+            premium = response_json["valid"]
+
+            assert isinstance(premium, bool)
+
+            return premium
+        except Exception as e:
+            return False
+
+    def is_premium(self) -> bool:
+        if self.license_str is None:
+            return False
+        elif self._verify(license_str=self.license_str):
+            return True
+        return False
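For reference, a minimal usage sketch (the license key is a placeholder; a real check hits https://license.litellm.ai). Note that _verify swallows all exceptions, so a dead license server degrades to non-premium rather than breaking proxy startup:

    import os

    os.environ["LITELLM_LICENSE"] = "my-license-key"  # placeholder, not a real key

    from litellm.proxy.auth.litellm_license import LicenseCheck

    checker = LicenseCheck()
    print(checker.is_premium())  # True only if the server responds {"valid": true}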
litellm/proxy/proxy_server.py

@@ -110,6 +110,7 @@ from litellm.router import LiteLLM_Params, Deployment, updateDeployment
 from litellm.router import ModelInfo as RouterModelInfo
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
+from litellm.proxy.auth.litellm_license import LicenseCheck
 from litellm.proxy.hooks.prompt_injection_detection import (
     _OPTIONAL_PromptInjectionDetection,
 )
@@ -150,6 +151,7 @@ from fastapi.responses import (
     ORJSONResponse,
     JSONResponse,
 )
+from fastapi.openapi.utils import get_openapi
 from fastapi.responses import RedirectResponse
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles
@@ -169,17 +171,30 @@ except Exception as e:
 except Exception as e:
     pass

+_license_check = LicenseCheck()
+premium_user: bool = _license_check.is_premium()
+
 ui_link = f"/ui/"
 ui_message = (
     f"👉 [```LiteLLM Admin Panel on /ui```]({ui_link}). Create, Edit Keys with SSO"
 )

+### CUSTOM BRANDING [ENTERPRISE FEATURE] ###
 _docs_url = None if os.getenv("NO_DOCS", "False") == "True" else "/"
+_title = os.getenv("DOCS_TITLE", "LiteLLM API") if premium_user else "LiteLLM API"
+_description = (
+    os.getenv(
+        "DOCS_DESCRIPTION",
+        f"Proxy Server to call 100+ LLMs in the OpenAI format\n\n{ui_message}",
+    )
+    if premium_user
+    else f"Proxy Server to call 100+ LLMs in the OpenAI format\n\n{ui_message}"
+)
+
 app = FastAPI(
     docs_url=_docs_url,
-    title="LiteLLM API",
-    description=f"Proxy Server to call 100+ LLMs in the OpenAI format\n\n{ui_message}",
+    title=_title,
+    description=_description,
     version=version,
     root_path=os.environ.get(
         "SERVER_ROOT_PATH", ""
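With the branding block in place, a premium user can retitle the docs page entirely from the environment. A sketch of the relevant variables (values illustrative), set before proxy_server is imported:

    import os

    # Illustrative values; they only take effect when LITELLM_LICENSE
    # verifies as premium (DOCS_TITLE / DOCS_DESCRIPTION are enterprise-gated).
    os.environ["DOCS_TITLE"] = "Acme LLM Gateway"
    os.environ["DOCS_DESCRIPTION"] = "Internal proxy for Acme LLM traffic"

    # NO_DOCS is not premium-gated: "True" disables the docs page (served at "/").
    os.environ["NO_DOCS"] = "False"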
@@ -187,6 +202,31 @@ app = FastAPI(
 )


+### CUSTOM API DOCS [ENTERPRISE FEATURE] ###
+# Custom OpenAPI schema generator to include only selected routes
+def custom_openapi():
+    if app.openapi_schema:
+        return app.openapi_schema
+    openapi_schema = get_openapi(
+        title=app.title,
+        version=app.version,
+        description=app.description,
+        routes=app.routes,
+    )
+    # Filter routes to include only specific ones
+    openai_routes = LiteLLMRoutes.openai_routes.value
+    paths_to_include: dict = {}
+    for route in openai_routes:
+        paths_to_include[route] = openapi_schema["paths"][route]
+    openapi_schema["paths"] = paths_to_include
+    app.openapi_schema = openapi_schema
+    return app.openapi_schema
+
+
+if os.getenv("DOCS_FILTERED", "False") == "True" and premium_user:
+    app.openapi = custom_openapi  # type: ignore
+
+
 class ProxyException(Exception):
     # NOTE: DO NOT MODIFY THIS
     # This is used to map exactly to OPENAI Exceptions
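Assigning to app.openapi is the standard FastAPI pattern for replacing the schema generator. A self-contained toy version of the same filtering idea (route names here are made up):

    from fastapi import FastAPI
    from fastapi.openapi.utils import get_openapi

    app = FastAPI()

    @app.get("/keep")           # will stay in the docs
    def keep(): return {"ok": True}

    @app.get("/internal/drop")  # will be filtered out
    def drop(): return {"ok": True}

    def filtered_openapi():
        if app.openapi_schema:  # serve the cached schema on repeat calls
            return app.openapi_schema
        schema = get_openapi(title=app.title, version=app.version, routes=app.routes)
        schema["paths"] = {p: s for p, s in schema["paths"].items() if p == "/keep"}
        app.openapi_schema = schema
        return app.openapi_schema

    app.openapi = filtered_openapi  # /openapi.json now only lists /keep

Note that the commit's loop indexes openapi_schema["paths"][route] directly, so every entry in LiteLLMRoutes.openai_routes must actually be mounted with that exact path; that is why the "/utils/token_counter" fix in the first hunk is needed.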
@@ -4874,11 +4914,12 @@ async def token_counter(request: TokenCountRequest):
     model_to_use = (
         litellm_model_name or request.model
     )  # use litellm model name, if it's not available then fallback to request.model
-    total_tokens, tokenizer_used = token_counter(
+    _tokenizer_used = litellm.utils._select_tokenizer(model=model_to_use)
+    tokenizer_used = _tokenizer_used["type"]
+    total_tokens = token_counter(
         model=model_to_use,
         text=prompt,
         messages=messages,
-        return_tokenizer_used=True,
     )
     return TokenCountResponse(
         total_tokens=total_tokens,
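The endpoint now derives the tokenizer name itself instead of threading a return_tokenizer_used flag through token_counter. A sketch of the two calls in isolation (model name illustrative; _select_tokenizer returns a dict whose "type" field is all the endpoint uses):

    import litellm
    from litellm.utils import _select_tokenizer

    tokenizer = _select_tokenizer(model="gpt-3.5-turbo")  # illustrative model
    print(tokenizer["type"])  # e.g. "openai_tokenizer"

    # token_counter now always returns a plain int
    print(litellm.token_counter(model="gpt-3.5-turbo", text="hello world"))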
litellm/utils.py

@@ -4123,8 +4123,7 @@ def token_counter(
     text: Optional[Union[str, List[str]]] = None,
     messages: Optional[List] = None,
     count_response_tokens: Optional[bool] = False,
-    return_tokenizer_used: Optional[bool] = False,
-):
+) -> int:
     """
     Count the number of tokens in a given text using a specified model.
@@ -4216,10 +4215,6 @@ def token_counter(
         )
     else:
         num_tokens = len(encoding.encode(text, disallowed_special=()))  # type: ignore
-    _tokenizer_type = tokenizer_json["type"]
-    if return_tokenizer_used:
-        # used by litellm proxy server -> POST /utils/token_counter
-        return num_tokens, _tokenizer_type
     return num_tokens
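End to end, the route fix plus this refactor keep POST /utils/token_counter working for clients. A hedged client sketch (the base URL, key, and exact response fields are assumptions about a locally running proxy):

    import requests

    resp = requests.post(
        "http://localhost:4000/utils/token_counter",  # assumed local proxy address
        headers={"Authorization": "Bearer sk-1234"},  # placeholder key
        json={"model": "gpt-3.5-turbo", "prompt": "hello world"},
    )
    print(resp.json())  # expect total_tokens plus the tokenizer that was used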