Merge pull request #3370 from BerriAI/litellm_latency_buffer

fix(lowest_latency.py): allow setting a buffer for getting values within a certain latency threshold
This commit is contained in:
Krish Dholakia 2024-04-30 16:01:47 -07:00 committed by GitHub
commit ce9ede6110
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 266 additions and 31 deletions

View file

@ -326,9 +326,9 @@ class Router:
litellm.failure_callback.append(self.deployment_callback_on_failure) litellm.failure_callback.append(self.deployment_callback_on_failure)
else: else:
litellm.failure_callback = [self.deployment_callback_on_failure] litellm.failure_callback = [self.deployment_callback_on_failure]
verbose_router_logger.info( print( # noqa
f"Intialized router with Routing strategy: {self.routing_strategy}\n\nRouting fallbacks: {self.fallbacks}\n\nRouting context window fallbacks: {self.context_window_fallbacks}\n\nRouter Redis Caching={self.cache.redis_cache}" f"Intialized router with Routing strategy: {self.routing_strategy}\n\nRouting fallbacks: {self.fallbacks}\n\nRouting context window fallbacks: {self.context_window_fallbacks}\n\nRouter Redis Caching={self.cache.redis_cache}"
) ) # noqa
self.routing_strategy_args = routing_strategy_args self.routing_strategy_args = routing_strategy_args
def print_deployment(self, deployment: dict): def print_deployment(self, deployment: dict):
@ -2616,6 +2616,11 @@ class Router:
for var in vars_to_include: for var in vars_to_include:
if var in _all_vars: if var in _all_vars:
_settings_to_return[var] = _all_vars[var] _settings_to_return[var] = _all_vars[var]
if (
var == "routing_strategy_args"
and self.routing_strategy == "latency-based-routing"
):
_settings_to_return[var] = self.lowestlatency_logger.routing_args.json()
return _settings_to_return return _settings_to_return
def update_settings(self, **kwargs): def update_settings(self, **kwargs):

View file

