Merge pull request #2774 from BerriAI/litellm_async_perf

(fix) improve async perf by 100ms
This commit is contained in:
Ishaan Jaff 2024-04-01 08:12:34 -07:00 committed by GitHub
commit bbfd850e12
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 43 additions and 10 deletions

View file

@@ -41,7 +41,7 @@ def test_completion_custom_provider_model_name():
messages=messages, messages=messages,
logger_fn=logger_fn, logger_fn=logger_fn,
) )
# Add any assertions here to check the,response # Add any assertions here to check the response
print(response) print(response)
print(response["choices"][0]["finish_reason"]) print(response["choices"][0]["finish_reason"])
except litellm.Timeout as e: except litellm.Timeout as e:

View file

@@ -1434,9 +1434,7 @@ class Logging:
model = self.model model = self.model
kwargs = self.model_call_details kwargs = self.model_call_details
input = kwargs.get( input = kwargs.get("messages", kwargs.get("input", None))
"messages", kwargs.get("input", None)
)
type = ( type = (
"embed" "embed"
@@ -1458,7 +1456,7 @@ class Logging:
model=model, model=model,
input=input, input=input,
user_id=kwargs.get("user", None), user_id=kwargs.get("user", None),
#user_props=self.model_call_details.get("user_props", None), # user_props=self.model_call_details.get("user_props", None),
extra=kwargs.get("optional_params", {}), extra=kwargs.get("optional_params", {}),
response_obj=result, response_obj=result,
start_time=start_time, start_time=start_time,
@@ -2064,8 +2062,6 @@ class Logging:
else "llm" else "llm"
) )
lunaryLogger.log_event( lunaryLogger.log_event(
type=_type, type=_type,
event="error", event="error",
@@ -2509,6 +2505,43 @@ def client(original_function):
@wraps(original_function) @wraps(original_function)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
# DO NOT MOVE THIS. It always needs to run first
# Check if this is an async function. If so only execute the async function
if (
kwargs.get("acompletion", False) == True
or kwargs.get("aembedding", False) == True
or kwargs.get("aimg_generation", False) == True
or kwargs.get("amoderation", False) == True
or kwargs.get("atext_completion", False) == True
or kwargs.get("atranscription", False) == True
):
# [OPTIONAL] CHECK MAX RETRIES / REQUEST
if litellm.num_retries_per_request is not None:
# check if previous_models passed in as ['litellm_params']['metadata]['previous_models']
previous_models = kwargs.get("metadata", {}).get(
"previous_models", None
)
if previous_models is not None:
if litellm.num_retries_per_request <= len(previous_models):
raise Exception(f"Max retries per request hit!")
# MODEL CALL
result = original_function(*args, **kwargs)
if "stream" in kwargs and kwargs["stream"] == True:
if (
"complete_response" in kwargs
and kwargs["complete_response"] == True
):
chunks = []
for idx, chunk in enumerate(result):
chunks.append(chunk)
return litellm.stream_chunk_builder(
chunks, messages=kwargs.get("messages", None)
)
else:
return result
return result
# Prints Exactly what was passed to litellm function - don't execute any logic here - it should just print # Prints Exactly what was passed to litellm function - don't execute any logic here - it should just print
print_args_passed_to_litellm(original_function, args, kwargs) print_args_passed_to_litellm(original_function, args, kwargs)
start_time = datetime.datetime.now() start_time = datetime.datetime.now()