fix(utils.py): if openai model, don't check hf tokenizers

This commit is contained in:
Krrish Dholakia 2024-08-12 16:28:22 -07:00
parent e9c88952b9
commit a8644d8a7d
2 changed files with 21 additions and 1 deletions

View file

@@ -11,7 +11,9 @@ import pytest
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
+from unittest.mock import AsyncMock, MagicMock, patch
+import litellm
 from litellm import (
     create_pretrained_tokenizer,
     decode,
@@ -343,3 +345,14 @@ def test_empty_tools():
     )
     print(result)
def test_gpt_4o_token_counter():
    """Verify that counting tokens for a gpt-4o model goes through the OpenAI
    token counter (i.e. tiktoken path) rather than a HuggingFace tokenizer."""
    counter_mock = MagicMock()
    with patch.object(litellm.utils, "openai_token_counter", new=counter_mock):
        token_counter(
            model="gpt-4o-2024-05-13",
            messages=[{"role": "user", "content": "Hey!"}],
        )
    # The mock outlives the patch context, so the call record is still available.
    counter_mock.assert_called()

View file

@@ -1610,10 +1610,17 @@ def _select_tokenizer(model: str):
     # default - tiktoken
     else:
         tokenizer = None
+        if (
+            model in litellm.open_ai_chat_completion_models
+            or model in litellm.open_ai_text_completion_models
+            or model in litellm.open_ai_embedding_models
+        ):
+            return {"type": "openai_tokenizer", "tokenizer": encoding}
         try:
             tokenizer = Tokenizer.from_pretrained(model)
             return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}
-        except:
+        except Exception:
             return {"type": "openai_tokenizer", "tokenizer": encoding}