[Bug Fix]: Errors in LiteLLM When Using Embeddings Model with Usage-Based Routing (#7390)

* use StandardLoggingPayload (slp) for usage based routing v2

* update error msg

* fix usage based routing v2

* test_tpm_rpm_updated

* fix unused imports

Ishaan Jaff 2024-12-23 17:42:24 -08:00 committed by GitHub
parent 48316520f4
commit 61b636c20d
3 changed files with 254 additions and 97 deletions
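
In short: the v2 lowest-TPM/RPM handler previously read the model group and deployment id out of `litellm_params["metadata"]` and token usage out of `response_obj["usage"]["total_tokens"]`, an access pattern that assumes a chat-style dict response and is what failed for embeddings calls in #7390. It now takes all three values from the `standard_logging_object` that LiteLLM attaches to callback kwargs. A minimal before/after sketch of just the extraction logic (names match the diff below; the surrounding kwargs are illustrative):

    # Before: assumes a chat-style dict response with a "usage" key.
    total_tokens = response_obj["usage"]["total_tokens"]

    # After: read everything from the StandardLoggingPayload in kwargs.
    slo = kwargs.get("standard_logging_object")
    if slo is None:
        raise ValueError("standard_logging_object not passed in.")
    model_group = slo.get("model_group")
    deployment_id = slo.get("model_id")
    total_tokens = slo.get("total_tokens")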

litellm/router_strategy/lowest_tpm_rpm_v2.py

@@ -1,10 +1,9 @@
 #### What this does ####
 # identifies lowest tpm deployment
 import random
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 import httpx
-from pydantic import BaseModel

 import litellm
 from litellm import token_counter
@@ -13,7 +12,7 @@ from litellm.caching.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
 from litellm.types.router import RouterErrors
-from litellm.types.utils import LiteLLMPydanticObjectBase
+from litellm.types.utils import LiteLLMPydanticObjectBase, StandardLoggingPayload
 from litellm.utils import get_utc_datetime, print_verbose

 if TYPE_CHECKING:
@@ -223,20 +222,19 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         """
         Update TPM/RPM usage on success
         """
-        if kwargs["litellm_params"].get("metadata") is None:
-            pass
-        else:
-            model_group = kwargs["litellm_params"]["metadata"].get(
-                "model_group", None
-            )
-
-            id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
-            if model_group is None or id is None:
-                return
-            elif isinstance(id, int):
-                id = str(id)
-
-            total_tokens = response_obj["usage"]["total_tokens"]
+        standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
+            "standard_logging_object"
+        )
+        if standard_logging_object is None:
+            raise ValueError("standard_logging_object not passed in.")
+        model_group = standard_logging_object.get("model_group")
+        id = standard_logging_object.get("model_id")
+        if model_group is None or id is None:
+            return
+        elif isinstance(id, int):
+            id = str(id)
+        total_tokens = standard_logging_object.get("total_tokens")

         # ------------
         # Setup values
@@ -261,7 +259,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             self.logged_success += 1
         except Exception as e:
             verbose_logger.exception(
-                "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
+                "litellm.proxy.hooks.lowest_tpm_rpm_v2.py::log_success_event(): Exception occured - {}".format(
                     str(e)
                 )
             )
@@ -272,26 +270,18 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         """
         Update TPM usage on success
         """
-        if kwargs["litellm_params"].get("metadata") is None:
-            pass
-        else:
-            model_group = kwargs["litellm_params"]["metadata"].get(
-                "model_group", None
-            )
-
-            if isinstance(response_obj, BaseModel) and not hasattr(
-                response_obj, "usage"
-            ):
-                return
-
-            id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
-            if model_group is None or id is None:
-                return
-            elif isinstance(id, int):
-                id = str(id)
-
-            total_tokens = cast(dict, response_obj)["usage"]["total_tokens"]
+        standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
+            "standard_logging_object"
+        )
+        if standard_logging_object is None:
+            raise ValueError("standard_logging_object not passed in.")
+        model_group = standard_logging_object.get("model_group")
+        id = standard_logging_object.get("model_id")
+        if model_group is None or id is None:
+            return
+        elif isinstance(id, int):
+            id = str(id)
+        total_tokens = standard_logging_object.get("total_tokens")

         # ------------
         # Setup values
         # ------------
@@ -319,7 +309,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             self.logged_success += 1
         except Exception as e:
             verbose_logger.exception(
-                "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
+                "litellm.proxy.hooks.lowest_tpm_rpm_v2.py::async_log_success_event(): Exception occured - {}".format(
                     str(e)
                 )
             )
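
For reference, a minimal sketch of driving the updated handler directly, mirroring the unit tests further down. The `router_cache=DualCache()` / empty `model_list` constructor wiring is assumed here, and the literal payload is an illustrative stand-in rather than a value from this commit:

    import time

    from litellm.caching.caching import DualCache
    from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2

    # Assumed constructor arguments; values are illustrative.
    handler = LowestTPMLoggingHandler_v2(router_cache=DualCache(), model_list=[])
    kwargs = {
        "litellm_params": {
            "metadata": {"model_group": "gpt-3.5-turbo"},
            "model_info": {"id": "1234"},
        },
        # Required after this change; the handler raises ValueError without it.
        "standard_logging_object": {
            "model_group": "gpt-3.5-turbo",
            "model_id": "1234",
            "total_tokens": 50,
        },
    }
    handler.log_success_event(
        response_obj={"usage": {"total_tokens": 50}},
        kwargs=kwargs,
        start_time=time.time(),
        end_time=time.time(),
    )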

create_mock_standard_logging_payload.py (new file)

@@ -0,0 +1,131 @@
+import io
+import os
+import sys
+
+sys.path.insert(0, os.path.abspath("../.."))
+
+import asyncio
+import gzip
+import json
+import logging
+import time
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+import litellm
+from litellm import completion
+from litellm._logging import verbose_logger
+from litellm.integrations.datadog.datadog import *
+from datetime import datetime, timedelta
+from litellm.types.utils import (
+    StandardLoggingPayload,
+    StandardLoggingModelInformation,
+    StandardLoggingMetadata,
+    StandardLoggingHiddenParams,
+)
+
+verbose_logger.setLevel(logging.DEBUG)
+
+
+def create_standard_logging_payload() -> StandardLoggingPayload:
+    return StandardLoggingPayload(
+        id="test_id",
+        call_type="completion",
+        response_cost=0.1,
+        response_cost_failure_debug_info=None,
+        status="success",
+        total_tokens=30,
+        prompt_tokens=20,
+        completion_tokens=10,
+        startTime=1234567890.0,
+        endTime=1234567891.0,
+        completionStartTime=1234567890.5,
+        model_map_information=StandardLoggingModelInformation(
+            model_map_key="gpt-3.5-turbo", model_map_value=None
+        ),
+        model="gpt-3.5-turbo",
+        model_id="model-123",
+        model_group="openai-gpt",
+        api_base="https://api.openai.com",
+        metadata=StandardLoggingMetadata(
+            user_api_key_hash="test_hash",
+            user_api_key_org_id=None,
+            user_api_key_alias="test_alias",
+            user_api_key_team_id="test_team",
+            user_api_key_user_id="test_user",
+            user_api_key_team_alias="test_team_alias",
+            spend_logs_metadata=None,
+            requester_ip_address="127.0.0.1",
+            requester_metadata=None,
+        ),
+        cache_hit=False,
+        cache_key=None,
+        saved_cache_cost=0.0,
+        request_tags=[],
+        end_user=None,
+        requester_ip_address="127.0.0.1",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        response={"choices": [{"message": {"content": "Hi there!"}}]},
+        error_str=None,
+        model_parameters={"stream": True},
+        hidden_params=StandardLoggingHiddenParams(
+            model_id="model-123",
+            cache_key=None,
+            api_base="https://api.openai.com",
+            response_cost="0.1",
+            additional_headers=None,
+        ),
+    )
+
+
+def create_standard_logging_payload_with_long_content() -> StandardLoggingPayload:
+    return StandardLoggingPayload(
+        id="test_id",
+        call_type="completion",
+        response_cost=0.1,
+        response_cost_failure_debug_info=None,
+        status="success",
+        total_tokens=30,
+        prompt_tokens=20,
+        completion_tokens=10,
+        startTime=1234567890.0,
+        endTime=1234567891.0,
+        completionStartTime=1234567890.5,
+        model_map_information=StandardLoggingModelInformation(
+            model_map_key="gpt-3.5-turbo", model_map_value=None
+        ),
+        model="gpt-3.5-turbo",
+        model_id="model-123",
+        model_group="openai-gpt",
+        api_base="https://api.openai.com",
+        metadata=StandardLoggingMetadata(
+            user_api_key_hash="test_hash",
+            user_api_key_org_id=None,
+            user_api_key_alias="test_alias",
+            user_api_key_team_id="test_team",
+            user_api_key_user_id="test_user",
+            user_api_key_team_alias="test_team_alias",
+            spend_logs_metadata=None,
+            requester_ip_address="127.0.0.1",
+            requester_metadata=None,
+        ),
+        cache_hit=False,
+        cache_key=None,
+        saved_cache_cost=0.0,
+        request_tags=[],
+        end_user=None,
+        requester_ip_address="127.0.0.1",
+        messages=[{"role": "user", "content": "Hello, world!" * 80000}],
+        response={"choices": [{"message": {"content": "Hi there!" * 80000}}]},
+        error_str="error_str" * 80000,
+        model_parameters={"stream": True},
+        hidden_params=StandardLoggingHiddenParams(
+            model_id="model-123",
+            cache_key=None,
+            api_base="https://api.openai.com",
+            response_cost="0.1",
+            additional_headers=None,
+        ),
+    )
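
The routing tests take this fully populated payload and overwrite just the routing-relevant fields per deployment before attaching it to the callback kwargs, as the updated tests below do. A compact sketch of that pattern:

    payload = create_standard_logging_payload()
    payload["model_group"] = "azure-model"  # model group being routed
    payload["model_id"] = "1234"            # deployment to credit the usage to
    payload["total_tokens"] = 50            # tokens recorded for TPM tracking
    kwargs = {"standard_logging_object": payload}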

TPM/RPM routing v2 unit tests

@@ -18,7 +18,7 @@ sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 from unittest.mock import AsyncMock, MagicMock, patch
+from litellm.types.utils import StandardLoggingPayload

 import pytest
 import litellm
@@ -28,6 +28,7 @@ from litellm.router_strategy.lowest_tpm_rpm_v2 import (
     LowestTPMLoggingHandler_v2 as LowestTPMLoggingHandler,
 )
 from litellm.utils import get_utc_datetime
+from create_mock_standard_logging_payload import create_standard_logging_payload

 ### UNIT TESTS FOR TPM/RPM ROUTING ###
@@ -44,17 +45,25 @@ def test_tpm_rpm_updated():
     )
     model_group = "gpt-3.5-turbo"
     deployment_id = "1234"
+    deployment = "azure/chatgpt-v-2"
+    total_tokens = 50
+    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload["model_group"] = model_group
+    standard_logging_payload["model_id"] = deployment_id
+    standard_logging_payload["total_tokens"] = total_tokens
     kwargs = {
         "litellm_params": {
             "metadata": {
-                "model_group": "gpt-3.5-turbo",
-                "deployment": "azure/chatgpt-v-2",
+                "model_group": model_group,
+                "deployment": deployment,
             },
             "model_info": {"id": deployment_id},
-        }
+        },
+        "standard_logging_object": standard_logging_payload,
     }
     start_time = time.time()
-    response_obj = {"usage": {"total_tokens": 50}}
+    response_obj = {"usage": {"total_tokens": total_tokens}}
     end_time = time.time()
     lowest_tpm_logger.pre_call_check(deployment=kwargs["litellm_params"])
     lowest_tpm_logger.log_success_event(
@@ -97,18 +106,25 @@ def test_get_available_deployments():
     )
     model_group = "gpt-3.5-turbo"
     ## DEPLOYMENT 1 ##
+    total_tokens = 50
     deployment_id = "1234"
+    deployment = "azure/chatgpt-v-2"
+    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload["model_group"] = model_group
+    standard_logging_payload["model_id"] = deployment_id
+    standard_logging_payload["total_tokens"] = total_tokens
     kwargs = {
         "litellm_params": {
             "metadata": {
-                "model_group": "gpt-3.5-turbo",
-                "deployment": "azure/chatgpt-v-2",
+                "model_group": model_group,
+                "deployment": deployment,
             },
             "model_info": {"id": deployment_id},
-        }
+        },
+        "standard_logging_object": standard_logging_payload,
     }
     start_time = time.time()
-    response_obj = {"usage": {"total_tokens": 50}}
+    response_obj = {"usage": {"total_tokens": total_tokens}}
     end_time = time.time()
     lowest_tpm_logger.log_success_event(
         response_obj=response_obj,
@@ -117,18 +133,24 @@ def test_get_available_deployments():
         end_time=end_time,
     )
     ## DEPLOYMENT 2 ##
+    total_tokens = 20
     deployment_id = "5678"
+    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload["model_group"] = model_group
+    standard_logging_payload["model_id"] = deployment_id
+    standard_logging_payload["total_tokens"] = total_tokens
     kwargs = {
         "litellm_params": {
             "metadata": {
-                "model_group": "gpt-3.5-turbo",
-                "deployment": "azure/chatgpt-v-2",
+                "model_group": model_group,
+                "deployment": deployment,
             },
             "model_info": {"id": deployment_id},
-        }
+        },
+        "standard_logging_object": standard_logging_payload,
     }
     start_time = time.time()
-    response_obj = {"usage": {"total_tokens": 20}}
+    response_obj = {"usage": {"total_tokens": total_tokens}}
     end_time = time.time()
     lowest_tpm_logger.log_success_event(
         response_obj=response_obj,
@@ -187,13 +209,17 @@ def test_router_get_available_deployments():
     print(f"router id's: {router.get_model_ids()}")
     ## DEPLOYMENT 1 ##
     deployment_id = 1
+    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload["model_group"] = "azure-model"
+    standard_logging_payload["model_id"] = str(deployment_id)
     kwargs = {
         "litellm_params": {
             "metadata": {
                 "model_group": "azure-model",
             },
             "model_info": {"id": 1},
-        }
+        },
+        "standard_logging_object": standard_logging_payload,
     }
     start_time = time.time()
     response_obj = {"usage": {"total_tokens": 50}}
@@ -206,13 +232,17 @@ def test_router_get_available_deployments():
     )
     ## DEPLOYMENT 2 ##
     deployment_id = 2
+    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload["model_group"] = "azure-model"
+    standard_logging_payload["model_id"] = str(deployment_id)
     kwargs = {
         "litellm_params": {
             "metadata": {
                 "model_group": "azure-model",
             },
             "model_info": {"id": 2},
-        }
+        },
+        "standard_logging_object": standard_logging_payload,
     }
     start_time = time.time()
     response_obj = {"usage": {"total_tokens": 20}}
@@ -260,16 +290,22 @@ def test_router_skip_rate_limited_deployments():
     ## DEPLOYMENT 1 ##
     deployment_id = 1
+    total_tokens = 1439
+    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload["model_group"] = "azure-model"
+    standard_logging_payload["model_id"] = str(deployment_id)
+    standard_logging_payload["total_tokens"] = total_tokens
     kwargs = {
         "litellm_params": {
             "metadata": {
                 "model_group": "azure-model",
            },
             "model_info": {"id": deployment_id},
-        }
+        },
+        "standard_logging_object": standard_logging_payload,
     }
     start_time = time.time()
-    response_obj = {"usage": {"total_tokens": 1439}}
+    response_obj = {"usage": {"total_tokens": total_tokens}}
     end_time = time.time()
     router.lowesttpm_logger_v2.log_success_event(
         response_obj=response_obj,