forked from phoenix/litellm-mirror
use thread pool executor for threads
This commit is contained in:
parent
60baa65e0e
commit
6255a49de9
1 changed file with 115 additions and 90 deletions
195
litellm/utils.py
195
litellm/utils.py
|
@ -178,7 +178,7 @@ from .types.router import LiteLLM_Params
|
||||||
MAX_THREADS = 100
|
MAX_THREADS = 100
|
||||||
|
|
||||||
# Create a ThreadPoolExecutor
|
# Create a ThreadPoolExecutor
|
||||||
executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
|
executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=MAX_THREADS)
|
||||||
sentry_sdk_instance = None
|
sentry_sdk_instance = None
|
||||||
capture_exception = None
|
capture_exception = None
|
||||||
add_breadcrumb = None
|
add_breadcrumb = None
|
||||||
|
@ -914,10 +914,13 @@ def client(original_function):
|
||||||
additional_args=None,
|
additional_args=None,
|
||||||
stream=kwargs.get("stream", False),
|
stream=kwargs.get("stream", False),
|
||||||
)
|
)
|
||||||
threading.Thread(
|
executor.submit(
|
||||||
target=logging_obj.success_handler,
|
logging_obj.success_handler,
|
||||||
args=(cached_result, start_time, end_time, cache_hit),
|
cached_result,
|
||||||
).start()
|
start_time,
|
||||||
|
end_time,
|
||||||
|
cache_hit,
|
||||||
|
)
|
||||||
return cached_result
|
return cached_result
|
||||||
else:
|
else:
|
||||||
print_verbose(
|
print_verbose(
|
||||||
|
@ -1000,9 +1003,12 @@ def client(original_function):
|
||||||
|
|
||||||
# LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
|
# LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
|
||||||
verbose_logger.info("Wrapper: Completed Call, calling success_handler")
|
verbose_logger.info("Wrapper: Completed Call, calling success_handler")
|
||||||
threading.Thread(
|
executor.submit(
|
||||||
target=logging_obj.success_handler, args=(result, start_time, end_time)
|
logging_obj.success_handler,
|
||||||
).start()
|
result,
|
||||||
|
start_time,
|
||||||
|
end_time,
|
||||||
|
)
|
||||||
# RETURN RESULT
|
# RETURN RESULT
|
||||||
if hasattr(result, "_hidden_params"):
|
if hasattr(result, "_hidden_params"):
|
||||||
result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
|
result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
|
||||||
|
@ -1280,10 +1286,13 @@ def client(original_function):
|
||||||
cached_result, start_time, end_time, cache_hit
|
cached_result, start_time, end_time, cache_hit
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
threading.Thread(
|
executor.submit(
|
||||||
target=logging_obj.success_handler,
|
logging_obj.success_handler,
|
||||||
args=(cached_result, start_time, end_time, cache_hit),
|
cached_result,
|
||||||
).start()
|
start_time,
|
||||||
|
end_time,
|
||||||
|
cache_hit,
|
||||||
|
)
|
||||||
cache_key = kwargs.get("preset_cache_key", None)
|
cache_key = kwargs.get("preset_cache_key", None)
|
||||||
if (
|
if (
|
||||||
isinstance(cached_result, BaseModel)
|
isinstance(cached_result, BaseModel)
|
||||||
|
@ -1385,15 +1394,13 @@ def client(original_function):
|
||||||
cache_hit,
|
cache_hit,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
threading.Thread(
|
executor.submit(
|
||||||
target=logging_obj.success_handler,
|
logging_obj.success_handler,
|
||||||
args=(
|
|
||||||
final_embedding_cached_response,
|
final_embedding_cached_response,
|
||||||
start_time,
|
start_time,
|
||||||
end_time,
|
end_time,
|
||||||
cache_hit,
|
cache_hit,
|
||||||
),
|
)
|
||||||
).start()
|
|
||||||
return final_embedding_cached_response
|
return final_embedding_cached_response
|
||||||
# MODEL CALL
|
# MODEL CALL
|
||||||
result = await original_function(*args, **kwargs)
|
result = await original_function(*args, **kwargs)
|
||||||
|
@ -1475,11 +1482,12 @@ def client(original_function):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(litellm.cache.cache, S3Cache):
|
elif isinstance(litellm.cache.cache, S3Cache):
|
||||||
threading.Thread(
|
executor.submit(
|
||||||
target=litellm.cache.add_cache,
|
litellm.cache.add_cache,
|
||||||
args=(result,) + args,
|
result,
|
||||||
kwargs=kwargs,
|
*args,
|
||||||
).start()
|
**kwargs,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
litellm.cache.async_add_cache(
|
litellm.cache.async_add_cache(
|
||||||
|
@ -1498,10 +1506,12 @@ def client(original_function):
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
logging_obj.async_success_handler(result, start_time, end_time)
|
logging_obj.async_success_handler(result, start_time, end_time)
|
||||||
)
|
)
|
||||||
threading.Thread(
|
executor.submit(
|
||||||
target=logging_obj.success_handler,
|
logging_obj.success_handler,
|
||||||
args=(result, start_time, end_time),
|
result,
|
||||||
).start()
|
start_time,
|
||||||
|
end_time,
|
||||||
|
)
|
||||||
|
|
||||||
# REBUILD EMBEDDING CACHING
|
# REBUILD EMBEDDING CACHING
|
||||||
if (
|
if (
|
||||||
|
@ -6455,15 +6465,6 @@ def get_model_list():
|
||||||
data = response.json()
|
data = response.json()
|
||||||
# update model list
|
# update model list
|
||||||
model_list = data["model_list"]
|
model_list = data["model_list"]
|
||||||
# # check if all model providers are in environment
|
|
||||||
# model_providers = data["model_providers"]
|
|
||||||
# missing_llm_provider = None
|
|
||||||
# for item in model_providers:
|
|
||||||
# if f"{item.upper()}_API_KEY" not in os.environ:
|
|
||||||
# missing_llm_provider = item
|
|
||||||
# break
|
|
||||||
# # update environment - if required
|
|
||||||
# threading.Thread(target=get_all_keys, args=(missing_llm_provider)).start()
|
|
||||||
return model_list
|
return model_list
|
||||||
return [] # return empty list by default
|
return [] # return empty list by default
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -8147,10 +8148,11 @@ class CustomStreamWrapper:
|
||||||
if response is None:
|
if response is None:
|
||||||
continue
|
continue
|
||||||
## LOGGING
|
## LOGGING
|
||||||
threading.Thread(
|
executor.submit(
|
||||||
target=self.run_success_logging_in_thread,
|
self.run_success_logging_in_thread,
|
||||||
args=(response, cache_hit),
|
response,
|
||||||
).start() # log response
|
cache_hit,
|
||||||
|
)
|
||||||
choice = response.choices[0]
|
choice = response.choices[0]
|
||||||
if isinstance(choice, StreamingChoices):
|
if isinstance(choice, StreamingChoices):
|
||||||
self.response_uptil_now += choice.delta.get("content", "") or ""
|
self.response_uptil_now += choice.delta.get("content", "") or ""
|
||||||
|
@ -8201,10 +8203,13 @@ class CustomStreamWrapper:
|
||||||
getattr(complete_streaming_response, "usage"),
|
getattr(complete_streaming_response, "usage"),
|
||||||
)
|
)
|
||||||
## LOGGING
|
## LOGGING
|
||||||
threading.Thread(
|
executor.submit(
|
||||||
target=self.logging_obj.success_handler,
|
self.logging_obj.success_handler,
|
||||||
args=(response, None, None, cache_hit),
|
response,
|
||||||
).start() # log response
|
None,
|
||||||
|
None,
|
||||||
|
cache_hit,
|
||||||
|
)
|
||||||
self.sent_stream_usage = True
|
self.sent_stream_usage = True
|
||||||
return response
|
return response
|
||||||
raise # Re-raise StopIteration
|
raise # Re-raise StopIteration
|
||||||
|
@ -8215,17 +8220,20 @@ class CustomStreamWrapper:
|
||||||
usage = calculate_total_usage(chunks=self.chunks)
|
usage = calculate_total_usage(chunks=self.chunks)
|
||||||
processed_chunk._hidden_params["usage"] = usage
|
processed_chunk._hidden_params["usage"] = usage
|
||||||
## LOGGING
|
## LOGGING
|
||||||
threading.Thread(
|
executor.submit(
|
||||||
target=self.run_success_logging_in_thread,
|
self.run_success_logging_in_thread,
|
||||||
args=(processed_chunk, cache_hit),
|
processed_chunk,
|
||||||
).start() # log response
|
cache_hit,
|
||||||
|
)
|
||||||
return processed_chunk
|
return processed_chunk
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback_exception = traceback.format_exc()
|
traceback_exception = traceback.format_exc()
|
||||||
# LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
|
# LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
|
||||||
threading.Thread(
|
executor.submit(
|
||||||
target=self.logging_obj.failure_handler, args=(e, traceback_exception)
|
self.logging_obj.failure_handler,
|
||||||
).start()
|
e,
|
||||||
|
traceback_exception,
|
||||||
|
)
|
||||||
if isinstance(e, OpenAIError):
|
if isinstance(e, OpenAIError):
|
||||||
raise e
|
raise e
|
||||||
else:
|
else:
|
||||||
|
@ -8314,11 +8322,13 @@ class CustomStreamWrapper:
|
||||||
if processed_chunk is None:
|
if processed_chunk is None:
|
||||||
continue
|
continue
|
||||||
## LOGGING
|
## LOGGING
|
||||||
## LOGGING
|
executor.submit(
|
||||||
threading.Thread(
|
self.logging_obj.success_handler,
|
||||||
target=self.logging_obj.success_handler,
|
processed_chunk,
|
||||||
args=(processed_chunk, None, None, cache_hit),
|
None,
|
||||||
).start() # log response
|
None,
|
||||||
|
cache_hit,
|
||||||
|
) # log response
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self.logging_obj.async_success_handler(
|
self.logging_obj.async_success_handler(
|
||||||
processed_chunk, cache_hit=cache_hit
|
processed_chunk, cache_hit=cache_hit
|
||||||
|
@ -8368,10 +8378,13 @@ class CustomStreamWrapper:
|
||||||
if processed_chunk is None:
|
if processed_chunk is None:
|
||||||
continue
|
continue
|
||||||
## LOGGING
|
## LOGGING
|
||||||
threading.Thread(
|
executor.submit(
|
||||||
target=self.logging_obj.success_handler,
|
self.logging_obj.success_handler,
|
||||||
args=(processed_chunk, None, None, cache_hit),
|
processed_chunk,
|
||||||
).start() # log processed_chunk
|
None,
|
||||||
|
None,
|
||||||
|
cache_hit,
|
||||||
|
) # log processed_chunk
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self.logging_obj.async_success_handler(
|
self.logging_obj.async_success_handler(
|
||||||
processed_chunk, cache_hit=cache_hit
|
processed_chunk, cache_hit=cache_hit
|
||||||
|
@ -8410,10 +8423,13 @@ class CustomStreamWrapper:
|
||||||
getattr(complete_streaming_response, "usage"),
|
getattr(complete_streaming_response, "usage"),
|
||||||
)
|
)
|
||||||
## LOGGING
|
## LOGGING
|
||||||
threading.Thread(
|
executor.submit(
|
||||||
target=self.logging_obj.success_handler,
|
self.logging_obj.success_handler,
|
||||||
args=(response, None, None, cache_hit),
|
response,
|
||||||
).start() # log response
|
None,
|
||||||
|
None,
|
||||||
|
cache_hit,
|
||||||
|
) # log response
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self.logging_obj.async_success_handler(
|
self.logging_obj.async_success_handler(
|
||||||
response, cache_hit=cache_hit
|
response, cache_hit=cache_hit
|
||||||
|
@ -8426,10 +8442,13 @@ class CustomStreamWrapper:
|
||||||
self.sent_last_chunk = True
|
self.sent_last_chunk = True
|
||||||
processed_chunk = self.finish_reason_handler()
|
processed_chunk = self.finish_reason_handler()
|
||||||
## LOGGING
|
## LOGGING
|
||||||
threading.Thread(
|
executor.submit(
|
||||||
target=self.logging_obj.success_handler,
|
self.logging_obj.success_handler,
|
||||||
args=(processed_chunk, None, None, cache_hit),
|
processed_chunk,
|
||||||
).start() # log response
|
None,
|
||||||
|
None,
|
||||||
|
cache_hit,
|
||||||
|
) # log response
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self.logging_obj.async_success_handler(
|
self.logging_obj.async_success_handler(
|
||||||
processed_chunk, cache_hit=cache_hit
|
processed_chunk, cache_hit=cache_hit
|
||||||
|
@ -8455,10 +8474,13 @@ class CustomStreamWrapper:
|
||||||
getattr(complete_streaming_response, "usage"),
|
getattr(complete_streaming_response, "usage"),
|
||||||
)
|
)
|
||||||
## LOGGING
|
## LOGGING
|
||||||
threading.Thread(
|
executor.submit(
|
||||||
target=self.logging_obj.success_handler,
|
self.logging_obj.success_handler,
|
||||||
args=(response, None, None, cache_hit),
|
response,
|
||||||
).start() # log response
|
None,
|
||||||
|
None,
|
||||||
|
cache_hit,
|
||||||
|
) # log response
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self.logging_obj.async_success_handler(
|
self.logging_obj.async_success_handler(
|
||||||
response, cache_hit=cache_hit
|
response, cache_hit=cache_hit
|
||||||
|
@ -8471,10 +8493,13 @@ class CustomStreamWrapper:
|
||||||
self.sent_last_chunk = True
|
self.sent_last_chunk = True
|
||||||
processed_chunk = self.finish_reason_handler()
|
processed_chunk = self.finish_reason_handler()
|
||||||
## LOGGING
|
## LOGGING
|
||||||
threading.Thread(
|
executor.submit(
|
||||||
target=self.logging_obj.success_handler,
|
self.logging_obj.success_handler,
|
||||||
args=(processed_chunk, None, None, cache_hit),
|
processed_chunk,
|
||||||
).start() # log response
|
None,
|
||||||
|
None,
|
||||||
|
cache_hit,
|
||||||
|
) # log response
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self.logging_obj.async_success_handler(
|
self.logging_obj.async_success_handler(
|
||||||
processed_chunk, cache_hit=cache_hit
|
processed_chunk, cache_hit=cache_hit
|
||||||
|
@ -8489,11 +8514,11 @@ class CustomStreamWrapper:
|
||||||
)
|
)
|
||||||
if self.logging_obj is not None:
|
if self.logging_obj is not None:
|
||||||
## LOGGING
|
## LOGGING
|
||||||
threading.Thread(
|
executor.submit(
|
||||||
target=self.logging_obj.failure_handler,
|
self.logging_obj.failure_handler,
|
||||||
args=(e, traceback_exception),
|
e,
|
||||||
).start() # log response
|
traceback_exception,
|
||||||
# Handle any exceptions that might occur during streaming
|
) # Handle any exceptions that might occur during streaming
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self.logging_obj.async_failure_handler(e, traceback_exception)
|
self.logging_obj.async_failure_handler(e, traceback_exception)
|
||||||
)
|
)
|
||||||
|
@ -8502,11 +8527,11 @@ class CustomStreamWrapper:
|
||||||
traceback_exception = traceback.format_exc()
|
traceback_exception = traceback.format_exc()
|
||||||
if self.logging_obj is not None:
|
if self.logging_obj is not None:
|
||||||
## LOGGING
|
## LOGGING
|
||||||
threading.Thread(
|
executor.submit(
|
||||||
target=self.logging_obj.failure_handler,
|
self.logging_obj.failure_handler,
|
||||||
args=(e, traceback_exception),
|
e,
|
||||||
).start() # log response
|
traceback_exception,
|
||||||
# Handle any exceptions that might occur during streaming
|
) # Handle any exceptions that might occur during streaming
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self.logging_obj.async_failure_handler(e, traceback_exception) # type: ignore
|
self.logging_obj.async_failure_handler(e, traceback_exception) # type: ignore
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue