fix(router.py): support multiple orgs in 1 model definition

Closes https://github.com/BerriAI/litellm/issues/3949
Krrish Dholakia 2024-06-18 19:36:58 -07:00
parent 83b97d9763
commit 14b66c3daa
3 changed files with 86 additions and 27 deletions
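
With this change, the "organization" value in litellm_params may be a list, and
the router creates one deployment per organization. A minimal usage sketch,
mirroring the test added in this commit (the api_key is a placeholder):

    import litellm

    router = litellm.Router(
        model_list=[
            {
                "model_name": "*",
                "litellm_params": {
                    "model": "openai/*",
                    "api_key": "my-key",  # placeholder credential
                    "api_base": "https://api.openai.com/v1",
                    # one deployment is created per org in this list
                    "organization": ["org-1", "org-2", "org-3"],
                },
            }
        ]
    )

    assert len(router.get_model_list()) == 3  # one deployment per org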


@@ -3543,6 +3543,22 @@ class Router:
         return hash_object.hexdigest()
 
+    def _create_deployment(
+        self, model: dict, _model_name: str, _litellm_params: dict, _model_info: dict
+    ):
+        deployment = Deployment(
+            **model,
+            model_name=_model_name,
+            litellm_params=_litellm_params,  # type: ignore
+            model_info=_model_info,
+        )
+
+        deployment = self._add_deployment(deployment=deployment)
+
+        model = deployment.to_json(exclude_none=True)
+
+        self.model_list.append(model)
+
     def set_model_list(self, model_list: list):
         original_model_list = copy.deepcopy(model_list)
         self.model_list = []
 
@@ -3565,18 +3581,24 @@ class Router:
             _id = self._generate_model_id(_model_name, _litellm_params)
             _model_info["id"] = _id
 
-            deployment = Deployment(
-                **model,
-                model_name=_model_name,
-                litellm_params=_litellm_params,
-                model_info=_model_info,
-            )
-
-            deployment = self._add_deployment(deployment=deployment)
-
-            model = deployment.to_json(exclude_none=True)
-
-            self.model_list.append(model)
+            if _litellm_params.get("organization", None) is not None and isinstance(
+                _litellm_params["organization"], list
+            ):  # Addresses https://github.com/BerriAI/litellm/issues/3949
+                for org in _litellm_params["organization"]:
+                    _litellm_params["organization"] = org
+                    self._create_deployment(
+                        model=model,
+                        _model_name=_model_name,
+                        _litellm_params=_litellm_params,
+                        _model_info=_model_info,
+                    )
+            else:
+                self._create_deployment(
+                    model=model,
+                    _model_name=_model_name,
+                    _litellm_params=_litellm_params,
+                    _model_info=_model_info,
+                )
 
         verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}")
         self.model_names = [m["model_name"] for m in model_list]
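
The fan-out above boils down to the following standalone sketch, where
expand_orgs is a hypothetical helper used only for illustration (the actual
commit instead mutates _litellm_params in place and calls _create_deployment
once per org, which works because each Deployment snapshots its params when
constructed):

    def expand_orgs(litellm_params: dict) -> list:
        """One params dict per organization when a list is passed."""
        org = litellm_params.get("organization", None)
        if org is not None and isinstance(org, list):
            # pin each copy to a single organization string
            return [{**litellm_params, "organization": o} for o in org]
        return [litellm_params]  # plain string or None: unchanged

    assert len(expand_orgs({"model": "openai/*", "organization": ["org-1", "org-2"]})) == 2
    assert len(expand_orgs({"model": "openai/*", "organization": "org-1"})) == 1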


@@ -1,24 +1,54 @@
 #### What this tests ####
 # This tests litellm router
 
-import sys, os, time, openai
-import traceback, asyncio
+import asyncio
+import os
+import sys
+import time
+import traceback
+
+import openai
 import pytest
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
+import os
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+
+import httpx
+from dotenv import load_dotenv
+
 import litellm
 from litellm import Router
 from litellm.router import Deployment, LiteLLM_Params, ModelInfo
-from concurrent.futures import ThreadPoolExecutor
-from collections import defaultdict
-from dotenv import load_dotenv
-import os, httpx
 
 load_dotenv()
 
 
+def test_router_multi_org_list():
+    """
+    Pass list of orgs in 1 model definition,
+    expect a unique deployment for each to be created
+    """
+    router = litellm.Router(
+        model_list=[
+            {
+                "model_name": "*",
+                "litellm_params": {
+                    "model": "openai/*",
+                    "api_key": "my-key",
+                    "api_base": "https://api.openai.com/v1",
+                    "organization": ["org-1", "org-2", "org-3"],
+                },
+            }
+        ]
+    )
+
+    assert len(router.get_model_list()) == 3
+
+
 def test_router_sensitive_keys():
     try:
         router = Router(
@@ -527,9 +557,10 @@ def test_router_context_window_fallback():
     - Send a 5k prompt
     - Assert it works
     """
-    from large_text import text
     import os
 
+    from large_text import text
+
     litellm.set_verbose = False
 
     print(f"len(text): {len(text)}")
@@ -577,9 +608,10 @@ async def test_async_router_context_window_fallback():
     - Send a 5k prompt
     - Assert it works
     """
-    from large_text import text
     import os
 
+    from large_text import text
+
     litellm.set_verbose = False
 
     print(f"len(text): {len(text)}")
@@ -660,9 +692,10 @@ def test_router_context_window_check_pre_call_check_in_group():
     - Send a 5k prompt
     - Assert it works
     """
-    from large_text import text
     import os
 
+    from large_text import text
+
     litellm.set_verbose = False
 
     print(f"len(text): {len(text)}")
@@ -708,9 +741,10 @@ def test_router_context_window_check_pre_call_check_out_group():
     - Send a 5k prompt
     - Assert it works
     """
-    from large_text import text
     import os
 
+    from large_text import text
+
     litellm.set_verbose = False
 
     print(f"len(text): {len(text)}")
@@ -1536,9 +1570,10 @@ def test_router_anthropic_key_dynamic():
 
 def test_router_timeout():
     litellm.set_verbose = True
-    from litellm._logging import verbose_logger
     import logging
 
+    from litellm._logging import verbose_logger
+
     verbose_logger.setLevel(logging.DEBUG)
     model_list = [
         {


@@ -2,12 +2,14 @@
 litellm.Router Types - includes RouterConfig, UpdateRouterConfig, ModelInfo etc
 """
 
-from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
-import uuid
+import datetime
 import enum
+import uuid
+from typing import Dict, List, Literal, Optional, Tuple, TypedDict, Union
+
 import httpx
 from pydantic import BaseModel, ConfigDict, Field
-import datetime
+
 from .completion import CompletionRequest
 from .embedding import EmbeddingRequest
 
@@ -293,7 +295,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
     timeout: Optional[Union[float, str, httpx.Timeout]]
     stream_timeout: Optional[Union[float, str]]
     max_retries: Optional[int]
-    organization: Optional[str]  # for openai orgs
+    organization: Optional[Union[List, str]]  # for openai orgs
     ## DROP PARAMS ##
     drop_params: Optional[bool]
     ## UNIFIED PROJECT/REGION ##
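
With the widened field, both the existing single-org form and the new list form
satisfy the TypedDict. A quick sketch (the import path litellm.types.router is
inferred from this file's docstring, not shown in the diff):

    from litellm.types.router import LiteLLMParamsTypedDict

    # total=False, so keys are optional; both forms now type-check
    single: LiteLLMParamsTypedDict = {"organization": "org-1"}
    multi: LiteLLMParamsTypedDict = {"organization": ["org-1", "org-2"]}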