diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index d6bf49dca..7f04b70e2 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -1001,3 +1001,17 @@ class LiteLLM_ErrorLogs(LiteLLMBase):
 
 class LiteLLM_SpendLogs_ResponseObject(LiteLLMBase):
     response: Optional[List[Union[LiteLLM_SpendLogs, Any]]] = None
+
+
+class TokenCountRequest(LiteLLMBase):
+    """Request body for ``/dev/token_counter``: target model plus a prompt or messages."""
+    model: str
+    prompt: Optional[str] = None
+    messages: Optional[List[dict]] = None
+
+
+class TokenCountResponse(LiteLLMBase):
+    """Response body for ``/dev/token_counter``."""
+    total_tokens: int
+    model: str
+    custom_llm_provider: str
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 60ed4c269..dc7ffea73 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -4766,25 +4766,24 @@ async def moderations(
 
 
 @router.post(
-    "/dev/token_counter ",
-    tags=["LLM Utils"],
-    dependencies=[Depends(user_api_key_auth)],
-    responses={
-        200: {
-            "cost": {
-                "description": "The calculated cost",
-                "example": 0.0,
-                "type": "float",
-            }
-        }
-    },
+    "/dev/token_counter", tags=["LLM Utils"], dependencies=[Depends(user_api_key_auth)]
 )
-async def token_counter(request: Request):
-    """ """
+async def token_counter(request: TokenCountRequest):
+    """Count the tokens ``model`` uses for ``prompt``/``messages``; 400 if both are missing."""
     from litellm import token_counter
 
-    data = await request.json()
-    total_tokens = token_counter(**data)
+    prompt = request.prompt
+    messages = request.messages
+
+    if prompt is None and messages is None:
+        raise HTTPException(
+            status_code=400, detail="prompt or messages must be provided"
+        )
+    total_tokens = token_counter(
+        model=request.model,
+        text=prompt,
+        messages=messages,
+    )
     return {"total_tokens": total_tokens}
 
 