Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 02:34:29 +00:00)
(Feat) - Add GCS Pub/Sub Logging integration for sending DB SpendLogs to BigQuery (#7976)

* add pub_sub
* fix custom batch logger for GCS PUB/SUB
* GCS_PUBSUB_PROJECT_ID
* e2e gcs pub sub
* add gcs pub sub
* fix logging
* add GcsPubSubLogger
* fix pub sub
* add pub sub
* docs gcs pub / sub
* docs on pub sub controls
* test_gcs_pub_sub
* fix publish_message
* test_async_gcs_pub_sub
This commit is contained in: commit 74caef0843 (parent c9a32ebf76)

11 changed files with 458 additions and 24 deletions
@@ -367,6 +367,8 @@ router_settings:

| GCS_PATH_SERVICE_ACCOUNT | Path to the Google Cloud service account JSON file |
| GCS_FLUSH_INTERVAL | Flush interval for GCS logging (in seconds). Specify how often you want a log to be sent to GCS. **Default is 20 seconds** |
| GCS_BATCH_SIZE | Batch size for GCS logging. Specify after how many logs you want to flush to GCS. If `GCS_BATCH_SIZE` is set to 10, logs are flushed every 10 logs. **Default is 2048** |
| GCS_PUBSUB_TOPIC_ID | PubSub Topic ID to send LiteLLM SpendLogs to. |
| GCS_PUBSUB_PROJECT_ID | Google Cloud Project ID containing the PubSub Topic that LiteLLM SpendLogs are sent to. |
| GENERIC_AUTHORIZATION_ENDPOINT | Authorization endpoint for generic OAuth providers |
| GENERIC_CLIENT_ID | Client ID for generic OAuth providers |
| GENERIC_CLIENT_SECRET | Client secret for generic OAuth providers |
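Taken together, `GCS_FLUSH_INTERVAL` and `GCS_BATCH_SIZE` bound how long a log can sit in memory before it is shipped: a flush fires when either limit is hit. As a hedged illustration (the values below are placeholders, not recommendations), a low-traffic proxy that wants fresher logs might lower both:

```shell
# Hypothetical tuning: flush whenever 10 seconds pass or 100 logs accumulate,
# whichever happens first (defaults: 20 seconds / 2048 logs).
export GCS_FLUSH_INTERVAL="10"
export GCS_BATCH_SIZE="100"
```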
@@ -1025,6 +1025,74 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \

6. Save the JSON file and add the path to `GCS_PATH_SERVICE_ACCOUNT`

## Google Cloud Storage - PubSub Topic

Log LLM Logs/SpendLogs to a [Google Cloud PubSub Topic](https://cloud.google.com/pubsub/docs/reference/rest)

:::info

✨ This is an Enterprise only feature [Get Started with Enterprise here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)

:::

| Property | Details |
|----------|---------|
| Description | Log LiteLLM `SpendLogs Table` to a Google Cloud PubSub Topic |

When to use `gcs_pubsub`?

- If your LiteLLM database has crossed 1M+ spend logs and you want to send `SpendLogs` to a PubSub Topic that can be consumed by BigQuery (see the subscription sketch below)
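One way to get those messages into BigQuery is a Pub/Sub BigQuery subscription on the topic. A minimal sketch, assuming the `litellmDB` topic from the usage example below and a pre-created BigQuery table whose schema matches the SpendLogs payload (both names are placeholders, and this wiring is outside the scope of this commit):

```shell
# Hypothetical: stream every published SpendLogs message into BigQuery.
gcloud pubsub subscriptions create litellm-spendlogs-to-bq \
  --topic=litellmDB \
  --bigquery-table=my-project:my_dataset.litellm_spend_logs
```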
#### Usage

1. Add `gcs_pubsub` to LiteLLM Config.yaml
```yaml
model_list:
  - litellm_params:
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
      api_key: my-fake-key
      model: openai/my-fake-model
    model_name: fake-openai-endpoint

litellm_settings:
  callbacks: ["gcs_pubsub"] # 👈 KEY CHANGE
```

2. Set required env variables

```shell
GCS_PUBSUB_TOPIC_ID="litellmDB"
GCS_PUBSUB_PROJECT_ID="reliableKeys"
```
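The logger also reads `GCS_PATH_SERVICE_ACCOUNT` when building auth headers (see `pub_sub.py` below), so if the proxy is not running with Google application-default credentials you will likely want something like the following as well (the path is a placeholder):

```shell
export GCS_PATH_SERVICE_ACCOUNT="/path/to/service_account.json"
```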

3. Start Proxy

```shell
litellm --config /path/to/config.yaml
```

4. Test it!

```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data '{
    "model": "fake-openai-endpoint",
    "messages": [
        {
            "role": "user",
            "content": "what llm are you"
        }
    ]
}'
```
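To confirm the SpendLog actually reached the topic, you can pull from any subscription attached to it. A hedged sketch, assuming a subscription named `litellm-spendlogs-sub` already exists on the topic (the name is a placeholder):

```shell
# Pull and acknowledge up to 5 recent messages from the topic's subscription.
gcloud pubsub subscriptions pull litellm-spendlogs-sub --auto-ack --limit=5
```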

## s3 Buckets

We will use the `--config` to set
@@ -1301,7 +1369,7 @@ LiteLLM supports customizing the following Datadog environment variables

## Lunary
-### Step1: Install dependencies and set your environment variables
+#### Step1: Install dependencies and set your environment variables
Install the dependencies
```shell
pip install litellm lunary
```

@@ -1312,7 +1380,7 @@ Get you Lunary public key from from https://app.lunary.ai/settings
export LUNARY_PUBLIC_KEY="<your-public-key>"
```

-### Step 2: Create a `config.yaml` and set `lunary` callbacks
+#### Step 2: Create a `config.yaml` and set `lunary` callbacks

```yaml
model_list:

@@ -1324,12 +1392,12 @@ litellm_settings:
  failure_callback: ["lunary"]
```

-### Step 3: Start the LiteLLM proxy
+#### Step 3: Start the LiteLLM proxy
```shell
litellm --config config.yaml
```

-### Step 4: Make a request
+#### Step 4: Make a request

```shell
curl -X POST 'http://0.0.0.0:4000/chat/completions' \

@@ -1352,14 +1420,14 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
## MLflow

-### Step1: Install dependencies
+#### Step1: Install dependencies
Install the dependencies.

```shell
pip install litellm mlflow
```

-### Step 2: Create a `config.yaml` with `mlflow` callback
+#### Step 2: Create a `config.yaml` with `mlflow` callback

```yaml
model_list:

@@ -1371,12 +1439,12 @@ litellm_settings:
  failure_callback: ["mlflow"]
```

-### Step 3: Start the LiteLLM proxy
+#### Step 3: Start the LiteLLM proxy
```shell
litellm --config config.yaml
```

-### Step 4: Make a request
+#### Step 4: Make a request

```shell
curl -X POST 'http://0.0.0.0:4000/chat/completions' \

@@ -1392,7 +1460,7 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
}'
```

-### Step 5: Review traces
+#### Step 5: Review traces

Run the following command to start MLflow UI and review recorded traces.
@@ -77,6 +77,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
    "langfuse",
    "pagerduty",
    "humanloop",
+   "gcs_pubsub",
]
logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None
_known_custom_logger_compatible_callbacks: List = list(
litellm/integrations/Readme.md (new file, 5 lines)

@@ -0,0 +1,5 @@
# Integrations

This folder contains logging integrations for litellm

e.g. logging to Datadog, Langfuse, Prometheus, s3, GCS Bucket, etc.
litellm/integrations/gcs_pubsub/pub_sub.py (new file, 202 lines)

@@ -0,0 +1,202 @@
"""
BETA

This is the PubSub logger for GCS PubSub; it sends LiteLLM SpendLogs payloads to GCS PubSub.

Users can use this instead of sending their SpendLogs to their Postgres database.
"""

import asyncio
import json
import os
import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Optional

if TYPE_CHECKING:
    from litellm.proxy._types import SpendLogsPayload
else:
    SpendLogsPayload = Any

from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)


class GcsPubSubLogger(CustomBatchLogger):
    def __init__(
        self,
        project_id: Optional[str] = None,
        topic_id: Optional[str] = None,
        credentials_path: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize Google Cloud Pub/Sub publisher

        Args:
            project_id (str): Google Cloud project ID
            topic_id (str): Pub/Sub topic ID
            credentials_path (str, optional): Path to Google Cloud credentials JSON file
        """
        from litellm.proxy.utils import _premium_user_check

        _premium_user_check()

        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )

        self.project_id = project_id or os.getenv("GCS_PUBSUB_PROJECT_ID")
        self.topic_id = topic_id or os.getenv("GCS_PUBSUB_TOPIC_ID")
        self.path_service_account_json = credentials_path or os.getenv(
            "GCS_PATH_SERVICE_ACCOUNT"
        )

        if not self.project_id or not self.topic_id:
            raise ValueError("Both project_id and topic_id must be provided")

        self.flush_lock = asyncio.Lock()
        super().__init__(**kwargs, flush_lock=self.flush_lock)
        # Flush the queue on a timer, in addition to the batch-size trigger
        # in async_log_success_event below.
        asyncio.create_task(self.periodic_flush())
        self.log_queue: List[SpendLogsPayload] = []

    async def construct_request_headers(self) -> Dict[str, str]:
        """Construct authorization headers using Vertex AI auth"""
        from litellm import vertex_chat_completion

        _auth_header, vertex_project = (
            await vertex_chat_completion._ensure_access_token_async(
                credentials=self.path_service_account_json,
                project_id=None,
                custom_llm_provider="vertex_ai",
            )
        )

        auth_header, _ = vertex_chat_completion._get_token_and_url(
            model="pub-sub",
            auth_header=_auth_header,
            vertex_credentials=self.path_service_account_json,
            vertex_project=vertex_project,
            vertex_location=None,
            gemini_api_key=None,
            stream=None,
            custom_llm_provider="vertex_ai",
            api_base=None,
        )

        headers = {
            "Authorization": f"Bearer {auth_header}",
            "Content-Type": "application/json",
        }
        return headers

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Async Log success events to GCS PubSub Topic

        - Creates a SpendLogsPayload
        - Adds to batch queue
        - Flushes based on CustomBatchLogger settings

        Raises:
            Raises a NON Blocking verbose_logger.exception if an error occurs
        """
        from litellm.proxy.spend_tracking.spend_tracking_utils import (
            get_logging_payload,
        )
        from litellm.proxy.utils import _premium_user_check

        _premium_user_check()

        try:
            verbose_logger.debug(
                "PubSub: Logging - Enters logging function for model %s", kwargs
            )
            spend_logs_payload = get_logging_payload(
                kwargs=kwargs,
                response_obj=response_obj,
                start_time=start_time,
                end_time=end_time,
            )
            self.log_queue.append(spend_logs_payload)

            if len(self.log_queue) >= self.batch_size:
                await self.async_send_batch()

        except Exception as e:
            verbose_logger.exception(
                f"PubSub Layer Error - {str(e)}\n{traceback.format_exc()}"
            )
            pass

    async def async_send_batch(self):
        """
        Sends the batch of messages to Pub/Sub
        """
        try:
            if not self.log_queue:
                return

            verbose_logger.debug(
                f"PubSub - about to flush {len(self.log_queue)} events"
            )

            # Each queued SpendLogsPayload is published as its own Pub/Sub message.
            for message in self.log_queue:
                await self.publish_message(message)

        except Exception as e:
            verbose_logger.exception(
                f"PubSub Error sending batch - {str(e)}\n{traceback.format_exc()}"
            )
        finally:
            self.log_queue.clear()

    async def publish_message(
        self, message: SpendLogsPayload
    ) -> Optional[Dict[str, Any]]:
        """
        Publish message to Google Cloud Pub/Sub using REST API

        Args:
            message: Message to publish (dict or string)

        Returns:
            dict: Published message response
        """
        try:
            headers = await self.construct_request_headers()

            # Prepare message data
            if isinstance(message, str):
                message_data = message
            else:
                message_data = json.dumps(message, default=str)

            # Base64 encode the message
            import base64

            encoded_message = base64.b64encode(message_data.encode("utf-8")).decode(
                "utf-8"
            )

            # Construct request body
            request_body = {"messages": [{"data": encoded_message}]}

            url = f"https://pubsub.googleapis.com/v1/projects/{self.project_id}/topics/{self.topic_id}:publish"

            response = await self.async_httpx_client.post(
                url=url, headers=headers, json=request_body
            )

            if response.status_code not in [200, 202]:
                verbose_logger.error("Pub/Sub publish error: %s", str(response.text))
                raise Exception(f"Failed to publish message: {response.text}")

            verbose_logger.debug("Pub/Sub response: %s", response.text)
            return response.json()

        except Exception as e:
            verbose_logger.error("Pub/Sub publish error: %s", str(e))
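A consumer has to reverse the base64 encoding done in `publish_message`. A minimal stdlib-only sketch for messages fetched via the Pub/Sub REST `projects.subscriptions:pull` endpoint (the `receivedMessages` shape is the documented REST response format; this helper is illustrative, not part of the commit):

```python
import base64
import json
from typing import Any, Dict, List


def decode_spend_logs(pull_response: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Recover SpendLogsPayload dicts from a REST `pull` response."""
    payloads = []
    for received in pull_response.get("receivedMessages", []):
        # `data` is base64 on the wire, mirroring publish_message above.
        raw = base64.b64decode(received["message"]["data"])
        payloads.append(json.loads(raw.decode("utf-8")))
    return payloads
```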
@@ -77,6 +77,7 @@ from ..integrations.datadog.datadog_llm_obs import DataDogLLMObsLogger
from ..integrations.dynamodb import DyanmoDBLogger
from ..integrations.galileo import GalileoObserve
from ..integrations.gcs_bucket.gcs_bucket import GCSBucketLogger
+from ..integrations.gcs_pubsub.pub_sub import GcsPubSubLogger
from ..integrations.greenscale import GreenscaleLogger
from ..integrations.helicone import HeliconeLogger
from ..integrations.humanloop import HumanloopLogger

@@ -2571,6 +2572,13 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
        pagerduty_logger = PagerDutyAlerting(**custom_logger_init_args)
        _in_memory_loggers.append(pagerduty_logger)
        return pagerduty_logger  # type: ignore
+    elif logging_integration == "gcs_pubsub":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, GcsPubSubLogger):
+                return callback
+        _gcs_pubsub_logger = GcsPubSubLogger()
+        _in_memory_loggers.append(_gcs_pubsub_logger)
+        return _gcs_pubsub_logger  # type: ignore
    elif logging_integration == "humanloop":
        for callback in _in_memory_loggers:
            if isinstance(callback, HumanloopLogger):
@@ -2704,6 +2712,10 @@ def get_custom_logger_compatible_class(  # noqa: PLR0915
            for callback in _in_memory_loggers:
                if isinstance(callback, PagerDutyAlerting):
                    return callback
+        elif logging_integration == "gcs_pubsub":
+            for callback in _in_memory_loggers:
+                if isinstance(callback, GcsPubSubLogger):
+                    return callback

        return None
    except Exception as e:
@@ -58,7 +58,7 @@ from litellm.proxy.management_helpers.utils import (
    add_new_member,
    management_endpoint_wrapper,
)
-from litellm.proxy.utils import PrismaClient
+from litellm.proxy.utils import PrismaClient, _premium_user_check

router = APIRouter()

@@ -1527,15 +1527,3 @@ def _set_team_metadata_field(
    _premium_user_check()
    team_data.metadata = team_data.metadata or {}
    team_data.metadata[field_name] = value
-
-
-def _premium_user_check():
-    from litellm.proxy.proxy_server import premium_user
-
-    if not premium_user:
-        raise HTTPException(
-            status_code=403,
-            detail={
-                "error": f"This feature is only available for LiteLLM Enterprise users. {CommonProxyErrors.not_premium_user.value}"
-            },
-        )
@@ -15,8 +15,7 @@ model_list:

litellm_settings:
-  callbacks: ["prometheus"]
-
+  callbacks: ["gcs_pubsub"]

guardrails:
  - guardrail_name: "bedrock-pre-guard"

@@ -26,3 +25,4 @@ guardrails:
      guardrailIdentifier: gf3sc1mzinjw
      guardrailVersion: "DRAFT"
      default_on: true
+
@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union, overload
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.proxy._types import (
    DB_CONNECTION_ERROR_TYPES,
+    CommonProxyErrors,
    ProxyErrorTypes,
    ProxyException,
)

@@ -2959,3 +2960,18 @@ def handle_exception_on_proxy(e: Exception) -> ProxyException:
        param=getattr(e, "param", "None"),
        code=status.HTTP_500_INTERNAL_SERVER_ERROR,
    )
+
+
+def _premium_user_check():
+    """
+    Raises an HTTPException if the user is not a premium user
+    """
+    from litellm.proxy.proxy_server import premium_user
+
+    if not premium_user:
+        raise HTTPException(
+            status_code=403,
+            detail={
+                "error": f"This feature is only available for LiteLLM Enterprise users. {CommonProxyErrors.not_premium_user.value}"
+            },
+        )
tests/logging_callback_tests/gcs_pub_sub_body/spend_logs_payload.json (new file, 27 lines)

@@ -0,0 +1,27 @@
{
  "request_id": "chatcmpl-2283081b-dc89-41f6-93e6-d4f914774027",
  "call_type": "acompletion",
  "api_key": "",
  "cache_hit": "None",
  "startTime": "2025-01-24 09:20:46.847371",
  "endTime": "2025-01-24 09:20:46.851954",
  "completionStartTime": "2025-01-24 09:20:46.851954",
  "model": "gpt-4o",
  "user": "",
  "team_id": "",
  "metadata": "{\"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
  "cache_key": "Cache OFF",
  "spend": 0.00022500000000000002,
  "total_tokens": 30,
  "prompt_tokens": 10,
  "completion_tokens": 20,
  "request_tags": "[]",
  "end_user": "",
  "api_base": "",
  "model_group": "",
  "model_id": "",
  "requester_ip_address": null,
  "custom_llm_provider": "openai",
  "messages": "{}",
  "response": "{}"
}
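Note that several fields in this payload (`metadata`, `request_tags`, `messages`, `response`) are themselves JSON serialized as strings. A downstream consumer that wants structured values has to decode them a second time once the message body has been recovered (see the pull sketch above); a small illustrative helper, with field names taken from the payload above rather than from the commit:

```python
import json


def expand_json_string_fields(spend_log: dict) -> dict:
    """Decode SpendLogs fields that are stored as JSON strings."""
    expanded = dict(spend_log)
    for field in ("metadata", "request_tags", "messages", "response"):
        value = expanded.get(field)
        if isinstance(value, str) and value:
            try:
                expanded[field] = json.loads(value)
            except json.JSONDecodeError:
                pass  # leave non-JSON strings untouched
    return expanded
```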
tests/logging_callback_tests/test_gcs_pub_sub.py (new file, 113 lines)

@@ -0,0 +1,113 @@
import io
import os
import sys

sys.path.insert(0, os.path.abspath("../.."))

import asyncio
import gzip
import json
import logging
import time
from unittest.mock import AsyncMock, patch

import pytest

import litellm
from litellm import completion
from litellm._logging import verbose_logger
from litellm.integrations.gcs_pubsub.pub_sub import *
from datetime import datetime, timedelta
from litellm.types.utils import (
    StandardLoggingPayload,
    StandardLoggingModelInformation,
    StandardLoggingMetadata,
    StandardLoggingHiddenParams,
)

verbose_logger.setLevel(logging.DEBUG)


def assert_gcs_pubsub_request_matches_expected(
    actual_request_body: dict,
    expected_file_name: str,
):
    """
    Helper function to compare actual GCS PubSub request body with expected JSON file.

    Args:
        actual_request_body (dict): The actual request body received from the API call
        expected_file_name (str): Name of the JSON file containing expected request body
    """
    # Get the current directory and read the expected request body
    pwd = os.path.dirname(os.path.realpath(__file__))
    expected_body_path = os.path.join(pwd, "gcs_pub_sub_body", expected_file_name)

    with open(expected_body_path, "r") as f:
        expected_request_body = json.load(f)

    # Replace dynamic values in actual request body
    time_fields = ["startTime", "endTime", "completionStartTime", "request_id"]
    for field in time_fields:
        if field in actual_request_body:
            actual_request_body[field] = expected_request_body[field]

    # Assert the entire request body matches
    assert (
        actual_request_body == expected_request_body
    ), f"Difference in request bodies: {json.dumps(actual_request_body, indent=2)} != {json.dumps(expected_request_body, indent=2)}"


@pytest.mark.asyncio
async def test_async_gcs_pub_sub():
    # Create a mock for the async_httpx_client's post method
    mock_post = AsyncMock()
    mock_post.return_value.status_code = 202
    mock_post.return_value.text = "Accepted"

    # Initialize the GcsPubSubLogger and set the mock
    gcs_pub_sub_logger = GcsPubSubLogger(flush_interval=1)
    gcs_pub_sub_logger.async_httpx_client.post = mock_post

    mock_construct_request_headers = AsyncMock()
    mock_construct_request_headers.return_value = {"Authorization": "Bearer mock_token"}
    gcs_pub_sub_logger.construct_request_headers = mock_construct_request_headers
    litellm.callbacks = [gcs_pub_sub_logger]

    # Make the completion call
    response = await litellm.acompletion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hello, world!"}],
        mock_response="hi",
    )

    await asyncio.sleep(3)  # Wait for async flush

    # Assert httpx post was called
    mock_post.assert_called_once()

    # Get the actual request body from the mock
    actual_url = mock_post.call_args[1]["url"]
    print("sent to url", actual_url)
    assert (
        actual_url
        == "https://pubsub.googleapis.com/v1/projects/reliableKeys/topics/litellmDB:publish"
    )
    actual_request = mock_post.call_args[1]["json"]

    # Extract and decode the base64 encoded message
    encoded_message = actual_request["messages"][0]["data"]
    import base64

    decoded_message = base64.b64decode(encoded_message).decode("utf-8")

    # Parse the JSON string into a dictionary
    actual_request = json.loads(decoded_message)
    print("##########\n")
    print(json.dumps(actual_request, indent=4))
    print("##########\n")
    # Verify the request body matches expected format
    assert_gcs_pubsub_request_matches_expected(
        actual_request, "spend_logs_payload.json"
    )
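To run this test locally, something like the following should work from the repo root (flags are a matter of taste; this assumes `pytest` and `pytest-asyncio` are installed):

```shell
pytest tests/logging_callback_tests/test_gcs_pub_sub.py -v -s
```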