forked from phoenix/litellm-mirror

feat(utils.py): adding encode and decode functions

This commit is contained in:
parent c038731c48
commit 4eeadd284a

3 changed files with 80 additions and 31 deletions
@@ -325,7 +325,9 @@ from .utils import (
     check_valid_key,
     get_llm_provider,
     completion_with_config,
-    register_model
+    register_model,
+    encode,
+    decode
 )
 from .llms.huggingface_restapi import HuggingfaceConfig
 from .llms.anthropic import AnthropicConfig
@@ -8,7 +8,7 @@ sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import time
-from litellm import token_counter
+from litellm import token_counter, encode, decode
 
 
 def test_tokenizers():
@@ -38,4 +38,35 @@ def test_tokenizers():
     except Exception as e:
         pytest.fail(f'An exception occurred: {e}')
 
-test_tokenizers()
+# test_tokenizers()
+
+
+def test_encoding_and_decoding():
+    try:
+        sample_text = "Hellö World, this is my input string!"
+
+        # openai encoding + decoding
+        openai_tokens = encode(model="gpt-3.5-turbo", text=sample_text)
+        openai_text = decode(model="gpt-3.5-turbo", tokens=openai_tokens)
+
+        assert openai_text == sample_text
+
+        # claude encoding + decoding
+        claude_tokens = encode(model="claude-instant-1", text=sample_text)
+        claude_text = decode(model="claude-instant-1", tokens=claude_tokens.ids)
+
+        assert claude_text == sample_text
+
+        # cohere encoding + decoding
+        cohere_tokens = encode(model="command-nightly", text=sample_text)
+        cohere_text = decode(model="command-nightly", tokens=cohere_tokens.ids)
+
+        assert cohere_text == sample_text
+
+        # llama2 encoding + decoding
+        llama2_tokens = encode(model="meta-llama/Llama-2-7b-chat", text=sample_text)
+        llama2_text = decode(model="meta-llama/Llama-2-7b-chat", tokens=llama2_tokens.ids)
+
+        assert llama2_text == sample_text
+    except Exception as e:
+        pytest.fail(f'An exception occurred: {e}')
+
+test_encoding_and_decoding()
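Note the asymmetry the round-trip tests work around: for the tiktoken-backed model, encode() returns a plain list of token ids that can be passed straight back to decode(), while for the Hugging Face-backed tokenizers (Anthropic, Cohere, Llama 2) it returns a tokenizers.Encoding object whose .ids attribute holds the list. A small caller-side sketch that normalizes both cases (the token_ids helper is illustrative, not part of this commit):

from litellm import encode

def token_ids(model: str, text: str) -> list:
    # encode() yields a tokenizers.Encoding (with an .ids attribute) for
    # HF-backed models and a plain List[int] for tiktoken-backed models;
    # normalize both to a plain list for downstream use.
    enc = encode(model=model, text=text)
    return enc.ids if hasattr(enc, "ids") else enc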
@@ -869,6 +869,42 @@ def get_replicate_completion_pricing(completion_response=None, total_time=0.0):
     return a100_80gb_price_per_second_public*total_time
 
 
+def _select_tokenizer(model: str):
+    # cohere
+    if model in litellm.cohere_models:
+        tokenizer = Tokenizer.from_pretrained("Cohere/command-nightly")
+        return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}
+    # anthropic
+    elif model in litellm.anthropic_models:
+        # read the tokenizer JSON file packaged with litellm
+        filename = pkg_resources.resource_filename(__name__, 'llms/tokenizers/anthropic_tokenizer.json')
+        with open(filename, 'r') as f:
+            json_data = json.load(f)
+        # re-serialize without escaping non-ASCII characters
+        json_str = json.dumps(json_data, ensure_ascii=False)
+        # load the tokenizer from the JSON string
+        tokenizer = Tokenizer.from_str(json_str)
+        return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}
+    # llama2
+    elif "llama-2" in model.lower():
+        tokenizer = Tokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+        return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}
+    # default - tiktoken
+    else:
+        return {"type": "openai_tokenizer", "tokenizer": encoding}
+
+
+def encode(model: str, text: str):
+    tokenizer_json = _select_tokenizer(model=model)
+    enc = tokenizer_json["tokenizer"].encode(text)
+    return enc
+
+
+def decode(model: str, tokens: List[int]):
+    tokenizer_json = _select_tokenizer(model=model)
+    dec = tokenizer_json["tokenizer"].decode(tokens)
+    return dec
+
+
 def token_counter(model="", text=None, messages: Optional[List] = None):
     """
     Count the number of tokens in a given text using a specified model.
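The "type" tag in the returned dict exists because the two backends disagree on what encode() produces: a tokenizers.Encoding object (Hugging Face) versus a plain list of ids (tiktoken). A quick round-trip sketch with the new public helpers, reusing the model names from the tests above (assumes a litellm install that includes this commit):

from litellm import encode, decode

sample = "Hellö World, this is my input string!"

# tiktoken-backed model: encode() returns a plain list of ints
gpt_tokens = encode(model="gpt-3.5-turbo", text=sample)
assert decode(model="gpt-3.5-turbo", tokens=gpt_tokens) == sample

# HF-backed model: encode() returns an Encoding; pass .ids to decode()
claude_tokens = encode(model="claude-instant-1", text=sample)
assert decode(model="claude-instant-1", tokens=claude_tokens.ids) == sample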
@@ -881,42 +917,22 @@ def token_counter(model="", text=None, messages: Optional[List] = None):
     Returns:
         int: The number of tokens in the text.
     """
-    # use tiktoken or anthropic's tokenizer depending on the model
+    # use tiktoken, anthropic, cohere or llama2's tokenizer depending on the model
     if text == None:
         if messages is not None:
-            text = " ".join([message["content"] for message in messages])
+            text = "".join([message["content"] for message in messages])
         else:
             raise ValueError("text and messages cannot both be None")
     num_tokens = 0
 
     if model is not None:
-        # cohere
-        if model in litellm.cohere_models:
-            tokenizer = Tokenizer.from_pretrained("Cohere/command-nightly")
-            enc = tokenizer.encode(text)
+        tokenizer_json = _select_tokenizer(model=model)
+        if tokenizer_json["type"] == "huggingface_tokenizer":
+            enc = tokenizer_json["tokenizer"].encode(text)
             num_tokens = len(enc.ids)
-        # anthropic
-        elif model in litellm.anthropic_models:
-            # Read the JSON file
-            filename = pkg_resources.resource_filename(__name__, 'llms/tokenizers/anthropic_tokenizer.json')
-            with open(filename, 'r') as f:
-                json_data = json.load(f)
-            # Decode the JSON data from utf-8
-            json_data_decoded = json.dumps(json_data, ensure_ascii=False)
-            # Convert to str
-            json_str = str(json_data_decoded)
-            # load tokenizer
-            tokenizer = Tokenizer.from_str(json_str)
-            enc = tokenizer.encode(text)
-            num_tokens = len(enc.ids)
-        # llama2
-        elif "llama-2" in model.lower():
-            tokenizer = Tokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-            enc = tokenizer.encode(text)
-            num_tokens = len(enc.ids)
-        # default - tiktoken
-        else:
-            num_tokens = len(encoding.encode(text))
+        elif tokenizer_json["type"] == "openai_tokenizer":
+            enc = tokenizer_json["tokenizer"].encode(text)
+            num_tokens = len(enc)
     else:
         num_tokens = len(encoding.encode(text))
     return num_tokens
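One behavioral change to note in this refactor: message contents are now joined with an empty string instead of a single space, which can change the reported token count for multi-message inputs. A minimal illustration with tiktoken (assuming the cl100k_base encoding that gpt-3.5-turbo uses):

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
messages = [{"content": "Hello"}, {"content": "world"}]

with_space = " ".join(m["content"] for m in messages)  # old behavior: "Hello world"
no_space = "".join(m["content"] for m in messages)     # new behavior: "Helloworld"

# The two joins can tokenize differently, so the counts may not match.
print(len(enc.encode(with_space)), len(enc.encode(no_space)))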