Fix team-based logging to langfuse + allow custom tokenizer on /token_counter endpoint (#7493)

* fix(langfuse_prompt_management.py): migrate dynamic logging to langfuse custom logger compatible class

* fix(langfuse_prompt_management.py): support failure callback logging to langfuse as well

* feat(proxy_server.py): support setting custom tokenizer on config.yaml

Allows customizing the tokenizer used by the `/utils/token_counter` endpoint

* fix(proxy_server.py): fix linting errors

* test: skip if file not found

* style: cleanup unused import

* docs(configs.md): add docs on setting custom tokenizer
Krish Dholakia 2024-12-31 23:18:41 -08:00 committed by GitHub
parent 6705e30d5d
commit 080de89cfb
11 changed files with 192 additions and 72 deletions


@@ -516,6 +516,32 @@ model_list:
$ litellm --config /path/to/config.yaml
```
### Set custom tokenizer
If you're using the [`/utils/token_counter` endpoint](https://litellm-api.up.railway.app/#/llm%20utils/token_counter_utils_token_counter_post) and want to set a custom Hugging Face tokenizer for a model, you can do so in the `config.yaml`:
```yaml
model_list:
  - model_name: openai-deepseek
    litellm_params:
      model: deepseek/deepseek-chat
      api_key: os.environ/OPENAI_API_KEY
    model_info:
      access_groups: ["restricted-models"]
      custom_tokenizer:
        identifier: deepseek-ai/DeepSeek-V3-Base
        revision: main
        auth_token: os.environ/HUGGINGFACE_API_KEY
```
**Spec**
```
custom_tokenizer:
  identifier: str # huggingface model identifier
  revision: str # huggingface model revision (usually 'main')
  auth_token: Optional[str] # huggingface auth token
```
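
For reference, here is a minimal sketch of calling the endpoint once the proxy is running with the config above (the base URL `http://localhost:4000` and the `LITELLM_API_KEY` variable are illustrative assumptions, not part of this change):

```python
# Hedged example: count tokens for the "openai-deepseek" entry defined above.
import os

import requests

response = requests.post(
    "http://localhost:4000/utils/token_counter",  # assumed local proxy address
    headers={"Authorization": f"Bearer {os.environ['LITELLM_API_KEY']}"},  # assumed proxy key
    json={
        "model": "openai-deepseek",
        "messages": [{"role": "user", "content": "Hello, how are you?"}],
    },
)
print(response.json())  # includes total_tokens and which tokenizer was used
```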
## General Settings `general_settings` (DB Connection, etc)
### Configure DB Pool Limits + Connection Timeouts


@@ -158,6 +158,7 @@ class LangFuseHandler:
        Returns:
            bool: True if the dynamic langfuse credentials are passed, False otherwise
        """
        if (
            standard_callback_dynamic_params.get("langfuse_host") is not None
            or standard_callback_dynamic_params.get("langfuse_public_key") is not None


@@ -16,7 +16,11 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
from ...litellm_core_utils.specialty_caches.dynamic_logging_cache import (
    DynamicLoggingCache,
)
from .langfuse import LangFuseLogger
from .langfuse_handler import LangFuseHandler

if TYPE_CHECKING:
    from langfuse import Langfuse

@@ -29,6 +33,8 @@ else:
    PROMPT_CLIENT = Any
    LangfuseClass = Any

in_memory_dynamic_logger_cache = DynamicLoggingCache()


@lru_cache(maxsize=10)
def langfuse_client_init(

@@ -252,7 +258,15 @@ class LangfusePromptManagement(LangFuseLogger, CustomLogger):
        return model, messages, non_default_params

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        self._old_log_event(
        standard_callback_dynamic_params = kwargs.get(
            "standard_callback_dynamic_params"
        )
        langfuse_logger_to_use = LangFuseHandler.get_langfuse_logger_for_request(
            globalLangfuseLogger=self,
            standard_callback_dynamic_params=standard_callback_dynamic_params,
            in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
        )
        langfuse_logger_to_use._old_log_event(
            kwargs=kwargs,
            response_obj=response_obj,
            start_time=start_time,

@@ -262,13 +276,21 @@ class LangfusePromptManagement(LangFuseLogger, CustomLogger):
        )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        standard_callback_dynamic_params = kwargs.get(
            "standard_callback_dynamic_params"
        )
        langfuse_logger_to_use = LangFuseHandler.get_langfuse_logger_for_request(
            globalLangfuseLogger=self,
            standard_callback_dynamic_params=standard_callback_dynamic_params,
            in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
        )
        standard_logging_object = cast(
            Optional[StandardLoggingPayload],
            kwargs.get("standard_logging_object", None),
        )
        if standard_logging_object is None:
            return
        self._old_log_event(
        langfuse_logger_to_use._old_log_event(
            start_time=start_time,
            end_time=end_time,
            response_obj=None,


@@ -0,0 +1,32 @@
from typing import Dict, Optional

from litellm.secret_managers.main import get_secret_str
from litellm.types.utils import StandardCallbackDynamicParams


def initialize_standard_callback_dynamic_params(
    kwargs: Optional[Dict] = None,
) -> StandardCallbackDynamicParams:
    """
    Initialize the standard callback dynamic params from the kwargs

    checks if langfuse_secret_key, gcs_bucket_name in kwargs and sets the corresponding attributes in StandardCallbackDynamicParams
    """
    standard_callback_dynamic_params = StandardCallbackDynamicParams()
    if kwargs:
        _supported_callback_params = (
            StandardCallbackDynamicParams.__annotations__.keys()
        )
        for param in _supported_callback_params:
            if param in kwargs:
                _param_value = kwargs.pop(param)
                if (
                    _param_value is not None
                    and isinstance(_param_value, str)
                    and "os.environ/" in _param_value
                ):
                    _param_value = get_secret_str(secret_name=_param_value)
                standard_callback_dynamic_params[param] = _param_value  # type: ignore
    return standard_callback_dynamic_params
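
A minimal usage sketch for the helper above (the import path assumes the new file lives under `litellm/litellm_core_utils/`, which the imports later in this commit suggest):

```python
# Sketch: kwargs carrying "os.environ/..." references are resolved into concrete
# secrets; keys that aren't StandardCallbackDynamicParams fields are ignored.
import os

from litellm.litellm_core_utils.initialize_dynamic_callback_params import (
    initialize_standard_callback_dynamic_params,
)

os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-example"  # illustrative value

params = initialize_standard_callback_dynamic_params(
    kwargs={
        "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",  # resolved via get_secret_str
        "langfuse_host": "https://cloud.langfuse.com",  # plain values pass through unchanged
        "temperature": 0.2,  # not a callback param, left alone
    }
)
print(params.get("langfuse_public_key"))  # "pk-lf-example"
```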


@@ -94,7 +94,11 @@ from ..integrations.supabase import Supabase
from ..integrations.traceloop import TraceloopLogger
from ..integrations.weights_biases import WeightsBiasesLogger
from .exception_mapping_utils import _get_response_headers
from .initialize_dynamic_callback_params import (
    initialize_standard_callback_dynamic_params as _initialize_standard_callback_dynamic_params,
)
from .logging_utils import _assemble_complete_response_from_streaming_chunks
from .specialty_caches.dynamic_logging_cache import DynamicLoggingCache

try:
    from ..proxy.enterprise.enterprise_callbacks.generic_api_callback import (
@@ -156,39 +160,6 @@ class ServiceTraceIDCache:
        return None


import hashlib


class DynamicLoggingCache:
    """
    Prevent memory leaks caused by initializing new logging clients on each request.

    Relevant Issue: https://github.com/BerriAI/litellm/issues/5695
    """

    def __init__(self) -> None:
        self.cache = InMemoryCache()

    def get_cache_key(self, args: dict) -> str:
        args_str = json.dumps(args, sort_keys=True)
        cache_key = hashlib.sha256(args_str.encode("utf-8")).hexdigest()
        return cache_key

    def get_cache(self, credentials: dict, service_name: str) -> Optional[Any]:
        key_name = self.get_cache_key(
            args={**credentials, "service_name": service_name}
        )
        response = self.cache.get_cache(key=key_name)
        return response

    def set_cache(self, credentials: dict, service_name: str, logging_obj: Any) -> None:
        key_name = self.get_cache_key(
            args={**credentials, "service_name": service_name}
        )
        self.cache.set_cache(key=key_name, value=logging_obj)
        return None


in_memory_trace_id_cache = ServiceTraceIDCache()
in_memory_dynamic_logger_cache = DynamicLoggingCache()
@@ -370,24 +341,7 @@ class Logging(LiteLLMLoggingBaseClass):
        checks if langfuse_secret_key, gcs_bucket_name in kwargs and sets the corresponding attributes in StandardCallbackDynamicParams
        """
        from litellm.secret_managers.main import get_secret_str

        standard_callback_dynamic_params = StandardCallbackDynamicParams()
        if kwargs:
            _supported_callback_params = (
                StandardCallbackDynamicParams.__annotations__.keys()
            )
            for param in _supported_callback_params:
                if param in kwargs:
                    _param_value = kwargs.pop(param)
                    if (
                        _param_value is not None
                        and isinstance(_param_value, str)
                        and "os.environ/" in _param_value
                    ):
                        _param_value = get_secret_str(secret_name=_param_value)
                    standard_callback_dynamic_params[param] = _param_value  # type: ignore
        return standard_callback_dynamic_params
        return _initialize_standard_callback_dynamic_params(kwargs)

    def update_environment_variables(
        self,
@@ -963,7 +917,9 @@ class Logging(LiteLLMLoggingBaseClass):
    def success_handler(  # noqa: PLR0915
        self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs
    ):
        print_verbose(f"Logging Details LiteLLM-Success Call: Cache_hit={cache_hit}")
        verbose_logger.debug(
            f"Logging Details LiteLLM-Success Call: Cache_hit={cache_hit}"
        )
        start_time, end_time, result = self._success_handler_helper_fn(
            start_time=start_time,
            end_time=end_time,

@@ -971,9 +927,7 @@ class Logging(LiteLLMLoggingBaseClass):
            cache_hit=cache_hit,
            standard_logging_object=kwargs.get("standard_logging_object", None),
        )
        # print(f"original response in success handler: {self.model_call_details['original_response']}")
        try:
            verbose_logger.debug(f"success callbacks: {litellm.success_callback}")

            ## BUILD COMPLETE STREAMED RESPONSE
            complete_streaming_response: Optional[
@@ -2015,6 +1969,7 @@ class Logging(LiteLLMLoggingBaseClass):
        )
        result = None  # result sent to all loggers, init this to None incase it's not created

        for callback in callbacks:
            try:
                if isinstance(callback, CustomLogger):  # custom logger class


@@ -0,0 +1,35 @@
import hashlib
import json
from typing import Any, Optional

from ...caching import InMemoryCache


class DynamicLoggingCache:
    """
    Prevent memory leaks caused by initializing new logging clients on each request.

    Relevant Issue: https://github.com/BerriAI/litellm/issues/5695
    """

    def __init__(self) -> None:
        self.cache = InMemoryCache()

    def get_cache_key(self, args: dict) -> str:
        args_str = json.dumps(args, sort_keys=True)
        cache_key = hashlib.sha256(args_str.encode("utf-8")).hexdigest()
        return cache_key

    def get_cache(self, credentials: dict, service_name: str) -> Optional[Any]:
        key_name = self.get_cache_key(
            args={**credentials, "service_name": service_name}
        )
        response = self.cache.get_cache(key=key_name)
        return response

    def set_cache(self, credentials: dict, service_name: str, logging_obj: Any) -> None:
        key_name = self.get_cache_key(
            args={**credentials, "service_name": service_name}
        )
        self.cache.set_cache(key=key_name, value=logging_obj)
        return None
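
A short sketch of the caching pattern this class enables (the credential values and the stand-in client object are illustrative):

```python
# Sketch: one logging client per unique (credentials, service) pair, reused across
# requests instead of being rebuilt each time.
from litellm.litellm_core_utils.specialty_caches.dynamic_logging_cache import (
    DynamicLoggingCache,
)

cache = DynamicLoggingCache()
credentials = {
    "langfuse_public_key": "pk-lf-team-1",  # illustrative values
    "langfuse_host": "https://cloud.langfuse.com",
}

client = cache.get_cache(credentials=credentials, service_name="langfuse")
if client is None:
    client = object()  # stand-in for a real LangFuseLogger built from these credentials
    cache.set_cache(credentials=credentials, service_name="langfuse", logging_obj=client)

# A later request with the same credentials gets the cached client back.
assert cache.get_cache(credentials=credentials, service_name="langfuse") is client
```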


@@ -5,17 +5,28 @@ model_list:
      api_key: os.environ/OPENAI_API_KEY
    model_info:
      access_groups: ["default-models"]
  - model_name: openai/o1-*
  - model_name: openai-deepseek
    litellm_params:
      model: openai/o1-*
      model: deepseek/deepseek-chat
      api_key: os.environ/OPENAI_API_KEY
    model_info:
      access_groups: ["restricted-models"]
  - model_name: azure-o1-preview
    litellm_params:
      model: azure/o1-preview
      api_key: os.environ/AZURE_OPENAI_O1_KEY
      api_base: os.environ/AZURE_API_BASE
    model_info:
      supports_native_streaming: True
      access_groups: ["shared-models"]
      custom_tokenizer:
        identifier: deepseek-ai/DeepSeek-V3-Base
        revision: main
        auth_token: os.environ/HUGGINGFACE_API_KEY

litellm_settings:
  default_team_settings:
    - team_id: "team_1"
      success_callback: ["langfuse"]
      failure_callback: ["langfuse"]
      langfuse_public_key: os.environ/LANGFUSE_PUBLIC_KEY
      langfuse_secret: os.environ/LANGFUSE_SECRET_KEY
      langfuse_host: os.environ/LANGFUSE_HOST
    - team_id: "team_2"
      success_callback: ["langfuse"]
      failure_callback: ["langfuse"]
      langfuse_public_key: os.environ/LANGFUSE_PROJECT3_PUBLIC
      langfuse_secret: os.environ/LANGFUSE_PROJECT3_SECRET
      langfuse_host: os.environ/LANGFUSE_HOST
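
The team-level settings above end up as per-request ("dynamic") Langfuse credentials. Roughly the same path can be exercised directly from the SDK, as in this hedged sketch (the model name and env var names are illustrative, and an OpenAI key is assumed to be configured):

```python
# Sketch: per-call Langfuse credentials. These kwargs are picked up as
# StandardCallbackDynamicParams, and LangFuseHandler reuses one Langfuse client
# per credential set via DynamicLoggingCache.
import os

import litellm

litellm.success_callback = ["langfuse"]
litellm.failure_callback = ["langfuse"]

response = litellm.completion(
    model="gpt-4o-mini",  # illustrative model; assumes OPENAI_API_KEY is set
    messages=[{"role": "user", "content": "Hello"}],
    # credentials for a specific team's Langfuse project
    langfuse_public_key=os.environ["LANGFUSE_PROJECT3_PUBLIC"],
    langfuse_secret=os.environ["LANGFUSE_PROJECT3_SECRET"],
    langfuse_host=os.environ["LANGFUSE_HOST"],
)
```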


@@ -274,6 +274,8 @@ from litellm.types.llms.anthropic import (
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.router import ModelInfo as RouterModelInfo
from litellm.types.router import RouterGeneralSettings, updateDeployment
from litellm.types.utils import CustomHuggingfaceTokenizer
from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.types.utils import StandardLoggingPayload
from litellm.utils import get_end_user_id_for_cost_tracking
@@ -5526,11 +5528,16 @@ async def token_counter(request: TokenCountRequest):
    deployment = None
    litellm_model_name = None
    model_info: Optional[ModelMapInfo] = None
    if llm_router is not None:
        # get 1 deployment corresponding to the model
        for _model in llm_router.model_list:
            if _model["model_name"] == request.model:
                deployment = _model
                model_info = llm_router.get_router_model_info(
                    deployment=deployment,
                    received_model_name=request.model,
                )
                break
    if deployment is not None:
        litellm_model_name = deployment.get("litellm_params", {}).get("model")

@@ -5541,12 +5548,22 @@ async def token_counter(request: TokenCountRequest):
    model_to_use = (
        litellm_model_name or request.model
    )  # use litellm model name, if it's not avalable then fallback to request.model
    _tokenizer_used = litellm.utils._select_tokenizer(model=model_to_use)
    custom_tokenizer: Optional[CustomHuggingfaceTokenizer] = None
    if model_info is not None:
        custom_tokenizer = cast(
            Optional[CustomHuggingfaceTokenizer],
            model_info.get("custom_tokenizer", None),
        )
    _tokenizer_used = litellm.utils._select_tokenizer(
        model=model_to_use, custom_tokenizer=custom_tokenizer
    )
    tokenizer_used = str(_tokenizer_used["type"])
    total_tokens = token_counter(
        model=model_to_use,
        text=prompt,
        messages=messages,
        custom_tokenizer=_tokenizer_used,
    )
    return TokenCountResponse(
        total_tokens=total_tokens,


@@ -1813,3 +1813,9 @@ class LiteLLMLoggingBaseClass:
        self, original_response, input=None, api_key=None, additional_args={}
    ):
        pass


class CustomHuggingfaceTokenizer(TypedDict):
    identifier: str
    revision: str  # usually 'main'
    auth_token: Optional[str]
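
The `utils.py` change below wires this type into `_select_tokenizer`. Here is a hedged end-to-end sketch of the same two steps the proxy endpoint performs (it downloads the tokenizer from the Hugging Face Hub on first use; the model name is illustrative):

```python
# Sketch: resolve a Hugging Face tokenizer from a CustomHuggingfaceTokenizer spec,
# then count tokens with it.
import litellm
from litellm.types.utils import CustomHuggingfaceTokenizer

custom_tokenizer: CustomHuggingfaceTokenizer = {
    "identifier": "deepseek-ai/DeepSeek-V3-Base",
    "revision": "main",
    "auth_token": None,  # set a Hugging Face token for gated or private repos
}

tokenizer_used = litellm.utils._select_tokenizer(
    model="deepseek/deepseek-chat", custom_tokenizer=custom_tokenizer
)
print(tokenizer_used["type"])  # "huggingface_tokenizer"

total_tokens = litellm.token_counter(
    model="deepseek/deepseek-chat",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    custom_tokenizer=tokenizer_used,
)
print(total_tokens)
```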


@@ -126,6 +126,7 @@ from litellm.types.utils import (
    ChatCompletionMessageToolCall,
    Choices,
    CostPerToken,
    CustomHuggingfaceTokenizer,
    Delta,
    Embedding,
    EmbeddingResponse,

@@ -1242,10 +1243,21 @@ def _is_async_request(
    return False


@lru_cache(maxsize=128)
def _select_tokenizer(
    model: str,
    model: str, custom_tokenizer: Optional[CustomHuggingfaceTokenizer] = None
):
    if custom_tokenizer is not None:
        custom_tokenizer = Tokenizer.from_pretrained(
            custom_tokenizer["identifier"],
            revision=custom_tokenizer["revision"],
            auth_token=custom_tokenizer["auth_token"],
        )
        return {"type": "huggingface_tokenizer", "tokenizer": custom_tokenizer}
    return _select_tokenizer_helper(model=model)


@lru_cache(maxsize=128)
def _select_tokenizer_helper(model: str):
    if model in litellm.cohere_models and "command-r" in model:
        # cohere
        cohere_tokenizer = Tokenizer.from_pretrained(


@@ -210,9 +210,12 @@ def test_model_info_bedrock_converse(monkeypatch):
    monkeypatch.setenv("LITELLM_LOCAL_MODEL_COST_MAP", "True")
    litellm.model_cost = litellm.get_model_cost_map(url="")

    try:
        # Load whitelist models from file
        with open("whitelisted_bedrock_models.txt", "r") as file:
            whitelist_models = [line.strip() for line in file.readlines()]
    except FileNotFoundError:
        pytest.skip("whitelisted_bedrock_models.txt not found")

    _enforce_bedrock_converse_models(
        model_cost=litellm.model_cost, whitelist_models=whitelist_models