feat(langfuse/): support langfuse prompt management (#7073)

* feat(langfuse/): support langfuse prompt management

Initial working commit for langfuse prompt management support

Closes https://github.com/BerriAI/litellm/issues/6269

* test: update test

* fix(litellm_logging.py): suppress linting error
Krish Dholakia 2024-12-06 23:10:22 -08:00 committed by GitHub
parent e4493248ae
commit 19a4273fda
6 changed files with 186 additions and 2 deletions
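
For context, a minimal usage sketch (not part of this commit's diff): calling a LiteLLM proxy that has the "langfuse" callback enabled, passing the two metadata keys the new pre-call hook reads. The proxy URL, API key, model, prompt name, and variables below are placeholders.

# Hypothetical client request against a proxy running with callbacks: ["langfuse"].
# "langfuse_prompt_id" / "langfuse_prompt_variables" are the metadata keys read by
# LangfusePromptManagement.async_pre_call_hook (see the new file below).
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",  # placeholder model name
    messages=[{"role": "user", "content": "What is LiteLLM?"}],
    extra_body={
        "metadata": {
            "langfuse_prompt_id": "my-prompt",  # placeholder Langfuse prompt name
            "langfuse_prompt_variables": {"topic": "proxies"},  # substituted on compile
        }
    },
)
print(response.choices[0].message.content)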

View file: litellm/__init__.py

@@ -63,6 +63,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
    "opik",
    "argilla",
    "mlflow",
    "langfuse",
]
logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None
_known_custom_logger_compatible_callbacks: List = list(
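
With "langfuse" added to this literal, the integration can be enabled by name. A minimal sketch, assuming standard litellm string-callback registration and the LANGFUSE_* environment variables the new module reads (credential values are placeholders):

import os
import litellm

os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-..."  # placeholder credential
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-..."  # placeholder credential

# "langfuse" is now a valid custom-logger-compatible callback name
litellm.callbacks = ["langfuse"]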

View file: litellm/integrations/langfuse/langfuse_handler.py

@@ -6,8 +6,11 @@ Used to get the LangFuseLogger for a given request
Handles Key/Team Based Langfuse Logging
"""
import os
from typing import TYPE_CHECKING, Any, Dict, Optional

from packaging.version import Version

from litellm.litellm_core_utils.litellm_logging import StandardCallbackDynamicParams

from .langfuse import LangFuseLogger, LangfuseLoggingConfig
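
The packaging import supports version-gating Langfuse SDK features; the new module below uses the same pattern. An illustrative sketch (the 2.6.0 threshold mirrors the check in the new file):

import langfuse
from packaging.version import Version

# Only pass newer SDK-only parameters when the installed client supports them
if Version(langfuse.version.__version__) >= Version("2.6.0"):
    extra_params = {"sdk_integration": "litellm"}
else:
    extra_params = {}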

View file: litellm/integrations/langfuse/langfuse_prompt_management.py

@@ -0,0 +1,163 @@
"""
Call Hook for LiteLLM Proxy which allows Langfuse prompt management.
"""

import os
import traceback
from typing import Literal, Optional, Union

from packaging.version import Version

from litellm._logging import verbose_proxy_logger
from litellm.caching.dual_cache import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.secret_managers.main import str_to_bool


class LangfusePromptManagement(CustomLogger):
    def __init__(
        self,
        langfuse_public_key=None,
        langfuse_secret=None,
        langfuse_host=None,
        flush_interval=1,
    ):
        try:
            import langfuse
            from langfuse import Langfuse
        except Exception as e:
            raise Exception(
                f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
            )
        # Instance variables
        self.secret_key = langfuse_secret or os.getenv("LANGFUSE_SECRET_KEY")
        self.public_key = langfuse_public_key or os.getenv("LANGFUSE_PUBLIC_KEY")
        self.langfuse_host = langfuse_host or os.getenv(
            "LANGFUSE_HOST", "https://cloud.langfuse.com"
        )
        if not (
            self.langfuse_host.startswith("http://")
            or self.langfuse_host.startswith("https://")
        ):
            # add http:// if unset, assume communicating over private network - e.g. render
            self.langfuse_host = "http://" + self.langfuse_host
        self.langfuse_release = os.getenv("LANGFUSE_RELEASE")
        self.langfuse_debug = os.getenv("LANGFUSE_DEBUG")
        self.langfuse_flush_interval = (
            os.getenv("LANGFUSE_FLUSH_INTERVAL") or flush_interval
        )

        parameters = {
            "public_key": self.public_key,
            "secret_key": self.secret_key,
            "host": self.langfuse_host,
            "release": self.langfuse_release,
            "debug": self.langfuse_debug,
            "flush_interval": self.langfuse_flush_interval,  # flush interval in seconds
        }

        if Version(langfuse.version.__version__) >= Version("2.6.0"):
            parameters["sdk_integration"] = "litellm"

        self.Langfuse = Langfuse(**parameters)

        # set the current langfuse project id in the environ
        # this is used by Alerting to link to the correct project
        try:
            project_id = self.Langfuse.client.projects.get().data[0].id
            os.environ["LANGFUSE_PROJECT_ID"] = project_id
        except Exception:
            project_id = None
        if os.getenv("UPSTREAM_LANGFUSE_SECRET_KEY") is not None:
            # read the raw debug value from the environment before converting
            # it, so the attribute exists when str_to_bool is called
            self.upstream_langfuse_debug = os.getenv("UPSTREAM_LANGFUSE_DEBUG")
            upstream_langfuse_debug = (
                str_to_bool(self.upstream_langfuse_debug)
                if self.upstream_langfuse_debug is not None
                else None
            )
            self.upstream_langfuse_secret_key = os.getenv(
                "UPSTREAM_LANGFUSE_SECRET_KEY"
            )
            self.upstream_langfuse_public_key = os.getenv(
                "UPSTREAM_LANGFUSE_PUBLIC_KEY"
            )
            self.upstream_langfuse_host = os.getenv("UPSTREAM_LANGFUSE_HOST")
            self.upstream_langfuse_release = os.getenv("UPSTREAM_LANGFUSE_RELEASE")
            self.upstream_langfuse = Langfuse(
                public_key=self.upstream_langfuse_public_key,
                secret_key=self.upstream_langfuse_secret_key,
                host=self.upstream_langfuse_host,
                release=self.upstream_langfuse_release,
                debug=(
                    upstream_langfuse_debug
                    if upstream_langfuse_debug is not None
                    else False
                ),
            )
        else:
            self.upstream_langfuse = None
    def _compile_prompt(
        self,
        metadata: dict,
        call_type: Union[Literal["completion"], Literal["text_completion"]],
    ) -> Optional[Union[str, list]]:
        compiled_prompt: Optional[Union[str, list]] = None

        if isinstance(metadata, dict):
            langfuse_prompt_id = metadata.get("langfuse_prompt_id")
            langfuse_prompt_variables = metadata.get("langfuse_prompt_variables") or {}
            if (
                langfuse_prompt_id
                and isinstance(langfuse_prompt_id, str)
                and isinstance(langfuse_prompt_variables, dict)
            ):
                langfuse_prompt = self.Langfuse.get_prompt(langfuse_prompt_id)
                compiled_prompt = langfuse_prompt.compile(**langfuse_prompt_variables)

        return compiled_prompt

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Union[
            Literal["completion"],
            Literal["text_completion"],
            Literal["embeddings"],
            Literal["image_generation"],
            Literal["moderation"],
            Literal["audio_transcription"],
            Literal["pass_through_endpoint"],
            Literal["rerank"],
        ],
    ) -> Union[Exception, str, dict, None]:
        metadata = data.get("metadata") or {}
        compiled_prompt: Optional[Union[str, list]] = None

        if call_type == "completion" or call_type == "text_completion":
            compiled_prompt = self._compile_prompt(metadata, call_type)
        if compiled_prompt is None:
            return await super().async_pre_call_hook(
                user_api_key_dict, cache, data, call_type
            )

        if call_type == "completion":
            if isinstance(compiled_prompt, list):
                data["messages"] = compiled_prompt + data["messages"]
            else:
                data["messages"] = [
                    {"role": "system", "content": compiled_prompt}
                ] + data["messages"]
        elif call_type == "text_completion" and isinstance(compiled_prompt, str):
            data["prompt"] = compiled_prompt + "\n" + data["prompt"]

        verbose_proxy_logger.debug(
            f"LangfusePromptManagement.async_pre_call_hook compiled_prompt: {compiled_prompt}, type: {type(compiled_prompt)}"
        )

        return await super().async_pre_call_hook(
            user_api_key_dict, cache, data, call_type
        )
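
To make the hook concrete, here is its compile-and-prepend flow reproduced standalone, without the proxy machinery. A sketch only: the prompt name and variables are hypothetical, and Langfuse credentials are assumed to be set via LANGFUSE_* environment variables.

from langfuse import Langfuse

langfuse_client = Langfuse()  # reads LANGFUSE_PUBLIC_KEY / LANGFUSE_SECRET_KEY / LANGFUSE_HOST
prompt = langfuse_client.get_prompt("my-prompt")  # hypothetical prompt name
compiled = prompt.compile(topic="proxies")  # str for text prompts, list of messages for chat prompts

data = {"messages": [{"role": "user", "content": "What is LiteLLM?"}]}
if isinstance(compiled, list):
    # chat prompt: prepend the compiled messages
    data["messages"] = compiled + data["messages"]
else:
    # text prompt: inject as a system message
    data["messages"] = [{"role": "system", "content": compiled}] + data["messages"]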

View file: litellm/litellm_core_utils/litellm_logging.py

@@ -75,6 +75,7 @@ from ..integrations.helicone import HeliconeLogger
from ..integrations.lago import LagoLogger
from ..integrations.langfuse.langfuse import LangFuseLogger
from ..integrations.langfuse.langfuse_handler import LangFuseHandler
from ..integrations.langfuse.langfuse_prompt_management import LangfusePromptManagement
from ..integrations.langsmith import LangsmithLogger
from ..integrations.literal_ai import LiteralAILogger
from ..integrations.logfire_logger import LogfireLevel, LogfireLogger
@@ -2349,9 +2350,17 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
        _mlflow_logger = MlflowLogger()
        _in_memory_loggers.append(_mlflow_logger)
        return _mlflow_logger  # type: ignore
    elif logging_integration == "langfuse":
        for callback in _in_memory_loggers:
            if isinstance(callback, LangfusePromptManagement):
                return callback
        langfuse_logger = LangfusePromptManagement()
        _in_memory_loggers.append(langfuse_logger)
        return langfuse_logger  # type: ignore


def get_custom_logger_compatible_class(  # noqa: PLR0915
    logging_integration: litellm._custom_logger_compatible_callbacks_literal,
) -> Optional[CustomLogger]:
    if logging_integration == "lago":
@@ -2402,6 +2411,10 @@ def get_custom_logger_compatible_class(
        for callback in _in_memory_loggers:
            if isinstance(callback, OpikLogger):
                return callback
    elif logging_integration == "langfuse":
        for callback in _in_memory_loggers:
            if isinstance(callback, LangfusePromptManagement):
                return callback
    elif logging_integration == "otel":
        from litellm.integrations.opentelemetry import OpenTelemetry
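
Both helpers treat the logger as a per-process singleton: the init path registers one LangfusePromptManagement instance in _in_memory_loggers, and the getter returns it. A sketch of that contract, assuming the helper signatures at this commit:

from litellm.litellm_core_utils.litellm_logging import (
    _init_custom_logger_compatible_class,
    get_custom_logger_compatible_class,
)

logger = _init_custom_logger_compatible_class(
    "langfuse", internal_usage_cache=None, llm_router=None
)
# repeated lookups return the same cached instance
assert get_custom_logger_compatible_class("langfuse") is logger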

View file: litellm/proxy/utils.py

@@ -508,6 +508,7 @@ class ProxyLogging:
        try:
            for callback in litellm.callbacks:
                _callback = None

                if isinstance(callback, str):
                    _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
@@ -515,7 +516,6 @@
                    )
                else:
                    _callback = callback  # type: ignore
                if _callback is not None and isinstance(_callback, CustomGuardrail):
                    from litellm.types.guardrails import GuardrailEventHooks
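
For reference, a simplified sketch of the resolution loop this hunk touches: string entries in litellm.callbacks are resolved to logger instances (e.g. "langfuse" now resolves to LangfusePromptManagement), while already-instantiated callbacks pass through unchanged.

import litellm
from litellm.litellm_core_utils.litellm_logging import get_custom_logger_compatible_class

resolved = []
for callback in litellm.callbacks:
    if isinstance(callback, str):
        _callback = get_custom_logger_compatible_class(callback)
    else:
        _callback = callback
    if _callback is not None:
        resolved.append(_callback)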

View file

@@ -33,6 +33,9 @@ from litellm.integrations.opik.opik import OpikLogger
from litellm.integrations.opentelemetry import OpenTelemetry
from litellm.integrations.mlflow import MlflowLogger
from litellm.integrations.argilla import ArgillaLogger
from litellm.integrations.langfuse.langfuse_prompt_management import (
    LangfusePromptManagement,
)
from litellm.proxy.hooks.dynamic_rate_limiter import _PROXY_DynamicRateLimitHandler
from unittest.mock import patch
@@ -61,6 +64,7 @@ callback_class_str_to_classType = {
    "arize": OpenTelemetry,
    "langtrace": OpenTelemetry,
    "mlflow": MlflowLogger,
    "langfuse": LangfusePromptManagement,
}
expected_env_vars = {
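
A hypothetical companion assertion in the same spirit as the mapping above, checking only what this diff establishes:

from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.langfuse.langfuse_prompt_management import (
    LangfusePromptManagement,
)


def test_langfuse_callback_is_custom_logger():
    # "langfuse" maps to LangfusePromptManagement, which must subclass
    # CustomLogger for the callback machinery to accept it
    assert issubclass(LangfusePromptManagement, CustomLogger)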