From 61b636c20d6ca8c5fdc7dcaacd48f8cddb67d9cd Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Mon, 23 Dec 2024 17:42:24 -0800
Subject: [PATCH] [Bug Fix]: Errors in LiteLLM When Using Embeddings Model with
 Usage-Based Routing (#7390)

* use slp for usage based routing v2

* update error msg

* fix usage based routing v2

* test_tpm_rpm_updated

* fix unused imports

* fix unused imports
---
 litellm/router_strategy/lowest_tpm_rpm_v2.py  | 150 ++++++++----------
 .../create_mock_standard_logging_payload.py   | 131 +++++++++++++++
 .../local_testing/test_tpm_rpm_routing_v2.py  |  70 ++++++--
 3 files changed, 254 insertions(+), 97 deletions(-)
 create mode 100644 tests/local_testing/create_mock_standard_logging_payload.py

diff --git a/litellm/router_strategy/lowest_tpm_rpm_v2.py b/litellm/router_strategy/lowest_tpm_rpm_v2.py
index 205da3808a..e6b65299b2 100644
--- a/litellm/router_strategy/lowest_tpm_rpm_v2.py
+++ b/litellm/router_strategy/lowest_tpm_rpm_v2.py
@@ -1,10 +1,9 @@
 #### What this does ####
 # identifies lowest tpm deployment
 import random
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 import httpx
-from pydantic import BaseModel

 import litellm
 from litellm import token_counter
@@ -13,7 +12,7 @@ from litellm.caching.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
 from litellm.types.router import RouterErrors
-from litellm.types.utils import LiteLLMPydanticObjectBase
+from litellm.types.utils import LiteLLMPydanticObjectBase, StandardLoggingPayload
 from litellm.utils import get_utc_datetime, print_verbose

 if TYPE_CHECKING:
@@ -223,45 +222,44 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             """
             Update TPM/RPM usage on success
             """
-            if kwargs["litellm_params"].get("metadata") is None:
-                pass
-            else:
-                model_group = kwargs["litellm_params"]["metadata"].get(
-                    "model_group", None
-                )
+            standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
+                "standard_logging_object"
+            )
+            if standard_logging_object is None:
+                raise ValueError("standard_logging_object not passed in.")
+            model_group = standard_logging_object.get("model_group")
+            id = standard_logging_object.get("model_id")
+            if model_group is None or id is None:
+                return
+            elif isinstance(id, int):
+                id = str(id)

-                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
-                if model_group is None or id is None:
-                    return
-                elif isinstance(id, int):
-                    id = str(id)
+            total_tokens = standard_logging_object.get("total_tokens")

-                total_tokens = response_obj["usage"]["total_tokens"]
+            # ------------
+            # Setup values
+            # ------------
+            dt = get_utc_datetime()
+            current_minute = dt.strftime(
+                "%H-%M"
+            )  # use the same timezone regardless of system clock

-                # ------------
-                # Setup values
-                # ------------
-                dt = get_utc_datetime()
-                current_minute = dt.strftime(
-                    "%H-%M"
-                )  # use the same timezone regardless of system clock
+            tpm_key = f"{id}:tpm:{current_minute}"
+            # ------------
+            # Update usage
+            # ------------
+            # update cache

-                tpm_key = f"{id}:tpm:{current_minute}"
-                # ------------
-                # Update usage
-                # ------------
-                # update cache
-
-                ## TPM
-                self.router_cache.increment_cache(
-                    key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl
-                )
-                ### TESTING ###
-                if self.test_flag:
-                    self.logged_success += 1
+            ## TPM
+            self.router_cache.increment_cache(
+                key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl
+            )
+            ### TESTING ###
+            if self.test_flag:
+                self.logged_success += 1
         except Exception as e:
             verbose_logger.exception(
-                "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
+                "litellm.proxy.hooks.lowest_tpm_rpm_v2.py::log_success_event(): Exception occured - {}".format(
                     str(e)
                 )
             )
@@ -272,54 +270,46 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             """
             Update TPM usage on success
             """
