Merge branch 'main' into litellm_bedrock_command_r_support

This commit is contained in:
Krish Dholakia 2024-05-11 21:24:42 -07:00 committed by GitHub
commit 1d651c6049
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
82 changed files with 3661 additions and 605 deletions

View file

@ -3,7 +3,6 @@
from datetime import datetime, timezone
import traceback
import importlib
import sys
import packaging
@ -15,13 +14,33 @@ def parse_usage(usage):
"prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
}
def parse_tool_calls(tool_calls):
if tool_calls is None:
return None
def clean_tool_call(tool_call):
serialized = {
"type": tool_call.type,
"id": tool_call.id,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
}
}
return serialized
return [clean_tool_call(tool_call) for tool_call in tool_calls]
def parse_messages(input):
if input is None:
return None
def clean_message(message):
# if is strin, return as is
# if is string, return as is
if isinstance(message, str):
return message
@ -35,9 +54,7 @@ def parse_messages(input):
# Only add tool_calls and function_call to res if they are set
if message.get("tool_calls"):
serialized["tool_calls"] = message.get("tool_calls")
if message.get("function_call"):
serialized["function_call"] = message.get("function_call")
serialized["tool_calls"] = parse_tool_calls(message.get("tool_calls"))
return serialized
@ -92,8 +109,13 @@ class LunaryLogger:
print_verbose(f"Lunary Logging - Logging request for model {model}")
litellm_params = kwargs.get("litellm_params", {})
optional_params = kwargs.get("optional_params", {})
metadata = litellm_params.get("metadata", {}) or {}
if optional_params:
# merge into extra
extra = {**extra, **optional_params}
tags = litellm_params.pop("tags", None) or []
if extra:
@ -103,7 +125,7 @@ class LunaryLogger:
# keep only serializable types
for param, value in extra.items():
if not isinstance(value, (str, int, bool, float)):
if not isinstance(value, (str, int, bool, float)) and param != "tools":
try:
extra[param] = str(value)
except:
@ -139,7 +161,7 @@ class LunaryLogger:
metadata=metadata,
runtime="litellm",
tags=tags,
extra=extra,
params=extra,
)
self.lunary_client.track_event(