Add and test provider-specific wildcard routing

This commit is contained in:
Ishaan Jaff 2024-08-07 13:49:46 -07:00
parent da7469296a
commit 6a1a4eb822

View file

@@ -17,6 +17,7 @@ import inspect
import json
import logging
import random
import re
import threading
import time
import traceback
@@ -310,6 +311,7 @@ class Router:
)
self.default_deployment = None # use this to track the users default deployment, when they want to use model = *
self.default_max_parallel_requests = default_max_parallel_requests
self.provider_default_deployments: Dict[str, List] = {}
if model_list is not None:
model_list = copy.deepcopy(model_list)
@@ -3607,6 +3609,10 @@ class Router:
),
)
provider_specific_deployment = re.match(
f"{custom_llm_provider}/*", deployment.model_name
)
# Check if user is trying to use model_name == "*"
# this is a catch all model for their specific api key
if deployment.model_name == "*":
@@ -3615,6 +3621,17 @@ class Router:
self.router_general_settings.pass_through_all_models = True
else:
self.default_deployment = deployment.to_json(exclude_none=True)
# Check if user is using provider specific wildcard routing
# example model_name = "databricks/*" or model_name = "anthropic/*"
elif provider_specific_deployment:
if custom_llm_provider in self.provider_default_deployments:
self.provider_default_deployments[custom_llm_provider].append(
deployment.to_json(exclude_none=True)
)
else:
self.provider_default_deployments[custom_llm_provider] = [
deployment.to_json(exclude_none=True)
]
# Azure GPT-Vision Enhancements, users can pass os.environ/
data_sources = deployment.litellm_params.get("dataSources", []) or []