HumanLoop integration for Prompt Management (#7479)

* feat(humanloop.py): initial commit for humanloop prompt management integration

Closes https://github.com/BerriAI/litellm/issues/213

* feat(humanloop.py): working e2e humanloop prompt management integration

Closes https://github.com/BerriAI/litellm/issues/213

* fix(humanloop.py): fix linting errors

* fix: fix linting erro

* fix: fix test

* test: handle filenotfound error
This commit is contained in:
Krish Dholakia 2024-12-30 22:26:03 -08:00 committed by GitHub
parent 0178e75cd9
commit 77c13df55d
9 changed files with 310 additions and 39 deletions

View file

@ -26,6 +26,8 @@ from litellm.constants import (
DEFAULT_REPLICATE_POLLING_RETRIES,
DEFAULT_REPLICATE_POLLING_DELAY_SECONDS,
LITELLM_CHAT_PROVIDERS,
HUMANLOOP_PROMPT_CACHE_TTL_SECONDS,
OPENAI_CHAT_COMPLETION_PARAMS,
)
from litellm.types.guardrails import GuardrailItem
from litellm.proxy._types import (
@ -72,6 +74,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
"argilla",
"mlflow",
"langfuse",
"humanloop",
]
logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None
_known_custom_logger_compatible_callbacks: List = list(

View file

@ -69,6 +69,45 @@ LITELLM_CHAT_PROVIDERS = [
"galadriel",
]
OPENAI_CHAT_COMPLETION_PARAMS = [
"functions",
"function_call",
"temperature",
"temperature",
"top_p",
"n",
"stream",
"stream_options",
"stop",
"max_completion_tokens",
"modalities",
"prediction",
"audio",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
"request_timeout",
"api_base",
"api_version",
"api_key",
"deployment_id",
"organization",
"base_url",
"default_headers",
"timeout",
"response_format",
"seed",
"tools",
"tool_choice",
"max_retries",
"parallel_tool_calls",
"logprobs",
"top_logprobs",
"extra_headers",
]
HUMANLOOP_PROMPT_CACHE_TTL_SECONDS = 60 # 1 minute
RESPONSE_FORMAT_TOOL_NAME = "json_tool_call" # default tool name used when converting response format to tool call
########################### Logging Callback Constants ###########################

View file

@ -0,0 +1,199 @@
"""
Humanloop integration
https://humanloop.com/
"""
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union, cast
import httpx
import litellm
from litellm.caching import DualCache
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import StandardCallbackDynamicParams
from .custom_logger import CustomLogger
class PromptManagementClient(TypedDict):
prompt_id: str
prompt_template: List[AllMessageValues]
model: Optional[str]
optional_params: Optional[Dict[str, Any]]
class HumanLoopPromptManager(DualCache):
@property
def integration_name(self):
return "humanloop"
def _get_prompt_from_id_cache(
self, humanloop_prompt_id: str
) -> Optional[PromptManagementClient]:
return cast(
Optional[PromptManagementClient], self.get_cache(key=humanloop_prompt_id)
)
def _compile_prompt_helper(
self, prompt_template: List[AllMessageValues], prompt_variables: Dict[str, Any]
) -> List[AllMessageValues]:
"""
Helper function to compile the prompt by substituting variables in the template.
Args:
prompt_template: List[AllMessageValues]
prompt_variables (dict): A dictionary of variables to substitute into the prompt template.
Returns:
list: A list of dictionaries with variables substituted.
"""
compiled_prompts: List[AllMessageValues] = []
for template in prompt_template:
tc = template.get("content")
if tc and isinstance(tc, str):
formatted_template = tc.replace("{{", "{").replace("}}", "}")
compiled_content = formatted_template.format(**prompt_variables)
template["content"] = compiled_content
compiled_prompts.append(template)
return compiled_prompts
def _get_prompt_from_id_api(
self, humanloop_prompt_id: str, humanloop_api_key: str
) -> PromptManagementClient:
client = _get_httpx_client()
base_url = "https://api.humanloop.com/v5/prompts/{}".format(humanloop_prompt_id)
response = client.get(
url=base_url,
headers={
"X-Api-Key": humanloop_api_key,
"Content-Type": "application/json",
},
)
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise Exception(f"Error getting prompt from Humanloop: {e.response.text}")
json_response = response.json()
template_message = json_response["template"]
if isinstance(template_message, dict):
template_messages = [template_message]
elif isinstance(template_message, list):
template_messages = template_message
else:
raise ValueError(f"Invalid template message type: {type(template_message)}")
template_model = json_response["model"]
optional_params = {}
for k, v in json_response.items():
if k in litellm.OPENAI_CHAT_COMPLETION_PARAMS:
optional_params[k] = v
return PromptManagementClient(
prompt_id=humanloop_prompt_id,
prompt_template=cast(List[AllMessageValues], template_messages),
model=template_model,
optional_params=optional_params,
)
def _get_prompt_from_id(
self, humanloop_prompt_id: str, humanloop_api_key: str
) -> PromptManagementClient:
prompt = self._get_prompt_from_id_cache(humanloop_prompt_id)
if prompt is None:
prompt = self._get_prompt_from_id_api(
humanloop_prompt_id, humanloop_api_key
)
self.set_cache(
key=humanloop_prompt_id,
value=prompt,
ttl=litellm.HUMANLOOP_PROMPT_CACHE_TTL_SECONDS,
)
return prompt
def compile_prompt(
self,
prompt_template: List[AllMessageValues],
prompt_variables: Optional[dict],
) -> List[AllMessageValues]:
compiled_prompt: Optional[Union[str, list]] = None
if prompt_variables is None:
prompt_variables = {}
compiled_prompt = self._compile_prompt_helper(
prompt_template=prompt_template,
prompt_variables=prompt_variables,
)
return compiled_prompt
def _get_model_from_prompt(
self, prompt_management_client: PromptManagementClient, model: str
) -> str:
if prompt_management_client["model"] is not None:
return prompt_management_client["model"]
else:
return model.replace("{}/".format(self.integration_name), "")
prompt_manager = HumanLoopPromptManager()
class HumanloopLogger(CustomLogger):
def get_chat_completion_prompt(
self,
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
headers: dict,
prompt_id: str,
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[
str,
List[AllMessageValues],
dict,
]:
humanloop_api_key = dynamic_callback_params.get(
"humanloop_api_key"
) or get_secret_str("HUMANLOOP_API_KEY")
if humanloop_api_key is None:
return super().get_chat_completion_prompt(
model=model,
messages=messages,
non_default_params=non_default_params,
headers=headers,
prompt_id=prompt_id,
prompt_variables=prompt_variables,
dynamic_callback_params=dynamic_callback_params,
)
prompt_template = prompt_manager._get_prompt_from_id(
humanloop_prompt_id=prompt_id, humanloop_api_key=humanloop_api_key
)
updated_messages = prompt_manager.compile_prompt(
prompt_template=prompt_template["prompt_template"],
prompt_variables=prompt_variables,
)
prompt_template_optional_params = prompt_template["optional_params"] or {}
updated_non_default_params = {
**non_default_params,
**prompt_template_optional_params,
}
model = prompt_manager._get_model_from_prompt(
prompt_management_client=prompt_template, model=model
)
return model, updated_messages, updated_non_default_params

View file

@ -76,6 +76,7 @@ from ..integrations.galileo import GalileoObserve
from ..integrations.gcs_bucket.gcs_bucket import GCSBucketLogger
from ..integrations.greenscale import GreenscaleLogger
from ..integrations.helicone import HeliconeLogger
from ..integrations.humanloop import HumanloopLogger
from ..integrations.lago import LagoLogger
from ..integrations.langfuse.langfuse import LangFuseLogger
from ..integrations.langfuse.langfuse_handler import LangFuseHandler
@ -446,6 +447,7 @@ class Logging(LiteLLMLoggingBaseClass):
prompt_id: str,
prompt_variables: Optional[dict],
) -> Tuple[str, List[AllMessageValues], dict]:
for (
custom_logger_compatible_callback
) in litellm._known_custom_logger_compatible_callbacks:
@ -455,6 +457,7 @@ class Logging(LiteLLMLoggingBaseClass):
internal_usage_cache=None,
llm_router=None,
)
if custom_logger is None:
continue
model, messages, non_default_params = (
@ -2428,6 +2431,14 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
langfuse_logger = LangfusePromptManagement()
_in_memory_loggers.append(langfuse_logger)
return langfuse_logger # type: ignore
elif logging_integration == "humanloop":
for callback in _in_memory_loggers:
if isinstance(callback, HumanloopLogger):
return callback
humanloop_logger = HumanloopLogger()
_in_memory_loggers.append(humanloop_logger)
return humanloop_logger # type: ignore
def get_custom_logger_compatible_class( # noqa: PLR0915

View file

@ -923,44 +923,7 @@ def completion( # type: ignore # noqa: PLR0915
assistant_continue_message=assistant_continue_message,
)
######## end of unpacking kwargs ###########
openai_params = [
"functions",
"function_call",
"temperature",
"temperature",
"top_p",
"n",
"stream",
"stream_options",
"stop",
"max_completion_tokens",
"modalities",
"prediction",
"audio",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
"request_timeout",
"api_base",
"api_version",
"api_key",
"deployment_id",
"organization",
"base_url",
"default_headers",
"timeout",
"response_format",
"seed",
"tools",
"tool_choice",
"max_retries",
"parallel_tool_calls",
"logprobs",
"top_logprobs",
"extra_headers",
]
openai_params = litellm.OPENAI_CHAT_COMPLETION_PARAMS
default_params = openai_params + all_litellm_params
litellm_params = {} # used to prevent unbound var errors
non_default_params = {

View file

@ -1601,6 +1601,9 @@ class StandardCallbackDynamicParams(TypedDict, total=False):
langsmith_project: Optional[str]
langsmith_base_url: Optional[str]
# Humanloop dynamic params
humanloop_api_key: Optional[str]
# Logging settings
turn_off_message_logging: Optional[bool] # when true will not log messages

View file

@ -4534,3 +4534,17 @@ def test_langfuse_completion(monkeypatch):
prompt_variables={"user_message": "this is used"},
messages=[{"role": "user", "content": "this is ignored"}],
)
def test_humanloop_completion(monkeypatch):
monkeypatch.setenv(
"HUMANLOOP_API_KEY", "hl_sk_59c1206e110c3f5b9985f0de4d23e7cbc79c4c4ae18c9f14"
)
litellm.set_verbose = True
resp = litellm.completion(
model="humanloop/gpt-3.5-turbo",
humanloop_api_key=os.getenv("HUMANLOOP_API_KEY"),
prompt_id="pr_nmSOVpEdyYPm2DrOwCoOm",
prompt_variables={"person": "John"},
messages=[{"role": "user", "content": "Tell me a joke."}],
)

View file

@ -0,0 +1,35 @@
import os
import sys
import threading
from datetime import datetime
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system-path
import pytest
from litellm.integrations.humanloop import HumanLoopPromptManager
from litellm.types.utils import StandardCallbackDynamicParams
from litellm.litellm_core_utils.litellm_logging import DynamicLoggingCache
from unittest.mock import Mock, patch
def test_compile_prompt():
prompt_manager = HumanLoopPromptManager()
prompt_template = [
{
"content": "You are {{person}}. Answer questions as this person. Do not break character.",
"name": None,
"tool_call_id": None,
"role": "system",
"tool_calls": None,
}
]
prompt_variables = {"person": "John"}
compiled_prompt = prompt_manager._compile_prompt_helper(
prompt_template, prompt_variables
)
assert (
compiled_prompt[0]["content"]
== "You are John. Answer questions as this person. Do not break character."
)

View file

@ -35,6 +35,7 @@ from litellm.integrations.langfuse.langfuse_prompt_management import (
LangfusePromptManagement,
)
from litellm.integrations.azure_storage.azure_storage import AzureBlobStorageLogger
from litellm.integrations.humanloop import HumanloopLogger
from litellm.proxy.hooks.dynamic_rate_limiter import _PROXY_DynamicRateLimitHandler
from unittest.mock import patch
@ -59,6 +60,7 @@ callback_class_str_to_classType = {
"argilla": ArgillaLogger,
"opentelemetry": OpenTelemetry,
"azure_storage": AzureBlobStorageLogger,
"humanloop": HumanloopLogger,
# OTEL compatible loggers
"logfire": OpenTelemetry,
"arize": OpenTelemetry,
@ -178,7 +180,9 @@ async def use_callback_in_llm_call(
assert isinstance(litellm.success_callback[0], expected_class)
assert isinstance(litellm.failure_callback[0], expected_class)
assert len(litellm._async_success_callback) == 1
assert (
len(litellm._async_success_callback) == 1
), f"Got={litellm._async_success_callback}"
assert len(litellm._async_failure_callback) == 1
assert len(litellm.success_callback) == 1
assert len(litellm.failure_callback) == 1