(fix) improve async perf

This commit is contained in:
Ishaan Jaff 2024-03-30 19:07:04 -07:00
parent cdd6e79e6c
commit bd95626579

View file

@ -1434,9 +1434,7 @@ class Logging:
model = self.model
kwargs = self.model_call_details
input = kwargs.get(
"messages", kwargs.get("input", None)
)
input = kwargs.get("messages", kwargs.get("input", None))
type = (
"embed"
@ -1444,7 +1442,7 @@ class Logging:
else "llm"
)
# this only logs streaming once, complete_streaming_response exists i.e when stream ends
# this only logs streaming once, complete_streaming_response exists i.e when stream ends
if self.stream:
if "complete_streaming_response" not in kwargs:
break
@ -1458,7 +1456,7 @@ class Logging:
model=model,
input=input,
user_id=kwargs.get("user", None),
#user_props=self.model_call_details.get("user_props", None),
# user_props=self.model_call_details.get("user_props", None),
extra=kwargs.get("optional_params", {}),
response_obj=result,
start_time=start_time,
@ -2064,8 +2062,6 @@ class Logging:
else "llm"
)
lunaryLogger.log_event(
type=_type,
event="error",
@ -2509,6 +2505,33 @@ def client(original_function):
@wraps(original_function)
def wrapper(*args, **kwargs):
# DO NOT MOVE THIS. It always needs to run first
# Check if this is an async function. If so only execute the async function
if (
kwargs.get("acompletion", False) == True
or kwargs.get("aembedding", False) == True
or kwargs.get("aimg_generation", False) == True
or kwargs.get("amoderation", False) == True
or kwargs.get("atext_completion", False) == True
or kwargs.get("atranscription", False) == True
):
# MODEL CALL
result = original_function(*args, **kwargs)
if "stream" in kwargs and kwargs["stream"] == True:
if (
"complete_response" in kwargs
and kwargs["complete_response"] == True
):
chunks = []
for idx, chunk in enumerate(result):
chunks.append(chunk)
return litellm.stream_chunk_builder(
chunks, messages=kwargs.get("messages", None)
)
else:
return result
return result
# Prints Exactly what was passed to litellm function - don't execute any logic here - it should just print
print_args_passed_to_litellm(original_function, args, kwargs)
start_time = datetime.datetime.now()
@ -6178,9 +6201,9 @@ def validate_environment(model: Optional[str] = None) -> dict:
def set_callbacks(callback_list, function_id=None):
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger
try:
for callback in callback_list:
print_verbose(f"callback: {callback}")