Merge branch 'main' into litellm_embedding_caching_updates

This commit is contained in:
Krish Dholakia 2024-01-11 23:58:51 +05:30 committed by GitHub
commit 817a3d29b7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 320 additions and 242 deletions

View file

@ -1975,6 +1975,8 @@ def client(original_function):
@wraps(original_function)
def wrapper(*args, **kwargs):
# Prints Exactly what was passed to litellm function - don't execute any logic here - it should just print
print_args_passed_to_litellm(original_function, args, kwargs)
start_time = datetime.datetime.now()
result = None
logging_obj = kwargs.get("litellm_logging_obj", None)
@ -2175,6 +2177,7 @@ def client(original_function):
@wraps(original_function)
async def wrapper_async(*args, **kwargs):
print_args_passed_to_litellm(original_function, args, kwargs)
start_time = datetime.datetime.now()
result = None
logging_obj = kwargs.get("litellm_logging_obj", None)
@ -2991,7 +2994,7 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
response=httpx.Response(
status_code=404,
content=error_str,
request=httpx.request(method="cost_per_token", url="https://github.com/BerriAI/litellm"), # type: ignore
request=httpx.Request(method="cost_per_token", url="https://github.com/BerriAI/litellm"), # type: ignore
),
llm_provider="",
)
@ -4318,7 +4321,7 @@ def get_llm_provider(
response=httpx.Response(
status_code=400,
content=error_str,
request=httpx.request(method="completion", url="https://github.com/BerriAI/litellm"), # type: ignore
request=httpx.Request(method="completion", url="https://github.com/BerriAI/litellm"), # type: ignore
),
llm_provider="",
)
@ -4333,7 +4336,7 @@ def get_llm_provider(
response=httpx.Response(
status_code=400,
content=error_str,
request=httpx.request(method="completion", url="https://github.com/BerriAI/litellm"), # type: ignore
request=httpx.Request(method="completion", url="https://github.com/BerriAI/litellm"), # type: ignore
),
llm_provider="",
)
@ -8427,3 +8430,49 @@ def transform_logprobs(hf_response):
transformed_logprobs = token_info
return transformed_logprobs
def print_args_passed_to_litellm(original_function, args, kwargs):
    """Print (via print_verbose) the exact call made to a litellm entry point.

    Debug-only helper: reconstructs the ``litellm.<fn>(...)`` call string from
    *args*/*kwargs* and prints it in green ANSI. Must never raise — any failure
    is swallowed so logging cannot break the actual request.

    Args:
        original_function: the wrapped litellm function being invoked.
        args: positional arguments passed to it.
        kwargs: keyword arguments passed to it.
    """
    try:
        # The async variants route through the sync function with a flag set
        # (e.g. acompletion -> completion with acompletion=True). The async
        # wrapper already printed the call, so skip the duplicate print here.
        async_flag_for = {
            "completion": "acompletion",
            "embedding": "aembedding",
            "img_generation": "aimg_generation",
        }
        flag = async_flag_for.get(original_function.__name__)
        if flag is not None and kwargs.get(flag):
            return

        args_str = ", ".join(map(repr, args))
        kwargs_str = ", ".join(f"{key}={repr(value)}" for key, value in kwargs.items())
        # Join only the non-empty parts so "(args)", "(kwargs)", "(args, kwargs)"
        # and "()" all render correctly from the same expression.
        call_args = ", ".join(part for part in (args_str, kwargs_str) if part)

        print_verbose("\n")  # new line before
        print_verbose("\033[92mRequest to litellm:\033[0m")
        print_verbose(
            f"\033[92mlitellm.{original_function.__name__}({call_args})\033[0m"
        )
        print_verbose("\n")  # new line after
    except Exception:
        # Intentionally best-effort: this is debug output and must be
        # non-blocking. Narrowed from bare `except:` so KeyboardInterrupt
        # and SystemExit still propagate.
        pass