Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
Litellm dev 02 10 2025 p2 (#8443)
* Fixed issue #8246 (#8250)

  * Fixed issue #8246

  * Added unit tests for discard() and for remove_callback_from_list_by_object()

* fix(openai.py): support dynamic passing of organization param to openai

  Handles the scenario where a client-side org id is passed to openai.

---------

Co-authored-by: Erez Hadad <erezh@il.ibm.com>
Parent: 47f46f92c8
Commit: e26d7df91b

9 changed files with 112 additions and 5 deletions
@@ -85,6 +85,21 @@ class LoggingCallbackManager:
             callback=callback, parent_list=litellm._async_failure_callback
         )
 
+    def remove_callback_from_list_by_object(
+        self, callback_list, obj
+    ):
+        """
+        Remove callbacks that are methods of a particular object (e.g., router cleanup)
+        """
+        if not isinstance(callback_list, list):  # Not list -> do nothing
+            return
+
+        remove_list = [c for c in callback_list if hasattr(c, '__self__') and c.__self__ == obj]
+
+        for c in remove_list:
+            callback_list.remove(c)
+
     def _add_string_callback_to_list(
         self, callback: str, parent_list: List[Union[CustomLogger, Callable, str]]
     ):
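The removal logic above keys off the bound-method attribute __self__: any callback that is a method of the given object is dropped, while plain functions and other objects' methods are left alone. A minimal standalone sketch of that matching behavior, using a hypothetical Owner class in place of the real Router:

class Owner:
    def on_success(self, *args, **kwargs):
        pass


def remove_callbacks_owned_by(callback_list, obj):
    # Bound methods expose the instance they are bound to via __self__.
    stale = [cb for cb in callback_list if hasattr(cb, "__self__") and cb.__self__ == obj]
    for cb in stale:
        callback_list.remove(cb)


owner = Owner()
callbacks = [owner.on_success, print]  # one bound method, one plain callable
remove_callbacks_owned_by(callbacks, owner)
assert callbacks == [print]  # only the owner's bound method was removed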
@@ -27,6 +27,7 @@ from typing_extensions import overload
 import litellm
 from litellm import LlmProviders
 from litellm._logging import verbose_logger
+from litellm.constants import DEFAULT_MAX_RETRIES
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
 from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
@@ -320,6 +321,17 @@ class OpenAIChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
+    def _set_dynamic_params_on_client(
+        self,
+        client: Union[OpenAI, AsyncOpenAI],
+        organization: Optional[str] = None,
+        max_retries: Optional[int] = None,
+    ):
+        if organization is not None:
+            client.organization = organization
+        if max_retries is not None:
+            client.max_retries = max_retries
+
     def _get_openai_client(
         self,
         is_async: bool,
@@ -327,11 +339,10 @@ class OpenAIChatCompletion(BaseLLM):
         api_base: Optional[str] = None,
         api_version: Optional[str] = None,
         timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
-        max_retries: Optional[int] = 2,
+        max_retries: Optional[int] = DEFAULT_MAX_RETRIES,
         organization: Optional[str] = None,
         client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
     ):
-        args = locals()
         if client is None:
             if not isinstance(max_retries, int):
                 raise OpenAIError(
@@ -364,7 +375,6 @@ class OpenAIChatCompletion(BaseLLM):
                     organization=organization,
                 )
             else:
-
                 _new_client = OpenAI(
                     api_key=api_key,
                     base_url=api_base,
@@ -383,6 +393,11 @@ class OpenAIChatCompletion(BaseLLM):
             return _new_client
 
         else:
+            self._set_dynamic_params_on_client(
+                client=client,
+                organization=organization,
+                max_retries=max_retries,
+            )
             return client
 
     @track_llm_api_timing()
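Taken together, these hunks let an existing client be reused while request-scoped values are overridden: when a client is passed in, _set_dynamic_params_on_client mutates its organization and max_retries instead of constructing a new client. A rough sketch of the same idea, assuming the openai Python SDK and a hypothetical get_client helper (not LiteLLM's actual client cache):

from typing import Optional

from openai import OpenAI

_cached_client: Optional[OpenAI] = None


def get_client(api_key: str, organization: Optional[str] = None) -> OpenAI:
    """Reuse one OpenAI client, applying a request-scoped org id when given."""
    global _cached_client
    if _cached_client is None:
        _cached_client = OpenAI(api_key=api_key)
    if organization is not None:
        # Same idea as _set_dynamic_params_on_client: mutate the cached client
        # rather than rebuilding it for every request.
        _cached_client.organization = organization
    return _cached_client


client = get_client("sk-...", organization="org-client-side-id")  # placeholder values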
File diff suppressed because one or more lines are too long
@@ -39,6 +39,7 @@ litellm_settings:
 
 general_settings:
   enable_jwt_auth: True
+  forward_openai_org_id: True
   litellm_jwtauth:
     user_id_jwt_field: "sub"
     team_ids_jwt_field: "groups"
@@ -238,6 +238,7 @@ class LiteLLMProxyRequestSetup:
             return None
         for header, value in headers.items():
             if header.lower() == "openai-organization":
+                verbose_logger.info(f"found openai org id: {value}, sending to llm")
                 return value
         return None
 
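With forward_openai_org_id: True set in the config hunk above, this header lookup lets the proxy pick up a caller's OpenAI-Organization header and forward it to the upstream provider. A hedged example of a client that would exercise this path, assuming a LiteLLM proxy at http://localhost:4000 and placeholder credentials (the OpenAI SDK sends organization as the OpenAI-Organization header, which the loop above matches case-insensitively):

from openai import OpenAI

client = OpenAI(
    api_key="sk-1234",                  # placeholder proxy key
    base_url="http://localhost:4000",   # assumed proxy address
    organization="org-client-side-id",  # placeholder client-side org id
)

resp = client.chat.completions.create(
    model="gpt-4o-mini",  # any model name routed by the proxy
    messages=[{"role": "user", "content": "hello"}],
)
print(resp.choices[0].message.content)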
@@ -58,7 +58,9 @@ async def route_request(
     elif "user_config" in data:
         router_config = data.pop("user_config")
         user_router = litellm.Router(**router_config)
-        return getattr(user_router, f"{route_type}")(**data)
+        ret_val = getattr(user_router, f"{route_type}")(**data)
+        user_router.discard()
+        return ret_val
 
     elif (
         route_type == "acompletion"
@@ -573,6 +573,21 @@ class Router:
             litellm.amoderation, call_type="moderation"
         )
 
+    def discard(self):
+        """
+        Pseudo-destructor to be invoked to clean up global data structures when router is no longer used.
+        For now, unhook router's callbacks from all lists
+        """
+        litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm._async_success_callback, self)
+        litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.success_callback, self)
+        litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm._async_failure_callback, self)
+        litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.failure_callback, self)
+        litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.input_callback, self)
+        litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.service_callback, self)
+        litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.callbacks, self)
+
     def _update_redis_cache(self, cache: RedisCache):
         """
         Update the redis cache for the router, if none set.
@@ -587,6 +602,7 @@ class Router:
         if self.cache.redis_cache is None:
             self.cache.redis_cache = cache
 
+
     def initialize_assistants_endpoint(self):
         ## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ##
         self.acreate_assistants = self.factory_function(litellm.acreate_assistants)
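Registering a Router appends its bound methods to litellm's module-level callback lists, so short-lived routers (such as the per-request user_config router in the earlier route_request hunk) would otherwise accumulate there. A small usage sketch, with an illustrative single-deployment model_list:

from litellm import Router

# Illustrative deployment config; substitute your own deployments.
model_list = [
    {"model_name": "gpt-4o-mini", "litellm_params": {"model": "gpt-4o-mini"}},
]

router = Router(model_list=model_list)
try:
    pass  # ... serve one request / one user_config with this router ...
finally:
    # Unhook the router's bound-method callbacks from the global lists so the
    # instance is not kept alive after use.
    router.discard()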
@@ -160,6 +160,39 @@ def test_async_callbacks():
     assert async_failure in litellm._async_failure_callback
 
 
+def test_remove_callback_from_list_by_object():
+    manager = LoggingCallbackManager()
+    # Reset all callbacks
+    manager._reset_all_callbacks()
+
+    class TestObject:
+        def __init__(self):
+            manager.add_litellm_callback(self.callback)
+            manager.add_litellm_success_callback(self.callback)
+            manager.add_litellm_failure_callback(self.callback)
+            manager.add_litellm_async_success_callback(self.callback)
+            manager.add_litellm_async_failure_callback(self.callback)
+
+        def callback(self):
+            pass
+
+    obj = TestObject()
+
+    manager.remove_callback_from_list_by_object(litellm.callbacks, obj)
+    manager.remove_callback_from_list_by_object(litellm.success_callback, obj)
+    manager.remove_callback_from_list_by_object(litellm.failure_callback, obj)
+    manager.remove_callback_from_list_by_object(litellm._async_success_callback, obj)
+    manager.remove_callback_from_list_by_object(litellm._async_failure_callback, obj)
+
+    # Verify all callback lists are empty
+    assert len(litellm.callbacks) == 0
+    assert len(litellm.success_callback) == 0
+    assert len(litellm.failure_callback) == 0
+    assert len(litellm._async_success_callback) == 0
+    assert len(litellm._async_failure_callback) == 0
+
+
 def test_reset_callbacks(callback_manager):
     # Add various callbacks
     callback_manager.add_litellm_callback("test")
@@ -918,6 +918,31 @@ def test_flush_cache(model_list):
     assert router.cache.get_cache("test") is None
 
 
+def test_discard(model_list):
+    """
+    Test that discard properly removes a Router from the callback lists
+    """
+    litellm.callbacks = []
+    litellm.success_callback = []
+    litellm._async_success_callback = []
+    litellm.failure_callback = []
+    litellm._async_failure_callback = []
+    litellm.input_callback = []
+    litellm.service_callback = []
+
+    router = Router(model_list=model_list)
+    router.discard()
+
+    # Verify all callback lists are empty
+    assert len(litellm.callbacks) == 0
+    assert len(litellm.success_callback) == 0
+    assert len(litellm.failure_callback) == 0
+    assert len(litellm._async_success_callback) == 0
+    assert len(litellm._async_failure_callback) == 0
+    assert len(litellm.input_callback) == 0
+    assert len(litellm.service_callback) == 0
+
+
 def test_initialize_assistants_endpoint(model_list):
     """Test if the 'initialize_assistants_endpoint' function is working correctly"""
     router = Router(model_list=model_list)