mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
add + test provider specific routing
This commit is contained in:
parent
0b98959e6d
commit
18305b23f4
1 changed files with 17 additions and 0 deletions
|
@ -17,6 +17,7 @@ import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
import re
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
@ -310,6 +311,7 @@ class Router:
|
||||||
)
|
)
|
||||||
self.default_deployment = None # use this to track the users default deployment, when they want to use model = *
|
self.default_deployment = None # use this to track the users default deployment, when they want to use model = *
|
||||||
self.default_max_parallel_requests = default_max_parallel_requests
|
self.default_max_parallel_requests = default_max_parallel_requests
|
||||||
|
self.provider_default_deployments: Dict[str, List] = {}
|
||||||
|
|
||||||
if model_list is not None:
|
if model_list is not None:
|
||||||
model_list = copy.deepcopy(model_list)
|
model_list = copy.deepcopy(model_list)
|
||||||
|
@ -3607,6 +3609,10 @@ class Router:
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
provider_specific_deployment = re.match(
|
||||||
|
f"{custom_llm_provider}/*", deployment.model_name
|
||||||
|
)
|
||||||
|
|
||||||
# Check if user is trying to use model_name == "*"
|
# Check if user is trying to use model_name == "*"
|
||||||
# this is a catch all model for their specific api key
|
# this is a catch all model for their specific api key
|
||||||
if deployment.model_name == "*":
|
if deployment.model_name == "*":
|
||||||
|
@ -3615,6 +3621,17 @@ class Router:
|
||||||
self.router_general_settings.pass_through_all_models = True
|
self.router_general_settings.pass_through_all_models = True
|
||||||
else:
|
else:
|
||||||
self.default_deployment = deployment.to_json(exclude_none=True)
|
self.default_deployment = deployment.to_json(exclude_none=True)
|
||||||
|
# Check if user is using provider specific wildcard routing
|
||||||
|
# example model_name = "databricks/*" or model_name = "anthropic/*"
|
||||||
|
elif provider_specific_deployment:
|
||||||
|
if custom_llm_provider in self.provider_default_deployments:
|
||||||
|
self.provider_default_deployments[custom_llm_provider].append(
|
||||||
|
deployment.to_json(exclude_none=True)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.provider_default_deployments[custom_llm_provider] = [
|
||||||
|
deployment.to_json(exclude_none=True)
|
||||||
|
]
|
||||||
|
|
||||||
# Azure GPT-Vision Enhancements, users can pass os.environ/
|
# Azure GPT-Vision Enhancements, users can pass os.environ/
|
||||||
data_sources = deployment.litellm_params.get("dataSources", []) or []
|
data_sources = deployment.litellm_params.get("dataSources", []) or []
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue