forked from phoenix/litellm-mirror
use thread pool executor for threads
This commit is contained in:
parent 60baa65e0e
commit 6255a49de9
1 changed file with 115 additions and 90 deletions
litellm/utils.py (205 changed lines: 115 additions, 90 deletions)
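What changed, in short: instead of spawning a brand-new thread for every success/failure log and cache write via threading.Thread(...).start(), the module now keeps one shared ThreadPoolExecutor (capped at MAX_THREADS = 100 workers) and hands the same callables to executor.submit(...). Below is a minimal, self-contained sketch of that pattern, not the litellm code itself; success_handler here is a hypothetical stand-in for logging_obj.success_handler, while the executor/MAX_THREADS names mirror the module-level ones added in the first hunk.

import threading
import time
from concurrent.futures import ThreadPoolExecutor

MAX_THREADS = 100

# One shared, bounded pool reused for all background logging work,
# mirroring the module-level executor added in the first hunk below.
executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=MAX_THREADS)


def success_handler(result, start_time, end_time):
    # Hypothetical stand-in for logging_obj.success_handler.
    time.sleep(0.01)  # simulate I/O-bound logging work
    print(f"logged {result!r} after {end_time - start_time:.3f}s")


result, start_time, end_time = "ok", time.time(), time.time() + 0.5

# Old pattern: a new OS thread per logged response.
threading.Thread(
    target=success_handler, args=(result, start_time, end_time)
).start()

# New pattern: reuse a worker thread from the bounded pool.
future = executor.submit(success_handler, result, start_time, end_time)
future.result()  # wait only if the caller actually needs the outcome

One behavioral note: an exception raised inside a submitted callable is stored on the returned Future and stays silent unless .result() or .exception() is called, whereas an unhandled exception in a raw Thread is at least printed by the default thread excepthook.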
@@ -178,7 +178,7 @@ from .types.router import LiteLLM_Params
 MAX_THREADS = 100
 
 # Create a ThreadPoolExecutor
-executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
+executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=MAX_THREADS)
 sentry_sdk_instance = None
 capture_exception = None
 add_breadcrumb = None
@@ -914,10 +914,13 @@ def client(original_function):
                         additional_args=None,
                         stream=kwargs.get("stream", False),
                     )
-                    threading.Thread(
-                        target=logging_obj.success_handler,
-                        args=(cached_result, start_time, end_time, cache_hit),
-                    ).start()
+                    executor.submit(
+                        logging_obj.success_handler,
+                        cached_result,
+                        start_time,
+                        end_time,
+                        cache_hit,
+                    )
                     return cached_result
                 else:
                     print_verbose(
@@ -1000,9 +1003,12 @@ def client(original_function):
 
             # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
             verbose_logger.info("Wrapper: Completed Call, calling success_handler")
-            threading.Thread(
-                target=logging_obj.success_handler, args=(result, start_time, end_time)
-            ).start()
+            executor.submit(
+                logging_obj.success_handler,
+                result,
+                start_time,
+                end_time,
+            )
             # RETURN RESULT
             if hasattr(result, "_hidden_params"):
                 result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
@@ -1280,10 +1286,13 @@ def client(original_function):
                                 cached_result, start_time, end_time, cache_hit
                             )
                         )
-                        threading.Thread(
-                            target=logging_obj.success_handler,
-                            args=(cached_result, start_time, end_time, cache_hit),
-                        ).start()
+                        executor.submit(
+                            logging_obj.success_handler,
+                            cached_result,
+                            start_time,
+                            end_time,
+                            cache_hit,
+                        )
                         cache_key = kwargs.get("preset_cache_key", None)
                         if (
                             isinstance(cached_result, BaseModel)
@@ -1385,15 +1394,13 @@ def client(original_function):
                             cache_hit,
                         )
                     )
-                    threading.Thread(
-                        target=logging_obj.success_handler,
-                        args=(
-                            final_embedding_cached_response,
-                            start_time,
-                            end_time,
-                            cache_hit,
-                        ),
-                    ).start()
+                    executor.submit(
+                        logging_obj.success_handler,
+                        final_embedding_cached_response,
+                        start_time,
+                        end_time,
+                        cache_hit,
+                    )
                     return final_embedding_cached_response
                 # MODEL CALL
                 result = await original_function(*args, **kwargs)
@@ -1475,11 +1482,12 @@ def client(original_function):
                             )
                         )
                     elif isinstance(litellm.cache.cache, S3Cache):
-                        threading.Thread(
-                            target=litellm.cache.add_cache,
-                            args=(result,) + args,
-                            kwargs=kwargs,
-                        ).start()
+                        executor.submit(
+                            litellm.cache.add_cache,
+                            result,
+                            *args,
+                            **kwargs,
+                        )
                     else:
                         asyncio.create_task(
                             litellm.cache.async_add_cache(
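The S3 cache hunk above also changes how arguments are forwarded: threading.Thread packs them into its args=/kwargs= parameters, while executor.submit takes them directly as *args/**kwargs. A small sketch of that equivalence, assuming a stand-in add_cache in place of litellm.cache.add_cache:

import threading
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=4)


def add_cache(result, *args, **kwargs):
    # Hypothetical stand-in for litellm.cache.add_cache.
    print("caching", result, args, kwargs)


result, args, kwargs = {"id": "resp-1"}, ("prompt text",), {"model": "gpt-3.5-turbo"}

# Old: the call is described via target/args/kwargs.
threading.Thread(target=add_cache, args=(result,) + args, kwargs=kwargs).start()

# New: the same positional and keyword arguments are passed straight through.
executor.submit(add_cache, result, *args, **kwargs)

Both calls invoke add_cache(result, "prompt text", model="gpt-3.5-turbo") on a background thread; only the dispatch mechanism differs.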
@@ -1498,10 +1506,12 @@ def client(original_function):
                 asyncio.create_task(
                     logging_obj.async_success_handler(result, start_time, end_time)
                 )
-                threading.Thread(
-                    target=logging_obj.success_handler,
-                    args=(result, start_time, end_time),
-                ).start()
+                executor.submit(
+                    logging_obj.success_handler,
+                    result,
+                    start_time,
+                    end_time,
+                )
 
                 # REBUILD EMBEDDING CACHING
                 if (
@@ -3197,7 +3207,7 @@ def get_optional_params(
 
         if stream:
             optional_params["stream"] = stream
-            #return optional_params
+            # return optional_params
         if max_tokens is not None:
             if "vicuna" in model or "flan" in model:
                 optional_params["max_length"] = max_tokens
@@ -6455,15 +6465,6 @@ def get_model_list():
            data = response.json()
            # update model list
            model_list = data["model_list"]
-            # # check if all model providers are in environment
-            # model_providers = data["model_providers"]
-            # missing_llm_provider = None
-            # for item in model_providers:
-            #     if f"{item.upper()}_API_KEY" not in os.environ:
-            #         missing_llm_provider = item
-            #         break
-            # # update environment - if required
-            # threading.Thread(target=get_all_keys, args=(missing_llm_provider)).start()
            return model_list
        return [] # return empty list by default
    except Exception:
@@ -8147,10 +8148,11 @@ class CustomStreamWrapper:
                 if response is None:
                     continue
                 ## LOGGING
-                threading.Thread(
-                    target=self.run_success_logging_in_thread,
-                    args=(response, cache_hit),
-                ).start() # log response
+                executor.submit(
+                    self.run_success_logging_in_thread,
+                    response,
+                    cache_hit,
+                )
                 choice = response.choices[0]
                 if isinstance(choice, StreamingChoices):
                     self.response_uptil_now += choice.delta.get("content", "") or ""
@@ -8201,10 +8203,13 @@ class CustomStreamWrapper:
                             getattr(complete_streaming_response, "usage"),
                         )
                     ## LOGGING
-                    threading.Thread(
-                        target=self.logging_obj.success_handler,
-                        args=(response, None, None, cache_hit),
-                    ).start() # log response
+                    executor.submit(
+                        self.logging_obj.success_handler,
+                        response,
+                        None,
+                        None,
+                        cache_hit,
+                    )
                     self.sent_stream_usage = True
                     return response
                 raise # Re-raise StopIteration
@@ -8215,17 +8220,20 @@ class CustomStreamWrapper:
                 usage = calculate_total_usage(chunks=self.chunks)
                 processed_chunk._hidden_params["usage"] = usage
                 ## LOGGING
-                threading.Thread(
-                    target=self.run_success_logging_in_thread,
-                    args=(processed_chunk, cache_hit),
-                ).start() # log response
+                executor.submit(
+                    self.run_success_logging_in_thread,
+                    processed_chunk,
+                    cache_hit,
+                )
                 return processed_chunk
         except Exception as e:
             traceback_exception = traceback.format_exc()
             # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
-            threading.Thread(
-                target=self.logging_obj.failure_handler, args=(e, traceback_exception)
-            ).start()
+            executor.submit(
+                self.logging_obj.failure_handler,
+                e,
+                traceback_exception,
+            )
             if isinstance(e, OpenAIError):
                 raise e
             else:
@@ -8314,11 +8322,13 @@ class CustomStreamWrapper:
                     if processed_chunk is None:
                         continue
                     ## LOGGING
-                    ## LOGGING
-                    threading.Thread(
-                        target=self.logging_obj.success_handler,
-                        args=(processed_chunk, None, None, cache_hit),
-                    ).start() # log response
+                    executor.submit(
+                        self.logging_obj.success_handler,
+                        processed_chunk,
+                        None,
+                        None,
+                        cache_hit,
+                    ) # log response
                     asyncio.create_task(
                         self.logging_obj.async_success_handler(
                             processed_chunk, cache_hit=cache_hit
@@ -8368,10 +8378,13 @@ class CustomStreamWrapper:
                     if processed_chunk is None:
                         continue
                     ## LOGGING
-                    threading.Thread(
-                        target=self.logging_obj.success_handler,
-                        args=(processed_chunk, None, None, cache_hit),
-                    ).start() # log processed_chunk
+                    executor.submit(
+                        self.logging_obj.success_handler,
+                        processed_chunk,
+                        None,
+                        None,
+                        cache_hit,
+                    ) # log processed_chunk
                     asyncio.create_task(
                         self.logging_obj.async_success_handler(
                             processed_chunk, cache_hit=cache_hit
@@ -8410,10 +8423,13 @@ class CustomStreamWrapper:
                             getattr(complete_streaming_response, "usage"),
                         )
                     ## LOGGING
-                    threading.Thread(
-                        target=self.logging_obj.success_handler,
-                        args=(response, None, None, cache_hit),
-                    ).start() # log response
+                    executor.submit(
+                        self.logging_obj.success_handler,
+                        response,
+                        None,
+                        None,
+                        cache_hit,
+                    ) # log response
                     asyncio.create_task(
                         self.logging_obj.async_success_handler(
                             response, cache_hit=cache_hit
@@ -8426,10 +8442,13 @@ class CustomStreamWrapper:
                     self.sent_last_chunk = True
                     processed_chunk = self.finish_reason_handler()
                     ## LOGGING
-                    threading.Thread(
-                        target=self.logging_obj.success_handler,
-                        args=(processed_chunk, None, None, cache_hit),
-                    ).start() # log response
+                    executor.submit(
+                        self.logging_obj.success_handler,
+                        processed_chunk,
+                        None,
+                        None,
+                        cache_hit,
+                    ) # log response
                     asyncio.create_task(
                         self.logging_obj.async_success_handler(
                             processed_chunk, cache_hit=cache_hit
@@ -8455,10 +8474,13 @@ class CustomStreamWrapper:
                             getattr(complete_streaming_response, "usage"),
                         )
                     ## LOGGING
-                    threading.Thread(
-                        target=self.logging_obj.success_handler,
-                        args=(response, None, None, cache_hit),
-                    ).start() # log response
+                    executor.submit(
+                        self.logging_obj.success_handler,
+                        response,
+                        None,
+                        None,
+                        cache_hit,
+                    ) # log response
                     asyncio.create_task(
                         self.logging_obj.async_success_handler(
                             response, cache_hit=cache_hit
@@ -8471,10 +8493,13 @@ class CustomStreamWrapper:
                     self.sent_last_chunk = True
                     processed_chunk = self.finish_reason_handler()
                     ## LOGGING
-                    threading.Thread(
-                        target=self.logging_obj.success_handler,
-                        args=(processed_chunk, None, None, cache_hit),
-                    ).start() # log response
+                    executor.submit(
+                        self.logging_obj.success_handler,
+                        processed_chunk,
+                        None,
+                        None,
+                        cache_hit,
+                    ) # log response
                     asyncio.create_task(
                         self.logging_obj.async_success_handler(
                             processed_chunk, cache_hit=cache_hit
@@ -8489,11 +8514,11 @@ class CustomStreamWrapper:
                 )
             if self.logging_obj is not None:
                 ## LOGGING
-                threading.Thread(
-                    target=self.logging_obj.failure_handler,
-                    args=(e, traceback_exception),
-                ).start() # log response
-                # Handle any exceptions that might occur during streaming
+                executor.submit(
+                    self.logging_obj.failure_handler,
+                    e,
+                    traceback_exception,
+                ) # Handle any exceptions that might occur during streaming
                 asyncio.create_task(
                     self.logging_obj.async_failure_handler(e, traceback_exception)
                 )
@@ -8502,11 +8527,11 @@ class CustomStreamWrapper:
             traceback_exception = traceback.format_exc()
             if self.logging_obj is not None:
                 ## LOGGING
-                threading.Thread(
-                    target=self.logging_obj.failure_handler,
-                    args=(e, traceback_exception),
-                ).start() # log response
-                # Handle any exceptions that might occur during streaming
+                executor.submit(
+                    self.logging_obj.failure_handler,
+                    e,
+                    traceback_exception,
+                ) # Handle any exceptions that might occur during streaming
                 asyncio.create_task(
                     self.logging_obj.async_failure_handler(e, traceback_exception) # type: ignore
                 )