Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
Litellm dev 12 26 2024 p4 (#7439)
* fix(model_dashboard.tsx): support setting model_info params (e.g. mode) on the UI. Closes https://github.com/BerriAI/litellm/issues/5270
* fix(lowest_tpm_rpm_v2.py): deployment rpm over-limit check. Fixes a selection error when getting potential deployments below the known tpm/rpm limit. Fixes https://github.com/BerriAI/litellm/issues/7395
* fix(test_tpm_rpm_routing_v2.py): add unit test for https://github.com/BerriAI/litellm/issues/7395
* fix(lowest_tpm_rpm_v2.py): fix tpm key name in dict post rpm update
* test: rename test to run earlier
* test: skip flaky test
This commit is contained in: parent 7cf347918e, commit d88de268dd
4 changed files with 173 additions and 69 deletions
Proxy config:

```diff
@@ -1,20 +1,17 @@
 model_list:
   - model_name: openai/*
     litellm_params:
       model: openai/*
       api_key: os.environ/OPENAI_API_KEY
-  - model_name: fake-openai-endpoint
+  - model_name: model-test
     litellm_params:
       model: openai/gpt-3.5-turbo
       api_key: os.environ/OPENAI_API_KEY
-  - model_name: whisper-v3
-    litellm_params:
-      model: fireworks_ai/whisper-v3
-      api_base: https://audio-prod.us-virginia-1.direct.fireworks.ai/v1
-      api_key: os.environ/FIREWORKS_API_KEY
-    model_info:
-      mode: audio_transcription
+      mock_response: "Hello, world!"
+      rpm: 1
+  - model_name: model-test
+    litellm_params:
+      model: openai/o1-mini
+      api_key: os.environ/OPENAI_API_KEY
+      mock_response: "Hello, world, it's o1!"
+      rpm: 10

 litellm_settings:
   callbacks: ["prometheus"]
   disable_end_user_cost_tracking_prometheus_only: true
+router_settings:
+  routing_strategy: usage-based-routing-v2
+  disable_cooldowns: True
```
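A sketch (not part of the commit) of how the two `model-test` entries above form a single model group under `usage-based-routing-v2`. This assumes the `litellm` Python SDK; `sk-1234` is a dummy key, since `mock_response` short-circuits real API calls:

```python
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "model-test",
            "litellm_params": {
                "model": "openai/gpt-3.5-turbo",
                "api_key": "sk-1234",
                "mock_response": "Hello, world!",
                "rpm": 1,
            },
        },
        {
            "model_name": "model-test",
            "litellm_params": {
                "model": "openai/o1-mini",
                "api_key": "sk-1234",
                "mock_response": "Hello, world, it's o1!",
                "rpm": 10,
            },
        },
    ],
    routing_strategy="usage-based-routing-v2",
)

# Both calls target the "model-test" group; the router picks whichever
# deployment is still below its tpm/rpm limits.
response = router.completion(
    model="model-test", messages=[{"role": "user", "content": "hi"}]
)
print(response.choices[0].message.content)
```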
litellm/router_strategy/lowest_tpm_rpm_v2.py:

```diff
@@ -315,59 +315,14 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         )
         pass

-    def _common_checks_available_deployment(  # noqa: PLR0915
+    def _return_potential_deployments(
         self,
-        model_group: str,
-        healthy_deployments: list,
-        tpm_keys: list,
-        tpm_values: Optional[list],
-        rpm_keys: list,
-        rpm_values: Optional[list],
-        messages: Optional[List[Dict[str, str]]] = None,
-        input: Optional[Union[str, List]] = None,
-    ) -> Optional[dict]:
-        """
-        Common checks for get available deployment, across sync + async implementations
-        """
-        if tpm_values is None or rpm_values is None:
-            return None
-
-        tpm_dict = {}  # {model_id: 1, ..}
-        for idx, key in enumerate(tpm_keys):
-            tpm_dict[tpm_keys[idx]] = tpm_values[idx]
-
-        rpm_dict = {}  # {model_id: 1, ..}
-        for idx, key in enumerate(rpm_keys):
-            rpm_dict[rpm_keys[idx]] = rpm_values[idx]
-
-        try:
-            input_tokens = token_counter(messages=messages, text=input)
-        except Exception:
-            input_tokens = 0
-        verbose_router_logger.debug(f"input_tokens={input_tokens}")
-        # -----------------------
-        # Find lowest used model
-        # ----------------------
+        healthy_deployments: List[Dict],
+        all_deployments: Dict,
+        input_tokens: int,
+        rpm_dict: Dict,
+    ):
         lowest_tpm = float("inf")
-
-        if tpm_dict is None:  # base case - none of the deployments have been used
-            # initialize a tpm dict with {model_id: 0}
-            tpm_dict = {}
-            for deployment in healthy_deployments:
-                tpm_dict[deployment["model_info"]["id"]] = 0
-        else:
-            dt = get_utc_datetime()
-            current_minute = dt.strftime(
-                "%H-%M"
-            )  # use the same timezone regardless of system clock
-
-            for d in healthy_deployments:
-                ## if healthy deployment not yet used
-                tpm_key = f"{d['model_info']['id']}:tpm:{current_minute}"
-                if tpm_key not in tpm_dict or tpm_dict[tpm_key] is None:
-                    tpm_dict[tpm_key] = 0
-
-        all_deployments = tpm_dict
         potential_deployments = []  # if multiple deployments have the same low value
         for item, item_tpm in all_deployments.items():
             ## get the item from model list
```
```diff
@@ -402,8 +357,10 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             _deployment_rpm = float("inf")
             if item_tpm + input_tokens > _deployment_tpm:
                 continue
-            elif (rpm_dict is not None and item in rpm_dict) and (
-                rpm_dict[item] + 1 >= _deployment_rpm
+            elif (
+                (rpm_dict is not None and item in rpm_dict)
+                and rpm_dict[item] is not None
+                and (rpm_dict[item] + 1 >= _deployment_rpm)
             ):
                 continue
             elif item_tpm == lowest_tpm:
```
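Why the guard changed: `rpm_dict` values come back from the cache and can be `None` for a deployment that has not been used in the current window, so the old `rpm_dict[item] + 1` could raise `TypeError` and abort deployment selection. A standalone sketch of the fixed check, simplified from the hunk above (the real method also tracks the lowest-tpm deployment):

```python
from typing import Dict, List, Optional


def filter_rpm_ok(healthy_deployments: List[Dict], rpm_dict: Optional[Dict]) -> List[Dict]:
    """Keep deployments whose next request stays under their rpm limit."""
    picked = []
    for d in healthy_deployments:
        item = d["model_info"]["id"]
        _deployment_rpm = d["litellm_params"].get("rpm", float("inf"))
        if (
            rpm_dict is not None
            and item in rpm_dict
            and rpm_dict[item] is not None  # the fix: an unused deployment reports None
            and rpm_dict[item] + 1 >= _deployment_rpm
        ):
            continue  # one more request would breach this deployment's rpm limit
        picked.append(d)
    return picked


deployments = [
    {"model_info": {"id": "a"}, "litellm_params": {"rpm": 1}},
    {"model_info": {"id": "b"}, "litellm_params": {"rpm": 10}},
]
assert [d["model_info"]["id"] for d in filter_rpm_ok(deployments, {"a": 1, "b": None})] == ["b"]
```

With a count of 1 against an `rpm` limit of 1, the first deployment is skipped; a `None` count now simply means "not yet used" and passes through.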
```diff
@@ -411,6 +368,62 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             elif item_tpm < lowest_tpm:
                 lowest_tpm = item_tpm
                 potential_deployments = [_deployment]
+        return potential_deployments
+
+    def _common_checks_available_deployment(  # noqa: PLR0915
+        self,
+        model_group: str,
+        healthy_deployments: list,
+        tpm_keys: list,
+        tpm_values: Optional[list],
+        rpm_keys: list,
+        rpm_values: Optional[list],
+        messages: Optional[List[Dict[str, str]]] = None,
+        input: Optional[Union[str, List]] = None,
+    ) -> Optional[dict]:
+        """
+        Common checks for get available deployment, across sync + async implementations
+        """
+
+        if tpm_values is None or rpm_values is None:
+            return None
+
+        tpm_dict = {}  # {model_id: 1, ..}
+        for idx, key in enumerate(tpm_keys):
+            tpm_dict[tpm_keys[idx].split(":")[0]] = tpm_values[idx]
+
+        rpm_dict = {}  # {model_id: 1, ..}
+        for idx, key in enumerate(rpm_keys):
+            rpm_dict[rpm_keys[idx].split(":")[0]] = rpm_values[idx]
+
+        try:
+            input_tokens = token_counter(messages=messages, text=input)
+        except Exception:
+            input_tokens = 0
+        verbose_router_logger.debug(f"input_tokens={input_tokens}")
+        # -----------------------
+        # Find lowest used model
+        # ----------------------
+
+        if tpm_dict is None:  # base case - none of the deployments have been used
+            # initialize a tpm dict with {model_id: 0}
+            tpm_dict = {}
+            for deployment in healthy_deployments:
+                tpm_dict[deployment["model_info"]["id"]] = 0
+        else:
+            for d in healthy_deployments:
+                ## if healthy deployment not yet used
+                tpm_key = d["model_info"]["id"]
+                if tpm_key not in tpm_dict or tpm_dict[tpm_key] is None:
+                    tpm_dict[tpm_key] = 0
+
+        all_deployments = tpm_dict
+        potential_deployments = self._return_potential_deployments(
+            healthy_deployments=healthy_deployments,
+            all_deployments=all_deployments,
+            input_tokens=input_tokens,
+            rpm_dict=rpm_dict,
+        )
         print_verbose("returning picked lowest tpm/rpm deployment.")

         if len(potential_deployments) > 0:
```
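The new `.split(":")[0]` calls normalize usage-cache keys of the form `{model_id}:tpm:{HH-MM}` back to bare model ids (the key format is visible in the test fixture below), so `tpm_dict`/`rpm_dict` lookups line up with deployment ids after the rpm update. A small illustration, with shortened made-up ids:

```python
# Illustration only: "dd8e67f" and "e13a569" are shortened, made-up ids.
tpm_keys = ["dd8e67f:tpm:02-17", "e13a569:tpm:02-17"]
tpm_values = [42, 0]

# Keyed by bare model id, matching rpm_dict and healthy_deployments.
tpm_dict = {key.split(":")[0]: value for key, value in zip(tpm_keys, tpm_values)}
assert tpm_dict == {"dd8e67f": 42, "e13a569": 0}
```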
test_tpm_rpm_routing_v2.py:

```diff
@@ -8,7 +8,7 @@ import sys
 import time
 import traceback
 from datetime import datetime
-
+from typing import Dict
 from dotenv import load_dotenv

 load_dotenv()
```
```diff
@@ -175,7 +175,7 @@ def test_get_available_deployments():

 def test_router_get_available_deployments():
     """
-    Test if routers 'get_available_deployments' returns the least busy deployment
+    Test if routers 'get_available_deployments' returns the lowest tpm deployment
     """
     model_list = [
         {
```
```diff
@@ -634,3 +634,67 @@ def test_router_caching_ttl_sync():
     assert current_ttl >= 0

     print(f"current_ttl: {current_ttl}")
+
+
+def test_return_potential_deployments():
+    """
+    Assert deployment at limit is filtered out
+    """
+    from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
+
+    test_cache = DualCache()
+    model_list = []
+    lowest_tpm_logger = LowestTPMLoggingHandler(
+        router_cache=test_cache, model_list=model_list
+    )
+
+    args: Dict = {
+        "healthy_deployments": [
+            {
+                "model_name": "model-test",
+                "litellm_params": {
+                    "rpm": 1,
+                    "api_key": "sk-1234",
+                    "model": "openai/gpt-3.5-turbo",
+                    "mock_response": "Hello, world!",
+                },
+                "model_info": {
+                    "id": "dd8e67fce56963bae6a60206b48d3f03faeb43be20cf0fd96a5f39b1a2bbd11d",
+                    "db_model": False,
+                },
+            },
+            {
+                "model_name": "model-test",
+                "litellm_params": {
+                    "rpm": 10,
+                    "api_key": "sk-1234",
+                    "model": "openai/o1-mini",
+                    "mock_response": "Hello, world, it's o1!",
+                },
+                "model_info": {
+                    "id": "e13a56981607e1749b1433e6968ffc7df5552540ad3faa44b0b44ba4f3443bfe",
+                    "db_model": False,
+                },
+            },
+        ],
+        "all_deployments": {
+            "dd8e67fce56963bae6a60206b48d3f03faeb43be20cf0fd96a5f39b1a2bbd11d": None,
+            "e13a56981607e1749b1433e6968ffc7df5552540ad3faa44b0b44ba4f3443bfe": None,
+            "dd8e67fce56963bae6a60206b48d3f03faeb43be20cf0fd96a5f39b1a2bbd11d:tpm:02-17": 0,
+            "e13a56981607e1749b1433e6968ffc7df5552540ad3faa44b0b44ba4f3443bfe:tpm:02-17": 0,
+        },
+        "input_tokens": 98,
+        "rpm_dict": {
+            "dd8e67fce56963bae6a60206b48d3f03faeb43be20cf0fd96a5f39b1a2bbd11d": 1,
+            "e13a56981607e1749b1433e6968ffc7df5552540ad3faa44b0b44ba4f3443bfe": None,
+        },
+    }
+
+    potential_deployments = lowest_tpm_logger._return_potential_deployments(
+        healthy_deployments=args["healthy_deployments"],
+        all_deployments=args["all_deployments"],
+        input_tokens=args["input_tokens"],
+        rpm_dict=args["rpm_dict"],
+    )
+
+    assert len(potential_deployments) == 1
```
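Walking through the fixture: the gpt-3.5-turbo deployment has `rpm: 1` and a current count of 1 in `rpm_dict`, so `1 + 1 >= 1` filters it out; the o1-mini deployment's count is `None`, which the fixed guard treats as not-yet-used instead of raising `TypeError`. Exactly one potential deployment therefore remains, which is what the final assert checks.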
model_dashboard.tsx:

```diff
@@ -250,6 +250,23 @@ const handleSubmit = async (
           litellmParamsObj[key] = value;
         }
       }
+    } else if (key == "model_info_params") {
+      console.log("model_info_params:", value);
+      let modelInfoParams = {};
+      if (value && value != undefined) {
+        try {
+          modelInfoParams = JSON.parse(value);
+        } catch (error) {
+          message.error(
+            "Failed to parse LiteLLM Extra Params: " + error,
+            10
+          );
+          throw new Error("Failed to parse litellm_extra_params: " + error);
+        }
+        for (const [key, value] of Object.entries(modelInfoParams)) {
+          modelInfoObj[key] = value;
+        }
+      }
     }

     // Check if key is any of the specified API related keys
```
```diff
@@ -2056,6 +2073,19 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
                 </Text>
               </Col>
             </Row>
+            <Form.Item
+              label="Model Info"
+              name="model_info_params"
+              tooltip="Optional model info params. Returned when calling `/model/info` endpoint."
+              className="mb-0"
+            >
+              <TextArea
+                rows={4}
+                placeholder='{
+                  "mode": "chat"
+                }'
+              />
+            </Form.Item>
           </>
           <div style={{ textAlign: "center", marginTop: "10px" }}>
             <Button2 htmlType="submit">Add Model</Button2>
```
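Taken together, the two dashboard hunks above add a Model Info textarea to the Add Model form; `handleSubmit` parses its JSON (e.g. `{"mode": "audio_transcription"}`) into `modelInfoObj`, which is stored as the model's `model_info` and returned by the `/model/info` endpoint.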