Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

Merge pull request #3682 from BerriAI/litellm_token_counter_endpoint
[Feat] `token_counter` endpoint

Commit 0a816b2c45
4 changed files with 214 additions and 2 deletions
litellm/proxy/_types.py

@@ -89,6 +89,8 @@ class LiteLLMRoutes(enum.Enum):
        "/v1/models",
    ]

    llm_utils_routes: List = ["utils/token_counter"]

    info_routes: List = [
        "/key/info",
        "/team/info",
@@ -1011,3 +1013,16 @@ class LiteLLM_ErrorLogs(LiteLLMBase):


class LiteLLM_SpendLogs_ResponseObject(LiteLLMBase):
    response: Optional[List[Union[LiteLLM_SpendLogs, Any]]] = None


class TokenCountRequest(LiteLLMBase):
    model: str
    prompt: Optional[str] = None
    messages: Optional[List[dict]] = None


class TokenCountResponse(LiteLLMBase):
    total_tokens: int
    request_model: str
    model_used: str
    tokenizer_type: str
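Taken together, the two new models define the request/response contract for the endpoint added below. A minimal sketch of what that pair looks like on the wire; the field values here are purely illustrative, not taken from the PR:

# Request body for POST /utils/token_counter (shape of TokenCountRequest);
# at least one of "prompt" or "messages" must be supplied.
token_count_request = {
    "model": "gpt-4",
    "messages": [{"role": "user", "content": "hello"}],
}

# Possible response body (shape of TokenCountResponse); values are hypothetical.
token_count_response = {
    "total_tokens": 8,                     # count computed by litellm.token_counter
    "request_model": "gpt-4",              # model name exactly as the caller sent it
    "model_used": "gpt-4",                 # underlying model the tokenizer was chosen for
    "tokenizer_type": "openai_tokenizer",  # or "huggingface_tokenizer"
}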
litellm/proxy/proxy_server.py

@@ -4777,6 +4777,56 @@ async def moderations(
    )


@router.post(
    "/utils/token_counter",
    tags=["llm utils"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=TokenCountResponse,
)
async def token_counter(request: TokenCountRequest):
    """Count tokens for a prompt or message list, using the tokenizer of the model's mapped deployment."""
    from litellm import token_counter

    global llm_router

    prompt = request.prompt
    messages = request.messages
    if prompt is None and messages is None:
        raise HTTPException(
            status_code=400, detail="prompt or messages must be provided"
        )

    deployment = None
    litellm_model_name = None
    if llm_router is not None:
        # get 1 deployment corresponding to the model
        for _model in llm_router.model_list:
            if _model["model_name"] == request.model:
                deployment = _model
                break
    if deployment is not None:
        litellm_model_name = deployment.get("litellm_params", {}).get("model")
        # remove the custom_llm_provider prefix from the litellm_model_name
        if "/" in litellm_model_name:
            litellm_model_name = litellm_model_name.split("/", 1)[1]

    model_to_use = (
        litellm_model_name or request.model
    )  # use the litellm model name; if it's not available, fall back to request.model
    total_tokens, tokenizer_used = token_counter(
        model=model_to_use,
        text=prompt,
        messages=messages,
        return_tokenizer_used=True,
    )
    return TokenCountResponse(
        total_tokens=total_tokens,
        request_model=request.model,
        model_used=model_to_use,
        tokenizer_type=tokenizer_used,
    )


#### KEY MANAGEMENT ####
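A quick sketch of calling the new route against a running proxy. The base URL http://localhost:4000 and the key sk-1234 are hypothetical placeholders; substitute your own deployment's values (the route is protected by user_api_key_auth):

import requests

resp = requests.post(
    "http://localhost:4000/utils/token_counter",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "gpt-4",
        "messages": [{"role": "user", "content": "hello"}],
    },
)
# Expected shape: {"total_tokens": ..., "request_model": "gpt-4",
#                  "model_used": ..., "tokenizer_type": ...}
print(resp.json())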
litellm/tests/test_proxy_token_counter.py (new file, 138 lines)
@@ -0,0 +1,138 @@
# Test the following scenarios:
# 1. Generate a Key, and use it to make a call


import sys, os
import traceback
from dotenv import load_dotenv
from fastapi import Request
from datetime import datetime

load_dotenv()
import os, io, time

# this file is to test litellm/proxy

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest, logging, asyncio
import litellm, asyncio
from litellm.proxy.proxy_server import token_counter
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend
from litellm._logging import verbose_proxy_logger

verbose_proxy_logger.setLevel(level=logging.DEBUG)

from litellm.proxy._types import TokenCountRequest, TokenCountResponse


from litellm import Router


@pytest.mark.asyncio
async def test_vLLM_token_counting():
    """
    Test Token counter for vLLM models
    - User passes model="special-alias"
    - token_counter should infer that special-alias maps to wolfram/miquliz-120b-v2.0
      -> token counter should use the Hugging Face tokenizer
    """

    llm_router = Router(
        model_list=[
            {
                "model_name": "special-alias",
                "litellm_params": {
                    "model": "openai/wolfram/miquliz-120b-v2.0",
                    "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
                },
            }
        ]
    )

    setattr(litellm.proxy.proxy_server, "llm_router", llm_router)

    response = await token_counter(
        request=TokenCountRequest(
            model="special-alias",
            messages=[{"role": "user", "content": "hello"}],
        )
    )

    print("response: ", response)

    assert (
        response.tokenizer_type == "huggingface_tokenizer"
    )  # SHOULD use the Hugging Face tokenizer
    assert response.model_used == "wolfram/miquliz-120b-v2.0"


@pytest.mark.asyncio
async def test_token_counting_model_not_in_model_list():
    """
    Test Token counter - when a model is not in model_list
    -> should use the default OpenAI tokenizer
    """

    llm_router = Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "gpt-4",
                },
            }
        ]
    )

    setattr(litellm.proxy.proxy_server, "llm_router", llm_router)

    response = await token_counter(
        request=TokenCountRequest(
            model="special-alias",
            messages=[{"role": "user", "content": "hello"}],
        )
    )

    print("response: ", response)

    assert (
        response.tokenizer_type == "openai_tokenizer"
    )  # SHOULD use the OpenAI tokenizer
    assert response.model_used == "special-alias"


@pytest.mark.asyncio
async def test_gpt_token_counting():
    """
    Test Token counter
    -> should work for gpt-4
    """

    llm_router = Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "gpt-4",
                },
            }
        ]
    )

    setattr(litellm.proxy.proxy_server, "llm_router", llm_router)

    response = await token_counter(
        request=TokenCountRequest(
            model="gpt-4",
            messages=[{"role": "user", "content": "hello"}],
        )
    )

    print("response: ", response)

    assert (
        response.tokenizer_type == "openai_tokenizer"
    )  # SHOULD use the OpenAI tokenizer
    assert response.request_model == "gpt-4"
litellm/utils.py

@@ -3880,6 +3880,11 @@ def _select_tokenizer(model: str):
        return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}
    # default - tiktoken
    else:
        tokenizer = None
        try:
            tokenizer = Tokenizer.from_pretrained(model)
            return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}
        except:
            return {"type": "openai_tokenizer", "tokenizer": encoding}
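With the new else-branch, a model that is not special-cased earlier in `_select_tokenizer` is first tried as a Hugging Face tokenizer, and tiktoken becomes the last resort. A rough sketch of the resulting behavior, assuming the `tokenizers` package is installed and the Hugging Face Hub is reachable (the first model id is the one the tests above rely on):

from litellm.utils import _select_tokenizer

# A hub-hosted model id should now resolve to its own tokenizer...
print(_select_tokenizer(model="wolfram/miquliz-120b-v2.0")["type"])  # expected: "huggingface_tokenizer"

# ...while an id the hub cannot serve still falls back to tiktoken.
print(_select_tokenizer(model="not-a-real-model")["type"])  # expected: "openai_tokenizer"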
@@ -4117,6 +4122,7 @@ def token_counter(
    text: Optional[Union[str, List[str]]] = None,
    messages: Optional[List] = None,
    count_response_tokens: Optional[bool] = False,
    return_tokenizer_used: Optional[bool] = False,
):
    """
    Count the number of tokens in a given text using a specified model.
@@ -4209,7 +4215,10 @@ def token_counter(
        )
    else:
        num_tokens = len(encoding.encode(text, disallowed_special=()))  # type: ignore

    _tokenizer_type = tokenizer_json["type"]
    if return_tokenizer_used:
        # used by litellm proxy server -> POST /utils/token_counter
        return num_tokens, _tokenizer_type
    return num_tokens
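On the library side, the same behavior is reachable directly. A minimal sketch of the new `return_tokenizer_used` flag; the exact count shown in the comment is illustrative and depends on the tokenizer:

import litellm

# With return_tokenizer_used=True, token_counter returns a (count, tokenizer_type)
# tuple instead of just the integer count.
total_tokens, tokenizer_type = litellm.token_counter(
    model="gpt-4",
    messages=[{"role": "user", "content": "hello"}],
    return_tokenizer_used=True,
)
print(total_tokens, tokenizer_type)  # e.g. 8 openai_tokenizer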