use thread pool executor for threads

This commit is contained in:
Ishaan Jaff 2024-10-10 08:53:34 +05:30
parent 60baa65e0e
commit 6255a49de9

View file

@ -178,7 +178,7 @@ from .types.router import LiteLLM_Params
MAX_THREADS = 100 MAX_THREADS = 100
# Create a ThreadPoolExecutor # Create a ThreadPoolExecutor
executor = ThreadPoolExecutor(max_workers=MAX_THREADS) executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=MAX_THREADS)
sentry_sdk_instance = None sentry_sdk_instance = None
capture_exception = None capture_exception = None
add_breadcrumb = None add_breadcrumb = None
@ -914,10 +914,13 @@ def client(original_function):
additional_args=None, additional_args=None,
stream=kwargs.get("stream", False), stream=kwargs.get("stream", False),
) )
threading.Thread( executor.submit(
target=logging_obj.success_handler, logging_obj.success_handler,
args=(cached_result, start_time, end_time, cache_hit), cached_result,
).start() start_time,
end_time,
cache_hit,
)
return cached_result return cached_result
else: else:
print_verbose( print_verbose(
@ -1000,9 +1003,12 @@ def client(original_function):
# LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
verbose_logger.info("Wrapper: Completed Call, calling success_handler") verbose_logger.info("Wrapper: Completed Call, calling success_handler")
threading.Thread( executor.submit(
target=logging_obj.success_handler, args=(result, start_time, end_time) logging_obj.success_handler,
).start() result,
start_time,
end_time,
)
# RETURN RESULT # RETURN RESULT
if hasattr(result, "_hidden_params"): if hasattr(result, "_hidden_params"):
result._hidden_params["model_id"] = kwargs.get("model_info", {}).get( result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
@ -1280,10 +1286,13 @@ def client(original_function):
cached_result, start_time, end_time, cache_hit cached_result, start_time, end_time, cache_hit
) )
) )
threading.Thread( executor.submit(
target=logging_obj.success_handler, logging_obj.success_handler,
args=(cached_result, start_time, end_time, cache_hit), cached_result,
).start() start_time,
end_time,
cache_hit,
)
cache_key = kwargs.get("preset_cache_key", None) cache_key = kwargs.get("preset_cache_key", None)
if ( if (
isinstance(cached_result, BaseModel) isinstance(cached_result, BaseModel)
@ -1385,15 +1394,13 @@ def client(original_function):
cache_hit, cache_hit,
) )
) )
threading.Thread( executor.submit(
target=logging_obj.success_handler, logging_obj.success_handler,
args=(
final_embedding_cached_response, final_embedding_cached_response,
start_time, start_time,
end_time, end_time,
cache_hit, cache_hit,
), )
).start()
return final_embedding_cached_response return final_embedding_cached_response
# MODEL CALL # MODEL CALL
result = await original_function(*args, **kwargs) result = await original_function(*args, **kwargs)
@ -1475,11 +1482,12 @@ def client(original_function):
) )
) )
elif isinstance(litellm.cache.cache, S3Cache): elif isinstance(litellm.cache.cache, S3Cache):
threading.Thread( executor.submit(
target=litellm.cache.add_cache, litellm.cache.add_cache,
args=(result,) + args, result,
kwargs=kwargs, *args,
).start() **kwargs,
)
else: else:
asyncio.create_task( asyncio.create_task(
litellm.cache.async_add_cache( litellm.cache.async_add_cache(
@ -1498,10 +1506,12 @@ def client(original_function):
asyncio.create_task( asyncio.create_task(
logging_obj.async_success_handler(result, start_time, end_time) logging_obj.async_success_handler(result, start_time, end_time)
) )
threading.Thread( executor.submit(
target=logging_obj.success_handler, logging_obj.success_handler,
args=(result, start_time, end_time), result,
).start() start_time,
end_time,
)
# REBUILD EMBEDDING CACHING # REBUILD EMBEDDING CACHING
if ( if (
@ -6455,15 +6465,6 @@ def get_model_list():
data = response.json() data = response.json()
# update model list # update model list
model_list = data["model_list"] model_list = data["model_list"]
# # check if all model providers are in environment
# model_providers = data["model_providers"]
# missing_llm_provider = None
# for item in model_providers:
# if f"{item.upper()}_API_KEY" not in os.environ:
# missing_llm_provider = item
# break
# # update environment - if required
# threading.Thread(target=get_all_keys, args=(missing_llm_provider)).start()
return model_list return model_list
return [] # return empty list by default return [] # return empty list by default
except Exception: except Exception:
@ -8147,10 +8148,11 @@ class CustomStreamWrapper:
if response is None: if response is None:
continue continue
## LOGGING ## LOGGING
threading.Thread( executor.submit(
target=self.run_success_logging_in_thread, self.run_success_logging_in_thread,
args=(response, cache_hit), response,
).start() # log response cache_hit,
)
choice = response.choices[0] choice = response.choices[0]
if isinstance(choice, StreamingChoices): if isinstance(choice, StreamingChoices):
self.response_uptil_now += choice.delta.get("content", "") or "" self.response_uptil_now += choice.delta.get("content", "") or ""
@ -8201,10 +8203,13 @@ class CustomStreamWrapper:
getattr(complete_streaming_response, "usage"), getattr(complete_streaming_response, "usage"),
) )
## LOGGING ## LOGGING
threading.Thread( executor.submit(
target=self.logging_obj.success_handler, self.logging_obj.success_handler,
args=(response, None, None, cache_hit), response,
).start() # log response None,
None,
cache_hit,
)
self.sent_stream_usage = True self.sent_stream_usage = True
return response return response
raise # Re-raise StopIteration raise # Re-raise StopIteration
@ -8215,17 +8220,20 @@ class CustomStreamWrapper:
usage = calculate_total_usage(chunks=self.chunks) usage = calculate_total_usage(chunks=self.chunks)
processed_chunk._hidden_params["usage"] = usage processed_chunk._hidden_params["usage"] = usage
## LOGGING ## LOGGING
threading.Thread( executor.submit(
target=self.run_success_logging_in_thread, self.run_success_logging_in_thread,
args=(processed_chunk, cache_hit), processed_chunk,
).start() # log response cache_hit,
)
return processed_chunk return processed_chunk
except Exception as e: except Exception as e:
traceback_exception = traceback.format_exc() traceback_exception = traceback.format_exc()
# LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
threading.Thread( executor.submit(
target=self.logging_obj.failure_handler, args=(e, traceback_exception) self.logging_obj.failure_handler,
).start() e,
traceback_exception,
)
if isinstance(e, OpenAIError): if isinstance(e, OpenAIError):
raise e raise e
else: else:
@ -8314,11 +8322,13 @@ class CustomStreamWrapper:
if processed_chunk is None: if processed_chunk is None:
continue continue
## LOGGING ## LOGGING
## LOGGING executor.submit(
threading.Thread( self.logging_obj.success_handler,
target=self.logging_obj.success_handler, processed_chunk,
args=(processed_chunk, None, None, cache_hit), None,
).start() # log response None,
cache_hit,
) # log response
asyncio.create_task( asyncio.create_task(
self.logging_obj.async_success_handler( self.logging_obj.async_success_handler(
processed_chunk, cache_hit=cache_hit processed_chunk, cache_hit=cache_hit
@ -8368,10 +8378,13 @@ class CustomStreamWrapper:
if processed_chunk is None: if processed_chunk is None:
continue continue
## LOGGING ## LOGGING
threading.Thread( executor.submit(
target=self.logging_obj.success_handler, self.logging_obj.success_handler,
args=(processed_chunk, None, None, cache_hit), processed_chunk,
).start() # log processed_chunk None,
None,
cache_hit,
) # log processed_chunk
asyncio.create_task( asyncio.create_task(
self.logging_obj.async_success_handler( self.logging_obj.async_success_handler(
processed_chunk, cache_hit=cache_hit processed_chunk, cache_hit=cache_hit
@ -8410,10 +8423,13 @@ class CustomStreamWrapper:
getattr(complete_streaming_response, "usage"), getattr(complete_streaming_response, "usage"),
) )
## LOGGING ## LOGGING
threading.Thread( executor.submit(
target=self.logging_obj.success_handler, self.logging_obj.success_handler,
args=(response, None, None, cache_hit), response,
).start() # log response None,
None,
cache_hit,
) # log response
asyncio.create_task( asyncio.create_task(
self.logging_obj.async_success_handler( self.logging_obj.async_success_handler(
response, cache_hit=cache_hit response, cache_hit=cache_hit
@ -8426,10 +8442,13 @@ class CustomStreamWrapper:
self.sent_last_chunk = True self.sent_last_chunk = True
processed_chunk = self.finish_reason_handler() processed_chunk = self.finish_reason_handler()
## LOGGING ## LOGGING
threading.Thread( executor.submit(
target=self.logging_obj.success_handler, self.logging_obj.success_handler,
args=(processed_chunk, None, None, cache_hit), processed_chunk,
).start() # log response None,
None,
cache_hit,
) # log response
asyncio.create_task( asyncio.create_task(
self.logging_obj.async_success_handler( self.logging_obj.async_success_handler(
processed_chunk, cache_hit=cache_hit processed_chunk, cache_hit=cache_hit
@ -8455,10 +8474,13 @@ class CustomStreamWrapper:
getattr(complete_streaming_response, "usage"), getattr(complete_streaming_response, "usage"),
) )
## LOGGING ## LOGGING
threading.Thread( executor.submit(
target=self.logging_obj.success_handler, self.logging_obj.success_handler,
args=(response, None, None, cache_hit), response,
).start() # log response None,
None,
cache_hit,
) # log response
asyncio.create_task( asyncio.create_task(
self.logging_obj.async_success_handler( self.logging_obj.async_success_handler(
response, cache_hit=cache_hit response, cache_hit=cache_hit
@ -8471,10 +8493,13 @@ class CustomStreamWrapper:
self.sent_last_chunk = True self.sent_last_chunk = True
processed_chunk = self.finish_reason_handler() processed_chunk = self.finish_reason_handler()
## LOGGING ## LOGGING
threading.Thread( executor.submit(
target=self.logging_obj.success_handler, self.logging_obj.success_handler,
args=(processed_chunk, None, None, cache_hit), processed_chunk,
).start() # log response None,
None,
cache_hit,
) # log response
asyncio.create_task( asyncio.create_task(
self.logging_obj.async_success_handler( self.logging_obj.async_success_handler(
processed_chunk, cache_hit=cache_hit processed_chunk, cache_hit=cache_hit
@ -8489,11 +8514,11 @@ class CustomStreamWrapper:
) )
if self.logging_obj is not None: if self.logging_obj is not None:
## LOGGING ## LOGGING
threading.Thread( executor.submit(
target=self.logging_obj.failure_handler, self.logging_obj.failure_handler,
args=(e, traceback_exception), e,
).start() # log response traceback_exception,
# Handle any exceptions that might occur during streaming ) # Handle any exceptions that might occur during streaming
asyncio.create_task( asyncio.create_task(
self.logging_obj.async_failure_handler(e, traceback_exception) self.logging_obj.async_failure_handler(e, traceback_exception)
) )
@ -8502,11 +8527,11 @@ class CustomStreamWrapper:
traceback_exception = traceback.format_exc() traceback_exception = traceback.format_exc()
if self.logging_obj is not None: if self.logging_obj is not None:
## LOGGING ## LOGGING
threading.Thread( executor.submit(
target=self.logging_obj.failure_handler, self.logging_obj.failure_handler,
args=(e, traceback_exception), e,
).start() # log response traceback_exception,
# Handle any exceptions that might occur during streaming ) # Handle any exceptions that might occur during streaming
asyncio.create_task( asyncio.create_task(
self.logging_obj.async_failure_handler(e, traceback_exception) # type: ignore self.logging_obj.async_failure_handler(e, traceback_exception) # type: ignore
) )