diff --git a/litellm/router.py b/litellm/router.py index db38df29f0..9200089d5b 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -3543,6 +3543,22 @@ class Router: return hash_object.hexdigest() + def _create_deployment( + self, model: dict, _model_name: str, _litellm_params: dict, _model_info: dict + ): + deployment = Deployment( + **model, + model_name=_model_name, + litellm_params=_litellm_params, # type: ignore + model_info=_model_info, + ) + + deployment = self._add_deployment(deployment=deployment) + + model = deployment.to_json(exclude_none=True) + + self.model_list.append(model) + def set_model_list(self, model_list: list): original_model_list = copy.deepcopy(model_list) self.model_list = [] @@ -3565,18 +3581,24 @@ class Router: _id = self._generate_model_id(_model_name, _litellm_params) _model_info["id"] = _id - deployment = Deployment( - **model, - model_name=_model_name, - litellm_params=_litellm_params, - model_info=_model_info, - ) - - deployment = self._add_deployment(deployment=deployment) - - model = deployment.to_json(exclude_none=True) - - self.model_list.append(model) + if _litellm_params.get("organization", None) is not None and isinstance( + _litellm_params["organization"], list + ): # Addresses https://github.com/BerriAI/litellm/issues/3949 + for org in _litellm_params["organization"]: + _litellm_params["organization"] = org + self._create_deployment( + model=model, + _model_name=_model_name, + _litellm_params=_litellm_params, + _model_info=_model_info, + ) + else: + self._create_deployment( + model=model, + _model_name=_model_name, + _litellm_params=_litellm_params, + _model_info=_model_info, + ) verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}") self.model_names = [m["model_name"] for m in model_list] diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 9b52c7d57c..4cde1b55f7 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -1,24 +1,54 @@ #### What this tests #### # This tests litellm router -import sys, os, time, openai -import traceback, asyncio +import asyncio +import os +import sys +import time +import traceback + +import openai import pytest sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path +import os +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor + +import httpx +from dotenv import load_dotenv + import litellm from litellm import Router from litellm.router import Deployment, LiteLLM_Params, ModelInfo -from concurrent.futures import ThreadPoolExecutor -from collections import defaultdict -from dotenv import load_dotenv -import os, httpx load_dotenv() +def test_router_multi_org_list(): + """ + Pass list of orgs in 1 model definition, + expect a unique deployment for each to be created + """ + router = litellm.Router( + model_list=[ + { + "model_name": "*", + "litellm_params": { + "model": "openai/*", + "api_key": "my-key", + "api_base": "https://api.openai.com/v1", + "organization": ["org-1", "org-2", "org-3"], + }, + } + ] + ) + + assert len(router.get_model_list()) == 3 + + def test_router_sensitive_keys(): try: router = Router( @@ -527,9 +557,10 @@ def test_router_context_window_fallback(): - Send a 5k prompt - Assert it works """ - from large_text import text import os + from large_text import text + litellm.set_verbose = False print(f"len(text): {len(text)}") @@ -577,9 +608,10 @@ async def test_async_router_context_window_fallback(): - Send a 5k prompt - Assert it works """ - from large_text import text import os + from large_text import text + litellm.set_verbose = False print(f"len(text): {len(text)}") @@ -660,9 +692,10 @@ def test_router_context_window_check_pre_call_check_in_group(): - Send a 5k prompt - Assert it works """ - from large_text import text import os + from large_text import text + litellm.set_verbose = False print(f"len(text): {len(text)}") @@ -708,9 +741,10 @@ def test_router_context_window_check_pre_call_check_out_group(): - Send a 5k prompt - Assert it works """ - from large_text import text import os + from large_text import text + litellm.set_verbose = False print(f"len(text): {len(text)}") @@ -1536,9 +1570,10 @@ def test_router_anthropic_key_dynamic(): def test_router_timeout(): litellm.set_verbose = True - from litellm._logging import verbose_logger import logging + from litellm._logging import verbose_logger + verbose_logger.setLevel(logging.DEBUG) model_list = [ { diff --git a/litellm/types/router.py b/litellm/types/router.py index aa63e95f54..da3c999dc8 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -2,12 +2,14 @@ litellm.Router Types - includes RouterConfig, UpdateRouterConfig, ModelInfo etc """ -from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict -import uuid +import datetime import enum +import uuid +from typing import Dict, List, Literal, Optional, Tuple, TypedDict, Union + import httpx from pydantic import BaseModel, ConfigDict, Field -import datetime + from .completion import CompletionRequest from .embedding import EmbeddingRequest @@ -293,7 +295,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False): timeout: Optional[Union[float, str, httpx.Timeout]] stream_timeout: Optional[Union[float, str]] max_retries: Optional[int] - organization: Optional[str] # for openai orgs + organization: Optional[Union[List, str]] # for openai orgs ## DROP PARAMS ## drop_params: Optional[bool] ## UNIFIED PROJECT/REGION ##