-            if kwargs["litellm_params"].get("metadata") is None:
-                pass
-            else:
-                model_group = kwargs["litellm_params"]["metadata"].get(
-                    "model_group", None
-                )
+            standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
+                "standard_logging_object"
+            )
+            if standard_logging_object is None:
+                raise ValueError("standard_logging_object not passed in.")
+            model_group = standard_logging_object.get("model_group")
+            id = standard_logging_object.get("model_id")
+            if model_group is None or id is None:
+                return
+            elif isinstance(id, int):
+                id = str(id)
+            total_tokens = standard_logging_object.get("total_tokens")
+            # ------------
+            # Setup values
+            # ------------
+            dt = get_utc_datetime()
+            current_minute = dt.strftime(
+                "%H-%M"
+            )  # use the same timezone regardless of system clock

-                if isinstance(response_obj, BaseModel) and not hasattr(
-                    response_obj, "usage"
-                ):
-                    return
+            tpm_key = f"{id}:tpm:{current_minute}"
+            # ------------
+            # Update usage
+            # ------------
+            # update cache
+            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+            ## TPM
+            await self.router_cache.async_increment_cache(
+                key=tpm_key,
+                value=total_tokens,
+                ttl=self.routing_args.ttl,
+                parent_otel_span=parent_otel_span,
+            )

-                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
-                if model_group is None or id is None:
-                    return
-                elif isinstance(id, int):
-                    id = str(id)
-
-                total_tokens = cast(dict, response_obj)["usage"]["total_tokens"]
-
-                # ------------
-                # Setup values
-                # ------------
-                dt = get_utc_datetime()
-                current_minute = dt.strftime(
-                    "%H-%M"
-                )  # use the same timezone regardless of system clock
-
-                tpm_key = f"{id}:tpm:{current_minute}"
-                # ------------
-                # Update usage
-                # ------------
-                # update cache
-                parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
-                ## TPM
-                await self.router_cache.async_increment_cache(
-                    key=tpm_key,
-                    value=total_tokens,
-                    ttl=self.routing_args.ttl,
-                    parent_otel_span=parent_otel_span,
-                )
-
-                ### TESTING ###
-                if self.test_flag:
-                    self.logged_success += 1
+            ### TESTING ###
+            if self.test_flag:
+                self.logged_success += 1
         except Exception as e:
             verbose_logger.exception(
-                "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
+                "litellm.proxy.hooks.lowest_tpm_rpm_v2.py::async_log_success_event(): Exception occured - {}".format(
                     str(e)
                 )
             )
