Merge pull request #2774 from BerriAI/litellm_async_perf

(fix) improve async perf by 100ms
Ishaan Jaff 2024-04-01 08:12:34 -07:00 committed by GitHub
commit bbfd850e12
2 changed files with 43 additions and 10 deletions

@@ -41,7 +41,7 @@ def test_completion_custom_provider_model_name():
             messages=messages,
             logger_fn=logger_fn,
         )
-        # Add any assertions here to check the,response
+        # Add any assertions here to check the response
         print(response)
         print(response["choices"][0]["finish_reason"])
     except litellm.Timeout as e:

@@ -1434,9 +1434,7 @@ class Logging:
                     model = self.model
                     kwargs = self.model_call_details
-                    input = kwargs.get(
-                        "messages", kwargs.get("input", None)
-                    )
+                    input = kwargs.get("messages", kwargs.get("input", None))
                     type = (
                         "embed"
@@ -1444,7 +1442,7 @@ class Logging:
                         else "llm"
                     )
-                    # this only logs streaming once, complete_streaming_response exists i.e when stream ends
+                    # this only logs streaming once, complete_streaming_response exists i.e when stream ends
                     if self.stream:
                         if "complete_streaming_response" not in kwargs:
                             break
@@ -1458,7 +1456,7 @@ class Logging:
                         model=model,
                         input=input,
                         user_id=kwargs.get("user", None),
-                        #user_props=self.model_call_details.get("user_props", None),
+                        # user_props=self.model_call_details.get("user_props", None),
                         extra=kwargs.get("optional_params", {}),
                         response_obj=result,
                         start_time=start_time,
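
The guard in the streaming hunk above is what keeps a streamed call from being logged once per chunk: the handler bails out until the aggregated complete_streaming_response has been attached to the call details, so each stream is logged exactly once, at stream end. A minimal sketch of the same pattern, using hypothetical names (log_success, model_call_details) rather than litellm's actual API:

def log_success(model_call_details: dict, is_stream: bool) -> None:
    # Stream still in progress: defer until the chunks have been folded
    # into complete_streaming_response by the final-chunk handler.
    if is_stream and "complete_streaming_response" not in model_call_details:
        return
    payload = model_call_details.get(
        "complete_streaming_response",
        model_call_details.get("original_response"),
    )
    print(f"logging once: {payload}")

details = {"original_response": "partial chunks"}
log_success(details, is_stream=True)   # deferred, nothing logged yet
details["complete_streaming_response"] = "full response"
log_success(details, is_stream=True)   # logs exactly once, at stream end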
@@ -2064,8 +2062,6 @@ class Logging:
                         else "llm"
                     )
                     lunaryLogger.log_event(
                         type=_type,
                         event="error",
@@ -2509,6 +2505,43 @@ def client(original_function):
     @wraps(original_function)
     def wrapper(*args, **kwargs):
         # DO NOT MOVE THIS. It always needs to run first
+        # Check if this is an async function. If so only execute the async function
+        if (
+            kwargs.get("acompletion", False) == True
+            or kwargs.get("aembedding", False) == True
+            or kwargs.get("aimg_generation", False) == True
+            or kwargs.get("amoderation", False) == True
+            or kwargs.get("atext_completion", False) == True
+            or kwargs.get("atranscription", False) == True
+        ):
+            # [OPTIONAL] CHECK MAX RETRIES / REQUEST
+            if litellm.num_retries_per_request is not None:
+                # check if previous_models passed in as ['litellm_params']['metadata']['previous_models']
+                previous_models = kwargs.get("metadata", {}).get(
+                    "previous_models", None
+                )
+                if previous_models is not None:
+                    if litellm.num_retries_per_request <= len(previous_models):
+                        raise Exception(f"Max retries per request hit!")
+            # MODEL CALL
+            result = original_function(*args, **kwargs)
+            if "stream" in kwargs and kwargs["stream"] == True:
+                if (
+                    "complete_response" in kwargs
+                    and kwargs["complete_response"] == True
+                ):
+                    chunks = []
+                    for idx, chunk in enumerate(result):
+                        chunks.append(chunk)
+                    return litellm.stream_chunk_builder(
+                        chunks, messages=kwargs.get("messages", None)
+                    )
+                else:
+                    return result
+            return result
         # Prints Exactly what was passed to litellm function - don't execute any logic here - it should just print
         print_args_passed_to_litellm(original_function, args, kwargs)
         start_time = datetime.datetime.now()
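
This added block is where the ~100ms saving comes from: when any async flag (acompletion, aembedding, aimg_generation, amoderation, atext_completion, atranscription) is set, wrapper returns the result of original_function immediately, skipping the argument printing, timestamping, and the rest of the sync bookkeeping that follows; only the optional num_retries_per_request cap and the stream_chunk_builder aggregation for complete_response=True streams still run. A minimal sketch of the dispatch pattern, with hypothetical setup code standing in for the sync path's bookkeeping (not litellm's actual code):

import datetime
import functools

def client(original_function):
    @functools.wraps(original_function)
    def wrapper(*args, **kwargs):
        # Fast path: for async calls, hand back the coroutine immediately;
        # awaiting it runs the async variant, which does its own logging.
        if kwargs.get("acompletion", False):
            return original_function(*args, **kwargs)
        # Slow path (sync): this setup is what the fast path now avoids.
        start_time = datetime.datetime.now()
        result = original_function(*args, **kwargs)
        print(f"sync call took {datetime.datetime.now() - start_time}")
        return result
    return wrapper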
@@ -6178,9 +6211,9 @@ def validate_environment(model: Optional[str] = None) -> dict:
 def set_callbacks(callback_list, function_id=None):
     global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger
     try:
         for callback in callback_list:
             print_verbose(f"callback: {callback}")
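
set_callbacks maps each string in callback_list onto one of those module-level logger globals. A simplified sketch of that dispatch, assuming a hypothetical registry dict in place of litellm's global list (LunaryLogger and LangfuseLogger here are stand-ins, not the real integrations):

class LunaryLogger:
    def log_event(self, **kwargs):
        print("lunary event:", kwargs)

class LangfuseLogger:
    def log_event(self, **kwargs):
        print("langfuse event:", kwargs)

REGISTRY = {"lunary": LunaryLogger, "langfuse": LangfuseLogger}

def set_callbacks(callback_list):
    # Instantiate only the integrations named in callback_list.
    loggers = {}
    for callback in callback_list:
        if callback in REGISTRY:
            loggers[callback] = REGISTRY[callback]()
    return loggers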