diff --git a/litellm/utils.py b/litellm/utils.py
index 55dd4068f..82c0e94d1 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -60,6 +60,10 @@ from litellm.litellm_core_utils.redact_messages import (
 )
 from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.llms.openai import (
+    ChatCompletionNamedToolChoiceParam,
+    ChatCompletionToolParam,
+)
 from litellm.types.utils import (
     CallTypes,
     ChatCompletionDeltaToolCall,
@@ -79,7 +83,6 @@ from litellm.types.utils import (
     TranscriptionResponse,
     Usage,
 )
-from litellm.types.llms.openai import ChatCompletionToolParam, ChatCompletionNamedToolChoiceParam

 oidc_cache = DualCache()

@@ -1572,7 +1575,7 @@ def openai_token_counter(
     model="gpt-3.5-turbo-0613",
     text: Optional[str] = None,
     is_tool_call: Optional[bool] = False,
-    tools: list[ChatCompletionToolParam] | None = None,
+    tools: List[ChatCompletionToolParam] | None = None,
     tool_choice: ChatCompletionNamedToolChoiceParam | None = None,
     count_response_tokens: Optional[
         bool
@@ -1617,7 +1620,7 @@ def openai_token_counter(
     for message in messages:
         num_tokens += tokens_per_message
         if message.get("role", None) == "system":
-                includes_system_message = True
+            includes_system_message = True
         for key, value in message.items():
             if isinstance(value, str):
                 num_tokens += len(encoding.encode(value, disallowed_special=()))
@@ -1868,6 +1871,7 @@ def _format_type(props, indent):
     # This is a guess, as an empty string doesn't yield the expected token count
     return "any"

+
 def token_counter(
     model="",
     custom_tokenizer: Optional[dict] = None,
@@ -1955,7 +1959,7 @@ def token_counter(
                 is_tool_call=is_tool_call,
                 count_response_tokens=count_response_tokens,
                 tools=tools,
-                tool_choice=tool_choice
+                tool_choice=tool_choice,
             )
         else:
             print_verbose(
@@ -1968,7 +1972,7 @@ def token_counter(
                 is_tool_call=is_tool_call,
                 count_response_tokens=count_response_tokens,
                 tools=tools,
-                tool_choice=tool_choice
+                tool_choice=tool_choice,
             )
     else:
         num_tokens = len(encoding.encode(text, disallowed_special=()))  # type: ignore