forked from phoenix/litellm-mirror
test - token count response
parent 22ba5fa186
commit 4a5e6aa43c
3 changed files with 157 additions and 14 deletions
@@ -1013,6 +1013,6 @@ class TokenCountRequest(LiteLLMBase):
 class TokenCountResponse(LiteLLMBase):
     total_tokens: int
-    model: str
-    base_model: str
+    request_model: str
+    model_used: str
     tokenizer_type: str

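For reference, here is a minimal sketch (not part of this commit) of how the renamed response fields surface on a response object. The field values are illustrative only, borrowed from the tests added below; request_model is what the client asked for, model_used is the deployment model the tokens were actually counted against.

from litellm.proxy._types import TokenCountResponse

# Illustrative values only (token count is made up for the example).
resp = TokenCountResponse(
    total_tokens=5,
    request_model="special-alias",
    model_used="wolfram/miquliz-120b-v2.0",
    tokenizer_type="huggingface_tokenizer",
)
print(resp.request_model, resp.model_used, resp.tokenizer_type)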
@@ -4779,33 +4779,38 @@ async def token_counter(request: TokenCountRequest):
     prompt = request.prompt
     messages = request.messages
+    if prompt is None and messages is None:
+        raise HTTPException(
+            status_code=400, detail="prompt or messages must be provided"
+        )
+
     deployment = None
     litellm_model_name = None
     if llm_router is not None:
         # get 1 deployment corresponding to the model
         for _model in llm_router.model_list:
             if _model["model_name"] == request.model:
                 deployment = _model
                 break
     if deployment is not None:
         litellm_model_name = deployment.get("litellm_params", {}).get("model")
         # remove the custom_llm_provider_prefix in the litellm_model_name
         if "/" in litellm_model_name:
             litellm_model_name = litellm_model_name.split("/", 1)[1]

-    if prompt is None and messages is None:
-        raise HTTPException(
-            status_code=400, detail="prompt or messages must be provided"
-        )
+    model_to_use = (
+        litellm_model_name or request.model
+    )  # use the litellm model name; if it's not available, fall back to request.model
+
     total_tokens, tokenizer_used = token_counter(
-        model=litellm_model_name,
+        model=model_to_use,
         text=prompt,
         messages=messages,
         return_tokenizer_used=True,
     )
     return TokenCountResponse(
         total_tokens=total_tokens,
-        model=request.model,
-        base_model=litellm_model_name,
+        request_model=request.model,
+        model_used=model_to_use,
         tokenizer_type=tokenizer_used,
     )
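The resolution rule this hunk introduces can be summarized standalone. The sketch below uses a hypothetical resolve_model helper (not part of this commit) to show the provider-prefix stripping and the fallback to request.model that the new tests exercise.

from typing import Optional


def resolve_model(request_model: str, litellm_model_name: Optional[str]) -> str:
    # hypothetical helper mirroring the endpoint logic above:
    # strip the custom_llm_provider prefix (e.g. "openai/") from the deployment's
    # model name, then fall back to the requested model if no deployment matched
    if litellm_model_name is not None and "/" in litellm_model_name:
        litellm_model_name = litellm_model_name.split("/", 1)[1]
    return litellm_model_name or request_model


# alias mapped by the router -> provider prefix stripped
assert resolve_model("special-alias", "openai/wolfram/miquliz-120b-v2.0") == "wolfram/miquliz-120b-v2.0"
# model not in the router's model_list -> request.model is used as-is
assert resolve_model("special-alias", None) == "special-alias"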
litellm/tests/test_proxy_token_counter.py (new file, 138 lines)
@@ -0,0 +1,138 @@
# Test the following scenarios:
# 1. Generate a Key, and use it to make a call


import sys, os
import traceback
from dotenv import load_dotenv
from fastapi import Request
from datetime import datetime

load_dotenv()
import os, io, time

# this file is to test litellm/proxy

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest, logging, asyncio
import litellm, asyncio
from litellm.proxy.proxy_server import token_counter
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend
from litellm._logging import verbose_proxy_logger

verbose_proxy_logger.setLevel(level=logging.DEBUG)

from litellm.proxy._types import TokenCountRequest, TokenCountResponse


from litellm import Router


@pytest.mark.asyncio
async def test_vLLM_token_counting():
    """
    Test Token counter for vLLM models
    - User passes model="special-alias"
    - token_counter should infer that special_alias -> maps to wolfram/miquliz-120b-v2.0
    -> token counter should use hugging face tokenizer
    """

    llm_router = Router(
        model_list=[
            {
                "model_name": "special-alias",
                "litellm_params": {
                    "model": "openai/wolfram/miquliz-120b-v2.0",
                    "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
                },
            }
        ]
    )

    setattr(litellm.proxy.proxy_server, "llm_router", llm_router)

    response = await token_counter(
        request=TokenCountRequest(
            model="special-alias",
            messages=[{"role": "user", "content": "hello"}],
        )
    )

    print("response: ", response)

    assert (
        response.tokenizer_type == "huggingface_tokenizer"
    )  # SHOULD use the hugging face tokenizer
    assert response.model_used == "wolfram/miquliz-120b-v2.0"


@pytest.mark.asyncio
async def test_token_counting_model_not_in_model_list():
    """
    Test Token counter - when a model is not in model_list
    -> should use the default OpenAI tokenizer
    """

    llm_router = Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "gpt-4",
                },
            }
        ]
    )

    setattr(litellm.proxy.proxy_server, "llm_router", llm_router)

    response = await token_counter(
        request=TokenCountRequest(
            model="special-alias",
            messages=[{"role": "user", "content": "hello"}],
        )
    )

    print("response: ", response)

    assert (
        response.tokenizer_type == "openai_tokenizer"
    )  # SHOULD use the OpenAI tokenizer
    assert response.model_used == "special-alias"


@pytest.mark.asyncio
async def test_gpt_token_counting():
    """
    Test Token counter
    -> should work for gpt-4
    """

    llm_router = Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "gpt-4",
                },
            }
        ]
    )

    setattr(litellm.proxy.proxy_server, "llm_router", llm_router)

    response = await token_counter(
        request=TokenCountRequest(
            model="gpt-4",
            messages=[{"role": "user", "content": "hello"}],
        )
    )

    print("response: ", response)

    assert (
        response.tokenizer_type == "openai_tokenizer"
    )  # SHOULD use the OpenAI tokenizer
    assert response.request_model == "gpt-4"
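A usage sketch (not part of this commit) of the prompt path the endpoint also accepts: it mirrors test_gpt_token_counting above but passes prompt instead of messages, and assumes the same environment as the tests (any API keys the Router setup needs are loaded from the environment).

import asyncio
import litellm
from litellm import Router
from litellm.proxy.proxy_server import token_counter
from litellm.proxy._types import TokenCountRequest


async def count_prompt_tokens():
    # same router setup as test_gpt_token_counting
    llm_router = Router(
        model_list=[{"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}}]
    )
    setattr(litellm.proxy.proxy_server, "llm_router", llm_router)

    # pass a raw prompt string instead of a messages list
    response = await token_counter(
        request=TokenCountRequest(model="gpt-4", prompt="hello")
    )
    print(response.total_tokens, response.tokenizer_type)  # expect "openai_tokenizer"


asyncio.run(count_prompt_tokens())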