add cost tracking for pass-through Imagen image generation

This commit is contained in:
Ishaan Jaff 2024-09-02 18:10:46 -07:00
parent 9fcab392a4
commit 4a0fdc40f1
4 changed files with 12 additions and 2 deletions

View file

@@ -24,7 +24,7 @@ from litellm.llms.anthropic.cost_calculation import (
 )
 from litellm.types.llms.openai import HttpxBinaryResponseContent
 from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
-from litellm.types.utils import Usage
+from litellm.types.utils import PassthroughCallTypes, Usage
 from litellm.utils import (
     CallTypes,
     CostPerToken,
@@ -625,6 +625,7 @@ def completion_cost(
     if (
         call_type == CallTypes.image_generation.value
         or call_type == CallTypes.aimage_generation.value
+        or call_type == PassthroughCallTypes.passthrough_image_generation.value
     ):
         ### IMAGE GENERATION COST CALCULATION ###
         if custom_llm_provider == "vertex_ai":

View file

@@ -110,6 +110,7 @@ class PassThroughEndpointLogging:
         from litellm.llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
             transform_vertex_response_to_openai,
         )
+        from litellm.types.utils import PassthroughCallTypes

         vertex_image_generation_class = VertexImageGeneration()
@@ -127,6 +128,10 @@ class PassThroughEndpointLogging:
                     model=model,
                 )
             )
+            logging_obj.call_type = (
+                PassthroughCallTypes.passthrough_image_generation.value
+            )
         else:
             litellm_model_response = await transform_vertex_response_to_openai(
                 response=_json_response,

View file

@@ -72,6 +72,6 @@ def test_get_llm_provider_deepseek_custom_api_base():
 def test_get_llm_provider_vertex_ai_image_models():
     model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
-        model="imagegeneration@006",
+        model="imagegeneration@006", custom_llm_provider=None
     )
     assert custom_llm_provider == "vertex_ai"

View file

@@ -119,6 +119,10 @@ class CallTypes(Enum):
     speech = "speech"

+class PassthroughCallTypes(Enum):
+    passthrough_image_generation = "passthrough-image-generation"

 class TopLogprob(OpenAIObject):
     token: str
     """The token."""