@ -4,6 +4,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
import dotenv, os, requests, random import dotenv, os, requests, random
from typing import Optional, Union, List, Dict from typing import Optional, Union, List, Dict
from datetime import datetime, timedelta from datetime import datetime, timedelta
import random
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
@ -29,6 +30,7 @@ class LiteLLMBase(BaseModel):
class RoutingArgs(LiteLLMBase): class RoutingArgs(LiteLLMBase):
ttl: int = 1 * 60 * 60 # 1 hour ttl: int = 1 * 60 * 60 # 1 hour
lowest_latency_buffer: float = 0
class LowestLatencyLoggingHandler(CustomLogger): class LowestLatencyLoggingHandler(CustomLogger):
@ -314,8 +316,12 @@ class LowestLatencyLoggingHandler(CustomLogger):
# randomly sample from all_deployments, incase all deployments have latency=0.0 # randomly sample from all_deployments, incase all deployments have latency=0.0
_items = all_deployments.items() _items = all_deployments.items()
all_deployments = random.sample(list(_items), len(_items)) all_deployments = random.sample(list(_items), len(_items))
all_deployments = dict(all_deployments) all_deployments = dict(all_deployments)
### GET AVAILABLE DEPLOYMENTS ### filter out any deployments > tpm/rpm limits
potential_deployments = []
for item, item_map in all_deployments.items(): for item, item_map in all_deployments.items():
## get the item from model list ## get the item from model list
_deployment = None _deployment = None
@ -364,17 +370,33 @@ class LowestLatencyLoggingHandler(CustomLogger):
# End of Debugging Logic # End of Debugging Logic
# -------------- # # -------------- #
if item_latency == 0: if (
deployment = _deployment
break
elif (
item_tpm + input_tokens > _deployment_tpm item_tpm + input_tokens > _deployment_tpm
or item_rpm + 1 > _deployment_rpm or item_rpm + 1 > _deployment_rpm
): # if user passed in tpm / rpm in the model_list ): # if user passed in tpm / rpm in the model_list
continue continue
elif item_latency < lowest_latency: else:
lowest_latency = item_latency potential_deployments.append((_deployment, item_latency))
deployment = _deployment
if len(potential_deployments) == 0:
return None
# Sort potential deployments by latency
sorted_deployments = sorted(potential_deployments, key=lambda x: x[1])
# Find lowest latency deployment
lowest_latency = sorted_deployments[0][1]
# Find deployments within buffer of lowest latency
buffer = self.routing_args.lowest_latency_buffer * lowest_latency
valid_deployments = [
x for x in sorted_deployments if x[1] <= lowest_latency + buffer
]
# Pick a random deployment from valid deployments
random_valid_deployment = random.choice(valid_deployments)
deployment = random_valid_deployment[0]
if request_kwargs is not None and "metadata" in request_kwargs: if request_kwargs is not None and "metadata" in request_kwargs:
request_kwargs["metadata"][ request_kwargs["metadata"][
"_latency_per_deployment" "_latency_per_deployment"

View file

@ -394,6 +394,8 @@ async def test_async_vertexai_response():
pass pass
except litellm.Timeout as e: except litellm.Timeout as e:
pass pass
except litellm.APIError as e:
pass
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred: {e}") pytest.fail(f"An exception occurred: {e}")

View file

@ -631,3 +631,95 @@ async def test_lowest_latency_routing_first_pick():
# assert that len(deployments) >1 # assert that len(deployments) >1
assert len(deployments) > 1 assert len(deployments) > 1
@pytest.mark.parametrize("buffer", [0, 1])
@pytest.mark.asyncio
async def test_lowest_latency_routing_buffer(buffer):
    """
    Allow shuffling calls within a certain latency buffer.

    Two deployments of the same model group get one recorded latency sample
    each: 3 seconds for id=1 and 2 seconds for id=2.

    - buffer=0: only one deployment may ever be selected.
    - buffer=1: both deployments must be selected at least once over
      50 picks.
    """
    model_list = [
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-turbo",
                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
                "api_base": "https://openai-france-1234.openai.azure.com",
                "rpm": 1440,
            },
            "model_info": {"id": 1},
        },
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-35-turbo",
                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
                "rpm": 6,
            },
            "model_info": {"id": 2},
        },
    ]
    router = Router(
        model_list=model_list,
        routing_strategy="latency-based-routing",
        set_verbose=False,
        num_retries=3,
        routing_strategy_args={"lowest_latency_buffer": buffer},
    )  # type: ignore

    def _record_latency(deployment_id, latency_seconds, total_tokens):
        """Log one successful call for `deployment_id` with the given latency."""
        start_time = time.time()
        # The logger is only handed the two timestamps, so derive the end
        # timestamp instead of sleeping; the test stays fast and the recorded
        # latencies are exact.
        end_time = start_time + latency_seconds
        router.lowestlatency_logger.log_success_event(
            response_obj={"usage": {"total_tokens": total_tokens}},
            kwargs={
                "litellm_params": {
                    "metadata": {
                        "model_group": "azure-model",
                    },
                    "model_info": {"id": deployment_id},
                }
            },
            start_time=start_time,
            end_time=end_time,
        )

    ## DEPLOYMENT 1 ##
    _record_latency(deployment_id=1, latency_seconds=3, total_tokens=50)

    ## DEPLOYMENT 2 ##
    _record_latency(deployment_id=2, latency_seconds=2, total_tokens=20)

    ## CHECK WHAT'S SELECTED ##
    selected_deployment_ids = set()
    for _ in range(50):
        # Pick once per iteration so the printed deployment is the same one
        # that gets recorded (selection is random when a buffer is set).
        selected = router.get_available_deployment(model="azure-model")
        print(selected)
        selected_deployment_ids.add(selected["model_info"]["id"])

    if buffer == 0:
        assert len(selected_deployment_ids) == 1
    else:
        assert len(selected_deployment_ids) == 2

View file

