fix opik types

This commit is contained in:
Ishaan Jaff 2024-10-10 18:37:53 +05:30
parent 4064bfc6dd
commit fbf756806e

View file

@ -2,27 +2,26 @@
Opik Logger that logs LLM events to an Opik server Opik Logger that logs LLM events to an Opik server
""" """
from typing import Dict, List import asyncio
import json import json
import traceback
from typing import Dict, List
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
import traceback from litellm.integrations.custom_batch_logger import CustomBatchLogger
from .utils import (
get_opik_config_variable,
create_uuid7,
create_usage_object,
get_traces_and_spans_from_payload
)
import asyncio
from litellm.llms.custom_httpx.http_handler import ( from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
_get_httpx_client, _get_httpx_client,
get_async_httpx_client,
httpxSpecialProvider, httpxSpecialProvider,
) )
from litellm.integrations.custom_batch_logger import CustomBatchLogger from .utils import (
create_usage_object,
create_uuid7,
get_opik_config_variable,
get_traces_and_spans_from_payload,
)
class OpikLogger(CustomBatchLogger): class OpikLogger(CustomBatchLogger):
""" """
@ -38,41 +37,39 @@ class OpikLogger(CustomBatchLogger):
self.opik_project_name = get_opik_config_variable( self.opik_project_name = get_opik_config_variable(
"project_name", "project_name",
user_value=kwargs.get("project_name", None), user_value=kwargs.get("project_name", None),
default_value="Default Project" default_value="Default Project",
) )
opik_base_url = get_opik_config_variable( opik_base_url = get_opik_config_variable(
"url_override", "url_override",
user_value=kwargs.get("url", None), user_value=kwargs.get("url", None),
default_value="https://www.comet.com/opik/api" default_value="https://www.comet.com/opik/api",
) )
opik_api_key = get_opik_config_variable( opik_api_key = get_opik_config_variable(
"api_key", "api_key", user_value=kwargs.get("api_key", None), default_value=None
user_value=kwargs.get("api_key", None),
default_value=None
) )
opik_workspace = get_opik_config_variable( opik_workspace = get_opik_config_variable(
"workspace", "workspace", user_value=kwargs.get("workspace", None), default_value=None
user_value=kwargs.get("workspace", None),
default_value=None
) )
self.trace_url = f"{opik_base_url}/v1/private/traces/batch" self.trace_url = f"{opik_base_url}/v1/private/traces/batch"
self.span_url = f"{opik_base_url}/v1/private/spans/batch" self.span_url = f"{opik_base_url}/v1/private/spans/batch"
self.headers = {} self.headers = {}
if opik_workspace: if opik_workspace:
self.headers["Comet-Workspace"] = opik_workspace self.headers["Comet-Workspace"] = opik_workspace
if opik_api_key: if opik_api_key:
self.headers["authorization"] = opik_api_key self.headers["authorization"] = opik_api_key
self.opik_workspace = opik_workspace
self.opik_api_key = opik_api_key
try: try:
asyncio.create_task(self.periodic_flush()) asyncio.create_task(self.periodic_flush())
self.flush_lock = asyncio.Lock() self.flush_lock = asyncio.Lock()
except Exception as e: except Exception as e:
verbose_logger.debug( verbose_logger.exception(
f"OpikLogger - Asynchronous processing not initialized as we are not running in an async context" f"OpikLogger - Asynchronous processing not initialized as we are not running in an async context {str(e)}"
) )
self.flush_lock = None self.flush_lock = None
@ -84,11 +81,13 @@ class OpikLogger(CustomBatchLogger):
kwargs=kwargs, kwargs=kwargs,
response_obj=response_obj, response_obj=response_obj,
start_time=start_time, start_time=start_time,
end_time=end_time end_time=end_time,
) )
self.log_queue.extend(opik_payload) self.log_queue.extend(opik_payload)
verbose_logger.debug(f"OpikLogger added event to log_queue - Will flush in {self.flush_interval} seconds...") verbose_logger.debug(
f"OpikLogger added event to log_queue - Will flush in {self.flush_interval} seconds..."
)
if len(self.log_queue) >= self.batch_size: if len(self.log_queue) >= self.batch_size:
verbose_logger.debug("OpikLogger - Flushing batch") verbose_logger.debug("OpikLogger - Flushing batch")
@ -97,13 +96,11 @@ class OpikLogger(CustomBatchLogger):
verbose_logger.exception( verbose_logger.exception(
f"OpikLogger failed to log success event - {str(e)}\n{traceback.format_exc()}" f"OpikLogger failed to log success event - {str(e)}\n{traceback.format_exc()}"
) )
def _sync_send(self, url: str, headers: Dict[str, str], batch: List[Dict]): def _sync_send(self, url: str, headers: Dict[str, str], batch: Dict):
try: try:
response = self.sync_httpx_client.post( response = self.sync_httpx_client.post(
url=url, url=url, headers=headers, json=batch # type: ignore
headers=headers,
json=batch
) )
response.raise_for_status() response.raise_for_status()
if response.status_code != 204: if response.status_code != 204:
@ -121,25 +118,27 @@ class OpikLogger(CustomBatchLogger):
kwargs=kwargs, kwargs=kwargs,
response_obj=response_obj, response_obj=response_obj,
start_time=start_time, start_time=start_time,
end_time=end_time end_time=end_time,
) )
traces, spans = get_traces_and_spans_from_payload(opik_payload) traces, spans = get_traces_and_spans_from_payload(opik_payload)
if len(traces) > 0: if len(traces) > 0:
self._sync_send(self.trace_url, self.headers, {"traces": traces}) self._sync_send(
url=self.trace_url, headers=self.headers, batch={"traces": traces}
)
if len(spans) > 0: if len(spans) > 0:
self._sync_send(self.span_url, self.headers, {"spans": spans}) self._sync_send(
url=self.span_url, headers=self.headers, batch={"spans": spans}
)
except Exception as e: except Exception as e:
verbose_logger.exception( verbose_logger.exception(
f"OpikLogger failed to log success event - {str(e)}\n{traceback.format_exc()}" f"OpikLogger failed to log success event - {str(e)}\n{traceback.format_exc()}"
) )
async def _submit_batch(self, url: str, headers: Dict[str, str], batch: List[Dict]): async def _submit_batch(self, url: str, headers: Dict[str, str], batch: Dict):
try: try:
response = await self.async_httpx_client.post( response = await self.async_httpx_client.post(
url=url, url=url, headers=headers, json=batch # type: ignore
headers=headers,
json=batch
) )
response.raise_for_status() response.raise_for_status()
@ -160,37 +159,42 @@ class OpikLogger(CustomBatchLogger):
headers = {} headers = {}
if self.opik_workspace: if self.opik_workspace:
headers["Comet-Workspace"] = self.opik_workspace headers["Comet-Workspace"] = self.opik_workspace
if self.opik_api_key: if self.opik_api_key:
headers["authorization"] = self.opik_api_key headers["authorization"] = self.opik_api_key
return headers return headers
async def async_send_batch(self): async def async_send_batch(self):
verbose_logger.exception("Calling async_send_batch") verbose_logger.exception("Calling async_send_batch")
if not self.log_queue: if not self.log_queue:
return return
# Split the log_queue into traces and spans # Split the log_queue into traces and spans
traces, spans = get_traces_and_spans_from_payload(self.log_queue) traces, spans = get_traces_and_spans_from_payload(self.log_queue)
# Send trace batch # Send trace batch
if len(traces) > 0: if len(traces) > 0:
await self._submit_batch(self.trace_url, self.headers, {"traces": traces}) await self._submit_batch(
url=self.trace_url, headers=self.headers, batch={"traces": traces}
)
if len(spans) > 0: if len(spans) > 0:
await self._submit_batch(self.span_url, self.headers, {"spans": spans}) await self._submit_batch(
url=self.span_url, headers=self.headers, batch={"spans": spans}
)
def _create_opik_payload(
self, kwargs, response_obj, start_time, end_time
) -> List[Dict]:
def _create_opik_payload(self, kwargs, response_obj, start_time, end_time) -> List[Dict]:
# Get metadata # Get metadata
_litellm_params = kwargs.get("litellm_params", {}) or {} _litellm_params = kwargs.get("litellm_params", {}) or {}
litellm_params_metadata = _litellm_params.get("metadata", {}) or {} litellm_params_metadata = _litellm_params.get("metadata", {}) or {}
# Extract opik metadata # Extract opik metadata
litellm_opik_metadata = litellm_params_metadata.get("opik", {}) litellm_opik_metadata = litellm_params_metadata.get("opik", {})
verbose_logger.debug(f"litellm_opik_metadata - {json.dumps(litellm_opik_metadata, default=str)}") verbose_logger.debug(
f"litellm_opik_metadata - {json.dumps(litellm_opik_metadata, default=str)}"
)
project_name = litellm_opik_metadata.get("project_name", self.opik_project_name) project_name = litellm_opik_metadata.get("project_name", self.opik_project_name)
# Extract trace_id and parent_span_id # Extract trace_id and parent_span_id
@ -212,16 +216,18 @@ class OpikLogger(CustomBatchLogger):
# Use standard_logging_object to create metadata and input/output data # Use standard_logging_object to create metadata and input/output data
standard_logging_object = kwargs.get("standard_logging_object", None) standard_logging_object = kwargs.get("standard_logging_object", None)
if standard_logging_object is None: if standard_logging_object is None:
verbose_logger.debug("OpikLogger skipping event; no standard_logging_object found") verbose_logger.debug(
"OpikLogger skipping event; no standard_logging_object found"
)
return [] return []
# Create input and output data # Create input and output data
input_data = standard_logging_object.get("messages", {}) input_data = standard_logging_object.get("messages", {})
output_data = standard_logging_object.get("response", {}) output_data = standard_logging_object.get("response", {})
# Create usage object # Create usage object
usage = create_usage_object(response_obj["usage"]) usage = create_usage_object(response_obj["usage"])
# Define span and trace names # Define span and trace names
span_name = "%s_%s_%s" % ( span_name = "%s_%s_%s" % (
response_obj.get("model", "unknown-model"), response_obj.get("model", "unknown-model"),
@ -229,14 +235,14 @@ class OpikLogger(CustomBatchLogger):
response_obj.get("created", 0), response_obj.get("created", 0),
) )
trace_name = response_obj.get("object", "unknown type") trace_name = response_obj.get("object", "unknown type")
# Create metadata object, we add the opik metadata first and then # Create metadata object, we add the opik metadata first and then
# update it with the standard_logging_object metadata # update it with the standard_logging_object metadata
metadata = litellm_opik_metadata metadata = litellm_opik_metadata
if "current_span_data" in metadata: if "current_span_data" in metadata:
del metadata["current_span_data"] del metadata["current_span_data"]
metadata["created_from"] = "litellm" metadata["created_from"] = "litellm"
metadata.update(standard_logging_object.get("metadata", {})) metadata.update(standard_logging_object.get("metadata", {}))
if "call_type" in standard_logging_object: if "call_type" in standard_logging_object:
metadata["type"] = standard_logging_object["call_type"] metadata["type"] = standard_logging_object["call_type"]
@ -245,12 +251,16 @@ class OpikLogger(CustomBatchLogger):
if "response_cost" in kwargs: if "response_cost" in kwargs:
metadata["cost"] = { metadata["cost"] = {
"total_tokens": kwargs["response_cost"], "total_tokens": kwargs["response_cost"],
"currency": "USD" "currency": "USD",
} }
if "response_cost_failure_debug_info" in kwargs: if "response_cost_failure_debug_info" in kwargs:
metadata["response_cost_failure_debug_info"] = kwargs["response_cost_failure_debug_info"] metadata["response_cost_failure_debug_info"] = kwargs[
"response_cost_failure_debug_info"
]
if "model_map_information" in standard_logging_object: if "model_map_information" in standard_logging_object:
metadata["model_map_information"] = standard_logging_object["model_map_information"] metadata["model_map_information"] = standard_logging_object[
"model_map_information"
]
if "model" in standard_logging_object: if "model" in standard_logging_object:
metadata["model"] = standard_logging_object["model"] metadata["model"] = standard_logging_object["model"]
if "model_id" in standard_logging_object: if "model_id" in standard_logging_object:
@ -269,40 +279,48 @@ class OpikLogger(CustomBatchLogger):
metadata["model_parameters"] = standard_logging_object["model_parameters"] metadata["model_parameters"] = standard_logging_object["model_parameters"]
if "hidden_params" in standard_logging_object: if "hidden_params" in standard_logging_object:
metadata["hidden_params"] = standard_logging_object["hidden_params"] metadata["hidden_params"] = standard_logging_object["hidden_params"]
payload = [] payload = []
if trace_id is None: if trace_id is None:
trace_id = create_uuid7() trace_id = create_uuid7()
verbose_logger.debug(f"OpikLogger creating payload for trace with id {trace_id}") verbose_logger.debug(
f"OpikLogger creating payload for trace with id {trace_id}"
)
payload.append({ payload.append(
{
"project_name": project_name,
"id": trace_id,
"name": trace_name,
"start_time": start_time.isoformat() + "Z",
"end_time": end_time.isoformat() + "Z",
"input": input_data,
"output": output_data,
"metadata": metadata,
"tags": opik_tags,
}
)
span_id = create_uuid7()
verbose_logger.debug(
f"OpikLogger creating payload for trace with id {trace_id} and span with id {span_id}"
)
payload.append(
{
"id": span_id,
"project_name": project_name, "project_name": project_name,
"id": trace_id, "trace_id": trace_id,
"name": trace_name, "parent_span_id": parent_span_id,
"name": span_name,
"type": "llm",
"start_time": start_time.isoformat() + "Z", "start_time": start_time.isoformat() + "Z",
"end_time": end_time.isoformat() + "Z", "end_time": end_time.isoformat() + "Z",
"input": input_data, "input": input_data,
"output": output_data, "output": output_data,
"metadata": metadata, "metadata": metadata,
"tags": opik_tags, "tags": opik_tags,
}) "usage": usage,
}
span_id = create_uuid7() )
verbose_logger.debug(f"OpikLogger creating payload for trace with id {trace_id} and span with id {span_id}")
payload.append({
"id": span_id,
"project_name": project_name,
"trace_id": trace_id,
"parent_span_id": parent_span_id,
"name": span_name,
"type": "llm",
"start_time": start_time.isoformat() + "Z",
"end_time": end_time.isoformat() + "Z",
"input": input_data,
"output": output_data,
"metadata": metadata,
"tags": opik_tags,
"usage": usage
})
verbose_logger.debug(f"Payload: {payload}") verbose_logger.debug(f"Payload: {payload}")
return payload return payload