diff --git a/litellm/litellm_core_utils/logging_callback_manager.py b/litellm/litellm_core_utils/logging_callback_manager.py index 860a57c5f6..e55df44474 100644 --- a/litellm/litellm_core_utils/logging_callback_manager.py +++ b/litellm/litellm_core_utils/logging_callback_manager.py @@ -85,6 +85,21 @@ class LoggingCallbackManager: callback=callback, parent_list=litellm._async_failure_callback ) + def remove_callback_from_list_by_object( + self, callback_list, obj + ): + """ + Remove callbacks that are methods of a particular object (e.g., router cleanup) + """ + if not isinstance(callback_list, list): # Not list -> do nothing + return + + remove_list = [c for c in callback_list if hasattr(c, '__self__') and c.__self__ is obj] + + for c in remove_list: + callback_list.remove(c) + + def _add_string_callback_to_list( self, callback: str, parent_list: List[Union[CustomLogger, Callable, str]] ): diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py index 82b9c9ba38..5465a24945 100644 --- a/litellm/llms/openai/openai.py +++ b/litellm/llms/openai/openai.py @@ -27,6 +27,7 @@ from typing_extensions import overload import litellm from litellm import LlmProviders from litellm._logging import verbose_logger +from litellm.constants import DEFAULT_MAX_RETRIES from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.litellm_core_utils.logging_utils import track_llm_api_timing from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator @@ -320,6 +321,17 @@ class OpenAIChatCompletion(BaseLLM): def __init__(self) -> None: super().__init__() + def _set_dynamic_params_on_client( + self, + client: Union[OpenAI, AsyncOpenAI], + organization: Optional[str] = None, + max_retries: Optional[int] = None, + ): + if organization is not None: + client.organization = organization + if max_retries is not None: + client.max_retries = max_retries + def _get_openai_client( self, is_async: bool, @@ -327,11 +339,10 @@ class 
OpenAIChatCompletion(BaseLLM): api_base: Optional[str] = None, api_version: Optional[str] = None, timeout: Union[float, httpx.Timeout] = httpx.Timeout(None), - max_retries: Optional[int] = 2, + max_retries: Optional[int] = DEFAULT_MAX_RETRIES, organization: Optional[str] = None, client: Optional[Union[OpenAI, AsyncOpenAI]] = None, ): - args = locals() if client is None: if not isinstance(max_retries, int): raise OpenAIError( @@ -364,7 +375,6 @@ class OpenAIChatCompletion(BaseLLM): organization=organization, ) else: - _new_client = OpenAI( api_key=api_key, base_url=api_base, @@ -383,6 +393,11 @@ class OpenAIChatCompletion(BaseLLM): return _new_client else: + self._set_dynamic_params_on_client( + client=client, + organization=organization, + max_retries=max_retries, + ) return client @track_llm_api_timing() diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding.html deleted file mode 100644 index 94a8d6d50c..0000000000 --- a/litellm/proxy/_experimental/out/onboarding.html +++ /dev/null @@ -1 +0,0 @@ -LiteLLM Dashboard \ No newline at end of file diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 19cc8cdd07..5347d4a791 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -39,6 +39,7 @@ litellm_settings: general_settings: enable_jwt_auth: True + forward_openai_org_id: True litellm_jwtauth: user_id_jwt_field: "sub" team_ids_jwt_field: "groups" diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index b913c238db..5892b7afc6 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -238,6 +238,7 @@ class LiteLLMProxyRequestSetup: return None for header, value in headers.items(): if header.lower() == "openai-organization": + verbose_logger.info(f"found openai org id: {value}, sending to llm") return value return None diff --git 
a/litellm/proxy/route_llm_request.py b/litellm/proxy/route_llm_request.py index 1ff211c20c..6102a26b23 100644 --- a/litellm/proxy/route_llm_request.py +++ b/litellm/proxy/route_llm_request.py @@ -58,7 +58,9 @@ async def route_request( elif "user_config" in data: router_config = data.pop("user_config") user_router = litellm.Router(**router_config) - return getattr(user_router, f"{route_type}")(**data) + ret_val = getattr(user_router, f"{route_type}")(**data) + user_router.discard() + return ret_val elif ( route_type == "acompletion" diff --git a/litellm/router.py b/litellm/router.py index 597ba9fd06..bdac540f1a 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -573,6 +573,21 @@ class Router: litellm.amoderation, call_type="moderation" ) + + def discard(self): + """ + Pseudo-destructor to be invoked to clean up global data structures when router is no longer used. + For now, unhook router's callbacks from all lists + """ + litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm._async_success_callback, self) + litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.success_callback, self) + litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm._async_failure_callback, self) + litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.failure_callback, self) + litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.input_callback, self) + litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.service_callback, self) + litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.callbacks, self) + + def _update_redis_cache(self, cache: RedisCache): """ Update the redis cache for the router, if none set. 
@@ -587,6 +602,7 @@ class Router: if self.cache.redis_cache is None: self.cache.redis_cache = cache + def initialize_assistants_endpoint(self): ## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ## self.acreate_assistants = self.factory_function(litellm.acreate_assistants) diff --git a/tests/litellm_utils_tests/test_logging_callback_manager.py b/tests/litellm_utils_tests/test_logging_callback_manager.py index 71ffb18678..1b70631e4d 100644 --- a/tests/litellm_utils_tests/test_logging_callback_manager.py +++ b/tests/litellm_utils_tests/test_logging_callback_manager.py @@ -160,6 +160,39 @@ def test_async_callbacks(): assert async_failure in litellm._async_failure_callback + +def test_remove_callback_from_list_by_object(): + manager = LoggingCallbackManager() + # Reset all callbacks + manager._reset_all_callbacks() + + class TestObject: + def __init__(self): + manager.add_litellm_callback(self.callback) + manager.add_litellm_success_callback(self.callback) + manager.add_litellm_failure_callback(self.callback) + manager.add_litellm_async_success_callback(self.callback) + manager.add_litellm_async_failure_callback(self.callback) + + def callback(self): + pass + + obj = TestObject() + + manager.remove_callback_from_list_by_object(litellm.callbacks, obj) + manager.remove_callback_from_list_by_object(litellm.success_callback, obj) + manager.remove_callback_from_list_by_object(litellm.failure_callback, obj) + manager.remove_callback_from_list_by_object(litellm._async_success_callback, obj) + manager.remove_callback_from_list_by_object(litellm._async_failure_callback, obj) + + # Verify all callback lists are empty + assert len(litellm.callbacks) == 0 + assert len(litellm.success_callback) == 0 + assert len(litellm.failure_callback) == 0 + assert len(litellm._async_success_callback) == 0 + assert len(litellm._async_failure_callback) == 0 + + + def test_reset_callbacks(callback_manager): # Add various callbacks callback_manager.add_litellm_callback("test") diff --git 
a/tests/router_unit_tests/test_router_helper_utils.py b/tests/router_unit_tests/test_router_helper_utils.py index e02b47ec36..f12371baeb 100644 --- a/tests/router_unit_tests/test_router_helper_utils.py +++ b/tests/router_unit_tests/test_router_helper_utils.py @@ -918,6 +918,31 @@ def test_flush_cache(model_list): assert router.cache.get_cache("test") is None +def test_discard(model_list): + """ + Test that discard properly removes a Router from the callback lists + """ + litellm.callbacks = [] + litellm.success_callback = [] + litellm._async_success_callback = [] + litellm.failure_callback = [] + litellm._async_failure_callback = [] + litellm.input_callback = [] + litellm.service_callback = [] + + router = Router(model_list=model_list) + router.discard() + + # Verify all callback lists are empty + assert len(litellm.callbacks) == 0 + assert len(litellm.success_callback) == 0 + assert len(litellm.failure_callback) == 0 + assert len(litellm._async_success_callback) == 0 + assert len(litellm._async_failure_callback) == 0 + assert len(litellm.input_callback) == 0 + assert len(litellm.service_callback) == 0 + + def test_initialize_assistants_endpoint(model_list): """Test if the 'initialize_assistants_endpoint' function is working correctly""" router = Router(model_list=model_list)