(Fixes) OpenAI Streaming Token Counting + Fixes usage tracking when litellm.turn_off_message_logging=True (#8156)

* working streaming usage tracking

* fix test_async_chat_openai_stream_options

* fix await asyncio.sleep(1)

* test_async_chat_azure

* fix s3 logging

* fix get_stream_options

* fix get_stream_options

* fix streaming handler

* test_stream_token_counting_with_redaction

* fix codeql concern
Ishaan Jaff 2025-01-31 15:06:37 -08:00 committed by GitHub
parent 38b4980018
commit ef6ab91ac2
8 changed files with 268 additions and 94 deletions
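
For orientation, the behavior this commit targets can be reproduced roughly as follows. This is an illustrative sketch condensed from the new test file added at the end of the diff (it assumes a valid OPENAI_API_KEY and is not itself part of the commit):

    import asyncio

    import litellm
    from litellm.integrations.custom_logger import CustomLogger


    class UsageRecorder(CustomLogger):
        async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
            # standard_logging_object carries the token counts litellm tracked for the call
            payload = kwargs.get("standard_logging_object") or {}
            print("tracked total_tokens:", payload.get("total_tokens"))


    litellm.turn_off_message_logging = True  # redaction on; usage should still be tracked
    litellm.logging_callback_manager.add_litellm_callback(UsageRecorder())


    async def main():
        response = await litellm.acompletion(
            model="gpt-4o",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            stream=True,  # with this commit, usage is requested from OpenAI by default
        )
        async for _ in response:
            pass
        await asyncio.sleep(2)  # give the async success callback time to run


    asyncio.run(main())

The changes below consolidate success logging so the callback fires once per stream with the assembled usage, rather than once per chunk.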

View file

@@ -1029,21 +1029,13 @@ class Logging(LiteLLMLoggingBaseClass):
             ] = None
         if "complete_streaming_response" in self.model_call_details:
             return  # break out of this.
-        if self.stream and (
-            isinstance(result, litellm.ModelResponse)
-            or isinstance(result, TextCompletionResponse)
-            or isinstance(result, ModelResponseStream)
-        ):
-            complete_streaming_response: Optional[
-                Union[ModelResponse, TextCompletionResponse]
-            ] = _assemble_complete_response_from_streaming_chunks(
-                result=result,
-                start_time=start_time,
-                end_time=end_time,
-                request_kwargs=self.model_call_details,
-                streaming_chunks=self.sync_streaming_chunks,
-                is_async=False,
-            )
+        complete_streaming_response = self._get_assembled_streaming_response(
+            result=result,
+            start_time=start_time,
+            end_time=end_time,
+            is_async=False,
+            streaming_chunks=self.sync_streaming_chunks,
+        )
         if complete_streaming_response is not None:
             verbose_logger.debug(
                 "Logging Details LiteLLM-Success Call streaming complete"
@@ -1542,22 +1534,13 @@ class Logging(LiteLLMLoggingBaseClass):
             return  # break out of this.
         complete_streaming_response: Optional[
             Union[ModelResponse, TextCompletionResponse]
-        ] = None
-        if self.stream is True and (
-            isinstance(result, litellm.ModelResponse)
-            or isinstance(result, litellm.ModelResponseStream)
-            or isinstance(result, TextCompletionResponse)
-        ):
-            complete_streaming_response: Optional[
-                Union[ModelResponse, TextCompletionResponse]
-            ] = _assemble_complete_response_from_streaming_chunks(
-                result=result,
-                start_time=start_time,
-                end_time=end_time,
-                request_kwargs=self.model_call_details,
-                streaming_chunks=self.streaming_chunks,
-                is_async=True,
-            )
+        ] = self._get_assembled_streaming_response(
+            result=result,
+            start_time=start_time,
+            end_time=end_time,
+            is_async=True,
+            streaming_chunks=self.streaming_chunks,
+        )

         if complete_streaming_response is not None:
             print_verbose("Async success callbacks: Got a complete streaming response")
@@ -2259,6 +2242,32 @@ class Logging(LiteLLMLoggingBaseClass):
             _new_callbacks.append(_c)
         return _new_callbacks

+    def _get_assembled_streaming_response(
+        self,
+        result: Union[ModelResponse, TextCompletionResponse, ModelResponseStream, Any],
+        start_time: datetime.datetime,
+        end_time: datetime.datetime,
+        is_async: bool,
+        streaming_chunks: List[Any],
+    ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
+        if isinstance(result, ModelResponse):
+            return result
+        elif isinstance(result, TextCompletionResponse):
+            return result
+        elif isinstance(result, ModelResponseStream):
+            complete_streaming_response: Optional[
+                Union[ModelResponse, TextCompletionResponse]
+            ] = _assemble_complete_response_from_streaming_chunks(
+                result=result,
+                start_time=start_time,
+                end_time=end_time,
+                request_kwargs=self.model_call_details,
+                streaming_chunks=streaming_chunks,
+                is_async=is_async,
+            )
+            return complete_streaming_response
+        return None
+

 def set_callbacks(callback_list, function_id=None):  # noqa: PLR0915
     """

View file

@@ -5,7 +5,6 @@ import threading
 import time
 import traceback
 import uuid
-from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Callable, Dict, List, Optional, cast

 import httpx
@@ -14,6 +13,7 @@ from pydantic import BaseModel
 import litellm
 from litellm import verbose_logger
 from litellm.litellm_core_utils.redact_messages import LiteLLMLoggingObject
+from litellm.litellm_core_utils.thread_pool_executor import executor
 from litellm.types.utils import Delta
 from litellm.types.utils import GenericStreamingChunk as GChunk
 from litellm.types.utils import (
@@ -29,11 +29,6 @@ from .exception_mapping_utils import exception_type
 from .llm_response_utils.get_api_base import get_api_base
 from .rules import Rules

-MAX_THREADS = 100
-
-# Create a ThreadPoolExecutor
-executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
-

 def is_async_iterable(obj: Any) -> bool:
     """
@@ -1568,21 +1563,6 @@
                     )
                     if processed_chunk is None:
                         continue
-                    ## LOGGING
-                    ## LOGGING
-                    executor.submit(
-                        self.logging_obj.success_handler,
-                        result=processed_chunk,
-                        start_time=None,
-                        end_time=None,
-                        cache_hit=cache_hit,
-                    )
-
-                    asyncio.create_task(
-                        self.logging_obj.async_success_handler(
-                            processed_chunk, cache_hit=cache_hit
-                        )
-                    )

                     if self.logging_obj._llm_caching_handler is not None:
                         asyncio.create_task(
@@ -1634,16 +1614,6 @@
                     )
                     if processed_chunk is None:
                         continue
-                    ## LOGGING
-                    threading.Thread(
-                        target=self.logging_obj.success_handler,
-                        args=(processed_chunk, None, None, cache_hit),
-                    ).start()  # log processed_chunk
-                    asyncio.create_task(
-                        self.logging_obj.async_success_handler(
-                            processed_chunk, cache_hit=cache_hit
-                        )
-                    )

                     choice = processed_chunk.choices[0]
                     if isinstance(choice, StreamingChoices):
@@ -1671,33 +1641,31 @@
                             "usage",
                             getattr(complete_streaming_response, "usage"),
                         )
-                        ## LOGGING
-                        threading.Thread(
-                            target=self.logging_obj.success_handler,
-                            args=(response, None, None, cache_hit),
-                        ).start()  # log response
-                        asyncio.create_task(
-                            self.logging_obj.async_success_handler(
-                                response, cache_hit=cache_hit
-                            )
-                        )
                     if self.sent_stream_usage is False and self.send_stream_usage is True:
                         self.sent_stream_usage = True
                         return response
+
+                    asyncio.create_task(
+                        self.logging_obj.async_success_handler(
+                            complete_streaming_response,
+                            cache_hit=cache_hit,
+                            start_time=None,
+                            end_time=None,
+                        )
+                    )
+                    executor.submit(
+                        self.logging_obj.success_handler,
+                        complete_streaming_response,
+                        cache_hit=cache_hit,
+                        start_time=None,
+                        end_time=None,
+                    )
+
                     raise StopAsyncIteration  # Re-raise StopIteration
                 else:
                     self.sent_last_chunk = True
                     processed_chunk = self.finish_reason_handler()
-                    ## LOGGING
-                    threading.Thread(
-                        target=self.logging_obj.success_handler,
-                        args=(processed_chunk, None, None, cache_hit),
-                    ).start()  # log response
-                    asyncio.create_task(
-                        self.logging_obj.async_success_handler(
-                            processed_chunk, cache_hit=cache_hit
-                        )
-                    )
                     return processed_chunk
         except httpx.TimeoutException as e:  # if httpx read timeout error occues
             traceback_exception = traceback.format_exc()

View file

@@ -0,0 +1,5 @@
+from concurrent.futures import ThreadPoolExecutor
+
+MAX_THREADS = 100
+# Create a ThreadPoolExecutor
+executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
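
This new module is consumed through `executor.submit(...)` in the streaming handler and `utils.py` changes elsewhere in this commit. A minimal standalone sketch of the pattern (`do_blocking_log` is a made-up stand-in for `logging_obj.success_handler`):

    from litellm.litellm_core_utils.thread_pool_executor import executor


    def do_blocking_log(payload: dict, cache_hit: bool = False) -> None:
        # stand-in for the blocking success_handler call used in the diffs above
        print("logged usage:", payload.get("usage"), "cache_hit:", cache_hit)


    # previously each log call spawned its own thread:
    #   threading.Thread(target=do_blocking_log, args=(payload,)).start()
    # now the call is queued onto one shared, bounded pool (MAX_THREADS workers):
    future = executor.submit(do_blocking_log, {"usage": {"total_tokens": 42}}, cache_hit=False)
    future.result()  # only for the demo; production callers fire and forget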

View file

@@ -14,6 +14,7 @@ from typing import (
     Union,
     cast,
 )
+from urllib.parse import urlparse

 import httpx
 import openai
@@ -833,8 +834,9 @@ class OpenAIChatCompletion(BaseLLM):
         stream_options: Optional[dict] = None,
     ):
         data["stream"] = True
-        if stream_options is not None:
-            data["stream_options"] = stream_options
+        data.update(
+            self.get_stream_options(stream_options=stream_options, api_base=api_base)
+        )

         openai_client: OpenAI = self._get_openai_client(  # type: ignore
             is_async=False,
@@ -893,8 +895,9 @@
     ):
         response = None
         data["stream"] = True
-        if stream_options is not None:
-            data["stream_options"] = stream_options
+        data.update(
+            self.get_stream_options(stream_options=stream_options, api_base=api_base)
+        )
         for _ in range(2):
             try:
                 openai_aclient: AsyncOpenAI = self._get_openai_client(  # type: ignore
@@ -977,6 +980,20 @@
                 status_code=500, message=f"{str(e)}", headers=error_headers
             )

+    def get_stream_options(
+        self, stream_options: Optional[dict], api_base: Optional[str]
+    ) -> dict:
+        """
+        Pass `stream_options` to the data dict for OpenAI requests
+        """
+        if stream_options is not None:
+            return {"stream_options": stream_options}
+        else:
+            # by default litellm will include usage for openai endpoints
+            if api_base is None or urlparse(api_base).hostname == "api.openai.com":
+                return {"stream_options": {"include_usage": True}}
+            return {}
+
     # Embedding
     @track_llm_api_timing()
     async def make_openai_embedding_request(
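
The default added here can be summarized with a standalone sketch (illustrative; `stream_options_for` mirrors the `get_stream_options` method above rather than importing the handler class):

    from typing import Optional
    from urllib.parse import urlparse


    def stream_options_for(stream_options: Optional[dict], api_base: Optional[str]) -> dict:
        if stream_options is not None:
            return {"stream_options": stream_options}  # an explicit value always wins
        if api_base is None or urlparse(api_base).hostname == "api.openai.com":
            return {"stream_options": {"include_usage": True}}  # default: ask OpenAI for usage
        return {}  # custom/proxy endpoints: leave the request unchanged


    assert stream_options_for(None, None) == {"stream_options": {"include_usage": True}}
    assert stream_options_for(None, "https://api.openai.com/v1") == {"stream_options": {"include_usage": True}}
    assert stream_options_for(None, "https://my-gateway.example.com/v1") == {}
    assert stream_options_for({"include_usage": False}, None) == {"stream_options": {"include_usage": False}}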

View file

@@ -166,7 +166,6 @@ with resources.open_text(
 # Convert to str (if necessary)
 claude_json_str = json.dumps(json_data)
 import importlib.metadata
-from concurrent.futures import ThreadPoolExecutor
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -185,6 +184,7 @@ from typing import (
 from openai import OpenAIError as OriginalError

+from litellm.litellm_core_utils.thread_pool_executor import executor
 from litellm.llms.base_llm.audio_transcription.transformation import (
     BaseAudioTranscriptionConfig,
 )
@@ -235,10 +235,6 @@ from .types.router import LiteLLM_Params
 ####### ENVIRONMENT VARIABLES ####################
 # Adjust to your specific application needs / system capabilities.
-MAX_THREADS = 100
-
-# Create a ThreadPoolExecutor
-executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
 sentry_sdk_instance = None
 capture_exception = None
 add_breadcrumb = None

View file

@@ -418,6 +418,8 @@ async def test_async_chat_openai_stream():
        )
        async for chunk in response:
            continue
+
+       await asyncio.sleep(1)
        ## test failure callback
        try:
            response = await litellm.acompletion(
@@ -428,6 +430,7 @@
            )
            async for chunk in response:
                continue
+           await asyncio.sleep(1)
        except Exception:
            pass
        time.sleep(1)
@@ -499,6 +502,8 @@ async def test_async_chat_azure_stream():
        )
        async for chunk in response:
            continue
+
+       await asyncio.sleep(1)
        # test failure callback
        try:
            response = await litellm.acompletion(
@@ -509,6 +514,7 @@
            )
            async for chunk in response:
                continue
+           await asyncio.sleep(1)
        except Exception:
            pass
        await asyncio.sleep(1)
@@ -540,6 +546,8 @@ async def test_async_chat_openai_stream_options():
        async for chunk in response:
            continue
+
+       await asyncio.sleep(1)
        print("mock client args list=", mock_client.await_args_list)
        mock_client.assert_awaited_once()
    except Exception as e:
@@ -607,6 +615,8 @@ async def test_async_chat_bedrock_stream():
        async for chunk in response:
            print(f"chunk: {chunk}")
            continue
+
+       await asyncio.sleep(1)
        ## test failure callback
        try:
            response = await litellm.acompletion(
@@ -617,6 +627,8 @@
            )
            async for chunk in response:
                continue
+
+           await asyncio.sleep(1)
        except Exception:
            pass
        await asyncio.sleep(1)
@@ -770,6 +782,8 @@ async def test_async_text_completion_bedrock():
        async for chunk in response:
            print(f"chunk: {chunk}")
            continue
+
+       await asyncio.sleep(1)
        ## test failure callback
        try:
            response = await litellm.atext_completion(
) )
async for chunk in response: async for chunk in response:
continue continue
await asyncio.sleep(1)
except Exception: except Exception:
pass pass
time.sleep(1) time.sleep(1)
@@ -809,6 +825,8 @@ async def test_async_text_completion_openai_stream():
        async for chunk in response:
            print(f"chunk: {chunk}")
            continue
+
+       await asyncio.sleep(1)
        ## test failure callback
        try:
            response = await litellm.atext_completion(
) )
async for chunk in response: async for chunk in response:
continue continue
await asyncio.sleep(1)
except Exception: except Exception:
pass pass
time.sleep(1) time.sleep(1)

View file

@@ -381,7 +381,7 @@ class CompletionCustomHandler(
 # Simple Azure OpenAI call
 ## COMPLETION

-@pytest.mark.flaky(retries=5, delay=1)
+# @pytest.mark.flaky(retries=5, delay=1)
 @pytest.mark.asyncio
 async def test_async_chat_azure():
     try:
@@ -427,11 +427,11 @@ async def test_async_chat_azure():
        async for chunk in response:
            print(f"async azure router chunk: {chunk}")
            continue
-       await asyncio.sleep(1)
+       await asyncio.sleep(2)
        print(f"customHandler.states: {customHandler_streaming_azure_router.states}")
        assert len(customHandler_streaming_azure_router.errors) == 0
        assert (
-           len(customHandler_streaming_azure_router.states) >= 4
+           len(customHandler_streaming_azure_router.states) >= 3
        )  # pre, post, stream (multiple times), success
        # failure
        model_list = [

View file

@@ -0,0 +1,159 @@
+import os
+import sys
+import traceback
+import uuid
+
+import pytest
+from dotenv import load_dotenv
+from fastapi import Request
+from fastapi.routing import APIRoute
+
+load_dotenv()
+import io
+import os
+import time
+import json
+
+# this file is to test litellm/proxy
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+import asyncio
+from typing import Optional
+from litellm.types.utils import StandardLoggingPayload, Usage
+from litellm.integrations.custom_logger import CustomLogger
+
+class TestCustomLogger(CustomLogger):
+    def __init__(self):
+        self.recorded_usage: Optional[Usage] = None
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        standard_logging_payload = kwargs.get("standard_logging_object")
+        print(
+            "standard_logging_payload",
+            json.dumps(standard_logging_payload, indent=4, default=str),
+        )
+
+        self.recorded_usage = Usage(
+            prompt_tokens=standard_logging_payload.get("prompt_tokens"),
+            completion_tokens=standard_logging_payload.get("completion_tokens"),
+            total_tokens=standard_logging_payload.get("total_tokens"),
+        )
+        pass
+
+@pytest.mark.asyncio
+async def test_stream_token_counting_gpt_4o():
+    """
+    When stream_options={"include_usage": True} logging callback tracks Usage == Usage from llm API
+    """
+    custom_logger = TestCustomLogger()
+    litellm.logging_callback_manager.add_litellm_callback(custom_logger)
+
+    response = await litellm.acompletion(
+        model="gpt-4o",
+        messages=[{"role": "user", "content": "Hello, how are you?" * 100}],
+        stream=True,
+        stream_options={"include_usage": True},
+    )
+
+    actual_usage = None
+    async for chunk in response:
+        if "usage" in chunk:
+            actual_usage = chunk["usage"]
+            print("chunk.usage", json.dumps(chunk["usage"], indent=4, default=str))
+        pass
+
+    await asyncio.sleep(2)
+
+    print("\n\n\n\n\n")
+    print(
+        "recorded_usage",
+        json.dumps(custom_logger.recorded_usage, indent=4, default=str),
+    )
+    print("\n\n\n\n\n")
+
+    assert actual_usage.prompt_tokens == custom_logger.recorded_usage.prompt_tokens
+    assert (
+        actual_usage.completion_tokens == custom_logger.recorded_usage.completion_tokens
+    )
+    assert actual_usage.total_tokens == custom_logger.recorded_usage.total_tokens
+
+@pytest.mark.asyncio
+async def test_stream_token_counting_without_include_usage():
+    """
+    When stream_options={"include_usage": True} is not passed, the usage tracked == usage from llm api chunk
+
+    by default, litellm passes `include_usage=True` for OpenAI API
+    """
+    custom_logger = TestCustomLogger()
+    litellm.logging_callback_manager.add_litellm_callback(custom_logger)
+
+    response = await litellm.acompletion(
+        model="gpt-4o",
+        messages=[{"role": "user", "content": "Hello, how are you?" * 100}],
+        stream=True,
+    )
+
+    actual_usage = None
+    async for chunk in response:
+        if "usage" in chunk:
+            actual_usage = chunk["usage"]
+            print("chunk.usage", json.dumps(chunk["usage"], indent=4, default=str))
+        pass
+
+    await asyncio.sleep(2)
+
+    print("\n\n\n\n\n")
+    print(
+        "recorded_usage",
+        json.dumps(custom_logger.recorded_usage, indent=4, default=str),
+    )
+    print("\n\n\n\n\n")
+
+    assert actual_usage.prompt_tokens == custom_logger.recorded_usage.prompt_tokens
+    assert (
+        actual_usage.completion_tokens == custom_logger.recorded_usage.completion_tokens
+    )
+    assert actual_usage.total_tokens == custom_logger.recorded_usage.total_tokens
+
+@pytest.mark.asyncio
+async def test_stream_token_counting_with_redaction():
+    """
+    When litellm.turn_off_message_logging=True is used, the usage tracked == usage from llm api chunk
+    """
+    litellm.turn_off_message_logging = True
+    custom_logger = TestCustomLogger()
+    litellm.logging_callback_manager.add_litellm_callback(custom_logger)
+
+    response = await litellm.acompletion(
+        model="gpt-4o",
+        messages=[{"role": "user", "content": "Hello, how are you?" * 100}],
+        stream=True,
+    )
+
+    actual_usage = None
+    async for chunk in response:
+        if "usage" in chunk:
+            actual_usage = chunk["usage"]
+            print("chunk.usage", json.dumps(chunk["usage"], indent=4, default=str))
+        pass
+
+    await asyncio.sleep(2)
+
+    print("\n\n\n\n\n")
+    print(
+        "recorded_usage",
+        json.dumps(custom_logger.recorded_usage, indent=4, default=str),
+    )
+    print("\n\n\n\n\n")
+
+    assert actual_usage.prompt_tokens == custom_logger.recorded_usage.prompt_tokens
+    assert (
+        actual_usage.completion_tokens == custom_logger.recorded_usage.completion_tokens
+    )
+    assert actual_usage.total_tokens == custom_logger.recorded_usage.total_tokens