(feat) fix api callback imports

This commit is contained in:
ishaan-jaff 2024-02-15 16:15:36 -08:00
parent 3e90acb750
commit 47b8715d25
3 changed files with 61 additions and 8 deletions

View file

@ -57,7 +57,9 @@ class GenericAPILogger:
# This is sync, because we run this in a separate thread. Running in a sepearate thread ensures it will never block an LLM API call # This is sync, because we run this in a separate thread. Running in a sepearate thread ensures it will never block an LLM API call
# Experience with s3, Langfuse shows that async logging events are complicated and can block LLM calls # Experience with s3, Langfuse shows that async logging events are complicated and can block LLM calls
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): def log_event(
self, kwargs, response_obj, start_time, end_time, user_id, print_verbose
):
try: try:
verbose_logger.debug( verbose_logger.debug(
f"GenericAPILogger Logging - Enters logging function for model {kwargs}" f"GenericAPILogger Logging - Enters logging function for model {kwargs}"

View file

@ -0,0 +1,45 @@
import sys
import os
import io, asyncio
# import logging
# logging.basicConfig(level=logging.DEBUG)
sys.path.insert(0, os.path.abspath("../.."))
print("Modified sys.path:", sys.path)
from litellm import completion
import litellm
litellm.num_retries = 3
import time, random
import pytest
@pytest.mark.asyncio
async def test_custom_api_logging():
try:
litellm.success_callback = ["generic"]
litellm.set_verbose = True
os.environ["GENERIC_LOGGER_ENDPOINT"] = "http://localhost:8000/log-event"
print("Testing generic api logging")
await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": f"This is a test"}],
max_tokens=10,
temperature=0.7,
user="ishaan-2",
)
except Exception as e:
pytest.fail(f"An exception occurred - {e}")
finally:
# post, close log file and verify
# Reset stdout to the original value
print("Passed! Testing async s3 logging")
# test_s3_logging()

View file

@ -12,6 +12,7 @@ import litellm
import dotenv, json, traceback, threading, base64, ast import dotenv, json, traceback, threading, base64, ast
import subprocess, os import subprocess, os
from os.path import abspath, join, dirname
import litellm, openai import litellm, openai
import itertools import itertools
import random, uuid, requests import random, uuid, requests
@ -33,11 +34,11 @@ from dataclasses import (
# import pkg_resources # import pkg_resources
from importlib import resources from importlib import resources
# filename = pkg_resources.resource_filename(__name__, "llms/tokenizers") # # filename = pkg_resources.resource_filename(__name__, "llms/tokenizers")
filename = str(resources.files("llms").joinpath("tokenizers")) # filename = str(resources.files().joinpath("llms/tokenizers"))
os.environ[ # os.environ[
"TIKTOKEN_CACHE_DIR" # "TIKTOKEN_CACHE_DIR"
] = filename # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071 # ] = filename # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071
encoding = tiktoken.get_encoding("cl100k_base") encoding = tiktoken.get_encoding("cl100k_base")
import importlib.metadata import importlib.metadata
from ._logging import verbose_logger from ._logging import verbose_logger
@ -76,8 +77,13 @@ from .exceptions import (
UnprocessableEntityError, UnprocessableEntityError,
) )
# import enterprise features # Import Enterprise features
from ..enterprise.callbacks.api_callback import GenericAPILogger project_path = abspath(join(dirname(__file__), "..", ".."))
# Add the "enterprise" directory to sys.path
enterprise_path = abspath(join(project_path, "enterprise"))
sys.path.append(enterprise_path)
from enterprise.callbacks.generic_api_callback import GenericAPILogger
from typing import cast, List, Dict, Union, Optional, Literal, Any from typing import cast, List, Dict, Union, Optional, Literal, Any
from .caching import Cache from .caching import Cache