feat - try using hf tokenizer

Ishaan Jaff 2024-05-16 10:59:29 -07:00
parent c646b809a6
commit 22ba5fa186
3 changed files with 34 additions and 4 deletions

@@ -79,6 +79,8 @@ class LiteLLMRoutes(enum.Enum):
         "/v1/models",
     ]
+    llm_utils_routes: List = ["utils/token_counter"]
+
     info_routes: List = [
         "/key/info",
         "/team/info",
@@ -1012,3 +1014,5 @@ class TokenCountRequest(LiteLLMBase):
 class TokenCountResponse(LiteLLMBase):
     total_tokens: int
     model: str
+    base_model: str
+    tokenizer_type: str

@@ -4775,21 +4775,38 @@ async def token_counter(request: TokenCountRequest):
     """ """
     from litellm import token_counter

+    global llm_router
+
     prompt = request.prompt
     messages = request.messages
+    if llm_router is not None:
+        # get 1 deployment corresponding to the model
+        for _model in llm_router.model_list:
+            if _model["model_name"] == request.model:
+                deployment = _model
+                break
+    litellm_model_name = deployment.get("litellm_params", {}).get("model")
+    # remove the custom_llm_provider_prefix in the litellm_model_name
+    if "/" in litellm_model_name:
+        litellm_model_name = litellm_model_name.split("/", 1)[1]
+
     if prompt is None and messages is None:
         raise HTTPException(
             status_code=400, detail="prompt or messages must be provided"
         )
-    total_tokens = token_counter(
-        model=request.model,
+    total_tokens, tokenizer_used = token_counter(
+        model=litellm_model_name,
         text=prompt,
         messages=messages,
+        return_tokenizer_used=True,
     )
     return TokenCountResponse(
         total_tokens=total_tokens,
         model=request.model,
+        base_model=litellm_model_name,
+        tokenizer_type=tokenizer_used,
     )
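The proxy change above wires the new "utils/token_counter" route (registered in llm_utils_routes) to litellm.token_counter and surfaces the extra base_model and tokenizer_type fields. A minimal sketch of calling it, assuming a proxy at http://localhost:4000, a placeholder virtual key, and a router model_name of "my-model-alias" (all illustrative, not part of this commit):

    # Illustrative client call; URL, key, and model alias are placeholders.
    import requests

    resp = requests.post(
        "http://localhost:4000/utils/token_counter",
        headers={"Authorization": "Bearer sk-1234"},
        json={
            "model": "my-model-alias",  # model_name as configured on llm_router
            "messages": [{"role": "user", "content": "hello"}],
        },
    )
    # Response follows TokenCountResponse: total_tokens, model, base_model, tokenizer_type
    print(resp.json())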

@@ -3860,7 +3860,12 @@ def _select_tokenizer(model: str):
         return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}
     # default - tiktoken
     else:
-        return {"type": "openai_tokenizer", "tokenizer": encoding}
+        tokenizer = None
+        try:
+            tokenizer = Tokenizer.from_pretrained(model)
+            return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}
+        except:
+            return {"type": "openai_tokenizer", "tokenizer": encoding}


 def encode(model="", text="", custom_tokenizer: Optional[dict] = None):
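The new fallback in _select_tokenizer leans on the Hugging Face tokenizers package: Tokenizer.from_pretrained(model) fetches a tokenizer.json from the Hub when the model id is a valid repo and raises otherwise, which the bare except turns into the existing tiktoken default. A standalone sketch of that pattern outside of litellm (the model id comment and encoding name are assumptions for illustration):

    # Illustrative sketch of the fallback, not litellm's actual helper.
    from tokenizers import Tokenizer
    import tiktoken

    def pick_tokenizer(model: str) -> dict:
        try:
            # Succeeds when the Hub repo ships a tokenizer.json
            return {"type": "huggingface_tokenizer", "tokenizer": Tokenizer.from_pretrained(model)}
        except Exception:
            # Otherwise fall back to tiktoken (litellm's default `encoding`)
            return {"type": "openai_tokenizer", "tokenizer": tiktoken.get_encoding("cl100k_base")}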
@@ -4097,6 +4102,7 @@ def token_counter(
     text: Optional[Union[str, List[str]]] = None,
     messages: Optional[List] = None,
     count_response_tokens: Optional[bool] = False,
+    return_tokenizer_used: Optional[bool] = False,
 ):
     """
     Count the number of tokens in a given text using a specified model.
@@ -4189,7 +4195,10 @@ def token_counter(
             )
     else:
         num_tokens = len(encoding.encode(text, disallowed_special=())) # type: ignore
+    _tokenizer_type = tokenizer_json["type"]
+    if return_tokenizer_used:
+        # used by litellm proxy server -> POST /utils/token_counter
+        return num_tokens, _tokenizer_type
     return num_tokens
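With the new return_tokenizer_used flag, token_counter keeps its existing int return by default and only returns a (count, tokenizer_type) tuple when the flag is set, which is what the proxy endpoint above unpacks. A quick usage sketch (model name is illustrative):

    # Default behavior is unchanged: an int comes back.
    import litellm

    count = litellm.token_counter(model="gpt-3.5-turbo", text="hello world")

    # Opt in to also learn which tokenizer was selected.
    count, tokenizer_type = litellm.token_counter(
        model="gpt-3.5-turbo",
        text="hello world",
        return_tokenizer_used=True,
    )
    # tokenizer_type is "openai_tokenizer" or "huggingface_tokenizer"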