@ -81,7 +81,6 @@ def test_async_fallbacks(caplog):
# Define the expected log messages # Define the expected log messages
# - error request, falling back notice, success notice # - error request, falling back notice, success notice
expected_logs = [ expected_logs = [
"Intialized router with Routing strategy: simple-shuffle\n\nRouting fallbacks: [{'gpt-3.5-turbo': ['azure/gpt-3.5-turbo']}]\n\nRouting context window fallbacks: None\n\nRouter Redis Caching=None",
"litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}\x1b[0m", "litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}\x1b[0m",
"Falling back to model_group = azure/gpt-3.5-turbo", "Falling back to model_group = azure/gpt-3.5-turbo",
"litellm.acompletion(model=azure/chatgpt-v-2)\x1b[32m 200 OK\x1b[0m", "litellm.acompletion(model=azure/chatgpt-v-2)\x1b[32m 200 OK\x1b[0m",

View file

@ -766,10 +766,10 @@ def test_usage_based_routing_fallbacks():
load_dotenv() load_dotenv()
# Constants for TPM and RPM allocation # Constants for TPM and RPM allocation
AZURE_FAST_TPM = 3 AZURE_FAST_RPM = 3
AZURE_BASIC_TPM = 4 AZURE_BASIC_RPM = 4
OPENAI_TPM = 400 OPENAI_RPM = 10
ANTHROPIC_TPM = 100000 ANTHROPIC_RPM = 100000
def get_azure_params(deployment_name: str): def get_azure_params(deployment_name: str):
params = { params = {
@ -798,22 +798,26 @@ def test_usage_based_routing_fallbacks():
{ {
"model_name": "azure/gpt-4-fast", "model_name": "azure/gpt-4-fast",
"litellm_params": get_azure_params("chatgpt-v-2"), "litellm_params": get_azure_params("chatgpt-v-2"),
"tpm": AZURE_FAST_TPM, "model_info": {"id": 1},
"rpm": AZURE_FAST_RPM,
}, },
{ {
"model_name": "azure/gpt-4-basic", "model_name": "azure/gpt-4-basic",
"litellm_params": get_azure_params("chatgpt-v-2"), "litellm_params": get_azure_params("chatgpt-v-2"),
"tpm": AZURE_BASIC_TPM, "model_info": {"id": 2},
"rpm": AZURE_BASIC_RPM,
}, },
{ {
"model_name": "openai-gpt-4", "model_name": "openai-gpt-4",
"litellm_params": get_openai_params("gpt-3.5-turbo"), "litellm_params": get_openai_params("gpt-3.5-turbo"),
"tpm": OPENAI_TPM, "model_info": {"id": 3},
"rpm": OPENAI_RPM,
}, },
{ {
"model_name": "anthropic-claude-instant-1.2", "model_name": "anthropic-claude-instant-1.2",
"litellm_params": get_anthropic_params("claude-instant-1.2"), "litellm_params": get_anthropic_params("claude-instant-1.2"),
"tpm": ANTHROPIC_TPM, "model_info": {"id": 4},
"rpm": ANTHROPIC_RPM,
}, },
] ]
# litellm.set_verbose=True # litellm.set_verbose=True
@ -844,10 +848,10 @@ def test_usage_based_routing_fallbacks():
mock_response="very nice to meet you", mock_response="very nice to meet you",
) )
print("response: ", response) print("response: ", response)
print("response._hidden_params: ", response._hidden_params) print(f"response._hidden_params: {response._hidden_params}")
# in this test, we expect azure/gpt-4 fast to fail, then azure-gpt-4 basic to fail and then openai-gpt-4 to pass # in this test, we expect azure/gpt-4 fast to fail, then azure-gpt-4 basic to fail and then openai-gpt-4 to pass
# the token count of this message is > AZURE_FAST_TPM, > AZURE_BASIC_TPM # the token count of this message is > AZURE_FAST_TPM, > AZURE_BASIC_TPM
assert response._hidden_params["custom_llm_provider"] == "openai" assert response._hidden_params["model_id"] == "1"
# now make 100 mock requests to OpenAI - expect it to fallback to anthropic-claude-instant-1.2 # now make 100 mock requests to OpenAI - expect it to fallback to anthropic-claude-instant-1.2
for i in range(20): for i in range(20):
@ -861,7 +865,7 @@ def test_usage_based_routing_fallbacks():
print("response._hidden_params: ", response._hidden_params) print("response._hidden_params: ", response._hidden_params)
if i == 19: if i == 19:
# by the 19th call we should have hit TPM LIMIT for OpenAI, it should fallback to anthropic-claude-instant-1.2 # by the 19th call we should have hit TPM LIMIT for OpenAI, it should fallback to anthropic-claude-instant-1.2
assert response._hidden_params["custom_llm_provider"] == "anthropic" assert response._hidden_params["model_id"] == "4"
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred {e}") pytest.fail(f"An exception occurred {e}")

