fix(router.py): generate consistent model ids

Having the same id for a deployment lets Redis usage caching work across multiple instances.
Krrish Dholakia 2024-04-10 15:23:57 -07:00
parent 180cf9bd5c
commit a47a719caa
4 changed files with 78 additions and 9 deletions
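
The idea in isolation, as a minimal sketch (the helper name and the use of sort_keys here are illustrative, not litellm's API): derive the deployment id by hashing the model group plus the deployment's params, so every instance that loads the same config computes the same id, and therefore the same Redis keys.

    import hashlib, json

    def stable_deployment_id(model_group: str, litellm_params: dict) -> str:
        # Serialize deterministically, then hash; any process with the same
        # config derives the identical 64-char hex id without coordination.
        payload = model_group + json.dumps(litellm_params, sort_keys=True, default=str)
        return hashlib.sha256(payload.encode()).hexdigest()

    params = {"model": "openai/my-fake-model", "api_key": "my-fake-key", "rpm": 100}
    assert stable_deployment_id("gpt-3.5-turbo", params) == stable_deployment_id(
        "gpt-3.5-turbo", params
    )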

litellm/router.py

@@ -11,7 +11,7 @@ import copy, httpx
 from datetime import datetime
 from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO
 import random, threading, time, traceback, uuid
-import litellm, openai
+import litellm, openai, hashlib, json
 from litellm.caching import RedisCache, InMemoryCache, DualCache
 import logging, asyncio
@@ -2072,6 +2072,34 @@ class Router:
             local_only=True,
         )  # cache for 1 hr

+    def _generate_model_id(self, model_group: str, litellm_params: dict):
+        """
+        Helper function to consistently generate the same id for a deployment
+
+        - create a string from all the litellm params
+        - hash
+        - use hash as id
+        """
+        concat_str = model_group
+        for k, v in litellm_params.items():
+            if isinstance(k, str):
+                concat_str += k
+            elif isinstance(k, dict):
+                concat_str += json.dumps(k)
+            else:
+                concat_str += str(k)
+
+            if isinstance(v, str):
+                concat_str += v
+            elif isinstance(v, dict):
+                concat_str += json.dumps(v)
+            else:
+                concat_str += str(v)
+
+        hash_object = hashlib.sha256(concat_str.encode())
+        return hash_object.hexdigest()
+
     def set_model_list(self, model_list: list):
         original_model_list = copy.deepcopy(model_list)
         self.model_list = []
@@ -2087,7 +2115,13 @@ class Router:
                if isinstance(v, str) and v.startswith("os.environ/"):
                    _litellm_params[k] = litellm.get_secret(v)

-            _model_info = model.pop("model_info", {})
+            _model_info: dict = model.pop("model_info", {})
+
+            # check if model info has id
+            if "id" not in _model_info:
+                _id = self._generate_model_id(_model_name, _litellm_params)
+                _model_info["id"] = _id
+
             deployment = Deployment(
                 **model,
                 model_name=_model_name,
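
One property of `_generate_model_id` worth noting: it concatenates `litellm_params` in dict insertion order, so the id is only stable when every instance builds the params in the same order, which holds when they all load the same config file. A quick standalone illustration of that sensitivity (hypothetical snippet mirroring the concatenation, not litellm code):

    import hashlib

    def concat_hash(model_group: str, params: dict) -> str:
        s = model_group
        for k, v in params.items():
            s += str(k) + str(v)
        return hashlib.sha256(s.encode()).hexdigest()

    a = concat_hash("gpt-3.5-turbo", {"model": "openai/x", "rpm": 100})
    b = concat_hash("gpt-3.5-turbo", {"rpm": 100, "model": "openai/x"})
    assert a != b  # same params, different insertion order -> different id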

litellm/router_strategy/lowest_tpm_rpm_v2.py

@@ -3,6 +3,7 @@
 import dotenv, os, requests, random
 from typing import Optional, Union, List, Dict
+import datetime as datetime_og
 from datetime import datetime

 dotenv.load_dotenv()  # Loading env variables using dotenv
@@ -59,7 +60,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         # ------------
         # Setup values
         # ------------
-        current_minute = datetime.now().strftime("%H-%M")
+        current_minute = datetime.now(datetime_og.UTC).strftime("%H-%M")
         tpm_key = f"{model_group}:tpm:{current_minute}"
         rpm_key = f"{model_group}:rpm:{current_minute}"
@@ -109,7 +110,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         # ------------
         # Setup values
         # ------------
-        current_minute = datetime.now().strftime("%H-%M")
+        current_minute = datetime.now(datetime_og.UTC).strftime(
+            "%H-%M"
+        )  # use the same timezone regardless of system clock
         tpm_key = f"{id}:tpm:{current_minute}"
         rpm_key = f"{id}:rpm:{current_minute}"
@@ -162,7 +165,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
         )
-        current_minute = datetime.now().strftime("%H-%M")
+        current_minute = datetime.now(datetime_og.UTC).strftime("%H-%M")
         tpm_keys = []
         rpm_keys = []
         for m in healthy_deployments:
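
The UTC change serves the same cross-instance goal: TPM/RPM counters are bucketed by minute, and two routers running in different timezones must land on the same Redis key. A minimal sketch of the bucketing (key format taken from the diff above; `datetime.UTC` needs Python 3.11+, `datetime.timezone.utc` is the portable spelling):

    import datetime as datetime_og
    from datetime import datetime

    model_group = "gpt-3.5-turbo"
    current_minute = datetime.now(datetime_og.UTC).strftime("%H-%M")

    # Every instance increments the same bucket, regardless of its system timezone.
    tpm_key = f"{model_group}:tpm:{current_minute}"
    rpm_key = f"{model_group}:rpm:{current_minute}"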

litellm/tests/test_router.py

@@ -932,6 +932,35 @@ def test_openai_completion_on_router():
 # test_openai_completion_on_router()
+def test_consistent_model_id():
+    """
+    - For a given model group + litellm params, assert the model id is always the same
+
+    Test on `_generate_model_id`
+    Test on `set_model_list`
+    Test on `_add_deployment`
+    """
+    model_group = "gpt-3.5-turbo"
+    litellm_params = {
+        "model": "openai/my-fake-model",
+        "api_key": "my-fake-key",
+        "api_base": "https://openai-function-calling-workers.tasslexyz.workers.dev/",
+        "stream_timeout": 0.001,
+    }
+
+    id1 = Router()._generate_model_id(
+        model_group=model_group, litellm_params=litellm_params
+    )
+
+    id2 = Router()._generate_model_id(
+        model_group=model_group, litellm_params=litellm_params
+    )
+
+    assert id1 == id2
+
+
 def test_reading_keys_os_environ():
     import openai
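
The same guarantee should extend through `set_model_list`: two routers built from identical config end up with identical deployment ids. A hedged sketch of that check (the `model_list` access pattern is an assumption about this version's internals, not a documented API):

    from litellm import Router

    model_list = [{
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "openai/my-fake-model", "api_key": "my-fake-key"},
    }]

    r1, r2 = Router(model_list=model_list), Router(model_list=model_list)
    # Assumed accessor; the exact shape of Router.model_list entries may differ.
    id1 = r1.model_list[0]["model_info"]["id"]
    id2 = r2.model_list[0]["model_info"]["id"]
    assert id1 == id2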

litellm/types/router.py

@@ -13,7 +13,7 @@ class ModelConfig(BaseModel):
     rpm: int

     class Config:
         protected_namespaces = ()


 class RouterConfig(BaseModel):
@@ -45,7 +45,8 @@ class RouterConfig(BaseModel):
     ] = "simple-shuffle"

     class Config:
         protected_namespaces = ()


 class ModelInfo(BaseModel):
     id: Optional[
@@ -132,9 +133,11 @@ class Deployment(BaseModel):
     litellm_params: LiteLLM_Params
     model_info: ModelInfo

-    def __init__(self, model_info: Optional[ModelInfo] = None, **params):
+    def __init__(self, model_info: Optional[Union[ModelInfo, dict]] = None, **params):
         if model_info is None:
             model_info = ModelInfo()
+        elif isinstance(model_info, dict):
+            model_info = ModelInfo(**model_info)
         super().__init__(model_info=model_info, **params)

     def to_json(self, **kwargs):
@@ -146,7 +149,7 @@ class Deployment(BaseModel):
     class Config:
         extra = "allow"
         protected_namespaces = ()

     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
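
The `Union[ModelInfo, dict]` widening is what lets `set_model_list` hand the raw `model_info` dict (with the injected id) straight to `Deployment`. A small sketch of the two call shapes that are now equivalent (import path assumed from the class names above):

    from litellm.types.router import Deployment, ModelInfo

    # A ModelInfo instance, as before:
    d1 = Deployment(
        model_name="gpt-3.5-turbo",
        litellm_params={"model": "openai/my-fake-model"},
        model_info=ModelInfo(id="abc123"),
    )

    # A plain dict, now coerced to ModelInfo inside __init__:
    d2 = Deployment(
        model_name="gpt-3.5-turbo",
        litellm_params={"model": "openai/my-fake-model"},
        model_info={"id": "abc123"},
    )
    assert d1.model_info.id == d2.model_info.id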