From 43f139fafd8e69d81c5fd5d8f95d511e0953c36f Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 24 Jan 2024 20:09:08 -0800
Subject: [PATCH] fix(ollama_chat.py): fix default token counting for ollama chat

---
 litellm/llms/ollama_chat.py | 12 ++++++++----
 litellm/utils.py            |  9 +++++++--
 2 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/litellm/llms/ollama_chat.py b/litellm/llms/ollama_chat.py
index 31e3f0d16a..e381c93f7d 100644
--- a/litellm/llms/ollama_chat.py
+++ b/litellm/llms/ollama_chat.py
@@ -220,8 +220,10 @@ def get_ollama_response(
         model_response["choices"][0]["message"] = response_json["message"]
         model_response["created"] = int(time.time())
         model_response["model"] = "ollama/" + model
-        prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt)))  # type: ignore
-        completion_tokens = response_json["eval_count"]
+        prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages))  # type: ignore
+        completion_tokens = response_json.get(
+            "eval_count", litellm.token_counter(text=response_json["message"])
+        )
         model_response["usage"] = litellm.Usage(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
@@ -320,8 +322,10 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
             model_response["choices"][0]["message"] = response_json["message"]
             model_response["created"] = int(time.time())
             model_response["model"] = "ollama/" + data["model"]
-            prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt)))  # type: ignore
-            completion_tokens = response_json["eval_count"]
+            prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=data["messages"]))  # type: ignore
+            completion_tokens = response_json.get(
+                "eval_count", litellm.token_counter(text=response_json["message"])
+            )
             model_response["usage"] = litellm.Usage(
                 prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
diff --git a/litellm/utils.py b/litellm/utils.py
index 03d38ff35a..4718083c2b 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2872,8 +2872,13 @@ def token_counter(
             print_verbose(
                 f"Token Counter - using generic token counter, for model={model}"
             )
-            enc = tokenizer_json["tokenizer"].encode(text)
-            num_tokens = len(enc)
+            num_tokens = openai_token_counter(
+                text=text,  # type: ignore
+                model="gpt-3.5-turbo",
+                messages=messages,
+                is_tool_call=is_tool_call,
+                count_response_tokens=count_response_tokens,
+            )
         else:
             num_tokens = len(encoding.encode(text))  # type: ignore
         return num_tokens
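
For context, a minimal sketch of the fallback pattern this patch introduces (not part of the commit itself): when an Ollama response omits prompt_eval_count or eval_count, the usage object is built by counting tokens locally instead of raising a KeyError. The count_tokens and build_usage helpers below are hypothetical stand-ins for litellm.token_counter and the litellm.Usage construction; they only illustrate the dict.get(...) fallback, assuming a whitespace tokenizer rather than a real model tokenizer.

# Hypothetical stand-in for litellm.token_counter: whitespace split only.
def count_tokens(text: str) -> int:
    return len(text.split())

# Build an OpenAI-style usage dict, falling back to local counting when
# Ollama does not report prompt_eval_count / eval_count in its response.
def build_usage(response_json: dict, prompt_text: str) -> dict:
    message_text = response_json.get("message", {}).get("content", "")
    prompt_tokens = response_json.get("prompt_eval_count", count_tokens(prompt_text))
    completion_tokens = response_json.get("eval_count", count_tokens(message_text))
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }

# Response missing both counters: the fallback computes them itself.
print(build_usage({"message": {"role": "assistant", "content": "Hello there"}},
                  "Say hello"))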