diff --git a/docs/my-website/docs/completion/token_usage.md b/docs/my-website/docs/completion/token_usage.md index 626973c57..807ccfd91 100644 --- a/docs/my-website/docs/completion/token_usage.md +++ b/docs/my-website/docs/completion/token_usage.md @@ -1,7 +1,7 @@ # Completion Token Usage & Cost By default LiteLLM returns token usage in all completion requests ([See here](https://litellm.readthedocs.io/en/latest/output/)) -However, we also expose 5 helper functions + **[NEW]** an API to calculate token usage across providers: +However, we also expose some helper functions + **[NEW]** an API to calculate token usage across providers: - `encode`: This encodes the text passed in, using the model-specific tokenizer. [**Jump to code**](#1-encode) @@ -9,17 +9,19 @@ However, we also expose 5 helper functions + **[NEW]** an API to calculate token - `token_counter`: This returns the number of tokens for a given input - it uses the tokenizer based on the model, and defaults to tiktoken if no model-specific tokenizer is available. [**Jump to code**](#3-token_counter) -- `cost_per_token`: This returns the cost (in USD) for prompt (input) and completion (output) tokens. Uses the live list from `api.litellm.ai`. [**Jump to code**](#4-cost_per_token) +- `create_pretrained_tokenizer` and `create_tokenizer`: LiteLLM provides default tokenizer support for OpenAI, Cohere, Anthropic, Llama2, and Llama3 models. If you are using a different model, you can create a custom tokenizer and pass it as `custom_tokenizer` to the `encode`, `decode`, and `token_counter` methods. [**Jump to code**](#4-create_pretrained_tokenizer-and-create_tokenizer) -- `completion_cost`: This returns the overall cost (in USD) for a given LLM API Call. It combines `token_counter` and `cost_per_token` to return the cost for that query (counting both cost of input and output). [**Jump to code**](#5-completion_cost) +- `cost_per_token`: This returns the cost (in USD) for prompt (input) and completion (output) tokens. Uses the live list from `api.litellm.ai`. [**Jump to code**](#5-cost_per_token) -- `get_max_tokens`: This returns the maximum number of tokens allowed for the given model. [**Jump to code**](#6-get_max_tokens) +- `completion_cost`: This returns the overall cost (in USD) for a given LLM API Call. It combines `token_counter` and `cost_per_token` to return the cost for that query (counting both cost of input and output). [**Jump to code**](#6-completion_cost) -- `model_cost`: This returns a dictionary for all models, with their max_tokens, input_cost_per_token and output_cost_per_token. It uses the `api.litellm.ai` call shown below. [**Jump to code**](#7-model_cost) +- `get_max_tokens`: This returns the maximum number of tokens allowed for the given model. [**Jump to code**](#7-get_max_tokens) -- `register_model`: This registers new / overrides existing models (and their pricing details) in the model cost dictionary. [**Jump to code**](#8-register_model) +- `model_cost`: This returns a dictionary for all models, with their max_tokens, input_cost_per_token and output_cost_per_token. It uses the `api.litellm.ai` call shown below. [**Jump to code**](#8-model_cost) -- `api.litellm.ai`: Live token + price count across [all supported models](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json). [**Jump to code**](#9-apilitellmai) +- `register_model`: This registers new / overrides existing models (and their pricing details) in the model cost dictionary. [**Jump to code**](#9-register_model) + +- `api.litellm.ai`: Live token + price count across [all supported models](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json). [**Jump to code**](#10-apilitellmai) 📣 This is a community maintained list. Contributions are welcome! ❤️ @@ -60,7 +62,24 @@ messages = [{"user": "role", "content": "Hey, how's it going"}] print(token_counter(model="gpt-3.5-turbo", messages=messages)) ``` -### 4. `cost_per_token` +### 4. `create_pretrained_tokenizer` and `create_tokenizer` + +```python +from litellm import create_pretrained_tokenizer, create_tokenizer + +# get tokenizer from huggingface repo +custom_tokenizer_1 = create_pretrained_tokenizer("Xenova/llama-3-tokenizer") + +# use tokenizer from json file +with open("tokenizer.json") as f: + json_data = json.load(f) + +json_str = json.dumps(json_data) + +custom_tokenizer_2 = create_tokenizer(json_str) +``` + +### 5. `cost_per_token` ```python from litellm import cost_per_token @@ -72,7 +91,7 @@ prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = cost_per_toke print(prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar) ``` -### 5. `completion_cost` +### 6. `completion_cost` * Input: Accepts a `litellm.completion()` response **OR** prompt + completion strings * Output: Returns a `float` of cost for the `completion` call @@ -99,7 +118,7 @@ cost = completion_cost(model="bedrock/anthropic.claude-v2", prompt="Hey!", compl formatted_string = f"${float(cost):.10f}" print(formatted_string) ``` -### 6. `get_max_tokens` +### 7. `get_max_tokens` Input: Accepts a model name - e.g., gpt-3.5-turbo (to get a complete list, call litellm.model_list). Output: Returns the maximum number of tokens allowed for the given model @@ -112,7 +131,7 @@ model = "gpt-3.5-turbo" print(get_max_tokens(model)) # Output: 4097 ``` -### 7. `model_cost` +### 8. `model_cost` * Output: Returns a dict object containing the max_tokens, input_cost_per_token, output_cost_per_token for all models on [community-maintained list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json) @@ -122,7 +141,7 @@ from litellm import model_cost print(model_cost) # {'gpt-3.5-turbo': {'max_tokens': 4000, 'input_cost_per_token': 1.5e-06, 'output_cost_per_token': 2e-06}, ...} ``` -### 8. `register_model` +### 9. `register_model` * Input: Provide EITHER a model cost dictionary or a url to a hosted json blob * Output: Returns updated model_cost dictionary + updates litellm.model_cost with model details. @@ -157,5 +176,3 @@ export LITELLM_LOCAL_MODEL_COST_MAP="True" ``` Note: this means you will need to upgrade to get updated pricing, and newer models. - - diff --git a/litellm/__init__.py b/litellm/__init__.py index 924dde3a8..dc640f0e9 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -613,6 +613,8 @@ from .utils import ( get_optional_params, modify_integration, token_counter, + create_pretrained_tokenizer, + create_tokenizer, cost_per_token, completion_cost, supports_function_calling, diff --git a/litellm/main.py b/litellm/main.py index 51ec95401..9765669fe 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -34,6 +34,8 @@ from litellm.utils import ( async_mock_completion_streaming_obj, convert_to_model_response_object, token_counter, + create_pretrained_tokenizer, + create_tokenizer, Usage, get_optional_params_embeddings, get_optional_params_image_gen, diff --git a/litellm/tests/test_token_counter.py b/litellm/tests/test_token_counter.py index af0db487e..4d759d4cf 100644 --- a/litellm/tests/test_token_counter.py +++ b/litellm/tests/test_token_counter.py @@ -9,7 +9,7 @@ sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path import time -from litellm import token_counter, encode, decode +from litellm import token_counter, create_pretrained_tokenizer, encode, decode def test_token_counter_normal_plus_function_calling(): @@ -69,15 +69,23 @@ def test_tokenizers(): model="meta-llama/Llama-2-7b-chat", text=sample_text ) + # llama3 tokenizer (also testing custom tokenizer) + llama3_tokens_1 = token_counter(model="meta-llama/llama-3-70b-instruct", text=sample_text) + + llama3_tokenizer = create_pretrained_tokenizer("Xenova/llama-3-tokenizer") + llama3_tokens_2 = token_counter(custom_tokenizer=llama3_tokenizer, text=sample_text) + print( - f"openai tokens: {openai_tokens}; claude tokens: {claude_tokens}; cohere tokens: {cohere_tokens}; llama2 tokens: {llama2_tokens}" + f"openai tokens: {openai_tokens}; claude tokens: {claude_tokens}; cohere tokens: {cohere_tokens}; llama2 tokens: {llama2_tokens}; llama3 tokens: {llama3_tokens_1}" ) # assert that all token values are different assert ( - openai_tokens != cohere_tokens != llama2_tokens + openai_tokens != cohere_tokens != llama2_tokens != llama3_tokens_1 ), "Token values are not different." + assert llama3_tokens_1 == llama3_tokens_2, "Custom tokenizer is not being used! It has been configured to use the same tokenizer as the built in llama3 tokenizer and the results should be the same." + print("test tokenizer: It worked!") except Exception as e: pytest.fail(f"An exception occured: {e}") diff --git a/litellm/tests/test_utils.py b/litellm/tests/test_utils.py index 44fb1607c..57b93df9c 100644 --- a/litellm/tests/test_utils.py +++ b/litellm/tests/test_utils.py @@ -20,6 +20,8 @@ from litellm.utils import ( validate_environment, function_to_dict, token_counter, + create_pretrained_tokenizer, + create_tokenizer, ) # Assuming your trim_messages, shorten_message_to_fit_limit, and get_token_count functions are all in a module named 'message_utils' diff --git a/litellm/utils.py b/litellm/utils.py index d8530f7ad..ec296e9dc 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -3775,29 +3775,34 @@ def _select_tokenizer(model: str): elif "llama-2" in model.lower() or "replicate" in model.lower(): tokenizer = Tokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") return {"type": "huggingface_tokenizer", "tokenizer": tokenizer} + # llama3 + elif "llama-3" in model.lower(): + tokenizer = Tokenizer.from_pretrained("Xenova/llama-3-tokenizer") + return {"type": "huggingface_tokenizer", "tokenizer": tokenizer} # default - tiktoken else: return {"type": "openai_tokenizer", "tokenizer": encoding} -def encode(model: str, text: str): +def encode(model="", text="", custom_tokenizer: Optional[dict] = None): """ Encodes the given text using the specified model. Args: model (str): The name of the model to use for tokenization. + custom_tokenizer (Optional[dict]): A custom tokenizer created with the `create_pretrained_tokenizer` or `create_tokenizer` method. Must be a dictionary with a string value for `type` and Tokenizer for `tokenizer`. Default is None. text (str): The text to be encoded. Returns: enc: The encoded text. """ - tokenizer_json = _select_tokenizer(model=model) + tokenizer_json = custom_tokenizer or _select_tokenizer(model=model) enc = tokenizer_json["tokenizer"].encode(text) return enc -def decode(model: str, tokens: List[int]): - tokenizer_json = _select_tokenizer(model=model) +def decode(model="", tokens: List[int] = [], custom_tokenizer: Optional[dict] = None): + tokenizer_json = custom_tokenizer or _select_tokenizer(model=model) dec = tokenizer_json["tokenizer"].decode(tokens) return dec @@ -3969,10 +3974,47 @@ def calculage_img_tokens( tile_tokens = (base_tokens * 2) * tiles_needed_high_res total_tokens = base_tokens + tile_tokens return total_tokens + + +def create_pretrained_tokenizer( + identifier: str, + revision="main", + auth_token: Optional[str] = None +): + """ + Creates a tokenizer from an existing file on a HuggingFace repository to be used with `token_counter`. + + Args: + identifier (str): The identifier of a Model on the Hugging Face Hub, that contains a tokenizer.json file + revision (str, defaults to main): A branch or commit id + auth_token (str, optional, defaults to None): An optional auth token used to access private repositories on the Hugging Face Hub + + Returns: + dict: A dictionary with the tokenizer and its type. + """ + + tokenizer = Tokenizer.from_pretrained(identifier, revision=revision, auth_token=auth_token) + return {"type": "huggingface_tokenizer", "tokenizer": tokenizer} + + +def create_tokenizer(json: str): + """ + Creates a tokenizer from a valid JSON string for use with `token_counter`. + + Args: + json (str): A valid JSON string representing a previously serialized tokenizer + + Returns: + dict: A dictionary with the tokenizer and its type. + """ + + tokenizer = Tokenizer.from_str(json) + return {"type": "huggingface_tokenizer", "tokenizer": tokenizer} def token_counter( model="", + custom_tokenizer: Optional[dict] = None, text: Optional[Union[str, List[str]]] = None, messages: Optional[List] = None, count_response_tokens: Optional[bool] = False, @@ -3982,13 +4024,14 @@ def token_counter( Args: model (str): The name of the model to use for tokenization. Default is an empty string. + custom_tokenizer (Optional[dict]): A custom tokenizer created with the `create_pretrained_tokenizer` or `create_tokenizer` method. Must be a dictionary with a string value for `type` and Tokenizer for `tokenizer`. Default is None. text (str): The raw text string to be passed to the model. Default is None. messages (Optional[List[Dict[str, str]]]): Alternative to passing in text. A list of dictionaries representing messages with "role" and "content" keys. Default is None. Returns: int: The number of tokens in the text. """ - # use tiktoken, anthropic, cohere or llama2's tokenizer depending on the model + # use tiktoken, anthropic, cohere, llama2, or llama3's tokenizer depending on the model is_tool_call = False num_tokens = 0 if text == None: @@ -4030,8 +4073,8 @@ def token_counter( elif isinstance(text, str): count_response_tokens = True # user just trying to count tokens for a text. don't add the chat_ml +3 tokens to this - if model is not None: - tokenizer_json = _select_tokenizer(model=model) + if model is not None or custom_tokenizer is not None: + tokenizer_json = custom_tokenizer or _select_tokenizer(model=model) if tokenizer_json["type"] == "huggingface_tokenizer": print_verbose( f"Token Counter - using hugging face token counter, for model={model}"