Litellm dev 02 10 2025 p2 (#8443)

* Fixed issue #8246 (#8250)

* Fixed issue #8246

* Added unit tests for discard() and for remove_callback_from_list_by_object()

* fix(openai.py): support dynamic passing of organization param to openai

handles scenario where client-side org id is passed to openai
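
A minimal caller-side sketch of the scenario this fixes, assuming the org id is supplied per request through litellm's `organization` parameter (model name and org id below are placeholders):

import litellm

# the organization passed here should now be applied to the (possibly cached)
# OpenAI client for this call rather than being ignored
response = litellm.completion(
    model="gpt-4o-mini",                      # placeholder model
    messages=[{"role": "user", "content": "hello"}],
    organization="org-example-client-org",    # client-side org id (placeholder)
)
print(response.choices[0].message.content)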

---------

Co-authored-by: Erez Hadad <erezh@il.ibm.com>
Krish Dholakia 2025-02-10 17:53:46 -08:00 committed by GitHub
parent 47f46f92c8
commit e26d7df91b
9 changed files with 112 additions and 5 deletions


@@ -85,6 +85,21 @@ class LoggingCallbackManager:
            callback=callback, parent_list=litellm._async_failure_callback
        )

    def remove_callback_from_list_by_object(
        self, callback_list, obj
    ):
        """
        Remove callbacks that are methods of a particular object (e.g., router cleanup)
        """
        if not isinstance(callback_list, list):  # Not list -> do nothing
            return

        remove_list = [c for c in callback_list if hasattr(c, '__self__') and c.__self__ == obj]

        for c in remove_list:
            callback_list.remove(c)

    def _add_string_callback_to_list(
        self, callback: str, parent_list: List[Union[CustomLogger, Callable, str]]
    ):
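
The removal relies on the fact that a bound method exposes its instance via `__self__`. A standalone sketch of that check, using illustrative names that are not taken from the codebase:

class Owner:
    def on_success(self, *args, **kwargs):
        pass

def plain_callback(*args, **kwargs):
    pass

owner = Owner()
callbacks = [owner.on_success, plain_callback]

# same filter the new helper applies: collect callbacks bound to `owner`, then drop them
remove_list = [c for c in callbacks if hasattr(c, "__self__") and c.__self__ == owner]
for c in remove_list:
    callbacks.remove(c)

assert callbacks == [plain_callback]  # only the bound method was removed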


@@ -27,6 +27,7 @@ from typing_extensions import overload
import litellm
from litellm import LlmProviders
from litellm._logging import verbose_logger
from litellm.constants import DEFAULT_MAX_RETRIES
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
@@ -320,6 +321,17 @@ class OpenAIChatCompletion(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    def _set_dynamic_params_on_client(
        self,
        client: Union[OpenAI, AsyncOpenAI],
        organization: Optional[str] = None,
        max_retries: Optional[int] = None,
    ):
        if organization is not None:
            client.organization = organization
        if max_retries is not None:
            client.max_retries = max_retries

    def _get_openai_client(
        self,
        is_async: bool,
@@ -327,11 +339,10 @@ class OpenAIChatCompletion(BaseLLM):
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
        max_retries: Optional[int] = 2,
        max_retries: Optional[int] = DEFAULT_MAX_RETRIES,
        organization: Optional[str] = None,
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ):
        args = locals()
        if client is None:
            if not isinstance(max_retries, int):
                raise OpenAIError(
@@ -364,7 +375,6 @@ class OpenAIChatCompletion(BaseLLM):
                    organization=organization,
                )
            else:
                _new_client = OpenAI(
                    api_key=api_key,
                    base_url=api_base,
@@ -383,6 +393,11 @@ class OpenAIChatCompletion(BaseLLM):
            return _new_client
        else:
            self._set_dynamic_params_on_client(
                client=client,
                organization=organization,
                max_retries=max_retries,
            )
            return client

    @track_llm_api_timing()
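
A rough sketch of what `_set_dynamic_params_on_client` amounts to when a cached client is reused: the per-request values simply overwrite attributes on the existing SDK client (key and org values are placeholders):

from openai import OpenAI

client = OpenAI(api_key="sk-placeholder", organization="org-default")

# roughly what the helper does when a request carries its own organization
request_organization = "org-from-this-request"
if request_organization is not None:
    client.organization = request_organization  # later calls on this client use the new org

print(client.organization)  # org-from-this-request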

File diff suppressed because one or more lines are too long


@@ -39,6 +39,7 @@ litellm_settings:
general_settings:
  enable_jwt_auth: True
  forward_openai_org_id: True
  litellm_jwtauth:
    user_id_jwt_field: "sub"
    team_ids_jwt_field: "groups"


@@ -238,6 +238,7 @@ class LiteLLMProxyRequestSetup:
            return None
        for header, value in headers.items():
            if header.lower() == "openai-organization":
                verbose_logger.info(f"found openai org id: {value}, sending to llm")
                return value
        return None
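
With `forward_openai_org_id: True` set in the proxy config above, the org id reaches this code through the standard `OpenAI-Organization` header. A hedged client-side sketch against a locally running proxy (URL, keys, and model name are placeholders):

from openai import OpenAI

client = OpenAI(
    api_key="sk-litellm-virtual-key",    # proxy virtual key (placeholder)
    base_url="http://localhost:4000",    # litellm proxy address (placeholder)
    organization="org-my-client-org",    # sent as the OpenAI-Organization header
)

response = client.chat.completions.create(
    model="gpt-4o-mini",                 # placeholder model name
    messages=[{"role": "user", "content": "hi"}],
)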


@@ -58,7 +58,9 @@ async def route_request(
    elif "user_config" in data:
        router_config = data.pop("user_config")
        user_router = litellm.Router(**router_config)
        return getattr(user_router, f"{route_type}")(**data)
        ret_val = getattr(user_router, f"{route_type}")(**data)
        user_router.discard()
        return ret_val

    elif (
        route_type == "acompletion"


@@ -573,6 +573,21 @@ class Router:
            litellm.amoderation, call_type="moderation"
        )

    def discard(self):
        """
        Pseudo-destructor to be invoked to clean up global data structures when router is no longer used.
        For now, unhook router's callbacks from all lists
        """
        litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm._async_success_callback, self)
        litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.success_callback, self)
        litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm._async_failure_callback, self)
        litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.failure_callback, self)
        litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.input_callback, self)
        litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.service_callback, self)
        litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.callbacks, self)

    def _update_redis_cache(self, cache: RedisCache):
        """
        Update the redis cache for the router, if none set.
@@ -587,6 +602,7 @@ class Router:
        if self.cache.redis_cache is None:
            self.cache.redis_cache = cache

    def initialize_assistants_endpoint(self):
        ## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ##
        self.acreate_assistants = self.factory_function(litellm.acreate_assistants)
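
A short usage sketch of `discard()`, mirroring the `route_request` change above: a router built on the fly (e.g. from a `user_config` payload) is unhooked from litellm's module-level callback lists once it is no longer needed. Model entries below are placeholders:

import litellm
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4o-mini",                 # placeholder alias
            "litellm_params": {"model": "gpt-4o-mini"},  # placeholder provider model
        }
    ]
)

# ... use the short-lived router ...

router.discard()  # unhook this router's callbacks from litellm's global lists
assert all(getattr(cb, "__self__", None) is not router for cb in litellm.success_callback)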


@@ -160,6 +160,39 @@ def test_async_callbacks():
    assert async_failure in litellm._async_failure_callback


def test_remove_callback_from_list_by_object():
    manager = LoggingCallbackManager()
    # Reset all callbacks
    manager._reset_all_callbacks()

    class TestObject:
        def __init__(self):
            manager.add_litellm_callback(self.callback)
            manager.add_litellm_success_callback(self.callback)
            manager.add_litellm_failure_callback(self.callback)
            manager.add_litellm_async_success_callback(self.callback)
            manager.add_litellm_async_failure_callback(self.callback)

        def callback(self):
            pass

    obj = TestObject()

    manager.remove_callback_from_list_by_object(litellm.callbacks, obj)
    manager.remove_callback_from_list_by_object(litellm.success_callback, obj)
    manager.remove_callback_from_list_by_object(litellm.failure_callback, obj)
    manager.remove_callback_from_list_by_object(litellm._async_success_callback, obj)
    manager.remove_callback_from_list_by_object(litellm._async_failure_callback, obj)

    # Verify all callback lists are empty
    assert len(litellm.callbacks) == 0
    assert len(litellm.success_callback) == 0
    assert len(litellm.failure_callback) == 0
    assert len(litellm._async_success_callback) == 0
    assert len(litellm._async_failure_callback) == 0


def test_reset_callbacks(callback_manager):
    # Add various callbacks
    callback_manager.add_litellm_callback("test")


@@ -918,6 +918,31 @@ def test_flush_cache(model_list):
    assert router.cache.get_cache("test") is None


def test_discard(model_list):
    """
    Test that discard properly removes a Router from the callback lists
    """
    litellm.callbacks = []
    litellm.success_callback = []
    litellm._async_success_callback = []
    litellm.failure_callback = []
    litellm._async_failure_callback = []
    litellm.input_callback = []
    litellm.service_callback = []

    router = Router(model_list=model_list)
    router.discard()

    # Verify all callback lists are empty
    assert len(litellm.callbacks) == 0
    assert len(litellm.success_callback) == 0
    assert len(litellm.failure_callback) == 0
    assert len(litellm._async_success_callback) == 0
    assert len(litellm._async_failure_callback) == 0
    assert len(litellm.input_callback) == 0
    assert len(litellm.service_callback) == 0


def test_initialize_assistants_endpoint(model_list):
    """Test if the 'initialize_assistants_endpoint' function is working correctly"""
    router = Router(model_list=model_list)