diff --git a/tests/local_testing/create_mock_standard_logging_payload.py b/tests/local_testing/create_mock_standard_logging_payload.py
new file mode 100644
index 0000000000..2fd6a4ffa8
--- /dev/null
+++ b/tests/local_testing/create_mock_standard_logging_payload.py
@@ -0,0 +1,131 @@
+import io
+import os
+import sys
+
+
+sys.path.insert(0, os.path.abspath("../.."))
+
+import asyncio
+import gzip
+import json
+import logging
+import time
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+import litellm
+from litellm import completion
+from litellm._logging import verbose_logger
+from litellm.integrations.datadog.datadog import *
+from datetime import datetime, timedelta
+from litellm.types.utils import (
+    StandardLoggingPayload,
+    StandardLoggingModelInformation,
+    StandardLoggingMetadata,
+    StandardLoggingHiddenParams,
+)
+
+verbose_logger.setLevel(logging.DEBUG)
+
+
+def create_standard_logging_payload() -> StandardLoggingPayload:
+    return StandardLoggingPayload(
+        id="test_id",
+        call_type="completion",
+        response_cost=0.1,
+        response_cost_failure_debug_info=None,
+        status="success",
+        total_tokens=30,
+        prompt_tokens=20,
+        completion_tokens=10,
+        startTime=1234567890.0,
+        endTime=1234567891.0,
+        completionStartTime=1234567890.5,
+        model_map_information=StandardLoggingModelInformation(
+            model_map_key="gpt-3.5-turbo", model_map_value=None
+        ),
+        model="gpt-3.5-turbo",
+        model_id="model-123",
+        model_group="openai-gpt",
+        api_base="https://api.openai.com",
+        metadata=StandardLoggingMetadata(
+            user_api_key_hash="test_hash",
+            user_api_key_org_id=None,
+            user_api_key_alias="test_alias",
+            user_api_key_team_id="test_team",
+            user_api_key_user_id="test_user",
+            user_api_key_team_alias="test_team_alias",
+            spend_logs_metadata=None,
+            requester_ip_address="127.0.0.1",
+            requester_metadata=None,
+        ),
+        cache_hit=False,
+        cache_key=None,
+        saved_cache_cost=0.0,
+        request_tags=[],
+        end_user=None,
+        requester_ip_address="127.0.0.1",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        response={"choices": [{"message": {"content": "Hi there!"}}]},
+        error_str=None,
+        model_parameters={"stream": True},
+        hidden_params=StandardLoggingHiddenParams(
+            model_id="model-123",
+            cache_key=None,
+            api_base="https://api.openai.com",
+            response_cost="0.1",
+            additional_headers=None,
+        ),
+    )
+
+
+def create_standard_logging_payload_with_long_content() -> StandardLoggingPayload:
+    return StandardLoggingPayload(
+        id="test_id",
+        call_type="completion",
+        response_cost=0.1,
+        response_cost_failure_debug_info=None,
+        status="success",
+        total_tokens=30,
+        prompt_tokens=20,
+        completion_tokens=10,
+        startTime=1234567890.0,
+        endTime=1234567891.0,
+        completionStartTime=1234567890.5,
+        model_map_information=StandardLoggingModelInformation(
+            model_map_key="gpt-3.5-turbo", model_map_value=None
+        ),
+        model="gpt-3.5-turbo",
+        model_id="model-123",
+        model_group="openai-gpt",
+        api_base="https://api.openai.com",
+        metadata=StandardLoggingMetadata(
+            user_api_key_hash="test_hash",
+            user_api_key_org_id=None,
+            user_api_key_alias="test_alias",
+            user_api_key_team_id="test_team",
+            user_api_key_user_id="test_user",
+            user_api_key_team_alias="test_team_alias",
+            spend_logs_metadata=None,
+            requester_ip_address="127.0.0.1",
+            requester_metadata=None,
+        ),
+        cache_hit=False,
+        cache_key=None,
+        saved_cache_cost=0.0,
+        request_tags=[],
+        end_user=None,
+        requester_ip_address="127.0.0.1",
+        messages=[{"role": "user", "content": "Hello, world!" * 80000}],
+        response={"choices": [{"message": {"content": "Hi there!" * 80000}}]},
+        error_str="error_str" * 80000,
+        model_parameters={"stream": True},
+        hidden_params=StandardLoggingHiddenParams(
+            model_id="model-123",
+            cache_key=None,
+            api_base="https://api.openai.com",
+            response_cost="0.1",
+            additional_headers=None,
+        ),
+    )
diff --git a/tests/local_testing/test_tpm_rpm_routing_v2.py b/tests/local_testing/test_tpm_rpm_routing_v2.py
index 3641eecadb..d718249bab 100644
--- a/tests/local_testing/test_tpm_rpm_routing_v2.py
+++ b/tests/local_testing/test_tpm_rpm_routing_v2.py
@@ -18,7 +18,7 @@ sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 from unittest.mock import AsyncMock, MagicMock, patch
-
+from litellm.types.utils import StandardLoggingPayload
 import pytest

 import litellm
@@ -28,6 +28,7 @@ from litellm.router_strategy.lowest_tpm_rpm_v2 import (
     LowestTPMLoggingHandler_v2 as LowestTPMLoggingHandler,
 )
 from litellm.utils import get_utc_datetime
+from create_mock_standard_logging_payload import create_standard_logging_payload

 ### UNIT TESTS FOR TPM/RPM ROUTING ###

@@ -44,17 +45,25 @@ def test_tpm_rpm_updated():
     )
     model_group = "gpt-3.5-turbo"
     deployment_id = "1234"
+    deployment = "azure/chatgpt-v-2"
+    total_tokens = 50
+    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload["model_group"] = model_group
+    standard_logging_payload["model_id"] = deployment_id
+    standard_logging_payload["total_tokens"] = total_tokens
     kwargs = {
         "litellm_params": {
             "metadata": {
-                "model_group": "gpt-3.5-turbo",
-                "deployment": "azure/chatgpt-v-2",
+                "model_group": model_group,
+                "deployment": deployment,
             },
             "model_info": {"id": deployment_id},
-        }
+        },
+        "standard_logging_object": standard_logging_payload,
     }
+
     start_time = time.time()
-    response_obj = {"usage": {"total_tokens": 50}}
+    response_obj = {"usage": {"total_tokens": total_tokens}}
     end_time = time.time()
     lowest_tpm_logger.pre_call_check(deployment=kwargs["litellm_params"])
     lowest_tpm_logger.log_success_event(
@@ -97,18 +106,25 @@ def test_get_available_deployments():
     )
     model_group = "gpt-3.5-turbo"
     ## DEPLOYMENT 1 ##
+    total_tokens = 50
     deployment_id = "1234"
+    deployment = "azure/chatgpt-v-2"
+    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload["model_group"] = model_group
+    standard_logging_payload["model_id"] = deployment_id
+    standard_logging_payload["total_tokens"] = total_tokens
     kwargs = {
         "litellm_params": {
             "metadata": {
-                "model_group": "gpt-3.5-turbo",
-                "deployment": "azure/chatgpt-v-2",
+                "model_group": model_group,
+                "deployment": deployment,
             },
             "model_info": {"id": deployment_id},
-        }
+        },
+        "standard_logging_object": standard_logging_payload,
     }
     start_time = time.time()
-    response_obj = {"usage": {"total_tokens": 50}}
+    response_obj = {"usage": {"total_tokens": total_tokens}}
     end_time = time.time()
     lowest_tpm_logger.log_success_event(
         response_obj=response_obj,
@@ -117,18 +133,24 @@ def test_get_available_deployments():
         end_time=end_time,
     )
     ## DEPLOYMENT 2 ##
+    total_tokens = 20
     deployment_id = "5678"
+    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload["model_group"] = model_group
+    standard_logging_payload["model_id"] = deployment_id
+    standard_logging_payload["total_tokens"] = total_tokens
     kwargs = {
         "litellm_params": {
             "metadata": {
-                "model_group": "gpt-3.5-turbo",
-                "deployment": "azure/chatgpt-v-2",
+                "model_group": model_group,
+                "deployment": deployment,
             },
             "model_info": {"id": deployment_id},
-        }
+        },
+        "standard_logging_object": standard_logging_payload,
     }
     start_time = time.time()
-    response_obj = {"usage": {"total_tokens": 20}}
+    response_obj = {"usage": {"total_tokens": total_tokens}}
     end_time = time.time()
     lowest_tpm_logger.log_success_event(
         response_obj=response_obj,
@@ -187,13 +209,17 @@ def test_router_get_available_deployments():
     print(f"router id's: {router.get_model_ids()}")
     ## DEPLOYMENT 1 ##
     deployment_id = 1
+    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload["model_group"] = "azure-model"
+    standard_logging_payload["model_id"] = str(deployment_id)
     kwargs = {
         "litellm_params": {
             "metadata": {
                 "model_group": "azure-model",
             },
             "model_info": {"id": 1},
-        }
+        },
+        "standard_logging_object": standard_logging_payload,
     }
     start_time = time.time()
     response_obj = {"usage": {"total_tokens": 50}}
@@ -206,13 +232,17 @@ def test_router_get_available_deployments():
     )
     ## DEPLOYMENT 2 ##
     deployment_id = 2
+    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload["model_group"] = "azure-model"
+    standard_logging_payload["model_id"] = str(deployment_id)
     kwargs = {
         "litellm_params": {
             "metadata": {
                 "model_group": "azure-model",
             },
             "model_info": {"id": 2},
-        }
+        },
+        "standard_logging_object": standard_logging_payload,
     }
     start_time = time.time()
     response_obj = {"usage": {"total_tokens": 20}}
@@ -260,16 +290,22 @@ def test_router_skip_rate_limited_deployments():

     ## DEPLOYMENT 1 ##
     deployment_id = 1
+    total_tokens = 1439
+    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload["model_group"] = "azure-model"
+    standard_logging_payload["model_id"] = str(deployment_id)
+    standard_logging_payload["total_tokens"] = total_tokens
     kwargs = {
         "litellm_params": {
             "metadata": {
                 "model_group": "azure-model",
             },
             "model_info": {"id": deployment_id},
-        }
+        },
+        "standard_logging_object": standard_logging_payload,
     }
     start_time = time.time()
-    response_obj = {"usage": {"total_tokens": 1439}}
+    response_obj = {"usage": {"total_tokens": total_tokens}}
     end_time = time.time()
     router.lowesttpm_logger_v2.log_success_event(
         response_obj=response_obj,