mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
Litellm dev 02 10 2025 p2 (#8443)
* Fixed issue #8246 (#8250)
* Fixed issue #8246
* Added unit tests for discard() and for remove_callback_from_list_by_object()
* fix(openai.py): support dynamic passing of organization param to openai
  handles scenario where client-side org id is passed to openai

---------

Co-authored-by: Erez Hadad <erezh@il.ibm.com>
This commit is contained in:
parent 47f46f92c8
commit e26d7df91b
9 changed files with 112 additions and 5 deletions
@@ -85,6 +85,21 @@ class LoggingCallbackManager:
             callback=callback, parent_list=litellm._async_failure_callback
         )
 
+    def remove_callback_from_list_by_object(
+        self, callback_list, obj
+    ):
+        """
+        Remove callbacks that are methods of a particular object (e.g., router cleanup)
+        """
+        if not isinstance(callback_list, list):  # Not a list -> do nothing
+            return
+
+        remove_list = [c for c in callback_list if hasattr(c, '__self__') and c.__self__ == obj]
+
+        for c in remove_list:
+            callback_list.remove(c)
+
     def _add_string_callback_to_list(
         self, callback: str, parent_list: List[Union[CustomLogger, Callable, str]]
     ):
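For orientation, a minimal standalone sketch of the matching rule the new method relies on: bound methods expose a `__self__` attribute pointing at their instance, while plain functions defined with `def` do not. `Foo` and `plain_fn` below are hypothetical, not litellm code:

    class Foo:
        def cb(self, *args):
            pass

    def plain_fn(*args):
        pass

    foo, other = Foo(), Foo()
    callbacks = [foo.cb, other.cb, plain_fn]
    # Same filter remove_callback_from_list_by_object(callbacks, foo) applies:
    # drop entries whose __self__ is foo; everything else survives.
    remaining = [c for c in callbacks if not (hasattr(c, "__self__") and c.__self__ == foo)]
    assert remaining == [other.cb, plain_fn]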
@@ -27,6 +27,7 @@ from typing_extensions import overload
 import litellm
 from litellm import LlmProviders
 from litellm._logging import verbose_logger
+from litellm.constants import DEFAULT_MAX_RETRIES
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
 from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
@@ -320,6 +321,17 @@ class OpenAIChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
+    def _set_dynamic_params_on_client(
+        self,
+        client: Union[OpenAI, AsyncOpenAI],
+        organization: Optional[str] = None,
+        max_retries: Optional[int] = None,
+    ):
+        if organization is not None:
+            client.organization = organization
+        if max_retries is not None:
+            client.max_retries = max_retries
+
     def _get_openai_client(
         self,
         is_async: bool,
@@ -327,11 +339,10 @@ class OpenAIChatCompletion(BaseLLM):
         api_base: Optional[str] = None,
         api_version: Optional[str] = None,
         timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
-        max_retries: Optional[int] = 2,
+        max_retries: Optional[int] = DEFAULT_MAX_RETRIES,
         organization: Optional[str] = None,
         client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
     ):
-        args = locals()
         if client is None:
             if not isinstance(max_retries, int):
                 raise OpenAIError(
@@ -364,7 +375,6 @@ class OpenAIChatCompletion(BaseLLM):
                     organization=organization,
                 )
             else:
-
                 _new_client = OpenAI(
                     api_key=api_key,
                     base_url=api_base,
@@ -383,6 +393,11 @@ class OpenAIChatCompletion(BaseLLM):
             return _new_client
 
         else:
+            self._set_dynamic_params_on_client(
+                client=client,
+                organization=organization,
+                max_retries=max_retries,
+            )
             return client
 
     @track_llm_api_timing()
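Taken together, the openai.py hunks let a per-request `organization` update a reused client in place instead of forcing a rebuild. A hedged sketch of the calling pattern this serves (`litellm.completion` and its `organization` kwarg are real litellm API; the model and org values are placeholders):

    import litellm

    # Client-side org id passed per request (the scenario the commit message names).
    response = litellm.completion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Hello"}],
        organization="org-XXXXXXXX",  # placeholder
    )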
File diff suppressed because one or more lines are too long
@@ -39,6 +39,7 @@ litellm_settings:
 
 general_settings:
   enable_jwt_auth: True
+  forward_openai_org_id: True
   litellm_jwtauth:
     user_id_jwt_field: "sub"
     team_ids_jwt_field: "groups"
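A sketch of the client request this setting is meant to forward, assuming a LiteLLM proxy listening on http://localhost:4000 (key and org values are placeholders). Setting `organization` makes the OpenAI SDK send an `OpenAI-Organization` header, which `forward_openai_org_id: True` tells the proxy to pass through to the upstream provider:

    from openai import OpenAI

    client = OpenAI(
        api_key="sk-1234",                 # proxy API key (placeholder)
        base_url="http://localhost:4000",  # assumed proxy address
        organization="org-XXXXXXXX",       # sent as the OpenAI-Organization header
    )
    client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "hi"}],
    )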
@@ -238,6 +238,7 @@ class LiteLLMProxyRequestSetup:
             return None
         for header, value in headers.items():
             if header.lower() == "openai-organization":
+                verbose_logger.info(f"found openai org id: {value}, sending to llm")
                 return value
         return None
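The hunk above shows only the body of the header scan; the enclosing helper's name is not visible in this diff, so `get_openai_org_id` below is a hypothetical stand-in that mirrors the logic shown:

    from typing import Optional

    def get_openai_org_id(headers: dict) -> Optional[str]:
        # Case-insensitive match on the OpenAI-Organization header.
        for header, value in headers.items():
            if header.lower() == "openai-organization":
                return value
        return None

    assert get_openai_org_id({"OpenAI-Organization": "org-123"}) == "org-123"
    assert get_openai_org_id({"Authorization": "Bearer x"}) is None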
@@ -58,7 +58,9 @@ async def route_request(
     elif "user_config" in data:
         router_config = data.pop("user_config")
         user_router = litellm.Router(**router_config)
-        return getattr(user_router, f"{route_type}")(**data)
+        ret_val = getattr(user_router, f"{route_type}")(**data)
+        user_router.discard()
+        return ret_val
 
     elif (
         route_type == "acompletion"
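This is the leak the commit message points at (issue #8246): a `Router` built per request from `user_config` hooks bound-method callbacks into litellm's module-level callback lists, so without cleanup every such request leaves a dead router behind. A minimal sketch of the lifecycle, with the caveat that exactly which lists a given litellm version touches may vary:

    import litellm
    from litellm import Router

    router = Router(model_list=[])  # stands in for Router(**router_config)
    # ... dispatch one request through `router` ...
    router.discard()  # unhook its bound-method callbacks from litellm.callbacks,
                      # litellm.success_callback, litellm.failure_callback, etc.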
@@ -573,6 +573,21 @@ class Router:
             litellm.amoderation, call_type="moderation"
         )
 
+    def discard(self):
+        """
+        Pseudo-destructor to be invoked to clean up global data structures when router is no longer used.
+        For now, unhook router's callbacks from all lists
+        """
+        litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm._async_success_callback, self)
+        litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.success_callback, self)
+        litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm._async_failure_callback, self)
+        litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.failure_callback, self)
+        litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.input_callback, self)
+        litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.service_callback, self)
+        litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.callbacks, self)
+
     def _update_redis_cache(self, cache: RedisCache):
         """
         Update the redis cache for the router, if none set.
@@ -587,6 +602,7 @@ class Router:
         if self.cache.redis_cache is None:
             self.cache.redis_cache = cache
 
+
     def initialize_assistants_endpoint(self):
         ## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ##
         self.acreate_assistants = self.factory_function(litellm.acreate_assistants)
@@ -160,6 +160,39 @@ def test_async_callbacks():
     assert async_failure in litellm._async_failure_callback
 
 
+def test_remove_callback_from_list_by_object():
+    manager = LoggingCallbackManager()
+    # Reset all callbacks
+    manager._reset_all_callbacks()
+
+    class TestObject:
+        def __init__(self):
+            manager.add_litellm_callback(self.callback)
+            manager.add_litellm_success_callback(self.callback)
+            manager.add_litellm_failure_callback(self.callback)
+            manager.add_litellm_async_success_callback(self.callback)
+            manager.add_litellm_async_failure_callback(self.callback)
+
+        def callback(self):
+            pass
+
+    obj = TestObject()
+
+    manager.remove_callback_from_list_by_object(litellm.callbacks, obj)
+    manager.remove_callback_from_list_by_object(litellm.success_callback, obj)
+    manager.remove_callback_from_list_by_object(litellm.failure_callback, obj)
+    manager.remove_callback_from_list_by_object(litellm._async_success_callback, obj)
+    manager.remove_callback_from_list_by_object(litellm._async_failure_callback, obj)
+
+    # Verify all callback lists are empty
+    assert len(litellm.callbacks) == 0
+    assert len(litellm.success_callback) == 0
+    assert len(litellm.failure_callback) == 0
+    assert len(litellm._async_success_callback) == 0
+    assert len(litellm._async_failure_callback) == 0
+
+
 def test_reset_callbacks(callback_manager):
     # Add various callbacks
     callback_manager.add_litellm_callback("test")
@@ -918,6 +918,31 @@ def test_flush_cache(model_list):
     assert router.cache.get_cache("test") is None
 
 
+def test_discard(model_list):
+    """
+    Test that discard properly removes a Router from the callback lists
+    """
+    litellm.callbacks = []
+    litellm.success_callback = []
+    litellm._async_success_callback = []
+    litellm.failure_callback = []
+    litellm._async_failure_callback = []
+    litellm.input_callback = []
+    litellm.service_callback = []
+
+    router = Router(model_list=model_list)
+    router.discard()
+
+    # Verify all callback lists are empty
+    assert len(litellm.callbacks) == 0
+    assert len(litellm.success_callback) == 0
+    assert len(litellm.failure_callback) == 0
+    assert len(litellm._async_success_callback) == 0
+    assert len(litellm._async_failure_callback) == 0
+    assert len(litellm.input_callback) == 0
+    assert len(litellm.service_callback) == 0
+
+
 def test_initialize_assistants_endpoint(model_list):
     """Test if the 'initialize_assistants_endpoint' function is working correctly"""
     router = Router(model_list=model_list)