fix(lunary.py): move parse_version to inside lunarylogger

This commit is contained in:
Krrish Dholakia 2024-04-03 13:52:42 -07:00
parent fcaa452ccd
commit 24d9fcb32c

View file

@ -4,11 +4,11 @@ from datetime import datetime, timezone
import traceback import traceback
import dotenv import dotenv
import importlib import importlib
from pkg_resources import parse_version
import sys import sys
dotenv.load_dotenv() dotenv.load_dotenv()
# convert to {completion: xx, tokens: xx} # convert to {completion: xx, tokens: xx}
def parse_usage(usage): def parse_usage(usage):
return { return {
@ -16,6 +16,7 @@ def parse_usage(usage):
"prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0, "prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
} }
def parse_messages(input): def parse_messages(input):
if input is None: if input is None:
return None return None
@ -28,7 +29,6 @@ def parse_messages(input):
if "message" in message: if "message" in message:
return clean_message(message["message"]) return clean_message(message["message"])
serialized = { serialized = {
"role": message.get("role"), "role": message.get("role"),
"content": message.get("content"), "content": message.get("content"),
@ -56,10 +56,14 @@ class LunaryLogger:
def __init__(self): def __init__(self):
try: try:
import lunary import lunary
from pkg_resources import parse_version
version = importlib.metadata.version("lunary") version = importlib.metadata.version("lunary")
# if version < 0.1.43 then raise ImportError # if version < 0.1.43 then raise ImportError
if parse_version(version) < parse_version("0.1.43"): if parse_version(version) < parse_version("0.1.43"):
print("Lunary version outdated. Required: > 0.1.43. Upgrade via 'pip install lunary --upgrade'") print(
"Lunary version outdated. Required: > 0.1.43. Upgrade via 'pip install lunary --upgrade'"
)
raise ImportError raise ImportError
self.lunary_client = lunary self.lunary_client = lunary
@ -88,9 +92,7 @@ class LunaryLogger:
print_verbose(f"Lunary Logging - Logging request for model {model}") print_verbose(f"Lunary Logging - Logging request for model {model}")
litellm_params = kwargs.get("litellm_params", {}) litellm_params = kwargs.get("litellm_params", {})
metadata = ( metadata = litellm_params.get("metadata", {}) or {}
litellm_params.get("metadata", {}) or {}
)
tags = litellm_params.pop("tags", None) or [] tags = litellm_params.pop("tags", None) or []
@ -148,7 +150,7 @@ class LunaryLogger:
runtime="litellm", runtime="litellm",
error=error_obj, error=error_obj,
output=parse_messages(output), output=parse_messages(output),
token_usage=usage token_usage=usage,
) )
except: except: