use thread pool executor for threads

Ishaan Jaff 2024-10-10 08:53:34 +05:30
parent 60baa65e0e
commit 6255a49de9


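For context, the change swaps per-call thread creation for one shared, bounded worker pool. Below is a minimal runnable sketch of that pattern, not the litellm code itself; success_handler is a stand-in for litellm's logging_obj.success_handler.

from concurrent.futures import ThreadPoolExecutor
import threading
import time

MAX_THREADS = 100
# One shared pool for all fire-and-forget logging work.
executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=MAX_THREADS)

def success_handler(result: str, start_time: float, end_time: float) -> None:
    # Stand-in for logging_obj.success_handler in the diff below.
    print(f"logged {result!r} in {end_time - start_time:.3f}s")

start = time.monotonic()
result = "response"
end = time.monotonic()

# Before: a new OS thread per logging call, unbounded under load.
threading.Thread(target=success_handler, args=(result, start, end)).start()

# After: reuse one of at most MAX_THREADS pooled worker threads.
executor.submit(success_handler, result, start, end)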
@@ -178,7 +178,7 @@ from .types.router import LiteLLM_Params
 MAX_THREADS = 100
 # Create a ThreadPoolExecutor
-executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
+executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=MAX_THREADS)
 sentry_sdk_instance = None
 capture_exception = None
 add_breadcrumb = None
@@ -914,10 +914,13 @@ def client(original_function):
 additional_args=None,
 stream=kwargs.get("stream", False),
 )
-threading.Thread(
-target=logging_obj.success_handler,
-args=(cached_result, start_time, end_time, cache_hit),
-).start()
+executor.submit(
+logging_obj.success_handler,
+cached_result,
+start_time,
+end_time,
+cache_hit,
+)
 return cached_result
 else:
 print_verbose(
@@ -1000,9 +1003,12 @@ def client(original_function):
 # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
 verbose_logger.info("Wrapper: Completed Call, calling success_handler")
-threading.Thread(
-target=logging_obj.success_handler, args=(result, start_time, end_time)
-).start()
+executor.submit(
+logging_obj.success_handler,
+result,
+start_time,
+end_time,
+)

 # RETURN RESULT
 if hasattr(result, "_hidden_params"):
 result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
@@ -1280,10 +1286,13 @@ def client(original_function):
 cached_result, start_time, end_time, cache_hit
 )
 )
-threading.Thread(
-target=logging_obj.success_handler,
-args=(cached_result, start_time, end_time, cache_hit),
-).start()
+executor.submit(
+logging_obj.success_handler,
+cached_result,
+start_time,
+end_time,
+cache_hit,
+)
 cache_key = kwargs.get("preset_cache_key", None)
 if (
 isinstance(cached_result, BaseModel)
@@ -1385,15 +1394,13 @@ def client(original_function):
 cache_hit,
 )
 )
-threading.Thread(
-target=logging_obj.success_handler,
-args=(
-final_embedding_cached_response,
-start_time,
-end_time,
-cache_hit,
-),
-).start()
+executor.submit(
+logging_obj.success_handler,
+final_embedding_cached_response,
+start_time,
+end_time,
+cache_hit,
+)
 return final_embedding_cached_response
 # MODEL CALL
 result = await original_function(*args, **kwargs)
@@ -1475,11 +1482,12 @@ def client(original_function):
 )
 )
 elif isinstance(litellm.cache.cache, S3Cache):
-threading.Thread(
-target=litellm.cache.add_cache,
-args=(result,) + args,
-kwargs=kwargs,
-).start()
+executor.submit(
+litellm.cache.add_cache,
+result,
+*args,
+**kwargs,
+)
 else:
 asyncio.create_task(
 litellm.cache.async_add_cache(
@@ -1498,10 +1506,12 @@ def client(original_function):
 asyncio.create_task(
 logging_obj.async_success_handler(result, start_time, end_time)
 )
-threading.Thread(
-target=logging_obj.success_handler,
-args=(result, start_time, end_time),
-).start()
+executor.submit(
+logging_obj.success_handler,
+result,
+start_time,
+end_time,
+)

 # REBUILD EMBEDDING CACHING
 if (
@@ -3197,7 +3207,7 @@ def get_optional_params(
 if stream:
 optional_params["stream"] = stream

-#return optional_params
+# return optional_params
 if max_tokens is not None:
 if "vicuna" in model or "flan" in model:
 optional_params["max_length"] = max_tokens
@@ -6455,15 +6465,6 @@ def get_model_list():
 data = response.json()
 # update model list
 model_list = data["model_list"]
-# # check if all model providers are in environment
-# model_providers = data["model_providers"]
-# missing_llm_provider = None
-# for item in model_providers:
-# if f"{item.upper()}_API_KEY" not in os.environ:
-# missing_llm_provider = item
-# break
-# # update environment - if required
-# threading.Thread(target=get_all_keys, args=(missing_llm_provider)).start()
 return model_list
 return [] # return empty list by default
 except Exception:
@@ -8147,10 +8148,11 @@ class CustomStreamWrapper:
 if response is None:
 continue
 ## LOGGING
-threading.Thread(
-target=self.run_success_logging_in_thread,
-args=(response, cache_hit),
-).start() # log response
+executor.submit(
+self.run_success_logging_in_thread,
+response,
+cache_hit,
+)
 choice = response.choices[0]
 if isinstance(choice, StreamingChoices):
 self.response_uptil_now += choice.delta.get("content", "") or ""
@@ -8201,10 +8203,13 @@ class CustomStreamWrapper:
 getattr(complete_streaming_response, "usage"),
 )
 ## LOGGING
-threading.Thread(
-target=self.logging_obj.success_handler,
-args=(response, None, None, cache_hit),
-).start() # log response
+executor.submit(
+self.logging_obj.success_handler,
+response,
+None,
+None,
+cache_hit,
+)
 self.sent_stream_usage = True
 return response
 raise # Re-raise StopIteration
@@ -8215,17 +8220,20 @@ class CustomStreamWrapper:
 usage = calculate_total_usage(chunks=self.chunks)
 processed_chunk._hidden_params["usage"] = usage
 ## LOGGING
-threading.Thread(
-target=self.run_success_logging_in_thread,
-args=(processed_chunk, cache_hit),
-).start() # log response
+executor.submit(
+self.run_success_logging_in_thread,
+processed_chunk,
+cache_hit,
+)
 return processed_chunk
 except Exception as e:
 traceback_exception = traceback.format_exc()
 # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
-threading.Thread(
-target=self.logging_obj.failure_handler, args=(e, traceback_exception)
-).start()
+executor.submit(
+self.logging_obj.failure_handler,
+e,
+traceback_exception,
+)
 if isinstance(e, OpenAIError):
 raise e
 else:
@@ -8314,11 +8322,13 @@ class CustomStreamWrapper:
 if processed_chunk is None:
 continue
 ## LOGGING
-## LOGGING
-threading.Thread(
-target=self.logging_obj.success_handler,
-args=(processed_chunk, None, None, cache_hit),
-).start() # log response
+executor.submit(
+self.logging_obj.success_handler,
+processed_chunk,
+None,
+None,
+cache_hit,
+) # log response
 asyncio.create_task(
 self.logging_obj.async_success_handler(
 processed_chunk, cache_hit=cache_hit
@@ -8368,10 +8378,13 @@ class CustomStreamWrapper:
 if processed_chunk is None:
 continue
 ## LOGGING
-threading.Thread(
-target=self.logging_obj.success_handler,
-args=(processed_chunk, None, None, cache_hit),
-).start() # log processed_chunk
+executor.submit(
+self.logging_obj.success_handler,
+processed_chunk,
+None,
+None,
+cache_hit,
+) # log processed_chunk
 asyncio.create_task(
 self.logging_obj.async_success_handler(
 processed_chunk, cache_hit=cache_hit
@@ -8410,10 +8423,13 @@ class CustomStreamWrapper:
 getattr(complete_streaming_response, "usage"),
 )
 ## LOGGING
-threading.Thread(
-target=self.logging_obj.success_handler,
-args=(response, None, None, cache_hit),
-).start() # log response
+executor.submit(
+self.logging_obj.success_handler,
+response,
+None,
+None,
+cache_hit,
+) # log response
 asyncio.create_task(
 self.logging_obj.async_success_handler(
 response, cache_hit=cache_hit
@@ -8426,10 +8442,13 @@ class CustomStreamWrapper:
 self.sent_last_chunk = True
 processed_chunk = self.finish_reason_handler()
 ## LOGGING
-threading.Thread(
-target=self.logging_obj.success_handler,
-args=(processed_chunk, None, None, cache_hit),
-).start() # log response
+executor.submit(
+self.logging_obj.success_handler,
+processed_chunk,
+None,
+None,
+cache_hit,
+) # log response
 asyncio.create_task(
 self.logging_obj.async_success_handler(
 processed_chunk, cache_hit=cache_hit
@@ -8455,10 +8474,13 @@ class CustomStreamWrapper:
 getattr(complete_streaming_response, "usage"),
 )
 ## LOGGING
-threading.Thread(
-target=self.logging_obj.success_handler,
-args=(response, None, None, cache_hit),
-).start() # log response
+executor.submit(
+self.logging_obj.success_handler,
+response,
+None,
+None,
+cache_hit,
+) # log response
 asyncio.create_task(
 self.logging_obj.async_success_handler(
 response, cache_hit=cache_hit
@@ -8471,10 +8493,13 @@ class CustomStreamWrapper:
 self.sent_last_chunk = True
 processed_chunk = self.finish_reason_handler()
 ## LOGGING
-threading.Thread(
-target=self.logging_obj.success_handler,
-args=(processed_chunk, None, None, cache_hit),
-).start() # log response
+executor.submit(
+self.logging_obj.success_handler,
+processed_chunk,
+None,
+None,
+cache_hit,
+) # log response
 asyncio.create_task(
 self.logging_obj.async_success_handler(
 processed_chunk, cache_hit=cache_hit
@@ -8489,11 +8514,11 @@ class CustomStreamWrapper:
 )
 if self.logging_obj is not None:
 ## LOGGING
-threading.Thread(
-target=self.logging_obj.failure_handler,
-args=(e, traceback_exception),
-).start() # log response
-# Handle any exceptions that might occur during streaming
+executor.submit(
+self.logging_obj.failure_handler,
+e,
+traceback_exception,
+) # Handle any exceptions that might occur during streaming
 asyncio.create_task(
 self.logging_obj.async_failure_handler(e, traceback_exception)
 )
@@ -8502,11 +8527,11 @@ class CustomStreamWrapper:
 traceback_exception = traceback.format_exc()
 if self.logging_obj is not None:
 ## LOGGING
-threading.Thread(
-target=self.logging_obj.failure_handler,
-args=(e, traceback_exception),
-).start() # log response
-# Handle any exceptions that might occur during streaming
+executor.submit(
+self.logging_obj.failure_handler,
+e,
+traceback_exception,
+) # Handle any exceptions that might occur during streaming
 asyncio.create_task(
 self.logging_obj.async_failure_handler(e, traceback_exception) # type: ignore
 )
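One behavioral difference worth noting with this migration, sketched below under stand-in names: a raw threading.Thread reports uncaught exceptions through threading.excepthook, while executor.submit() stores them on the returned Future, where they stay silent unless the Future is inspected.

from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=2)

def failing_handler() -> None:
    # Stand-in for a logging callback that raises.
    raise RuntimeError("logging failed")

# Fire-and-forget: the RuntimeError is captured on the Future and,
# if nobody looks at it, never surfaces anywhere.
future = executor.submit(failing_handler)

# One way to keep the fire-and-forget style without losing tracebacks:
# attach a done-callback that reports the stored exception.
future.add_done_callback(
    lambda f: print("handler error:", f.exception())
)

executor.shutdown(wait=True)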