From 18305b23f429c1c9e6c9fed123497a40442f2b15 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Wed, 7 Aug 2024 13:49:46 -0700
Subject: [PATCH] add + test provider specific routing

---
Review note: the provider wildcard check must match the literal model name
"<provider>/*" (e.g. "anthropic/*", "databricks/*"). As a regular expression,
"<provider>/*" means "<provider> followed by zero or more slashes" and, being
unanchored at the end, matches every model_name that merely starts with the
provider string (e.g. "openai-prod" or "openai/gpt-4o"). The pattern below is
anchored, escapes the provider name, and escapes the asterisk so that only
genuine provider wildcard deployments are stored in
provider_default_deployments.

 litellm/router.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/litellm/router.py b/litellm/router.py
index 5a4d83885..51fb12ea8 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -17,6 +17,7 @@ import inspect
 import json
 import logging
 import random
+import re
 import threading
 import time
 import traceback
@@ -310,6 +311,7 @@ class Router:
         )
         self.default_deployment = None  # use this to track the users default deployment, when they want to use model = *
         self.default_max_parallel_requests = default_max_parallel_requests
+        self.provider_default_deployments: Dict[str, List] = {}
 
         if model_list is not None:
             model_list = copy.deepcopy(model_list)
@@ -3607,6 +3609,10 @@ class Router:
             ),
         )
 
+        provider_specific_deployment = re.match(
+            rf"^{re.escape(custom_llm_provider)}/\*$", deployment.model_name
+        )
+
         # Check if user is trying to use model_name == "*"
         # this is a catch all model for their specific api key
         if deployment.model_name == "*":
@@ -3615,6 +3621,17 @@ class Router:
                 self.router_general_settings.pass_through_all_models = True
             else:
                 self.default_deployment = deployment.to_json(exclude_none=True)
+        # Check if user is using provider specific wildcard routing
+        # example model_name = "databricks/*" or model_name = "anthropic/*"
+        elif provider_specific_deployment:
+            if custom_llm_provider in self.provider_default_deployments:
+                self.provider_default_deployments[custom_llm_provider].append(
+                    deployment.to_json(exclude_none=True)
+                )
+            else:
+                self.provider_default_deployments[custom_llm_provider] = [
+                    deployment.to_json(exclude_none=True)
+                ]
 
         # Azure GPT-Vision Enhancements, users can pass os.environ/
         data_sources = deployment.litellm_params.get("dataSources", []) or []