(feat) router - add model_group_alias_map

ishaan-jaff 2023-12-06 20:13:22 -08:00
parent 9573123e2b
commit ee70c4e822
3 changed files with 57 additions and 1 deletion

@@ -45,6 +45,7 @@ caching: bool = False # Not used anymore, will be removed in next MAJOR release
caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
cache: Optional[Cache] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
model_alias_map: Dict[str, str] = {}
+model_group_alias_map: Dict[str, str] = {}
max_budget: float = 0.0 # set the max budget across all providers
_current_cost = 0 # private variable, used if max budget is set
error_logs: Dict = {}
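
The new module-level setting lets a caller alias one model group name to another that is actually configured on the router. A minimal usage sketch, grounded in the test added below (the alias target must match a model_name in the router's model_list):

import litellm

# Requests for "gpt-4" are redirected to deployments registered under "gpt-3.5-turbo".
litellm.model_group_alias_map = {"gpt-4": "gpt-3.5-turbo"}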

@@ -1101,6 +1101,11 @@ class Router:
                    return deployment
            raise ValueError(f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}")
+        # check if an alias is set for this model group on litellm.model_group_alias_map
+        if model in litellm.model_group_alias_map:
+            self.print_verbose(f"Using a model alias. Got Request for {model}, sending requests to {litellm.model_group_alias_map.get(model)}")
+            model = litellm.model_group_alias_map.get(model)
        ## get healthy deployments
        ### get all deployments
        ### filter out the deployments currently cooling down
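
The router change boils down to a single remapping step before deployment selection: if the requested model group is a key in the alias map, swap it for the mapped group name and continue as usual. A simplified standalone sketch of that logic (resolve_model_group is a hypothetical helper for illustration, not part of the Router API):

from typing import Dict

def resolve_model_group(model: str, alias_map: Dict[str, str]) -> str:
    # Swap an aliased model group for the group name that is actually configured.
    if model in alias_map:
        return alias_map[model]
    return model

# resolve_model_group("gpt-4", {"gpt-4": "gpt-3.5-turbo"}) returns "gpt-3.5-turbo"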

@@ -284,4 +284,54 @@ def test_weighted_selection_router_no_rpm_set():
    except Exception as e:
        traceback.print_exc()
        pytest.fail(f"Error occurred: {e}")
-test_weighted_selection_router_no_rpm_set()
+# test_weighted_selection_router_no_rpm_set()
+def test_model_group_aliases():
+    try:
+        litellm.set_verbose = False
+        litellm.model_group_alias_map = {"gpt-4": "gpt-3.5-turbo"}
+        model_list = [
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo-0613",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                    "rpm": 6,
+                },
+            },
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "azure/chatgpt-v-2",
+                    "api_key": os.getenv("AZURE_API_KEY"),
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                    "api_version": os.getenv("AZURE_API_VERSION"),
+                    "rpm": 1440,
+                },
+            },
+            {
+                "model_name": "claude-1",
+                "litellm_params": {
+                    "model": "bedrock/claude1.2",
+                    "rpm": 1440,
+                },
+            },
+        ]
+        router = Router(
+            model_list=model_list,
+        )
+        for _ in range(20):
+            # every request for the aliased group "gpt-4" should land on a "gpt-3.5-turbo" deployment
+            selected_model = router.get_available_deployment("gpt-4")
+            print("\n selected model", selected_model)
+            selected_model_name = selected_model.get("model_name")
+            if selected_model_name != "gpt-3.5-turbo":  # string comparison, not identity check
+                pytest.fail(f"Selected model {selected_model_name} is not gpt-3.5-turbo")
+        router.reset()
+    except Exception as e:
+        traceback.print_exc()
+        pytest.fail(f"Error occurred: {e}")
+# test_model_group_aliases()
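
Assuming Router.completion resolves deployments through get_available_deployment (the method the change above hooks into), the alias should also apply to regular completion calls. A hedged end-to-end sketch, reusing the model_list from the test:

import litellm
from litellm import Router

litellm.model_group_alias_map = {"gpt-4": "gpt-3.5-turbo"}
router = Router(model_list=model_list)  # model_list as defined in the test above

# The request names "gpt-4", but the alias routes it to the "gpt-3.5-turbo" deployments.
response = router.completion(
    model="gpt-4",
    messages=[{"role": "user", "content": "hello"}],
)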