Mirror of https://github.com/BerriAI/litellm.git

(Fixes) OpenAI Streaming Token Counting + Fixes usage tracking when litellm.turn_off_message_logging=True (#8156)

Commit message:
* working streaming usage tracking
* fix test_async_chat_openai_stream_options
* fix await asyncio.sleep(1)
* test_async_chat_azure
* fix s3 logging
* fix get_stream_options
* fix get_stream_options
* fix streaming handler
* test_stream_token_counting_with_redaction
* fix codeql concern

Parent: 38b4980018
Commit: ef6ab91ac2
8 changed files with 268 additions and 94 deletions
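For context, a condensed sketch of the scenario this commit fixes, adapted from the new test file at the bottom of this diff. The UsageRecorder class and main wrapper are illustrative names (not part of the change), and a configured OpenAI key is assumed; per the commit title, usage recorded by callbacks could previously be missing for streamed calls, particularly with redaction enabled.

import asyncio

import litellm
from litellm.integrations.custom_logger import CustomLogger


class UsageRecorder(CustomLogger):
    """Minimal callback that records token usage from the standard logging payload."""

    def __init__(self):
        self.usage = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        payload = kwargs.get("standard_logging_object") or {}
        self.usage = {
            "prompt_tokens": payload.get("prompt_tokens"),
            "completion_tokens": payload.get("completion_tokens"),
            "total_tokens": payload.get("total_tokens"),
        }


async def main():
    # redaction should no longer drop usage after this fix
    litellm.turn_off_message_logging = True
    recorder = UsageRecorder()
    litellm.logging_callback_manager.add_litellm_callback(recorder)

    response = await litellm.acompletion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        stream=True,
    )
    async for _ in response:
        pass

    await asyncio.sleep(2)  # give async success callbacks time to run
    print("tracked usage:", recorder.usage)


if __name__ == "__main__":
    asyncio.run(main())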
@@ -1029,21 +1029,13 @@ class Logging(LiteLLMLoggingBaseClass):
         ] = None
         if "complete_streaming_response" in self.model_call_details:
             return  # break out of this.
-        if self.stream and (
-            isinstance(result, litellm.ModelResponse)
-            or isinstance(result, TextCompletionResponse)
-            or isinstance(result, ModelResponseStream)
-        ):
-            complete_streaming_response: Optional[
-                Union[ModelResponse, TextCompletionResponse]
-            ] = _assemble_complete_response_from_streaming_chunks(
-                result=result,
-                start_time=start_time,
-                end_time=end_time,
-                request_kwargs=self.model_call_details,
-                streaming_chunks=self.sync_streaming_chunks,
-                is_async=False,
-            )
+        complete_streaming_response = self._get_assembled_streaming_response(
+            result=result,
+            start_time=start_time,
+            end_time=end_time,
+            is_async=False,
+            streaming_chunks=self.sync_streaming_chunks,
+        )
         if complete_streaming_response is not None:
             verbose_logger.debug(
                 "Logging Details LiteLLM-Success Call streaming complete"
@@ -1542,22 +1534,13 @@ class Logging(LiteLLMLoggingBaseClass):
             return  # break out of this.
         complete_streaming_response: Optional[
             Union[ModelResponse, TextCompletionResponse]
-        ] = None
-        if self.stream is True and (
-            isinstance(result, litellm.ModelResponse)
-            or isinstance(result, litellm.ModelResponseStream)
-            or isinstance(result, TextCompletionResponse)
-        ):
-            complete_streaming_response: Optional[
-                Union[ModelResponse, TextCompletionResponse]
-            ] = _assemble_complete_response_from_streaming_chunks(
-                result=result,
-                start_time=start_time,
-                end_time=end_time,
-                request_kwargs=self.model_call_details,
-                streaming_chunks=self.streaming_chunks,
-                is_async=True,
-            )
+        ] = self._get_assembled_streaming_response(
+            result=result,
+            start_time=start_time,
+            end_time=end_time,
+            is_async=True,
+            streaming_chunks=self.streaming_chunks,
+        )

         if complete_streaming_response is not None:
             print_verbose("Async success callbacks: Got a complete streaming response")
@@ -2259,6 +2242,32 @@ class Logging(LiteLLMLoggingBaseClass):
                 _new_callbacks.append(_c)
         return _new_callbacks

+    def _get_assembled_streaming_response(
+        self,
+        result: Union[ModelResponse, TextCompletionResponse, ModelResponseStream, Any],
+        start_time: datetime.datetime,
+        end_time: datetime.datetime,
+        is_async: bool,
+        streaming_chunks: List[Any],
+    ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
+        if isinstance(result, ModelResponse):
+            return result
+        elif isinstance(result, TextCompletionResponse):
+            return result
+        elif isinstance(result, ModelResponseStream):
+            complete_streaming_response: Optional[
+                Union[ModelResponse, TextCompletionResponse]
+            ] = _assemble_complete_response_from_streaming_chunks(
+                result=result,
+                start_time=start_time,
+                end_time=end_time,
+                request_kwargs=self.model_call_details,
+                streaming_chunks=streaming_chunks,
+                is_async=is_async,
+            )
+            return complete_streaming_response
+        return None
+

 def set_callbacks(callback_list, function_id=None):  # noqa: PLR0915
     """
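Design note: the new Logging._get_assembled_streaming_response helper consolidates the sync and async logic replaced in the two hunks above. Complete ModelResponse/TextCompletionResponse objects pass through unchanged, and a full response is assembled from the accumulated chunks only when a ModelResponseStream arrives, so both success paths report the same token usage.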
(next changed file)
@@ -5,7 +5,6 @@ import threading
 import time
 import traceback
 import uuid
-from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Callable, Dict, List, Optional, cast

 import httpx

@@ -14,6 +13,7 @@ from pydantic import BaseModel
 import litellm
 from litellm import verbose_logger
 from litellm.litellm_core_utils.redact_messages import LiteLLMLoggingObject
+from litellm.litellm_core_utils.thread_pool_executor import executor
 from litellm.types.utils import Delta
 from litellm.types.utils import GenericStreamingChunk as GChunk
 from litellm.types.utils import (

@@ -29,11 +29,6 @@ from .exception_mapping_utils import exception_type
 from .llm_response_utils.get_api_base import get_api_base
 from .rules import Rules

-MAX_THREADS = 100
-
-# Create a ThreadPoolExecutor
-executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
-

 def is_async_iterable(obj: Any) -> bool:
     """
@@ -1568,21 +1563,6 @@ class CustomStreamWrapper:
                 )
                 if processed_chunk is None:
                     continue
-                ## LOGGING
-                ## LOGGING
-                executor.submit(
-                    self.logging_obj.success_handler,
-                    result=processed_chunk,
-                    start_time=None,
-                    end_time=None,
-                    cache_hit=cache_hit,
-                )
-
-                asyncio.create_task(
-                    self.logging_obj.async_success_handler(
-                        processed_chunk, cache_hit=cache_hit
-                    )
-                )
                 if self.logging_obj._llm_caching_handler is not None:
                     asyncio.create_task(

@@ -1634,16 +1614,6 @@ class CustomStreamWrapper:
                 )
                 if processed_chunk is None:
                     continue
-                ## LOGGING
-                threading.Thread(
-                    target=self.logging_obj.success_handler,
-                    args=(processed_chunk, None, None, cache_hit),
-                ).start()  # log processed_chunk
-                asyncio.create_task(
-                    self.logging_obj.async_success_handler(
-                        processed_chunk, cache_hit=cache_hit
-                    )
-                )
                 choice = processed_chunk.choices[0]
                 if isinstance(choice, StreamingChoices):
@ -1671,33 +1641,31 @@ class CustomStreamWrapper:
|
||||||
"usage",
|
"usage",
|
||||||
getattr(complete_streaming_response, "usage"),
|
getattr(complete_streaming_response, "usage"),
|
||||||
)
|
)
|
||||||
## LOGGING
|
|
||||||
threading.Thread(
|
|
||||||
target=self.logging_obj.success_handler,
|
|
||||||
args=(response, None, None, cache_hit),
|
|
||||||
).start() # log response
|
|
||||||
asyncio.create_task(
|
|
||||||
self.logging_obj.async_success_handler(
|
|
||||||
response, cache_hit=cache_hit
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if self.sent_stream_usage is False and self.send_stream_usage is True:
|
if self.sent_stream_usage is False and self.send_stream_usage is True:
|
||||||
self.sent_stream_usage = True
|
self.sent_stream_usage = True
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
asyncio.create_task(
|
||||||
|
self.logging_obj.async_success_handler(
|
||||||
|
complete_streaming_response,
|
||||||
|
cache_hit=cache_hit,
|
||||||
|
start_time=None,
|
||||||
|
end_time=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
executor.submit(
|
||||||
|
self.logging_obj.success_handler,
|
||||||
|
complete_streaming_response,
|
||||||
|
cache_hit=cache_hit,
|
||||||
|
start_time=None,
|
||||||
|
end_time=None,
|
||||||
|
)
|
||||||
|
|
||||||
raise StopAsyncIteration # Re-raise StopIteration
|
raise StopAsyncIteration # Re-raise StopIteration
|
||||||
else:
|
else:
|
||||||
self.sent_last_chunk = True
|
self.sent_last_chunk = True
|
||||||
processed_chunk = self.finish_reason_handler()
|
processed_chunk = self.finish_reason_handler()
|
||||||
## LOGGING
|
|
||||||
threading.Thread(
|
|
||||||
target=self.logging_obj.success_handler,
|
|
||||||
args=(processed_chunk, None, None, cache_hit),
|
|
||||||
).start() # log response
|
|
||||||
asyncio.create_task(
|
|
||||||
self.logging_obj.async_success_handler(
|
|
||||||
processed_chunk, cache_hit=cache_hit
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return processed_chunk
|
return processed_chunk
|
||||||
except httpx.TimeoutException as e: # if httpx read timeout error occues
|
except httpx.TimeoutException as e: # if httpx read timeout error occues
|
||||||
traceback_exception = traceback.format_exc()
|
traceback_exception = traceback.format_exc()
|
||||||
|
|
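Net effect of the streaming-handler hunks above: the per-chunk threading.Thread / executor.submit calls into success_handler are removed, and the success callbacks now run once against the assembled complete_streaming_response (which carries the final usage block), with the synchronous handler dispatched through the shared executor instead of a new thread per response.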
litellm/litellm_core_utils/thread_pool_executor.py (new file, +5)

from concurrent.futures import ThreadPoolExecutor

MAX_THREADS = 100
# Create a ThreadPoolExecutor
executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
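This shared executor replaces the module-level pools removed from the streaming handler above and from the main utils module below. A minimal usage sketch, assuming an installed litellm that includes this file; the print callable is a stand-in for illustration only:

from litellm.litellm_core_utils.thread_pool_executor import executor

# fire-and-forget submission, as the streaming handler now does for success_handler
future = executor.submit(print, "success_handler would run here")
future.result()  # block only to show the call completed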
(next changed file)
@@ -14,6 +14,7 @@ from typing import (
     Union,
     cast,
 )
+from urllib.parse import urlparse

 import httpx
 import openai

@@ -833,8 +834,9 @@ class OpenAIChatCompletion(BaseLLM):
         stream_options: Optional[dict] = None,
     ):
         data["stream"] = True
-        if stream_options is not None:
-            data["stream_options"] = stream_options
+        data.update(
+            self.get_stream_options(stream_options=stream_options, api_base=api_base)
+        )

         openai_client: OpenAI = self._get_openai_client(  # type: ignore
             is_async=False,

@@ -893,8 +895,9 @@ class OpenAIChatCompletion(BaseLLM):
     ):
         response = None
         data["stream"] = True
-        if stream_options is not None:
-            data["stream_options"] = stream_options
+        data.update(
+            self.get_stream_options(stream_options=stream_options, api_base=api_base)
+        )
         for _ in range(2):
             try:
                 openai_aclient: AsyncOpenAI = self._get_openai_client(  # type: ignore

@@ -977,6 +980,20 @@ class OpenAIChatCompletion(BaseLLM):
                     status_code=500, message=f"{str(e)}", headers=error_headers
                 )

+    def get_stream_options(
+        self, stream_options: Optional[dict], api_base: Optional[str]
+    ) -> dict:
+        """
+        Pass `stream_options` to the data dict for OpenAI requests
+        """
+        if stream_options is not None:
+            return {"stream_options": stream_options}
+        else:
+            # by default litellm will include usage for openai endpoints
+            if api_base is None or urlparse(api_base).hostname == "api.openai.com":
+                return {"stream_options": {"include_usage": True}}
+            return {}
+
     # Embedding
     @track_llm_api_timing()
     async def make_openai_embedding_request(
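To make the new default concrete, here is the same decision logic written as a standalone function. This is a mirror of the get_stream_options method added above, shown without the class for illustration, not an import from litellm:

from typing import Optional
from urllib.parse import urlparse


def get_stream_options(stream_options: Optional[dict], api_base: Optional[str]) -> dict:
    # mirror of the method added in this diff, standalone for illustration
    if stream_options is not None:
        return {"stream_options": stream_options}
    # by default, usage is requested for api.openai.com endpoints
    if api_base is None or urlparse(api_base).hostname == "api.openai.com":
        return {"stream_options": {"include_usage": True}}
    return {}


print(get_stream_options({"include_usage": False}, None))
# -> {'stream_options': {'include_usage': False}}  (explicit options pass through)
print(get_stream_options(None, "https://api.openai.com/v1"))
# -> {'stream_options': {'include_usage': True}}   (usage requested by default)
print(get_stream_options(None, "http://localhost:8000/v1"))
# -> {}                                            (OpenAI-compatible proxies left untouched)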
(next changed file)
@@ -166,7 +166,6 @@ with resources.open_text(
 # Convert to str (if necessary)
 claude_json_str = json.dumps(json_data)
 import importlib.metadata
-from concurrent.futures import ThreadPoolExecutor
 from typing import (
     TYPE_CHECKING,
     Any,

@@ -185,6 +184,7 @@ from typing import (

 from openai import OpenAIError as OriginalError

+from litellm.litellm_core_utils.thread_pool_executor import executor
 from litellm.llms.base_llm.audio_transcription.transformation import (
     BaseAudioTranscriptionConfig,
 )

@@ -235,10 +235,6 @@ from .types.router import LiteLLM_Params

 ####### ENVIRONMENT VARIABLES ####################
 # Adjust to your specific application needs / system capabilities.
-MAX_THREADS = 100
-
-# Create a ThreadPoolExecutor
-executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
 sentry_sdk_instance = None
 capture_exception = None
 add_breadcrumb = None
(next changed file)
@@ -418,6 +418,8 @@ async def test_async_chat_openai_stream():
     )
     async for chunk in response:
         continue
+
+    await asyncio.sleep(1)
     ## test failure callback
     try:
         response = await litellm.acompletion(

@@ -428,6 +430,7 @@ async def test_async_chat_openai_stream():
         )
         async for chunk in response:
             continue
+        await asyncio.sleep(1)
     except Exception:
         pass
     time.sleep(1)

@@ -499,6 +502,8 @@ async def test_async_chat_azure_stream():
     )
     async for chunk in response:
         continue
+
+    await asyncio.sleep(1)
     # test failure callback
     try:
         response = await litellm.acompletion(

@@ -509,6 +514,7 @@ async def test_async_chat_azure_stream():
         )
         async for chunk in response:
             continue
+        await asyncio.sleep(1)
     except Exception:
         pass
     await asyncio.sleep(1)

@@ -540,6 +546,8 @@ async def test_async_chat_openai_stream_options():

         async for chunk in response:
             continue
+
+        await asyncio.sleep(1)
         print("mock client args list=", mock_client.await_args_list)
         mock_client.assert_awaited_once()
     except Exception as e:

@@ -607,6 +615,8 @@ async def test_async_chat_bedrock_stream():
         async for chunk in response:
             print(f"chunk: {chunk}")
             continue
+
+        await asyncio.sleep(1)
         ## test failure callback
         try:
             response = await litellm.acompletion(

@@ -617,6 +627,8 @@ async def test_async_chat_bedrock_stream():
             )
             async for chunk in response:
                 continue
+
+            await asyncio.sleep(1)
         except Exception:
             pass
         await asyncio.sleep(1)

@@ -770,6 +782,8 @@ async def test_async_text_completion_bedrock():
         async for chunk in response:
             print(f"chunk: {chunk}")
             continue
+
+        await asyncio.sleep(1)
         ## test failure callback
         try:
             response = await litellm.atext_completion(

@@ -780,6 +794,8 @@ async def test_async_text_completion_bedrock():
             )
             async for chunk in response:
                 continue
+
+            await asyncio.sleep(1)
         except Exception:
             pass
         time.sleep(1)

@@ -809,6 +825,8 @@ async def test_async_text_completion_openai_stream():
     async for chunk in response:
         print(f"chunk: {chunk}")
         continue
+
+    await asyncio.sleep(1)
     ## test failure callback
     try:
         response = await litellm.atext_completion(

@@ -819,6 +837,8 @@ async def test_async_text_completion_openai_stream():
         )
         async for chunk in response:
             continue
+
+        await asyncio.sleep(1)
     except Exception:
         pass
     time.sleep(1)
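The test edits above all follow one pattern: an await asyncio.sleep(1) is inserted after each streamed response is consumed, presumably to give the now fully asynchronous success callbacks time to run before the failure-path call or the surrounding assertions execute.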
(next changed file)
@@ -381,7 +381,7 @@ class CompletionCustomHandler(

 # Simple Azure OpenAI call
 ## COMPLETION
-@pytest.mark.flaky(retries=5, delay=1)
+# @pytest.mark.flaky(retries=5, delay=1)
 @pytest.mark.asyncio
 async def test_async_chat_azure():
     try:

@@ -427,11 +427,11 @@ async def test_async_chat_azure():
         async for chunk in response:
             print(f"async azure router chunk: {chunk}")
             continue
-        await asyncio.sleep(1)
+        await asyncio.sleep(2)
         print(f"customHandler.states: {customHandler_streaming_azure_router.states}")
         assert len(customHandler_streaming_azure_router.errors) == 0
         assert (
-            len(customHandler_streaming_azure_router.states) >= 4
+            len(customHandler_streaming_azure_router.states) >= 3
         )  # pre, post, stream (multiple times), success
         # failure
         model_list = [
tests/logging_callback_tests/test_token_counting.py (new file, +159)

import os
import sys
import traceback
import uuid
import pytest
from dotenv import load_dotenv
from fastapi import Request
from fastapi.routing import APIRoute

load_dotenv()
import io
import os
import time
import json

# this file is to test litellm/proxy

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
import asyncio
from typing import Optional
from litellm.types.utils import StandardLoggingPayload, Usage
from litellm.integrations.custom_logger import CustomLogger


class TestCustomLogger(CustomLogger):
    def __init__(self):
        self.recorded_usage: Optional[Usage] = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        standard_logging_payload = kwargs.get("standard_logging_object")
        print(
            "standard_logging_payload",
            json.dumps(standard_logging_payload, indent=4, default=str),
        )

        self.recorded_usage = Usage(
            prompt_tokens=standard_logging_payload.get("prompt_tokens"),
            completion_tokens=standard_logging_payload.get("completion_tokens"),
            total_tokens=standard_logging_payload.get("total_tokens"),
        )
        pass


@pytest.mark.asyncio
async def test_stream_token_counting_gpt_4o():
    """
    When stream_options={"include_usage": True} logging callback tracks Usage == Usage from llm API
    """
    custom_logger = TestCustomLogger()
    litellm.logging_callback_manager.add_litellm_callback(custom_logger)

    response = await litellm.acompletion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hello, how are you?" * 100}],
        stream=True,
        stream_options={"include_usage": True},
    )

    actual_usage = None
    async for chunk in response:
        if "usage" in chunk:
            actual_usage = chunk["usage"]
            print("chunk.usage", json.dumps(chunk["usage"], indent=4, default=str))
        pass

    await asyncio.sleep(2)

    print("\n\n\n\n\n")
    print(
        "recorded_usage",
        json.dumps(custom_logger.recorded_usage, indent=4, default=str),
    )
    print("\n\n\n\n\n")

    assert actual_usage.prompt_tokens == custom_logger.recorded_usage.prompt_tokens
    assert (
        actual_usage.completion_tokens == custom_logger.recorded_usage.completion_tokens
    )
    assert actual_usage.total_tokens == custom_logger.recorded_usage.total_tokens


@pytest.mark.asyncio
async def test_stream_token_counting_without_include_usage():
    """
    When stream_options={"include_usage": True} is not passed, the usage tracked == usage from llm api chunk

    by default, litellm passes `include_usage=True` for OpenAI API
    """
    custom_logger = TestCustomLogger()
    litellm.logging_callback_manager.add_litellm_callback(custom_logger)

    response = await litellm.acompletion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hello, how are you?" * 100}],
        stream=True,
    )

    actual_usage = None
    async for chunk in response:
        if "usage" in chunk:
            actual_usage = chunk["usage"]
            print("chunk.usage", json.dumps(chunk["usage"], indent=4, default=str))
        pass

    await asyncio.sleep(2)

    print("\n\n\n\n\n")
    print(
        "recorded_usage",
        json.dumps(custom_logger.recorded_usage, indent=4, default=str),
    )
    print("\n\n\n\n\n")

    assert actual_usage.prompt_tokens == custom_logger.recorded_usage.prompt_tokens
    assert (
        actual_usage.completion_tokens == custom_logger.recorded_usage.completion_tokens
    )
    assert actual_usage.total_tokens == custom_logger.recorded_usage.total_tokens


@pytest.mark.asyncio
async def test_stream_token_counting_with_redaction():
    """
    When litellm.turn_off_message_logging=True is used, the usage tracked == usage from llm api chunk
    """
    litellm.turn_off_message_logging = True
    custom_logger = TestCustomLogger()
    litellm.logging_callback_manager.add_litellm_callback(custom_logger)

    response = await litellm.acompletion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hello, how are you?" * 100}],
        stream=True,
    )

    actual_usage = None
    async for chunk in response:
        if "usage" in chunk:
            actual_usage = chunk["usage"]
            print("chunk.usage", json.dumps(chunk["usage"], indent=4, default=str))
        pass

    await asyncio.sleep(2)

    print("\n\n\n\n\n")
    print(
        "recorded_usage",
        json.dumps(custom_logger.recorded_usage, indent=4, default=str),
    )
    print("\n\n\n\n\n")

    assert actual_usage.prompt_tokens == custom_logger.recorded_usage.prompt_tokens
    assert (
        actual_usage.completion_tokens == custom_logger.recorded_usage.completion_tokens
    )
    assert actual_usage.total_tokens == custom_logger.recorded_usage.total_tokens
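The three new tests stream against live gpt-4o (explicit include_usage, default stream_options, and redaction via litellm.turn_off_message_logging) and assert that the usage recorded by the logging callback matches the usage chunk returned by the API. Running them locally presumably requires a configured OPENAI_API_KEY, e.g. pytest tests/logging_callback_tests/test_token_counting.py -s.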