Litellm dev 12 26 2024 p4 (#7439)

* fix(model_dashboard.tsx): support setting model_info params (e.g. mode) on the UI

Closes https://github.com/BerriAI/litellm/issues/5270
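
For reference, the same model_info params can be set in code when defining a deployment; a minimal sketch, assuming the Python Router is used directly (the key below is a placeholder, and the deployment mirrors the whisper-v3 entry from the proxy config in this repo, not anything added by this commit):

from litellm import Router

# "mode" under model_info is the kind of value the new UI field lets admins
# attach; it is what the /model/info endpoint reports back.
router = Router(
    model_list=[
        {
            "model_name": "whisper-v3",
            "litellm_params": {
                "model": "fireworks_ai/whisper-v3",
                "api_key": "sk-1234",  # placeholder key
            },
            "model_info": {"mode": "audio_transcription"},
        }
    ]
)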

* fix(lowest_tpm_rpm_v2.py): fix deployment rpm over-limit check

Fixes a selection error when getting potential deployments below the known tpm/rpm limit.

Fixes https://github.com/BerriAI/litellm/issues/7395
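
The core of the fix, as a standalone sketch (simplified; the real check lives in _return_potential_deployments in the diff below, and the helper name here is only for illustration):

from typing import Dict, List, Optional


def filter_rpm_over_limit(
    healthy_deployments: List[Dict],
    rpm_dict: Dict[str, Optional[int]],
) -> List[Dict]:
    """Drop deployments whose next request would hit their rpm limit.

    A None usage value means "not used this minute" and must not be
    treated as over-limit -- that was the selection error being fixed.
    """
    allowed = []
    for deployment in healthy_deployments:
        limit = deployment.get("litellm_params", {}).get("rpm", float("inf"))
        used = rpm_dict.get(deployment["model_info"]["id"])
        if used is not None and used + 1 >= limit:
            continue  # already at the per-minute request limit
        allowed.append(deployment)
    return allowed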

* fix(test_tpm_rpm_routing_v2.py): add unit test for https://github.com/BerriAI/litellm/issues/7395

* fix(lowest_tpm_rpm_v2.py): fix tpm key name in dict post rpm update
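
In context: the usage cache keys carry a ":tpm:<minute>" / ":rpm:<minute>" suffix, while the deployment dicts are keyed by bare model id; the fix strips the suffix when building those dicts (sketch, values illustrative):

# Cache keys look like "<model_id>:tpm:<HH-MM>"; the usage dicts must be keyed
# by the bare model id so the tpm/rpm checks can find each deployment.
tpm_keys = ["dd8e67f...:tpm:02-17"]  # illustrative, truncated id
tpm_values = [42]

tpm_dict = {key.split(":")[0]: value for key, value in zip(tpm_keys, tpm_values)}
# -> {"dd8e67f...": 42}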

* test: rename test to run earlier

* test: skip flaky test
Krish Dholakia 2024-12-27 12:01:42 -08:00 committed by GitHub
parent 7cf347918e
commit d88de268dd
4 changed files with 173 additions and 69 deletions

Proxy config YAML:

@@ -1,20 +1,17 @@
 model_list:
   - model_name: openai/*
     litellm_params:
       model: openai/*
       api_key: os.environ/OPENAI_API_KEY
-  - model_name: fake-openai-endpoint
+  - model_name: model-test
     litellm_params:
       model: openai/gpt-3.5-turbo
       api_key: os.environ/OPENAI_API_KEY
-  - model_name: whisper-v3
-    litellm_params:
-      model: fireworks_ai/whisper-v3
-      api_base: https://audio-prod.us-virginia-1.direct.fireworks.ai/v1
-      api_key: os.environ/FIREWORKS_API_KEY
-    model_info:
-      mode: audio_transcription
+      mock_response: "Hello, world!"
+      rpm: 1
+  - model_name: model-test
+    litellm_params:
+      model: openai/o1-mini
+      api_key: os.environ/OPENAI_API_KEY
+      mock_response: "Hello, world, it's o1!"
+      rpm: 10
 
 litellm_settings:
   callbacks: ["prometheus"]
   disable_end_user_cost_tracking_prometheus_only: true
+
+router_settings:
+  routing_strategy: usage-based-routing-v2
+  disable_cooldowns: True
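
For reference, the equivalent setup without the proxy, using the Python Router directly (a sketch; the API key is a placeholder):

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "model-test",
            "litellm_params": {
                "model": "openai/gpt-3.5-turbo",
                "api_key": "sk-1234",  # placeholder
                "mock_response": "Hello, world!",
                "rpm": 1,
            },
        },
        {
            "model_name": "model-test",
            "litellm_params": {
                "model": "openai/o1-mini",
                "api_key": "sk-1234",  # placeholder
                "mock_response": "Hello, world, it's o1!",
                "rpm": 10,
            },
        },
    ],
    routing_strategy="usage-based-routing-v2",
)

# Requests to "model-test" go to whichever deployment is below its tpm/rpm
# limits and currently has the lowest tpm usage.
response = router.completion(
    model="model-test",
    messages=[{"role": "user", "content": "hi"}],
)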

litellm/router_strategy/lowest_tpm_rpm_v2.py:

@@ -315,59 +315,14 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             )
             pass
 
-    def _common_checks_available_deployment(  # noqa: PLR0915
+    def _return_potential_deployments(
         self,
-        model_group: str,
-        healthy_deployments: list,
-        tpm_keys: list,
-        tpm_values: Optional[list],
-        rpm_keys: list,
-        rpm_values: Optional[list],
-        messages: Optional[List[Dict[str, str]]] = None,
-        input: Optional[Union[str, List]] = None,
-    ) -> Optional[dict]:
-        """
-        Common checks for get available deployment, across sync + async implementations
-        """
-        if tpm_values is None or rpm_values is None:
-            return None
-
-        tpm_dict = {}  # {model_id: 1, ..}
-        for idx, key in enumerate(tpm_keys):
-            tpm_dict[tpm_keys[idx]] = tpm_values[idx]
-
-        rpm_dict = {}  # {model_id: 1, ..}
-        for idx, key in enumerate(rpm_keys):
-            rpm_dict[rpm_keys[idx]] = rpm_values[idx]
-
-        try:
-            input_tokens = token_counter(messages=messages, text=input)
-        except Exception:
-            input_tokens = 0
-        verbose_router_logger.debug(f"input_tokens={input_tokens}")
-
-        # -----------------------
-        # Find lowest used model
-        # ----------------------
+        healthy_deployments: List[Dict],
+        all_deployments: Dict,
+        input_tokens: int,
+        rpm_dict: Dict,
+    ):
         lowest_tpm = float("inf")
 
-        if tpm_dict is None:  # base case - none of the deployments have been used
-            # initialize a tpm dict with {model_id: 0}
-            tpm_dict = {}
-            for deployment in healthy_deployments:
-                tpm_dict[deployment["model_info"]["id"]] = 0
-        else:
-            dt = get_utc_datetime()
-            current_minute = dt.strftime(
-                "%H-%M"
-            )  # use the same timezone regardless of system clock
-
-            for d in healthy_deployments:
-                ## if healthy deployment not yet used
-                tpm_key = f"{d['model_info']['id']}:tpm:{current_minute}"
-                if tpm_key not in tpm_dict or tpm_dict[tpm_key] is None:
-                    tpm_dict[tpm_key] = 0
-
-            all_deployments = tpm_dict
-
         potential_deployments = []  # if multiple deployments have the same low value
         for item, item_tpm in all_deployments.items():
             ## get the item from model list
@@ -402,8 +357,10 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
                 _deployment_rpm = float("inf")
             if item_tpm + input_tokens > _deployment_tpm:
                 continue
-            elif (rpm_dict is not None and item in rpm_dict) and (
-                rpm_dict[item] + 1 >= _deployment_rpm
+            elif (
+                (rpm_dict is not None and item in rpm_dict)
+                and rpm_dict[item] is not None
+                and (rpm_dict[item] + 1 >= _deployment_rpm)
             ):
                 continue
             elif item_tpm == lowest_tpm:
@@ -411,6 +368,62 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             elif item_tpm < lowest_tpm:
                 lowest_tpm = item_tpm
                 potential_deployments = [_deployment]
+        return potential_deployments
+
+    def _common_checks_available_deployment(  # noqa: PLR0915
+        self,
+        model_group: str,
+        healthy_deployments: list,
+        tpm_keys: list,
+        tpm_values: Optional[list],
+        rpm_keys: list,
+        rpm_values: Optional[list],
+        messages: Optional[List[Dict[str, str]]] = None,
+        input: Optional[Union[str, List]] = None,
+    ) -> Optional[dict]:
+        """
+        Common checks for get available deployment, across sync + async implementations
+        """
+        if tpm_values is None or rpm_values is None:
+            return None
+
+        tpm_dict = {}  # {model_id: 1, ..}
+        for idx, key in enumerate(tpm_keys):
+            tpm_dict[tpm_keys[idx].split(":")[0]] = tpm_values[idx]
+
+        rpm_dict = {}  # {model_id: 1, ..}
+        for idx, key in enumerate(rpm_keys):
+            rpm_dict[rpm_keys[idx].split(":")[0]] = rpm_values[idx]
+
+        try:
+            input_tokens = token_counter(messages=messages, text=input)
+        except Exception:
+            input_tokens = 0
+        verbose_router_logger.debug(f"input_tokens={input_tokens}")
+
+        # -----------------------
+        # Find lowest used model
+        # ----------------------
+        if tpm_dict is None:  # base case - none of the deployments have been used
+            # initialize a tpm dict with {model_id: 0}
+            tpm_dict = {}
+            for deployment in healthy_deployments:
+                tpm_dict[deployment["model_info"]["id"]] = 0
+        else:
+            for d in healthy_deployments:
+                ## if healthy deployment not yet used
+                tpm_key = d["model_info"]["id"]
+                if tpm_key not in tpm_dict or tpm_dict[tpm_key] is None:
+                    tpm_dict[tpm_key] = 0
+
+        all_deployments = tpm_dict
+
+        potential_deployments = self._return_potential_deployments(
+            healthy_deployments=healthy_deployments,
+            all_deployments=all_deployments,
+            input_tokens=input_tokens,
+            rpm_dict=rpm_dict,
+        )
 
         print_verbose("returning picked lowest tpm/rpm deployment.")
         if len(potential_deployments) > 0:

test_tpm_rpm_routing_v2.py:

@@ -8,7 +8,7 @@ import sys
 import time
 import traceback
-from datetime import datetime
+from typing import Dict
 
 from dotenv import load_dotenv
 
 load_dotenv()
@@ -175,7 +175,7 @@ def test_get_available_deployments():
 def test_router_get_available_deployments():
     """
-    Test if routers 'get_available_deployments' returns the least busy deployment
+    Test if routers 'get_available_deployments' returns the lowest tpm deployment
     """
     model_list = [
         {
@@ -634,3 +634,67 @@ def test_router_caching_ttl_sync():
     assert current_ttl >= 0
 
     print(f"current_ttl: {current_ttl}")
+
+
+def test_return_potential_deployments():
+    """
+    Assert deployment at limit is filtered out
+    """
+    from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
+
+    test_cache = DualCache()
+    model_list = []
+    lowest_tpm_logger = LowestTPMLoggingHandler_v2(
+        router_cache=test_cache, model_list=model_list
+    )
+
+    args: Dict = {
+        "healthy_deployments": [
+            {
+                "model_name": "model-test",
+                "litellm_params": {
+                    "rpm": 1,
+                    "api_key": "sk-1234",
+                    "model": "openai/gpt-3.5-turbo",
+                    "mock_response": "Hello, world!",
+                },
+                "model_info": {
+                    "id": "dd8e67fce56963bae6a60206b48d3f03faeb43be20cf0fd96a5f39b1a2bbd11d",
+                    "db_model": False,
+                },
+            },
+            {
+                "model_name": "model-test",
+                "litellm_params": {
+                    "rpm": 10,
+                    "api_key": "sk-1234",
+                    "model": "openai/o1-mini",
+                    "mock_response": "Hello, world, it's o1!",
+                },
+                "model_info": {
+                    "id": "e13a56981607e1749b1433e6968ffc7df5552540ad3faa44b0b44ba4f3443bfe",
+                    "db_model": False,
+                },
+            },
+        ],
+        "all_deployments": {
+            "dd8e67fce56963bae6a60206b48d3f03faeb43be20cf0fd96a5f39b1a2bbd11d": None,
+            "e13a56981607e1749b1433e6968ffc7df5552540ad3faa44b0b44ba4f3443bfe": None,
+            "dd8e67fce56963bae6a60206b48d3f03faeb43be20cf0fd96a5f39b1a2bbd11d:tpm:02-17": 0,
+            "e13a56981607e1749b1433e6968ffc7df5552540ad3faa44b0b44ba4f3443bfe:tpm:02-17": 0,
+        },
+        "input_tokens": 98,
+        "rpm_dict": {
+            "dd8e67fce56963bae6a60206b48d3f03faeb43be20cf0fd96a5f39b1a2bbd11d": 1,
+            "e13a56981607e1749b1433e6968ffc7df5552540ad3faa44b0b44ba4f3443bfe": None,
+        },
+    }
+
+    potential_deployments = lowest_tpm_logger._return_potential_deployments(
+        healthy_deployments=args["healthy_deployments"],
+        all_deployments=args["all_deployments"],
+        input_tokens=args["input_tokens"],
+        rpm_dict=args["rpm_dict"],
+    )
+    assert len(potential_deployments) == 1

model_dashboard.tsx:

@@ -250,6 +250,23 @@ const handleSubmit = async (
            litellmParamsObj[key] = value;
          }
        }
+      } else if (key == "model_info_params") {
+        console.log("model_info_params:", value);
+        let modelInfoParams = {};
+        if (value && value != undefined) {
+          try {
+            modelInfoParams = JSON.parse(value);
+          } catch (error) {
+            message.error(
+              "Failed to parse LiteLLM Extra Params: " + error,
+              10
+            );
+            throw new Error("Failed to parse litellm_extra_params: " + error);
+          }
+          for (const [key, value] of Object.entries(modelInfoParams)) {
+            modelInfoObj[key] = value;
+          }
+        }
       }
 
       // Check if key is any of the specified API related keys
@@ -2056,6 +2073,19 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
                  </Text>
                </Col>
              </Row>
+             <Form.Item
+               label="Model Info"
+               name="model_info_params"
+               tooltip="Optional model info params. Returned when calling `/model/info` endpoint."
+               className="mb-0"
+             >
+               <TextArea
+                 rows={4}
+                 placeholder='{
+                   "mode": "chat"
+                 }'
+               />
+             </Form.Item>
            </>
            <div style={{ textAlign: "center", marginTop: "10px" }}>
              <Button2 htmlType="submit">Add Model</Button2>
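
Once a mode is saved through the new field, it should be visible when querying the proxy; a hedged sketch (the proxy URL and key are placeholders, and the exact response shape may differ by version):

import requests

PROXY_BASE_URL = "http://localhost:4000"  # placeholder proxy address
LITELLM_KEY = "sk-1234"                   # placeholder key

resp = requests.get(
    f"{PROXY_BASE_URL}/model/info",
    headers={"Authorization": f"Bearer {LITELLM_KEY}"},
    timeout=10,
)
resp.raise_for_status()

for deployment in resp.json().get("data", []):
    # model_info is where UI-set params such as "mode" are returned.
    print(deployment.get("model_name"), deployment.get("model_info", {}))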