fix(router.py): support multiple orgs in 1 model definition

Closes https://github.com/BerriAI/litellm/issues/3949

parent 83b97d9763
commit 14b66c3daa

3 changed files with 86 additions and 27 deletions
litellm/router.py

@@ -3543,6 +3543,22 @@ class Router:
         return hash_object.hexdigest()

+    def _create_deployment(
+        self, model: dict, _model_name: str, _litellm_params: dict, _model_info: dict
+    ):
+        deployment = Deployment(
+            **model,
+            model_name=_model_name,
+            litellm_params=_litellm_params,  # type: ignore
+            model_info=_model_info,
+        )
+
+        deployment = self._add_deployment(deployment=deployment)
+
+        model = deployment.to_json(exclude_none=True)
+
+        self.model_list.append(model)
+
     def set_model_list(self, model_list: list):
         original_model_list = copy.deepcopy(model_list)
         self.model_list = []
@@ -3565,18 +3581,24 @@ class Router:
             _id = self._generate_model_id(_model_name, _litellm_params)
             _model_info["id"] = _id

-            deployment = Deployment(
-                **model,
-                model_name=_model_name,
-                litellm_params=_litellm_params,
-                model_info=_model_info,
-            )
-
-            deployment = self._add_deployment(deployment=deployment)
-
-            model = deployment.to_json(exclude_none=True)
-
-            self.model_list.append(model)
+            if _litellm_params.get("organization", None) is not None and isinstance(
+                _litellm_params["organization"], list
+            ):  # Addresses https://github.com/BerriAI/litellm/issues/3949
+                for org in _litellm_params["organization"]:
+                    _litellm_params["organization"] = org
+                    self._create_deployment(
+                        model=model,
+                        _model_name=_model_name,
+                        _litellm_params=_litellm_params,
+                        _model_info=_model_info,
+                    )
+            else:
+                self._create_deployment(
+                    model=model,
+                    _model_name=_model_name,
+                    _litellm_params=_litellm_params,
+                    _model_info=_model_info,
+                )

         verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}")
         self.model_names = [m["model_name"] for m in model_list]
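In short, set_model_list now fans a single model definition out into one deployment per organization, with the new _create_deployment helper handling both the list and scalar cases. Below is a minimal standalone sketch of that fan-out, independent of the Router class (expand_orgs is a hypothetical name for illustration, not a litellm API):

import copy

def expand_orgs(model_def: dict) -> list:
    # When `organization` is a list, emit one model definition per org;
    # otherwise pass the definition through unchanged. (Sketch only.)
    params = model_def.get("litellm_params", {})
    orgs = params.get("organization")
    if not isinstance(orgs, list):
        return [model_def]
    deployments = []
    for org in orgs:
        expanded = copy.deepcopy(model_def)  # each deployment gets its own params dict
        expanded["litellm_params"]["organization"] = org
        deployments.append(expanded)
    return deployments

# One definition carrying three orgs yields three deployments.
assert len(expand_orgs({
    "model_name": "*",
    "litellm_params": {"model": "openai/*", "organization": ["org-1", "org-2", "org-3"]},
})) == 3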
litellm/tests/test_router.py

@@ -1,24 +1,54 @@
 #### What this tests ####
 # This tests litellm router

-import sys, os, time, openai
-import traceback, asyncio
+import asyncio
+import os
+import sys
+import time
+import traceback
+
+import openai
 import pytest

 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
+import os
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+
+import httpx
+from dotenv import load_dotenv
+
 import litellm
 from litellm import Router
 from litellm.router import Deployment, LiteLLM_Params, ModelInfo
-from concurrent.futures import ThreadPoolExecutor
-from collections import defaultdict
-from dotenv import load_dotenv
-import os, httpx

 load_dotenv()


+def test_router_multi_org_list():
+    """
+    Pass list of orgs in 1 model definition,
+    expect a unique deployment for each to be created
+    """
+    router = litellm.Router(
+        model_list=[
+            {
+                "model_name": "*",
+                "litellm_params": {
+                    "model": "openai/*",
+                    "api_key": "my-key",
+                    "api_base": "https://api.openai.com/v1",
+                    "organization": ["org-1", "org-2", "org-3"],
+                },
+            }
+        ]
+    )
+
+    assert len(router.get_model_list()) == 3
+
+
 def test_router_sensitive_keys():
     try:
         router = Router(
@@ -527,9 +557,10 @@ def test_router_context_window_fallback():
     - Send a 5k prompt
     - Assert it works
     """
-    from large_text import text
     import os
+
+    from large_text import text
+
     litellm.set_verbose = False

     print(f"len(text): {len(text)}")
@@ -577,9 +608,10 @@ async def test_async_router_context_window_fallback():
     - Send a 5k prompt
     - Assert it works
     """
-    from large_text import text
     import os
+
+    from large_text import text
+
     litellm.set_verbose = False

     print(f"len(text): {len(text)}")
@@ -660,9 +692,10 @@ def test_router_context_window_check_pre_call_check_in_group():
     - Send a 5k prompt
     - Assert it works
     """
-    from large_text import text
     import os
+
+    from large_text import text
+
     litellm.set_verbose = False

     print(f"len(text): {len(text)}")
@@ -708,9 +741,10 @@ def test_router_context_window_check_pre_call_check_out_group():
     - Send a 5k prompt
     - Assert it works
     """
-    from large_text import text
     import os
+
+    from large_text import text
+
     litellm.set_verbose = False

     print(f"len(text): {len(text)}")
@@ -1536,9 +1570,10 @@ def test_router_anthropic_key_dynamic():


 def test_router_timeout():
     litellm.set_verbose = True
-    from litellm._logging import verbose_logger
     import logging
+
+    from litellm._logging import verbose_logger
+
     verbose_logger.setLevel(logging.DEBUG)
     model_list = [
         {
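The new test only asserts the deployment count. As a usage sketch (values are illustrative, and it assumes each entry returned by router.get_model_list() is a plain dict with a litellm_params key, as the to_json call in the router change suggests), you can also inspect which org each expanded deployment ended up with:

import litellm

router = litellm.Router(
    model_list=[
        {
            "model_name": "*",
            "litellm_params": {
                "model": "openai/*",
                "api_key": "my-key",
                "api_base": "https://api.openai.com/v1",
                "organization": ["org-1", "org-2", "org-3"],
            },
        }
    ]
)

# Expect one deployment per org: org-1, org-2, org-3
for deployment in router.get_model_list():
    print(deployment["litellm_params"]["organization"])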
litellm/types/router.py

@@ -2,12 +2,14 @@
 litellm.Router Types - includes RouterConfig, UpdateRouterConfig, ModelInfo etc
 """

-from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
-import uuid
+import datetime
 import enum
+import uuid
+from typing import Dict, List, Literal, Optional, Tuple, TypedDict, Union

 import httpx
 from pydantic import BaseModel, ConfigDict, Field
-import datetime

 from .completion import CompletionRequest
 from .embedding import EmbeddingRequest
@@ -293,7 +295,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
     timeout: Optional[Union[float, str, httpx.Timeout]]
     stream_timeout: Optional[Union[float, str]]
     max_retries: Optional[int]
-    organization: Optional[str]  # for openai orgs
+    organization: Optional[Union[List, str]]  # for openai orgs
     ## DROP PARAMS ##
     drop_params: Optional[bool]
     ## UNIFIED PROJECT/REGION ##
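With the widened field, both a single org string and a list of orgs now type-check against LiteLLMParamsTypedDict. A trimmed-down sketch (the real TypedDict carries many more keys; LiteLLMParamsSketch is a hypothetical stand-in):

from typing import List, Optional, TypedDict, Union

class LiteLLMParamsSketch(TypedDict, total=False):
    organization: Optional[Union[List, str]]  # for openai orgs

single: LiteLLMParamsSketch = {"organization": "org-1"}
multi: LiteLLMParamsSketch = {"organization": ["org-1", "org-2", "org-3"]}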