litellm-mirror/tests/litellm/integrations/test_athina.py
2025-02-28 21:48:13 +05:30

207 lines
No EOL
8.1 KiB
Python

import unittest
from unittest.mock import patch, MagicMock, ANY
import json
import datetime
import sys
import os
# Put the directory two levels above the *current working directory* at the
# front of sys.path, presumably so a local litellm checkout is importable
# without installing it.
# NOTE(review): the path is resolved against the CWD, not this file's
# location, so which directory is added depends on where the runner starts.
sys.path.insert(0, os.path.abspath(os.path.join(os.pardir, os.pardir)))
from litellm.integrations.athina import AthinaLogger
class TestAthinaLogger(unittest.TestCase):
    """Unit tests for ``AthinaLogger`` with the outgoing HTTP call mocked.

    Every logging test patches ``litellm.module_level_client.post``, then
    inspects the arguments passed to that mock and the messages the logger
    reports through ``print_verbose``.
    """

    # Endpoint the logger is expected to derive from ATHINA_BASE_URL below.
    EXPECTED_LOGGING_URL = 'https://test.athina.ai/api/v1/log/inference'

    def setUp(self):
        # Fake credentials/host for AthinaLogger to read when constructed.
        self.env_patcher = patch.dict('os.environ', {
            'ATHINA_API_KEY': 'test-api-key',
            'ATHINA_BASE_URL': 'https://test.athina.ai',
        })
        self.env_patcher.start()
        # Register the undo immediately: unittest skips tearDown() when setUp()
        # raises, so if AthinaLogger() below failed, the patched environment
        # would otherwise leak into every later test.
        self.addCleanup(self.env_patcher.stop)
        self.logger = AthinaLogger()
        # Fixed one-second window so the expected response_time is exact.
        self.start_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
        self.end_time = datetime.datetime(2023, 1, 1, 12, 0, 1)
        self.print_verbose = MagicMock()

    @staticmethod
    def _http_response(status_code, text=None):
        """Return a fake HTTP response with *status_code* and, if given, *text*.

        When *text* is None the attribute is left as MagicMock's auto-attribute.
        """
        response = MagicMock()
        response.status_code = status_code
        if text is not None:
            response.text = text
        return response

    @staticmethod
    def _response_obj(dump):
        """Return a fake completion response whose ``model_dump()`` yields *dump*."""
        response_obj = MagicMock()
        response_obj.model_dump.return_value = dump
        return response_obj

    @staticmethod
    def _usage():
        """Return a fresh token-usage dict (10 prompt + 5 completion = 15)."""
        return {'prompt_tokens': 10, 'completion_tokens': 5, 'total_tokens': 15}

    def test_init(self):
        """The API key, logging URL and headers are built from the environment."""
        self.assertEqual(self.logger.athina_api_key, 'test-api-key')
        self.assertEqual(self.logger.athina_logging_url, self.EXPECTED_LOGGING_URL)
        self.assertEqual(self.logger.headers, {
            'athina-api-key': 'test-api-key',
            'Content-Type': 'application/json',
        })

    @patch('litellm.module_level_client.post')
    def test_log_event_success(self, mock_post):
        """A 200 response posts the full payload and reports success."""
        mock_post.return_value = self._http_response(200, "Success")
        kwargs = {
            'model': 'gpt-4',
            'messages': [{'role': 'user', 'content': 'Hello'}],
            'stream': False,
            'litellm_params': {
                'metadata': {
                    'environment': 'test-environment',
                    'prompt_slug': 'test-prompt',
                    'customer_id': 'test-customer',
                    'customer_user_id': 'test-user',
                    'session_id': 'test-session',
                    'external_reference_id': 'test-ext-ref',
                    'context': 'test-context',
                    'expected_response': 'test-expected',
                    'user_query': 'test-query',
                    'tags': ['test-tag'],
                    'user_feedback': 'test-feedback',
                    'model_options': {'test-opt': 'test-val'},
                    'custom_attributes': {'test-attr': 'test-val'},
                }
            },
        }
        response_obj = self._response_obj({
            'id': 'resp-123',
            'choices': [{'message': {'content': 'Hi there'}}],
            'usage': self._usage(),
        })

        self.logger.log_event(
            kwargs, response_obj, self.start_time, self.end_time, self.print_verbose
        )

        mock_post.assert_called_once()
        call_args = mock_post.call_args
        self.assertEqual(call_args.args[0], self.EXPECTED_LOGGING_URL)
        self.assertEqual(call_args.kwargs['headers'], self.logger.headers)

        sent_data = json.loads(call_args.kwargs['data'])
        self.assertEqual(sent_data['language_model_id'], 'gpt-4')
        self.assertEqual(sent_data['prompt'], kwargs['messages'])
        self.assertEqual(sent_data['prompt_tokens'], 10)
        self.assertEqual(sent_data['completion_tokens'], 5)
        self.assertEqual(sent_data['total_tokens'], 15)
        self.assertEqual(sent_data['response_time'], 1000)  # 1 second = 1000 ms
        # Metadata fields expected to be forwarded verbatim into the payload.
        # Literal expectations keep the check independent of any mutation of
        # kwargs by the logger.
        expected_metadata_fields = {
            'customer_id': 'test-customer',
            'session_id': 'test-session',
            'environment': 'test-environment',
            'prompt_slug': 'test-prompt',
            'external_reference_id': 'test-ext-ref',
            'context': 'test-context',
            'expected_response': 'test-expected',
            'user_query': 'test-query',
            'tags': ['test-tag'],
            'user_feedback': 'test-feedback',
            'model_options': {'test-opt': 'test-val'},
            'custom_attributes': {'test-attr': 'test-val'},
        }
        for field, expected in expected_metadata_fields.items():
            self.assertEqual(sent_data[field], expected, msg=f"metadata field {field!r}")

        self.print_verbose.assert_called_once_with("Athina Logger Succeeded - Success")

    @patch('litellm.module_level_client.post')
    def test_log_event_error_response(self, mock_post):
        """A non-200 response is reported with its body and status code."""
        mock_post.return_value = self._http_response(400, "Bad Request")
        kwargs = {
            'model': 'gpt-4',
            'messages': [{'role': 'user', 'content': 'Hello'}],
            'stream': False,
        }
        response_obj = self._response_obj({
            'id': 'resp-123',
            'choices': [{'message': {'content': 'Hi there'}}],
            'usage': self._usage(),
        })

        self.logger.log_event(
            kwargs, response_obj, self.start_time, self.end_time, self.print_verbose
        )

        self.print_verbose.assert_called_once_with("Athina Logger Error - Bad Request, 400")

    @patch('litellm.module_level_client.post')
    def test_log_event_exception(self, mock_post):
        """An exception raised by the HTTP call is reported, not propagated."""
        mock_post.side_effect = Exception("Test exception")
        kwargs = {
            'model': 'gpt-4',
            'messages': [{'role': 'user', 'content': 'Hello'}],
            'stream': False,
        }
        response_obj = self._response_obj({})

        self.logger.log_event(
            kwargs, response_obj, self.start_time, self.end_time, self.print_verbose
        )

        self.print_verbose.assert_called_once()
        self.assertIn(
            "Athina Logger Error - Test exception", self.print_verbose.call_args.args[0]
        )

    @patch('litellm.module_level_client.post')
    def test_log_event_with_tools(self, mock_post):
        """Tool definitions in optional_params are included in the payload."""
        mock_post.return_value = self._http_response(200)
        tools = [{'type': 'function', 'function': {'name': 'get_weather'}}]
        kwargs = {
            'model': 'gpt-4',
            'messages': [{'role': 'user', 'content': "What's the weather?"}],
            'stream': False,
            'optional_params': {'tools': tools},
        }
        response_obj = self._response_obj({
            'id': 'resp-123',
            'usage': self._usage(),
        })

        self.logger.log_event(
            kwargs, response_obj, self.start_time, self.end_time, self.print_verbose
        )

        # Assert the call first so a missing POST fails clearly instead of
        # surfacing as a TypeError on a None call_args.
        mock_post.assert_called_once()
        sent_data = json.loads(mock_post.call_args.kwargs['data'])
        self.assertEqual(sent_data['tools'], tools)
if __name__ == "__main__":
    # Support running this file directly (python test_athina.py) as well as
    # through a test runner such as pytest.
    unittest.main()