fix(types/utils.py): support passing prompt cache usage stats in usage object

Passes deepseek prompt caching values through to end user
This commit is contained in:
Krrish Dholakia 2024-08-02 09:30:50 -07:00
parent cd073d5ad3
commit 0a30ba9674
3 changed files with 40 additions and 9 deletions

View file

@ -1007,3 +1007,20 @@ def test_completion_cost_anthropic():
print(input_cost) print(input_cost)
print(output_cost) print(output_cost)
def test_completion_cost_deepseek():
litellm.set_verbose = True
model_name = "deepseek/deepseek-chat"
messages = [{"role": "user", "content": "Hey, how's it going?"}]
try:
response_1 = litellm.completion(model=model_name, messages=messages)
response_2 = litellm.completion(model=model_name, messages=messages)
# Add any assertions here to check the response
print(response_2)
assert response_2.usage.prompt_cache_hit_tokens is not None
assert response_2.usage.prompt_cache_miss_tokens is not None
except litellm.APIError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")

View file

@ -5,7 +5,7 @@ from enum import Enum
from typing import Dict, List, Literal, Optional, Tuple, Union from typing import Dict, List, Literal, Optional, Tuple, Union
from openai._models import BaseModel as OpenAIObject from openai._models import BaseModel as OpenAIObject
from pydantic import ConfigDict from pydantic import ConfigDict, Field, PrivateAttr
from typing_extensions import Dict, Required, TypedDict, override from typing_extensions import Dict, Required, TypedDict, override
from ..litellm_core_utils.core_helpers import map_finish_reason from ..litellm_core_utils.core_helpers import map_finish_reason
@ -445,16 +445,28 @@ class Choices(OpenAIObject):
class Usage(OpenAIObject): class Usage(OpenAIObject):
prompt_cache_hit_tokens: Optional[int] = Field(default=None)
prompt_cache_miss_tokens: Optional[int] = Field(default=None)
prompt_tokens: Optional[int] = Field(default=None)
completion_tokens: Optional[int] = Field(default=None)
total_tokens: Optional[int] = Field(default=None)
def __init__( def __init__(
self, prompt_tokens=None, completion_tokens=None, total_tokens=None, **params self,
prompt_tokens: Optional[int] = None,
completion_tokens: Optional[int] = None,
total_tokens: Optional[int] = None,
prompt_cache_hit_tokens: Optional[int] = None,
prompt_cache_miss_tokens: Optional[int] = None,
): ):
super(Usage, self).__init__(**params) data = {
if prompt_tokens: "prompt_tokens": prompt_tokens,
self.prompt_tokens = prompt_tokens "completion_tokens": completion_tokens,
if completion_tokens: "total_tokens": total_tokens,
self.completion_tokens = completion_tokens "prompt_cache_hit_tokens": prompt_cache_hit_tokens,
if total_tokens: "prompt_cache_miss_tokens": prompt_cache_miss_tokens,
self.total_tokens = total_tokens }
super().__init__(**data)
def __contains__(self, key): def __contains__(self, key):
# Define custom behavior for the 'in' operator # Define custom behavior for the 'in' operator

View file

@ -5825,6 +5825,8 @@ def convert_to_model_response_object(
model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore
model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore
model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
model_response_object.usage.prompt_cache_hit_tokens = response_object["usage"].get("prompt_cache_hit_tokens", None) # type: ignore
model_response_object.usage.prompt_cache_miss_tokens = response_object["usage"].get("prompt_cache_miss_tokens", None) # type: ignore
if "created" in response_object: if "created" in response_object:
model_response_object.created = response_object["created"] or int( model_response_object.created = response_object["created"] or int(