#### What this does ####
# This file contains the LiteralAILogger class, which logs each LLM call as a
# "step" to the Literal AI observability platform.
import asyncio
import os
import uuid
from typing import List, Optional

import httpx

from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.llms.custom_httpx.http_handler import (
    HTTPHandler,
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.types.utils import StandardLoggingPayload


class LiteralAILogger(CustomBatchLogger):
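    """Batch logger that ships LiteLLM request logs to Literal AI.

    Each success/failure event is converted into a Literal AI "step" and
    appended to ``self.log_queue``; once the queue reaches ``batch_size``,
    the queued steps are sent in one GraphQL ``ingestStep`` batch mutation.
    """
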
    def __init__(
        self,
        literalai_api_key=None,
        literalai_api_url="https://cloud.getliteral.ai",
        env=None,
        **kwargs,
    ):
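        """Set up credentials, HTTP clients, and batching.

        The ``LITERAL_API_URL`` env var overrides ``literalai_api_url``,
        ``LITERAL_API_KEY`` is the fallback when ``literalai_api_key`` is not
        passed, and ``LITERAL_BATCH_SIZE`` optionally sets the batch size.
        """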
        self.literalai_api_url = os.getenv("LITERAL_API_URL") or literalai_api_url
        self.headers = {
            "Content-Type": "application/json",
            "x-api-key": literalai_api_key or os.getenv("LITERAL_API_KEY"),
            "x-client-name": "litellm",
        }
        if env:
            self.headers["x-env"] = env
        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        self.sync_http_handler = HTTPHandler()
        batch_size = os.getenv("LITERAL_BATCH_SIZE", None)
        self.flush_lock = asyncio.Lock()
        super().__init__(
            **kwargs,
            flush_lock=self.flush_lock,
            batch_size=int(batch_size) if batch_size else None,
        )

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
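        """Sync success hook: queue the step and flush when the batch is full."""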
        try:
            verbose_logger.debug(
                "Literal AI Layer Logging - kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                "Literal AI logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                self._send_batch()
        except Exception:
            verbose_logger.exception(
                "Literal AI Layer Error - error logging success event."
            )

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
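        """Sync failure hook: queue the step and flush when the batch is full."""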
        verbose_logger.info("Literal AI Failure Event Logging!")
        try:
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                "Literal AI logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                self._send_batch()
        except Exception:
            verbose_logger.exception(
                "Literal AI Layer Error - error logging failure event."
            )

    def _send_batch(self):
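        """Send every queued step to Literal AI using the blocking HTTP client.

        The whole batch is expressed as a single GraphQL mutation, so N queued
        steps cost one POST to ``/api/graphql``.
        """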
        if not self.log_queue:
            return

        url = f"{self.literalai_api_url}/api/graphql"
        query = self._steps_query_builder(self.log_queue)
        variables = self._steps_variables_builder(self.log_queue)
        try:
            response = self.sync_http_handler.post(
                url=url,
                json={
                    "query": query,
                    "variables": variables,
                },
                headers=self.headers,
            )

            if response.status_code >= 300:
                verbose_logger.error(
                    f"Literal AI Error: {response.status_code} - {response.text}"
                )
            else:
                verbose_logger.debug(
                    f"Batch of {len(self.log_queue)} runs successfully created"
                )
        except Exception:
            verbose_logger.exception("Literal AI Layer Error")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
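        """Async success hook: queue the step; flush the queue when full."""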
        try:
            verbose_logger.debug(
                "Literal AI Async Layer Logging - kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                "Literal AI logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception:
            verbose_logger.exception(
                "Literal AI Layer Error - error logging async success event."
            )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
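        """Async failure hook: queue the step; flush the queue when full."""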
        verbose_logger.info("Literal AI Failure Event Logging!")
        try:
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                "Literal AI logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception:
            verbose_logger.exception(
                "Literal AI Layer Error - error logging async failure event."
            )

    async def async_send_batch(self):
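        """Async counterpart of ``_send_batch``, used by the queue-flush logic."""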
        if not self.log_queue:
            return

        url = f"{self.literalai_api_url}/api/graphql"
        query = self._steps_query_builder(self.log_queue)
        variables = self._steps_variables_builder(self.log_queue)

        try:
            response = await self.async_httpx_client.post(
                url=url,
                json={
                    "query": query,
                    "variables": variables,
                },
                headers=self.headers,
            )
            if response.status_code >= 300:
                verbose_logger.error(
                    f"Literal AI Error: {response.status_code} - {response.text}"
                )
            else:
                verbose_logger.debug(
                    f"Batch of {len(self.log_queue)} runs successfully created"
                )
        except httpx.HTTPStatusError as e:
            verbose_logger.exception(
                f"Literal AI HTTP Error: {e.response.status_code} - {e.response.text}"
            )
        except Exception:
            verbose_logger.exception("Literal AI Layer Error")

    def _prepare_log_data(self, kwargs, response_obj, start_time, end_time) -> dict:
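        """Build a Literal AI step dict from LiteLLM's ``standard_logging_object``.

        Raises ``ValueError`` if the standard logging payload is missing.
        """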
        logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )

        if logging_payload is None:
            raise ValueError("standard_logging_object not found in kwargs")
        clean_metadata = logging_payload["metadata"]
        metadata = kwargs.get("litellm_params", {}).get("metadata", {})

        settings = logging_payload["model_parameters"]
        messages = logging_payload["messages"]
        response = logging_payload["response"]
        choices: List = []
        if isinstance(response, dict) and "choices" in response:
            choices = response["choices"]
        message_completion = choices[0]["message"] if choices else None
        prompt_id = None
        variables = None
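
        # Pull Literal AI prompt-template info when a message carries it.
        # Note: getattr() reads attributes, not dict keys, so plain dict
        # messages leave prompt_id and variables as None.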
        if messages and isinstance(messages, list) and isinstance(messages[0], dict):
            for message in messages:
                if literal_prompt := getattr(message, "__literal_prompt__", None):
                    prompt_id = literal_prompt.get("prompt_id")
                    variables = literal_prompt.get("variables")
                    message["uuid"] = literal_prompt.get("uuid")
                    message["templated"] = True

        tools = settings.pop("tools", None)

        step = {
            "id": metadata.get("step_id", str(uuid.uuid4())),
            "error": logging_payload["error_str"],
            "name": kwargs.get("model", ""),
            "threadId": metadata.get("literalai_thread_id", None),
            "parentId": metadata.get("literalai_parent_id", None),
            "rootRunId": metadata.get("literalai_root_run_id", None),
            "input": None,
            "output": None,
            "type": "llm",
            "tags": metadata.get("tags", metadata.get("literalai_tags", None)),
            "startTime": str(start_time),
            "endTime": str(end_time),
            "metadata": clean_metadata,
            "generation": {
                "inputTokenCount": logging_payload["prompt_tokens"],
                "outputTokenCount": logging_payload["completion_tokens"],
                "tokenCount": logging_payload["total_tokens"],
                "promptId": prompt_id,
                "variables": variables,
                "provider": kwargs.get("custom_llm_provider", "litellm"),
                "model": kwargs.get("model", ""),
                "duration": (end_time - start_time).total_seconds(),
                "settings": settings,
                "messages": messages,
                "messageCompletion": message_completion,
                "tools": tools,
            },
        }
        return step

    def _steps_query_variables_builder(self, steps):
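        """Declare the GraphQL variables for each queued step.

        Produces one block of declarations per step, suffixed with the step
        index: ``$id_0: String!``, ``$threadId_0: String``, and so on.
        """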
        generated = ""
        for id in range(len(steps)):
            generated += f"""$id_{id}: String!
        $threadId_{id}: String
        $rootRunId_{id}: String
        $type_{id}: StepType
        $startTime_{id}: DateTime
        $endTime_{id}: DateTime
        $error_{id}: String
        $input_{id}: Json
        $output_{id}: Json
        $metadata_{id}: Json
        $parentId_{id}: String
        $name_{id}: String
        $tags_{id}: [String!]
        $generation_{id}: GenerationPayloadInput
        $scores_{id}: [ScorePayloadInput!]
        $attachments_{id}: [AttachmentPayloadInput!]
        """
        return generated

    def _steps_ingest_steps_builder(self, steps):
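        """Emit one aliased ``ingestStep`` call (``step0``, ``step1``, ...) per step."""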
        generated = ""
        for id in range(len(steps)):
            generated += f"""
        step{id}: ingestStep(
            id: $id_{id}
            threadId: $threadId_{id}
            rootRunId: $rootRunId_{id}
            startTime: $startTime_{id}
            endTime: $endTime_{id}
            type: $type_{id}
            error: $error_{id}
            input: $input_{id}
            output: $output_{id}
            metadata: $metadata_{id}
            parentId: $parentId_{id}
            name: $name_{id}
            tags: $tags_{id}
            generation: $generation_{id}
            scores: $scores_{id}
            attachments: $attachments_{id}
        ) {{
            ok
            message
        }}
        """
        return generated

    def _steps_query_builder(self, steps):
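        """Assemble the complete batch mutation.

        For a single queued step the generated document is roughly:

            mutation AddStep($id_0: String! $threadId_0: String ...) {
                step0: ingestStep(id: $id_0 threadId: $threadId_0 ...) {
                    ok
                    message
                }
            }
        """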
        return f"""
        mutation AddStep({self._steps_query_variables_builder(steps)}) {{
            {self._steps_ingest_steps_builder(steps)}
        }}
        """

    def _steps_variables_builder(self, steps):
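        """Flatten the queued steps into a single GraphQL variables dict.

        Each step's keys are suffixed with the step index, matching the
        declarations above (e.g. ``{"id_0": ..., "threadId_0": ...}``).
        """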
        def serialize_step(event, id):
            result = {}

            for key, value in event.items():
                # Only keep the keys that are not None to avoid overriding existing values
                if value is not None:
                    result[f"{key}_{id}"] = value

            return result

        variables = {}
        for i, step in enumerate(steps):
            variables.update(serialize_step(step, i))
        return variables
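

# Example usage (a minimal sketch, not part of the upstream module). The logger
# is a litellm custom callback, so registering an instance should be enough for
# completion calls to be batched and shipped to Literal AI; the model, env, and
# key below are illustrative assumptions.
#
#   import litellm
#
#   litellm.callbacks = [LiteralAILogger(literalai_api_key="...", env="dev")]
#
#   litellm.completion(
#       model="gpt-4o-mini",
#       messages=[{"role": "user", "content": "Hello!"}],
#   )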