fix(router.py): support multiple orgs in 1 model definition

Closes https://github.com/BerriAI/litellm/issues/3949
Krrish Dholakia 2024-06-18 19:36:58 -07:00
parent 83b97d9763
commit 14b66c3daa
3 changed files with 86 additions and 27 deletions
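The change lets one model definition carry a list of OpenAI organizations; the router expands it into one deployment per org. A minimal usage sketch, based on the new test below (the api_key and org IDs are placeholders; it also assumes get_model_list() returns the raw deployment dicts the router stores):

import litellm

# One model definition carrying a list of orgs.
router = litellm.Router(
    model_list=[
        {
            "model_name": "*",
            "litellm_params": {
                "model": "openai/*",
                "api_key": "my-key",  # placeholder
                "api_base": "https://api.openai.com/v1",
                "organization": ["org-1", "org-2", "org-3"],
            },
        }
    ]
)

# One deployment per organization is registered on the router.
for deployment in router.get_model_list():
    print(deployment["litellm_params"]["organization"])  # org-1, org-2, org-3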

litellm/router.py

@@ -3543,6 +3543,22 @@ class Router:
 
         return hash_object.hexdigest()
 
+    def _create_deployment(
+        self, model: dict, _model_name: str, _litellm_params: dict, _model_info: dict
+    ):
+        deployment = Deployment(
+            **model,
+            model_name=_model_name,
+            litellm_params=_litellm_params,  # type: ignore
+            model_info=_model_info,
+        )
+
+        deployment = self._add_deployment(deployment=deployment)
+
+        model = deployment.to_json(exclude_none=True)
+
+        self.model_list.append(model)
+
     def set_model_list(self, model_list: list):
         original_model_list = copy.deepcopy(model_list)
         self.model_list = []
@@ -3565,18 +3581,24 @@ class Router:
             _id = self._generate_model_id(_model_name, _litellm_params)
             _model_info["id"] = _id
 
-            deployment = Deployment(
-                **model,
-                model_name=_model_name,
-                litellm_params=_litellm_params,
-                model_info=_model_info,
-            )
-
-            deployment = self._add_deployment(deployment=deployment)
-
-            model = deployment.to_json(exclude_none=True)
-
-            self.model_list.append(model)
+            if _litellm_params.get("organization", None) is not None and isinstance(
+                _litellm_params["organization"], list
+            ):  # Addresses https://github.com/BerriAI/litellm/issues/3949
+                for org in _litellm_params["organization"]:
+                    _litellm_params["organization"] = org
+                    self._create_deployment(
+                        model=model,
+                        _model_name=_model_name,
+                        _litellm_params=_litellm_params,
+                        _model_info=_model_info,
+                    )
+            else:
+                self._create_deployment(
+                    model=model,
+                    _model_name=_model_name,
+                    _litellm_params=_litellm_params,
+                    _model_info=_model_info,
+                )
 
         verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}")
         self.model_names = [m["model_name"] for m in model_list]
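Note that the loop above reuses a single _litellm_params dict, overwriting its "organization" key once per org before each _create_deployment call, which then passes the params into a Deployment object. A hypothetical standalone sketch of the same fan-out, written with explicit copies for clarity (expand_orgs is illustrative, not part of the Router API):

def expand_orgs(litellm_params: dict) -> list:
    """Fan a list-valued "organization" out into one params dict per org."""
    org = litellm_params.get("organization", None)
    if org is not None and isinstance(org, list):
        # one copy per org, each pinned to a single organization
        return [{**litellm_params, "organization": o} for o in org]
    return [litellm_params]

print(expand_orgs({"model": "openai/*", "organization": ["org-1", "org-2"]}))
# [{'model': 'openai/*', 'organization': 'org-1'},
#  {'model': 'openai/*', 'organization': 'org-2'}]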

litellm/tests/test_router.py

@@ -1,24 +1,54 @@
 #### What this tests ####
 # This tests litellm router
 
-import sys, os, time, openai
-import traceback, asyncio
+import asyncio
+import os
+import sys
+import time
+import traceback
+
+import openai
 import pytest
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import os
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+
+import httpx
+from dotenv import load_dotenv
+
 import litellm
 from litellm import Router
 from litellm.router import Deployment, LiteLLM_Params, ModelInfo
-from concurrent.futures import ThreadPoolExecutor
-from collections import defaultdict
-from dotenv import load_dotenv
-import os, httpx
 
 load_dotenv()
 
 
+def test_router_multi_org_list():
+    """
+    Pass list of orgs in 1 model definition,
+    expect a unique deployment for each to be created
+    """
+    router = litellm.Router(
+        model_list=[
+            {
+                "model_name": "*",
+                "litellm_params": {
+                    "model": "openai/*",
+                    "api_key": "my-key",
+                    "api_base": "https://api.openai.com/v1",
+                    "organization": ["org-1", "org-2", "org-3"],
+                },
+            }
+        ]
+    )
+
+    assert len(router.get_model_list()) == 3
+
+
 def test_router_sensitive_keys():
     try:
         router = Router(
@@ -527,9 +557,10 @@ def test_router_context_window_fallback():
     - Send a 5k prompt
     - Assert it works
     """
-    from large_text import text
     import os
+
+    from large_text import text
 
     litellm.set_verbose = False
 
     print(f"len(text): {len(text)}")
@@ -577,9 +608,10 @@ async def test_async_router_context_window_fallback():
     - Send a 5k prompt
     - Assert it works
     """
-    from large_text import text
     import os
+
+    from large_text import text
 
     litellm.set_verbose = False
 
     print(f"len(text): {len(text)}")
@@ -660,9 +692,10 @@ def test_router_context_window_check_pre_call_check_in_group():
     - Send a 5k prompt
     - Assert it works
     """
-    from large_text import text
     import os
+
+    from large_text import text
 
     litellm.set_verbose = False
 
     print(f"len(text): {len(text)}")
@@ -708,9 +741,10 @@ def test_router_context_window_check_pre_call_check_out_group():
     - Send a 5k prompt
     - Assert it works
     """
-    from large_text import text
     import os
+
+    from large_text import text
 
     litellm.set_verbose = False
 
     print(f"len(text): {len(text)}")
@@ -1536,9 +1570,10 @@ def test_router_anthropic_key_dynamic():
 def test_router_timeout():
     litellm.set_verbose = True
-    from litellm._logging import verbose_logger
     import logging
+
+    from litellm._logging import verbose_logger
 
     verbose_logger.setLevel(logging.DEBUG)
     model_list = [
         {

litellm/types/router.py

@@ -2,12 +2,14 @@
 litellm.Router Types - includes RouterConfig, UpdateRouterConfig, ModelInfo etc
 """
 
-from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
-import uuid
+import datetime
 import enum
+import uuid
+from typing import Dict, List, Literal, Optional, Tuple, TypedDict, Union
+
 import httpx
 from pydantic import BaseModel, ConfigDict, Field
-import datetime
 
 from .completion import CompletionRequest
 from .embedding import EmbeddingRequest
@@ -293,7 +295,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
     timeout: Optional[Union[float, str, httpx.Timeout]]
     stream_timeout: Optional[Union[float, str]]
     max_retries: Optional[int]
-    organization: Optional[str]  # for openai orgs
+    organization: Optional[Union[List, str]]  # for openai orgs
     ## DROP PARAMS ##
     drop_params: Optional[bool]
     ## UNIFIED PROJECT/REGION ##
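With the widened organization type, both shapes now satisfy the TypedDict. A minimal sketch (the dict values are placeholders):

from litellm.types.router import LiteLLMParamsTypedDict

# A single org string, as before...
single: LiteLLMParamsTypedDict = {"model": "openai/*", "organization": "org-1"}

# ...or a list of orgs, which set_model_list() now fans out into
# one deployment per organization.
multi: LiteLLMParamsTypedDict = {
    "model": "openai/*",
    "organization": ["org-1", "org-2", "org-3"],
}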