View file

@ -8130,7 +8130,10 @@ def exception_type(
llm_provider="vertex_ai", llm_provider="vertex_ai",
response=original_exception.response, response=original_exception.response,
) )
elif "None Unknown Error." in error_str: elif (
"None Unknown Error." in error_str
or "Content has no parts." in error_str
):
exception_mapping_worked = True exception_mapping_worked = True
raise APIError( raise APIError(
message=f"VertexAIException - {error_str}", message=f"VertexAIException - {error_str}",

View file

@ -15,7 +15,13 @@ import {
Grid, Grid,
Button, Button,
TextInput, TextInput,
Select as Select2,
SelectItem,
Col, Col,
Accordion,
AccordionBody,
AccordionHeader,
AccordionList,
} from "@tremor/react"; } from "@tremor/react";
import { TabPanel, TabPanels, TabGroup, TabList, Tab, Icon } from "@tremor/react"; import { TabPanel, TabPanels, TabGroup, TabList, Tab, Icon } from "@tremor/react";
import { getCallbacksCall, setCallbacksCall, serviceHealthCheck } from "./networking"; import { getCallbacksCall, setCallbacksCall, serviceHealthCheck } from "./networking";
@ -24,6 +30,7 @@ import { InformationCircleIcon, PencilAltIcon, PencilIcon, StatusOnlineIcon, Tra
import StaticGenerationSearchParamsBailoutProvider from "next/dist/client/components/static-generation-searchparams-bailout-provider"; import StaticGenerationSearchParamsBailoutProvider from "next/dist/client/components/static-generation-searchparams-bailout-provider";
import AddFallbacks from "./add_fallbacks" import AddFallbacks from "./add_fallbacks"
import openai from "openai"; import openai from "openai";
import Paragraph from "antd/es/skeleton/Paragraph";
interface GeneralSettingsPageProps { interface GeneralSettingsPageProps {
accessToken: string | null; accessToken: string | null;
@ -72,6 +79,62 @@ async function testFallbackModelResponse(
} }
} }
// Props for the AccordionHero component, which renders the editable
// per-strategy routing arguments.
interface AccordionHeroProps {
// Currently selected routing strategy; the args table is only rendered for "latency-based-routing".
selectedStrategy: string | null;
// Argument name -> value pairs; one table row is rendered per entry.
strategyArgs: routingStrategyArgs;
// Argument name -> help text shown under the argument name.
paramExplanation: { [key: string]: string }
}
// Strategy-specific arguments for latency-based routing. Both fields are optional.
interface routingStrategyArgs {
// (int) Sliding window, in seconds, used when averaging a deployment's latency.
ttl?: number;
// (float) Shuffle between deployments within this % of the lowest latency.
lowest_latency_buffer?: number;
}
// Fallback args: 1 hour ttl and no buffer (always pick the lowest-latency deployment).
const defaultLowestLatencyArgs: routingStrategyArgs = {
"ttl": 3600,
"lowest_latency_buffer": 0
}
// Convert a routing-strategy arg into the text shown in its input box.
// Missing values (null / undefined) become an empty string instead of
// throwing on `.toString()` or rendering the literal text "null".
const formatStrategyArgValue = (value: unknown): string => {
  if (value === null || value === undefined) {
    return "";
  }
  return typeof value === "object" ? JSON.stringify(value, null, 2) : String(value);
};

/**
 * Collapsible "Routing Strategy Specific Args" section.
 *
 * For "latency-based-routing" it renders one table row per entry of
 * `strategyArgs`: the argument name, its explanation from
 * `paramExplanation`, and a text input named after the argument (the input
 * name is what the save handler queries the DOM for). Any other strategy
 * shows a "No specific settings" placeholder.
 */
export const AccordionHero: React.FC<AccordionHeroProps> = ({ selectedStrategy, strategyArgs, paramExplanation }) => (
  <Accordion>
    <AccordionHeader className="text-sm font-medium text-tremor-content-strong dark:text-dark-tremor-content-strong">Routing Strategy Specific Args</AccordionHeader>
    <AccordionBody>
      {
        selectedStrategy === "latency-based-routing" ?
          <Card>
            <Table>
              <TableHead>
                <TableRow>
                  <TableHeaderCell>Setting</TableHeaderCell>
                  <TableHeaderCell>Value</TableHeaderCell>
                </TableRow>
              </TableHead>
              <TableBody>
                {Object.entries(strategyArgs).map(([param, value]) => (
                  <TableRow key={param}>
                    <TableCell>
                      <Text>{param}</Text>
                      <p style={{fontSize: '0.65rem', color: '#808080', fontStyle: 'italic'}} className="mt-1">{paramExplanation[param]}</p>
                    </TableCell>
                    <TableCell>
                      <TextInput
                        name={param}
                        defaultValue={formatStrategyArgValue(value)}
                      />
                    </TableCell>
                  </TableRow>
                ))}
              </TableBody>
            </Table>
          </Card>
          : <Text>No specific settings</Text>
      }
    </AccordionBody>
  </Accordion>
);
const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
accessToken, accessToken,
userRole, userRole,
@ -82,6 +145,8 @@ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
const [isModalVisible, setIsModalVisible] = useState(false); const [isModalVisible, setIsModalVisible] = useState(false);
const [form] = Form.useForm(); const [form] = Form.useForm();
const [selectedCallback, setSelectedCallback] = useState<string | null>(null); const [selectedCallback, setSelectedCallback] = useState<string | null>(null);
const [selectedStrategy, setSelectedStrategy] = useState<string | null>(null)
const [strategySettings, setStrategySettings] = useState<routingStrategyArgs | null>(null);
let paramExplanation: { [key: string]: string } = { let paramExplanation: { [key: string]: string } = {
"routing_strategy_args": "(dict) Arguments to pass to the routing strategy", "routing_strategy_args": "(dict) Arguments to pass to the routing strategy",
@ -91,6 +156,8 @@ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
"num_retries": "(int) Number of retries for failed requests. Defaults to 0.", "num_retries": "(int) Number of retries for failed requests. Defaults to 0.",
"timeout": "(float) Timeout for requests. Defaults to None.", "timeout": "(float) Timeout for requests. Defaults to None.",
"retry_after": "(int) Minimum time to wait before retrying a failed request", "retry_after": "(int) Minimum time to wait before retrying a failed request",
"ttl": "(int) Sliding window to look back over when calculating the average latency of a deployment. Default - 1 hour (in seconds).",
"lowest_latency_buffer": "(float) Shuffle between deployments within this % of the lowest latency. Default - 0 (i.e. always pick lowest latency)."
} }
useEffect(() => { useEffect(() => {
@ -141,6 +208,7 @@ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
try { try {
await setCallbacksCall(accessToken, payload); await setCallbacksCall(accessToken, payload);
setRouterSettings({ ...routerSettings }); setRouterSettings({ ...routerSettings });
setSelectedStrategy(routerSettings["routing_strategy"])
message.success("Router settings updated successfully"); message.success("Router settings updated successfully");
} catch (error) { } catch (error) {
message.error("Failed to update router settings: " + error, 20); message.error("Failed to update router settings: " + error, 20);
@ -156,11 +224,33 @@ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
const updatedVariables = Object.fromEntries( const updatedVariables = Object.fromEntries(
Object.entries(router_settings).map(([key, value]) => { Object.entries(router_settings).map(([key, value]) => {
if (key !== 'routing_strategy_args') { if (key !== 'routing_strategy_args' && key !== "routing_strategy") {
return [key, (document.querySelector(`input[name="${key}"]`) as HTMLInputElement)?.value || value]; return [key, (document.querySelector(`input[name="${key}"]`) as HTMLInputElement)?.value || value];
} }
else if (key == "routing_strategy") {
return [key, selectedStrategy]
}
else if (key == "routing_strategy_args" && selectedStrategy == "latency-based-routing") {
let setRoutingStrategyArgs: routingStrategyArgs = {}
const lowestLatencyBufferElement = document.querySelector(`input[name="lowest_latency_buffer"]`) as HTMLInputElement;
const ttlElement = document.querySelector(`input[name="ttl"]`) as HTMLInputElement;
if (lowestLatencyBufferElement?.value) {
setRoutingStrategyArgs["lowest_latency_buffer"] = Number(lowestLatencyBufferElement.value)
}
if (ttlElement?.value) {
setRoutingStrategyArgs["ttl"] = Number(ttlElement.value)
}
console.log(`setRoutingStrategyArgs: ${setRoutingStrategyArgs}`)
return [
"routing_strategy_args", setRoutingStrategyArgs
]
}
return null; return null;
}).filter(entry => entry !== null) as Iterable<[string, unknown]> }).filter(entry => entry !== null && entry !== undefined) as Iterable<[string, unknown]>
); );
console.log("updatedVariables", updatedVariables); console.log("updatedVariables", updatedVariables);
@ -183,6 +273,7 @@ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
return null; return null;
} }
return ( return (
<div className="w-full mx-4"> <div className="w-full mx-4">
<TabGroup className="gap-2 p-8 h-[75vh] w-full mt-2"> <TabGroup className="gap-2 p-8 h-[75vh] w-full mt-2">
@ -203,24 +294,41 @@ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
</TableRow> </TableRow>
</TableHead> </TableHead>
<TableBody> <TableBody>
{Object.entries(routerSettings).filter(([param, value]) => param != "fallbacks" && param != "context_window_fallbacks").map(([param, value]) => ( {Object.entries(routerSettings).filter(([param, value]) => param != "fallbacks" && param != "context_window_fallbacks" && param != "routing_strategy_args").map(([param, value]) => (
<TableRow key={param}> <TableRow key={param}>
<TableCell> <TableCell>
<Text>{param}</Text> <Text>{param}</Text>
<p style={{fontSize: '0.65rem', color: '#808080', fontStyle: 'italic'}} className="mt-1">{paramExplanation[param]}</p> <p style={{fontSize: '0.65rem', color: '#808080', fontStyle: 'italic'}} className="mt-1">{paramExplanation[param]}</p>
</TableCell> </TableCell>
<TableCell> <TableCell>
<TextInput {
name={param} param == "routing_strategy" ?
defaultValue={ <Select2 defaultValue={value} className="w-full max-w-md" onValueChange={setSelectedStrategy}>
typeof value === 'object' ? JSON.stringify(value, null, 2) : value.toString() <SelectItem value="usage-based-routing">usage-based-routing</SelectItem>
} <SelectItem value="latency-based-routing">latency-based-routing</SelectItem>
/> <SelectItem value="simple-shuffle">simple-shuffle</SelectItem>
</Select2> :
<TextInput
name={param}
defaultValue={
typeof value === 'object' ? JSON.stringify(value, null, 2) : value.toString()
}
/>
}
</TableCell> </TableCell>
</TableRow> </TableRow>
))} ))}
</TableBody> </TableBody>
</Table> </Table>
<AccordionHero
selectedStrategy={selectedStrategy}
strategyArgs={
routerSettings && routerSettings['routing_strategy_args'] && Object.keys(routerSettings['routing_strategy_args']).length > 0
? routerSettings['routing_strategy_args']
: defaultLowestLatencyArgs // default value when keys length is 0
}
paramExplanation={paramExplanation}
/>
</Card> </Card>
<Col> <Col>
<Button className="mt-2" onClick={() => handleSaveChanges(routerSettings)}> <Button className="mt-2" onClick={() => handleSaveChanges(routerSettings)}>