forked from phoenix/litellm-mirror
fix(router.py): support multiple orgs in 1 model definition
Closes https://github.com/BerriAI/litellm/issues/3949
This commit is contained in:
parent
83b97d9763
commit
14b66c3daa
3 changed files with 86 additions and 27 deletions
|
@ -3543,6 +3543,22 @@ class Router:
|
|||
|
||||
return hash_object.hexdigest()
|
||||
|
||||
def _create_deployment(
    self, model: dict, _model_name: str, _litellm_params: dict, _model_info: dict
):
    """
    Build a Deployment from the raw model dict and register it on the router.

    The Deployment is passed through self._add_deployment for any router-side
    setup, serialized back to a plain dict (dropping None fields), and appended
    to self.model_list.
    """
    raw_deployment = Deployment(
        **model,
        model_name=_model_name,
        litellm_params=_litellm_params,  # type: ignore
        model_info=_model_info,
    )
    # Let the router attach/initialize the deployment before storing it.
    registered = self._add_deployment(deployment=raw_deployment)
    self.model_list.append(registered.to_json(exclude_none=True))
|
||||
|
||||
def set_model_list(self, model_list: list):
|
||||
original_model_list = copy.deepcopy(model_list)
|
||||
self.model_list = []
|
||||
|
@ -3565,18 +3581,24 @@ class Router:
|
|||
_id = self._generate_model_id(_model_name, _litellm_params)
|
||||
_model_info["id"] = _id
|
||||
|
||||
deployment = Deployment(
|
||||
**model,
|
||||
model_name=_model_name,
|
||||
litellm_params=_litellm_params,
|
||||
model_info=_model_info,
|
||||
)
|
||||
|
||||
deployment = self._add_deployment(deployment=deployment)
|
||||
|
||||
model = deployment.to_json(exclude_none=True)
|
||||
|
||||
self.model_list.append(model)
|
||||
if _litellm_params.get("organization", None) is not None and isinstance(
|
||||
_litellm_params["organization"], list
|
||||
): # Addresses https://github.com/BerriAI/litellm/issues/3949
|
||||
for org in _litellm_params["organization"]:
|
||||
_litellm_params["organization"] = org
|
||||
self._create_deployment(
|
||||
model=model,
|
||||
_model_name=_model_name,
|
||||
_litellm_params=_litellm_params,
|
||||
_model_info=_model_info,
|
||||
)
|
||||
else:
|
||||
self._create_deployment(
|
||||
model=model,
|
||||
_model_name=_model_name,
|
||||
_litellm_params=_litellm_params,
|
||||
_model_info=_model_info,
|
||||
)
|
||||
|
||||
verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}")
|
||||
self.model_names = [m["model_name"] for m in model_list]
|
||||
|
|
|
@ -1,24 +1,54 @@
|
|||
#### What this tests ####
|
||||
# This tests litellm router
|
||||
|
||||
import sys, os, time, openai
|
||||
import traceback, asyncio
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import litellm
|
||||
from litellm import Router
|
||||
from litellm.router import Deployment, LiteLLM_Params, ModelInfo
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from collections import defaultdict
|
||||
from dotenv import load_dotenv
|
||||
import os, httpx
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def test_router_multi_org_list():
    """
    When a single model definition carries a list of organizations,
    the router should expand it into one unique deployment per org.
    """
    org_list = ["org-1", "org-2", "org-3"]
    router = litellm.Router(
        model_list=[
            {
                "model_name": "*",
                "litellm_params": {
                    "model": "openai/*",
                    "api_key": "my-key",
                    "api_base": "https://api.openai.com/v1",
                    "organization": org_list,
                },
            }
        ]
    )
    # One deployment must exist for each organization in the list.
    assert len(router.get_model_list()) == len(org_list)
|
||||
|
||||
|
||||
def test_router_sensitive_keys():
|
||||
try:
|
||||
router = Router(
|
||||
|
@ -527,9 +557,10 @@ def test_router_context_window_fallback():
|
|||
- Send a 5k prompt
|
||||
- Assert it works
|
||||
"""
|
||||
from large_text import text
|
||||
import os
|
||||
|
||||
from large_text import text
|
||||
|
||||
litellm.set_verbose = False
|
||||
|
||||
print(f"len(text): {len(text)}")
|
||||
|
@ -577,9 +608,10 @@ async def test_async_router_context_window_fallback():
|
|||
- Send a 5k prompt
|
||||
- Assert it works
|
||||
"""
|
||||
from large_text import text
|
||||
import os
|
||||
|
||||
from large_text import text
|
||||
|
||||
litellm.set_verbose = False
|
||||
|
||||
print(f"len(text): {len(text)}")
|
||||
|
@ -660,9 +692,10 @@ def test_router_context_window_check_pre_call_check_in_group():
|
|||
- Send a 5k prompt
|
||||
- Assert it works
|
||||
"""
|
||||
from large_text import text
|
||||
import os
|
||||
|
||||
from large_text import text
|
||||
|
||||
litellm.set_verbose = False
|
||||
|
||||
print(f"len(text): {len(text)}")
|
||||
|
@ -708,9 +741,10 @@ def test_router_context_window_check_pre_call_check_out_group():
|
|||
- Send a 5k prompt
|
||||
- Assert it works
|
||||
"""
|
||||
from large_text import text
|
||||
import os
|
||||
|
||||
from large_text import text
|
||||
|
||||
litellm.set_verbose = False
|
||||
|
||||
print(f"len(text): {len(text)}")
|
||||
|
@ -1536,9 +1570,10 @@ def test_router_anthropic_key_dynamic():
|
|||
|
||||
def test_router_timeout():
|
||||
litellm.set_verbose = True
|
||||
from litellm._logging import verbose_logger
|
||||
import logging
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
verbose_logger.setLevel(logging.DEBUG)
|
||||
model_list = [
|
||||
{
|
||||
|
|
|
@ -2,12 +2,14 @@
|
|||
litellm.Router Types - includes RouterConfig, UpdateRouterConfig, ModelInfo etc
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
|
||||
import uuid
|
||||
import datetime
|
||||
import enum
|
||||
import uuid
|
||||
from typing import Dict, List, Literal, Optional, Tuple, TypedDict, Union
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
import datetime
|
||||
|
||||
from .completion import CompletionRequest
|
||||
from .embedding import EmbeddingRequest
|
||||
|
||||
|
@ -293,7 +295,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
|
|||
timeout: Optional[Union[float, str, httpx.Timeout]]
|
||||
stream_timeout: Optional[Union[float, str]]
|
||||
max_retries: Optional[int]
|
||||
organization: Optional[str] # for openai orgs
|
||||
organization: Optional[Union[List, str]] # for openai orgs
|
||||
## DROP PARAMS ##
|
||||
drop_params: Optional[bool]
|
||||
## UNIFIED PROJECT/REGION ##
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue