diff --git a/litellm/utils.py b/litellm/utils.py
index 15266ad34..e37ac15dc 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -178,7 +178,7 @@ from .types.router import LiteLLM_Params
 MAX_THREADS = 100
 
 # Create a ThreadPoolExecutor
-executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
+executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=MAX_THREADS)
 sentry_sdk_instance = None
 capture_exception = None
 add_breadcrumb = None
@@ -914,10 +914,13 @@ def client(original_function):
                             additional_args=None,
                             stream=kwargs.get("stream", False),
                         )
-                        threading.Thread(
-                            target=logging_obj.success_handler,
-                            args=(cached_result, start_time, end_time, cache_hit),
-                        ).start()
+                        executor.submit(
+                            logging_obj.success_handler,
+                            cached_result,
+                            start_time,
+                            end_time,
+                            cache_hit,
+                        )
                         return cached_result
                     else:
                         print_verbose(
@@ -1000,9 +1003,12 @@ def client(original_function):
 
             # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
             verbose_logger.info("Wrapper: Completed Call, calling success_handler")
-            threading.Thread(
-                target=logging_obj.success_handler, args=(result, start_time, end_time)
-            ).start()
+            executor.submit(
+                logging_obj.success_handler,
+                result,
+                start_time,
+                end_time,
+            )
             # RETURN RESULT
             if hasattr(result, "_hidden_params"):
                 result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
@@ -1280,10 +1286,13 @@ def client(original_function):
                                 cached_result, start_time, end_time, cache_hit
                             )
                         )
-                        threading.Thread(
-                            target=logging_obj.success_handler,
-                            args=(cached_result, start_time, end_time, cache_hit),
-                        ).start()
+                        executor.submit(
+                            logging_obj.success_handler,
+                            cached_result,
+                            start_time,
+                            end_time,
+                            cache_hit,
+                        )
                         cache_key = kwargs.get("preset_cache_key", None)
                         if (
                             isinstance(cached_result, BaseModel)
@@ -1385,15 +1394,13 @@ def client(original_function):
                                 cache_hit,
                             )
                         )
-                        threading.Thread(
-                            target=logging_obj.success_handler,
-                            args=(
-                                final_embedding_cached_response,
-                                start_time,
-                                end_time,
-                                cache_hit,
-                            ),
-                        ).start()
+                        executor.submit(
+                            logging_obj.success_handler,
+                            final_embedding_cached_response,
+                            start_time,
+                            end_time,
+                            cache_hit,
+                        )
                         return final_embedding_cached_response
             # MODEL CALL
             result = await original_function(*args, **kwargs)
@@ -1475,11 +1482,12 @@ def client(original_function):
                             )
                         )
                     elif isinstance(litellm.cache.cache, S3Cache):
-                        threading.Thread(
-                            target=litellm.cache.add_cache,
-                            args=(result,) + args,
-                            kwargs=kwargs,
-                        ).start()
+                        executor.submit(
+                            litellm.cache.add_cache,
+                            result,
+                            *args,
+                            **kwargs,
+                        )
                     else:
                         asyncio.create_task(
                             litellm.cache.async_add_cache(
@@ -1498,10 +1506,12 @@ def client(original_function):
             asyncio.create_task(
                 logging_obj.async_success_handler(result, start_time, end_time)
             )
-            threading.Thread(
-                target=logging_obj.success_handler,
-                args=(result, start_time, end_time),
-            ).start()
+            executor.submit(
+                logging_obj.success_handler,
+                result,
+                start_time,
+                end_time,
+            )
 
             # REBUILD EMBEDDING CACHING
             if (
@@ -3197,7 +3207,7 @@ def get_optional_params(
 
         if stream:
            optional_params["stream"] = stream
-            #return optional_params
+            # return optional_params
        if max_tokens is not None:
            if "vicuna" in model or "flan" in model:
                optional_params["max_length"] = max_tokens
@@ -6455,15 +6465,6 @@ def get_model_list():
             data = response.json()
             # update model list
             model_list = data["model_list"]
-            # # check if all model providers are in environment
-            # model_providers = data["model_providers"]
-            # missing_llm_provider = None
-            # for item in model_providers:
-            #     if f"{item.upper()}_API_KEY" not in os.environ:
-            #         missing_llm_provider = item
-            #         break
-            # # update environment - if required
-            # threading.Thread(target=get_all_keys, args=(missing_llm_provider)).start()
             return model_list
         return []  # return empty list by default
     except Exception:
@@ -8147,10 +8148,11 @@ class CustomStreamWrapper:
                 if response is None:
                     continue
                 ## LOGGING
-                threading.Thread(
-                    target=self.run_success_logging_in_thread,
-                    args=(response, cache_hit),
-                ).start()  # log response
+                executor.submit(
+                    self.run_success_logging_in_thread,
+                    response,
+                    cache_hit,
+                )
                 choice = response.choices[0]
                 if isinstance(choice, StreamingChoices):
                     self.response_uptil_now += choice.delta.get("content", "") or ""
@@ -8201,10 +8203,13 @@ class CustomStreamWrapper:
                         getattr(complete_streaming_response, "usage"),
                     )
                     ## LOGGING
-                    threading.Thread(
-                        target=self.logging_obj.success_handler,
-                        args=(response, None, None, cache_hit),
-                    ).start()  # log response
+                    executor.submit(
+                        self.logging_obj.success_handler,
+                        response,
+                        None,
+                        None,
+                        cache_hit,
+                    )
                     self.sent_stream_usage = True
                     return response
                 raise  # Re-raise StopIteration
@@ -8215,17 +8220,20 @@ class CustomStreamWrapper:
                 usage = calculate_total_usage(chunks=self.chunks)
                 processed_chunk._hidden_params["usage"] = usage
                 ## LOGGING
-                threading.Thread(
-                    target=self.run_success_logging_in_thread,
-                    args=(processed_chunk, cache_hit),
-                ).start()  # log response
+                executor.submit(
+                    self.run_success_logging_in_thread,
+                    processed_chunk,
+                    cache_hit,
+                )
                 return processed_chunk
         except Exception as e:
             traceback_exception = traceback.format_exc()
             # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
-            threading.Thread(
-                target=self.logging_obj.failure_handler, args=(e, traceback_exception)
-            ).start()
+            executor.submit(
+                self.logging_obj.failure_handler,
+                e,
+                traceback_exception,
+            )
             if isinstance(e, OpenAIError):
                 raise e
             else:
@@ -8314,11 +8322,13 @@ class CustomStreamWrapper:
                     if processed_chunk is None:
                         continue
                     ## LOGGING
-                    ## LOGGING
-                    threading.Thread(
-                        target=self.logging_obj.success_handler,
-                        args=(processed_chunk, None, None, cache_hit),
-                    ).start()  # log response
+                    executor.submit(
+                        self.logging_obj.success_handler,
+                        processed_chunk,
+                        None,
+                        None,
+                        cache_hit,
+                    )  # log response
                    asyncio.create_task(
                        self.logging_obj.async_success_handler(
                            processed_chunk, cache_hit=cache_hit
@@ -8368,10 +8378,13 @@ class CustomStreamWrapper:
                     if processed_chunk is None:
                         continue
                     ## LOGGING
-                    threading.Thread(
-                        target=self.logging_obj.success_handler,
-                        args=(processed_chunk, None, None, cache_hit),
-                    ).start()  # log processed_chunk
+                    executor.submit(
+                        self.logging_obj.success_handler,
+                        processed_chunk,
+                        None,
+                        None,
+                        cache_hit,
+                    )  # log processed_chunk
                    asyncio.create_task(
                        self.logging_obj.async_success_handler(
                            processed_chunk, cache_hit=cache_hit
@@ -8410,10 +8423,13 @@ class CustomStreamWrapper:
                             getattr(complete_streaming_response, "usage"),
                         )
                         ## LOGGING
-                        threading.Thread(
-                            target=self.logging_obj.success_handler,
-                            args=(response, None, None, cache_hit),
-                        ).start()  # log response
+                        executor.submit(
+                            self.logging_obj.success_handler,
+                            response,
+                            None,
+                            None,
+                            cache_hit,
+                        )  # log response
                        asyncio.create_task(
                            self.logging_obj.async_success_handler(
                                response, cache_hit=cache_hit
@@ -8426,10 +8442,13 @@ class CustomStreamWrapper:
                         self.sent_last_chunk = True
                         processed_chunk = self.finish_reason_handler()
                         ## LOGGING
-                        threading.Thread(
-                            target=self.logging_obj.success_handler,
-                            args=(processed_chunk, None, None, cache_hit),
-                        ).start()  # log response
+                        executor.submit(
+                            self.logging_obj.success_handler,
+                            processed_chunk,
+                            None,
+                            None,
+                            cache_hit,
+                        )  # log response
                        asyncio.create_task(
                            self.logging_obj.async_success_handler(
                                processed_chunk, cache_hit=cache_hit
@@ -8455,10 +8474,13 @@ class CustomStreamWrapper:
                             getattr(complete_streaming_response, "usage"),
                         )
                         ## LOGGING
-                        threading.Thread(
-                            target=self.logging_obj.success_handler,
-                            args=(response, None, None, cache_hit),
-                        ).start()  # log response
+                        executor.submit(
+                            self.logging_obj.success_handler,
+                            response,
+                            None,
+                            None,
+                            cache_hit,
+                        )  # log response
                        asyncio.create_task(
                            self.logging_obj.async_success_handler(
                                response, cache_hit=cache_hit
@@ -8471,10 +8493,13 @@ class CustomStreamWrapper:
                         self.sent_last_chunk = True
                         processed_chunk = self.finish_reason_handler()
                         ## LOGGING
-                        threading.Thread(
-                            target=self.logging_obj.success_handler,
-                            args=(processed_chunk, None, None, cache_hit),
-                        ).start()  # log response
+                        executor.submit(
+                            self.logging_obj.success_handler,
+                            processed_chunk,
+                            None,
+                            None,
+                            cache_hit,
+                        )  # log response
                        asyncio.create_task(
                            self.logging_obj.async_success_handler(
                                processed_chunk, cache_hit=cache_hit
@@ -8489,11 +8514,11 @@ class CustomStreamWrapper:
             )
             if self.logging_obj is not None:
                 ## LOGGING
-                threading.Thread(
-                    target=self.logging_obj.failure_handler,
-                    args=(e, traceback_exception),
-                ).start()  # log response
-                # Handle any exceptions that might occur during streaming
+                executor.submit(
+                    self.logging_obj.failure_handler,
+                    e,
+                    traceback_exception,
+                )  # Handle any exceptions that might occur during streaming
                 asyncio.create_task(
                     self.logging_obj.async_failure_handler(e, traceback_exception)
                 )
@@ -8502,11 +8527,11 @@ class CustomStreamWrapper:
             traceback_exception = traceback.format_exc()
             if self.logging_obj is not None:
                 ## LOGGING
-                threading.Thread(
-                    target=self.logging_obj.failure_handler,
-                    args=(e, traceback_exception),
-                ).start()  # log response
-                # Handle any exceptions that might occur during streaming
+                executor.submit(
+                    self.logging_obj.failure_handler,
+                    e,
+                    traceback_exception,
+                )  # Handle any exceptions that might occur during streaming
                 asyncio.create_task(
                     self.logging_obj.async_failure_handler(e, traceback_exception)  # type: ignore
                 )
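The change this diff applies is mechanical but worth spelling out: every fire-and-forget `threading.Thread(target=..., args=...).start()` used for logging callbacks becomes `executor.submit(...)` on the module-level `ThreadPoolExecutor`, so concurrent logging work is capped at `MAX_THREADS` workers that are created lazily and reused, instead of a fresh OS thread being spawned per call. A minimal sketch of the pattern in isolation, using a stand-in `success_handler` (not litellm's actual handler):

```python
from concurrent.futures import ThreadPoolExecutor
import threading

MAX_THREADS = 100  # mirrors the module-level constant in the diff
executor = ThreadPoolExecutor(max_workers=MAX_THREADS)

def success_handler(result, start_time, end_time):
    # Stand-in for logging_obj.success_handler: report which thread ran us.
    print(
        f"logged {result!r} ({end_time - start_time:.3f}s) "
        f"on {threading.current_thread().name}"
    )

# Before: one new OS thread per logging call, unbounded under load.
threading.Thread(target=success_handler, args=("resp", 0.0, 0.1)).start()

# After: the target and its args go positionally to submit(); workers are
# pooled, so at most MAX_THREADS logging tasks run at once and the rest queue.
future = executor.submit(success_handler, "resp", 0.0, 0.1)
future.result()  # only for this demo; the diff fires and forgets
```

One behavioral difference to keep in mind: a raw `Thread` prints an uncaught handler exception to stderr, while `submit()` stores it on the returned `Future`; since the diff discards the future, a failing handler now fails silently unless it catches its own exceptions. Queuing is the other trade-off: once all `MAX_THREADS` workers are busy, new logging tasks wait rather than spawning more threads, which is exactly the backpressure the change is after.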