[Fix] o1-mini causes pydantic warnings on reasoning_tokens (#5754)

* add requester_metadata in standard logging payload

* log requester_metadata in metadata

* use StandardLoggingPayload for logging

* docs StandardLoggingPayload

* fix import

* include standard logging object in failure

* add test for requester metadata

* handle completion_tokens_details

* add test for completion_tokens_details

Ishaan Jaff 2024-09-17 20:23:14 -07:00 (committed by GitHub)
parent d0425e7767
commit 7f4dfe434a
8 changed files with 91 additions and 22 deletions


@@ -61,7 +61,7 @@ litellm_settings:
 Removes any field with `user_api_key_*` from metadata.

-## What gets logged?
+## What gets logged? StandardLoggingPayload

 Found under `kwargs["standard_logging_object"]`. This is a standard payload, logged for every response.
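In practice the payload can be read from any custom callback; a minimal sketch, assuming a function-style success callback (the handler name is illustrative):

```python
import litellm

def log_standard_payload(kwargs, completion_response, start_time, end_time):
    # Populated by litellm for every response; see StandardLoggingPayload in litellm.types.utils
    payload = kwargs.get("standard_logging_object")
    if payload is not None:
        print(payload["response_cost"], payload["metadata"]["requester_metadata"])

litellm.success_callback = [log_standard_payload]
```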


@@ -16,6 +16,7 @@ from litellm.litellm_core_utils.logging_utils import (
 )
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from litellm.proxy._types import CommonProxyErrors, SpendLogsMetadata, SpendLogsPayload
+from litellm.types.utils import StandardLoggingMetadata, StandardLoggingPayload


 class RequestKwargs(TypedDict):
@@ -30,6 +31,7 @@ class GCSBucketPayload(TypedDict):
     start_time: str
     end_time: str
     response_cost: Optional[float]
+    metadata: Optional[StandardLoggingMetadata]
     spend_log_metadata: str
     exception: Optional[str]
     log_event_type: Optional[str]
@@ -183,13 +185,22 @@ class GCSBucketLogger(GCSBucketBase):
             end_user_id=kwargs.get("end_user_id", None),
         )

         # Ensure everything in the payload is converted to str
+        payload: Optional[StandardLoggingPayload] = kwargs.get(
+            "standard_logging_object", None
+        )
+        if payload is None:
+            raise ValueError("standard_logging_object not found in kwargs")
+
         gcs_payload: GCSBucketPayload = GCSBucketPayload(
             request_kwargs=request_kwargs,
             response_obj=response_dict,
             start_time=start_time,
             end_time=end_time,
+            metadata=payload["metadata"],
             spend_log_metadata=_spend_log_payload.get("metadata", ""),
-            response_cost=kwargs.get("response_cost", None),
+            response_cost=payload["response_cost"],
             exception=exception_str,
             log_event_type=None,
         )


@@ -1628,6 +1628,17 @@ class Logging:
         self.model_call_details.setdefault("original_response", None)
         self.model_call_details["response_cost"] = 0

+        ## STANDARDIZED LOGGING PAYLOAD
+        self.model_call_details["standard_logging_object"] = (
+            get_standard_logging_object_payload(
+                kwargs=self.model_call_details,
+                init_response_obj={},
+                start_time=start_time,
+                end_time=end_time,
+                logging_obj=self,
+            )
+        )
         if hasattr(exception, "headers") and isinstance(exception.headers, dict):
             self.model_call_details.setdefault("litellm_params", {})
             metadata = (
@@ -2419,6 +2430,7 @@ def get_standard_logging_object_payload(
             user_api_key_team_alias=None,
             spend_logs_metadata=None,
             requester_ip_address=None,
+            requester_metadata=None,
         )
     if isinstance(metadata, dict):
         # Filter the metadata dictionary to include only the specified keys


@@ -96,7 +96,7 @@ def convert_key_logging_metadata_to_callback(
     for var, value in data.callback_vars.items():
         if team_callback_settings_obj.callback_vars is None:
             team_callback_settings_obj.callback_vars = {}
-        team_callback_settings_obj.callback_vars[var] = (
+        team_callback_settings_obj.callback_vars[var] = str(
             litellm.utils.get_secret(value, default_value=value) or value
         )
@@ -204,6 +204,13 @@ async def add_litellm_data_to_request(
     if _metadata_variable_name not in data:
         data[_metadata_variable_name] = {}

+    # We want to log the "metadata" from the client side request. Avoid circular reference by not directly assigning metadata to itself
+    if "metadata" in data and data["metadata"] is not None:
+        data[_metadata_variable_name]["requester_metadata"] = copy.deepcopy(
+            data["metadata"]
+        )
+
     data[_metadata_variable_name]["user_api_key"] = user_api_key_dict.api_key
     data[_metadata_variable_name]["user_api_key_alias"] = getattr(
         user_api_key_dict, "key_alias", None
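The `copy.deepcopy` is what makes the comment in that hunk hold: assigning a dict into one of its own keys without copying creates a self-referencing structure that breaks JSON serialization. A standalone illustration (values are hypothetical):

```python
import copy
import json

data = {"metadata": {"foo": "bar"}}

# Without a copy, the dict would contain itself and serialization would recurse:
#   data["metadata"]["requester_metadata"] = data["metadata"]
#   json.dumps(data)  # ValueError: Circular reference detected

# deepcopy snapshots the client metadata before nesting it back in:
data["metadata"]["requester_metadata"] = copy.deepcopy(data["metadata"])
print(json.dumps(data))  # {"metadata": {"foo": "bar", "requester_metadata": {"foo": "bar"}}}
```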


@@ -31,9 +31,8 @@ general_settings:
       "os.environ/SLACK_WEBHOOK_URL_2",
     ],
   }
   key_management_system: "azure_key_vault"

 litellm_settings:
-  success_callback: ["prometheus"]
+  callbacks: ["gcs_bucket"]


@@ -89,6 +89,7 @@ async def test_basic_gcs_logger():
         "user_api_key_team_alias": None,
         "user_api_key_metadata": {},
         "requester_ip_address": "127.0.0.1",
+        "requester_metadata": {"foo": "bar"},
         "spend_logs_metadata": {"hello": "world"},
         "headers": {
             "content-type": "application/json",
@@ -159,6 +160,8 @@ async def test_basic_gcs_logger():
         == "116544810872468347480"
     )
+    assert gcs_payload["metadata"]["requester_metadata"] == {"foo": "bar"}
+
     # Delete Object from GCS
     print("deleting object from GCS")
     await gcs_logger.delete_gcs_object(object_name=object_name)


@@ -5,7 +5,7 @@ from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union

 from openai._models import BaseModel as OpenAIObject
-from openai.types.audio.transcription_create_params import FileTypes
+from openai.types.audio.transcription_create_params import FileTypes  # type: ignore
 from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage
 from pydantic import ConfigDict, Field, PrivateAttr
 from typing_extensions import Callable, Dict, Required, TypedDict, override
@@ -253,7 +253,7 @@ class HiddenParams(OpenAIObject):
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)

-    def json(self, **kwargs):
+    def json(self, **kwargs):  # type: ignore
         try:
             return self.model_dump()  # noqa
         except:
@@ -359,7 +359,7 @@ class Message(OpenAIObject):
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)

-    def json(self, **kwargs):
+    def json(self, **kwargs):  # type: ignore
         try:
             return self.model_dump()  # noqa
         except:
@@ -490,6 +490,19 @@ class Usage(CompletionUsage):
             completion_tokens_details = CompletionTokensDetails(
                 reasoning_tokens=reasoning_tokens
             )

+        # Ensure completion_tokens_details is properly handled
+        if "completion_tokens_details" in params:
+            if isinstance(params["completion_tokens_details"], dict):
+                completion_tokens_details = CompletionTokensDetails(
+                    **params["completion_tokens_details"]
+                )
+            elif isinstance(
+                params["completion_tokens_details"], CompletionTokensDetails
+            ):
+                completion_tokens_details = params["completion_tokens_details"]
+            del params["completion_tokens_details"]
+
         super().__init__(
             prompt_tokens=prompt_tokens or 0,
             completion_tokens=completion_tokens or 0,
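A sketch of the two input shapes this normalization is meant to accept, mirroring the new test at the bottom of this commit (token counts are illustrative):

```python
from openai.types.completion_usage import CompletionTokensDetails
from litellm.types.utils import Usage

# dict form, as parsed from a raw o1-mini response body
u1 = Usage(prompt_tokens=14, completion_tokens=436, total_tokens=450,
           completion_tokens_details={"reasoning_tokens": 0})

# already-typed form, as produced by the openai SDK
u2 = Usage(prompt_tokens=14, completion_tokens=436, total_tokens=450,
           completion_tokens_details=CompletionTokensDetails(reasoning_tokens=0))

# both should normalize to a CompletionTokensDetails instance, so pydantic
# no longer warns when serializing reasoning_tokens
assert isinstance(u1.completion_tokens_details, CompletionTokensDetails)
assert isinstance(u2.completion_tokens_details, CompletionTokensDetails)
```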
@@ -641,6 +654,7 @@ class ModelResponse(OpenAIObject):
         if choices is not None and isinstance(choices, list):
             new_choices = []
             for choice in choices:
+                _new_choice = None
                 if isinstance(choice, StreamingChoices):
                     _new_choice = choice
                 elif isinstance(choice, dict):
@@ -715,7 +729,7 @@
         # Allow dictionary-style access to attributes
         return getattr(self, key)

-    def json(self, **kwargs):
+    def json(self, **kwargs):  # type: ignore
         try:
             return self.model_dump()  # noqa
         except:
@@ -804,7 +818,7 @@
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)

-    def json(self, **kwargs):
+    def json(self, **kwargs):  # type: ignore
         try:
             return self.model_dump()  # noqa
         except:
@@ -855,7 +869,7 @@
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)

-    def json(self, **kwargs):
+    def json(self, **kwargs):  # type: ignore
         try:
             return self.model_dump()  # noqa
         except:
@@ -911,6 +925,7 @@
         if choices is not None and isinstance(choices, list):
             new_choices = []
             for choice in choices:
+                _new_choice = None
                 if isinstance(choice, TextChoices):
                     _new_choice = choice
                 elif isinstance(choice, dict):
@@ -937,12 +952,12 @@
             usage = Usage()

         super(TextCompletionResponse, self).__init__(
-            id=id,
-            object=object,
-            created=created,
-            model=model,
-            choices=choices,
-            usage=usage,
+            id=id,  # type: ignore
+            object=object,  # type: ignore
+            created=created,  # type: ignore
+            model=model,  # type: ignore
+            choices=choices,  # type: ignore
+            usage=usage,  # type: ignore
             **params,
         )
@@ -986,7 +1001,7 @@
     revised_prompt: Optional[str] = None

     def __init__(self, b64_json=None, url=None, revised_prompt=None):
-        super().__init__(b64_json=b64_json, url=url, revised_prompt=revised_prompt)
+        super().__init__(b64_json=b64_json, url=url, revised_prompt=revised_prompt)  # type: ignore

     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
@@ -1004,7 +1019,7 @@
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)

-    def json(self, **kwargs):
+    def json(self, **kwargs):  # type: ignore
         try:
             return self.model_dump()  # noqa
         except:
@@ -1057,7 +1072,7 @@
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)

-    def json(self, **kwargs):
+    def json(self, **kwargs):  # type: ignore
         try:
             return self.model_dump()  # noqa
         except:
@@ -1072,7 +1087,7 @@
     _response_headers: Optional[dict] = None

     def __init__(self, text=None):
-        super().__init__(text=text)
+        super().__init__(text=text)  # type: ignore

     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
@@ -1090,7 +1105,7 @@
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)

-    def json(self, **kwargs):
+    def json(self, **kwargs):  # type: ignore
         try:
             return self.model_dump()  # noqa
         except:
@@ -1247,6 +1262,7 @@ class StandardLoggingMetadata(TypedDict):
     spend_logs_metadata: Optional[
         dict
     ]  # special param to log k,v pairs to spendlogs for a call
     requester_ip_address: Optional[str]
+    requester_metadata: Optional[dict]


 class StandardLoggingHiddenParams(TypedDict):
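With the proxy change above, metadata sent by the client should surface in this new field; a sketch of a request that would populate it (endpoint and key are placeholders):

```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    # forwarded by the proxy into metadata["requester_metadata"] in the standard logging payload
    extra_body={"metadata": {"foo": "bar"}},
)
```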


@@ -99,3 +99,24 @@ async def test_o1_max_completion_tokens(respx_mock: MockRouter, model: str):
     print(f"response: {response}")

     assert isinstance(response, ModelResponse)
+
+
+def test_litellm_responses():
+    """
+    ensures that type of completion_tokens_details is correctly handled / returned
+    """
+    from litellm import ModelResponse
+    from litellm.types.utils import CompletionTokensDetails
+
+    response = ModelResponse(
+        usage={
+            "completion_tokens": 436,
+            "prompt_tokens": 14,
+            "total_tokens": 450,
+            "completion_tokens_details": {"reasoning_tokens": 0},
+        }
+    )
+
+    print("response: ", response)
+    assert isinstance(response.usage.completion_tokens_details, CompletionTokensDetails)