Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
Fix team-based logging to langfuse + allow custom tokenizer on /token_counter endpoint (#7493)
* fix(langfuse_prompt_management.py): migrate dynamic logging to a langfuse custom-logger-compatible class
* fix(langfuse_prompt_management.py): support failure callback logging to langfuse as well
* feat(proxy_server.py): support setting a custom tokenizer on config.yaml; allows customizing the value for `/utils/token_counter`
* fix(proxy_server.py): fix linting errors
* test: skip if file not found
* style: clean up unused import
* docs(configs.md): add docs on setting a custom tokenizer
parent 6705e30d5d
commit 080de89cfb
11 changed files with 192 additions and 72 deletions
@@ -516,6 +516,32 @@ model_list:
$ litellm --config /path/to/config.yaml
```

### Set custom tokenizer

If you're using the [`/utils/token_counter` endpoint](https://litellm-api.up.railway.app/#/llm%20utils/token_counter_utils_token_counter_post), and want to set a custom huggingface tokenizer for a model, you can do so in the `config.yaml`:

```yaml
model_list:
  - model_name: openai-deepseek
    litellm_params:
      model: deepseek/deepseek-chat
      api_key: os.environ/OPENAI_API_KEY
    model_info:
      access_groups: ["restricted-models"]
      custom_tokenizer:
        identifier: deepseek-ai/DeepSeek-V3-Base
        revision: main
        auth_token: os.environ/HUGGINGFACE_API_KEY
```

**Spec**

```
custom_tokenizer:
  identifier: str            # huggingface model identifier
  revision: str              # huggingface model revision (usually 'main')
  auth_token: Optional[str]  # huggingface auth token
```

## General Settings `general_settings` (DB Connection, etc)

### Configure DB Pool Limits + Connection Timeouts
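To sanity-check the configuration, the endpoint can be called directly. A minimal sketch, assuming the proxy is running on `http://localhost:4000` with a valid virtual key; the request body follows `TokenCountRequest`, and the `tokenizer_type` response field name is an assumption.

```python
# Hedged sketch: query the token counter endpoint of a locally running proxy.
# URL, key, and the "tokenizer_type" response field are assumptions.
import requests

resp = requests.post(
    "http://localhost:4000/utils/token_counter",
    headers={"Authorization": "Bearer sk-1234"},  # placeholder virtual key
    json={
        "model": "openai-deepseek",  # matches model_name in config.yaml
        "messages": [{"role": "user", "content": "Hello, how are you?"}],
    },
)
resp.raise_for_status()
data = resp.json()
print(data["total_tokens"], data.get("tokenizer_type"))
```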
@@ -158,6 +158,7 @@ class LangFuseHandler:
        Returns:
            bool: True if the dynamic langfuse credentials are passed, False otherwise
        """

        if (
            standard_callback_dynamic_params.get("langfuse_host") is not None
            or standard_callback_dynamic_params.get("langfuse_public_key") is not None
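The check above is what flips a request from the shared Langfuse client to a request-scoped one. A minimal sketch of the calling side, assuming the dynamic-callback kwargs (`langfuse_public_key`, `langfuse_secret_key`, `langfuse_host`) that `StandardCallbackDynamicParams` recognizes; the model name and credential values are placeholders.

```python
# Hedged sketch: per-request Langfuse credentials via dynamic callback params.
# Model name and credential values are placeholders.
import litellm

litellm.success_callback = ["langfuse"]

response = litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "hi"}],
    # These land in standard_callback_dynamic_params, so the check above
    # returns True and a request-scoped Langfuse client is used.
    langfuse_public_key="pk-lf-team-1",
    langfuse_secret_key="sk-lf-team-1",
    langfuse_host="https://cloud.langfuse.com",
)
```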
@@ -16,7 +16,11 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload

from ...litellm_core_utils.specialty_caches.dynamic_logging_cache import (
    DynamicLoggingCache,
)
from .langfuse import LangFuseLogger
from .langfuse_handler import LangFuseHandler

if TYPE_CHECKING:
    from langfuse import Langfuse
@@ -29,6 +33,8 @@ else:
    PROMPT_CLIENT = Any
    LangfuseClass = Any

in_memory_dynamic_logger_cache = DynamicLoggingCache()


@lru_cache(maxsize=10)
def langfuse_client_init(
@@ -252,7 +258,15 @@ class LangfusePromptManagement(LangFuseLogger, CustomLogger):
        return model, messages, non_default_params

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        self._old_log_event(
        standard_callback_dynamic_params = kwargs.get(
            "standard_callback_dynamic_params"
        )
        langfuse_logger_to_use = LangFuseHandler.get_langfuse_logger_for_request(
            globalLangfuseLogger=self,
            standard_callback_dynamic_params=standard_callback_dynamic_params,
            in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
        )
        langfuse_logger_to_use._old_log_event(
            kwargs=kwargs,
            response_obj=response_obj,
            start_time=start_time,
@@ -262,13 +276,21 @@ class LangfusePromptManagement(LangFuseLogger, CustomLogger):
        )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        standard_callback_dynamic_params = kwargs.get(
            "standard_callback_dynamic_params"
        )
        langfuse_logger_to_use = LangFuseHandler.get_langfuse_logger_for_request(
            globalLangfuseLogger=self,
            standard_callback_dynamic_params=standard_callback_dynamic_params,
            in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
        )
        standard_logging_object = cast(
            Optional[StandardLoggingPayload],
            kwargs.get("standard_logging_object", None),
        )
        if standard_logging_object is None:
            return
        self._old_log_event(
        langfuse_logger_to_use._old_log_event(
            start_time=start_time,
            end_time=end_time,
            response_obj=None,
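With the failure handler wired through the same `LangFuseHandler.get_langfuse_logger_for_request` path, errors should now reach the team-specific Langfuse project too. A hedged sketch of exercising that path from the SDK; the invalid API key is deliberate and all values are placeholders.

```python
# Hedged sketch: failure callback logging to Langfuse with dynamic credentials.
# The invalid API key is intentional; all values are placeholders.
import litellm

litellm.failure_callback = ["langfuse"]

try:
    litellm.completion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "hi"}],
        api_key="sk-invalid",  # force a failure
        langfuse_public_key="pk-lf-team-1",
        langfuse_secret_key="sk-lf-team-1",
    )
except Exception:
    # The failure handler above resolves the team-specific Langfuse client
    # (via the in-memory dynamic logger cache) and records the error there.
    pass
```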
@@ -0,0 +1,32 @@
from typing import Dict, Optional

from litellm.secret_managers.main import get_secret_str
from litellm.types.utils import StandardCallbackDynamicParams


def initialize_standard_callback_dynamic_params(
    kwargs: Optional[Dict] = None,
) -> StandardCallbackDynamicParams:
    """
    Initialize the standard callback dynamic params from the kwargs

    checks if langfuse_secret_key, gcs_bucket_name in kwargs and sets the corresponding attributes in StandardCallbackDynamicParams
    """

    standard_callback_dynamic_params = StandardCallbackDynamicParams()
    if kwargs:
        _supported_callback_params = (
            StandardCallbackDynamicParams.__annotations__.keys()
        )
        for param in _supported_callback_params:
            if param in kwargs:
                _param_value = kwargs.pop(param)
                if (
                    _param_value is not None
                    and isinstance(_param_value, str)
                    and "os.environ/" in _param_value
                ):
                    _param_value = get_secret_str(secret_name=_param_value)
                standard_callback_dynamic_params[param] = _param_value  # type: ignore

    return standard_callback_dynamic_params
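A quick illustration of what the new helper does with request kwargs: recognized callback params are popped off, `os.environ/...` references are resolved through the secret manager, and everything else is left alone. The import path and values below are assumptions for the sketch.

```python
# Hedged usage sketch for the new helper (import path and values assumed).
from litellm.litellm_core_utils.initialize_dynamic_callback_params import (
    initialize_standard_callback_dynamic_params,
)

request_kwargs = {
    "langfuse_public_key": "pk-lf-team-1",
    "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",  # resolved via get_secret_str
    "temperature": 0.2,  # not a callback param, left in kwargs untouched
}

dynamic_params = initialize_standard_callback_dynamic_params(request_kwargs)
print(dynamic_params.get("langfuse_public_key"))  # "pk-lf-team-1"
print("temperature" in request_kwargs)            # True: non-callback kwargs survive
```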
@@ -94,7 +94,11 @@ from ..integrations.supabase import Supabase
from ..integrations.traceloop import TraceloopLogger
from ..integrations.weights_biases import WeightsBiasesLogger
from .exception_mapping_utils import _get_response_headers
from .initialize_dynamic_callback_params import (
    initialize_standard_callback_dynamic_params as _initialize_standard_callback_dynamic_params,
)
from .logging_utils import _assemble_complete_response_from_streaming_chunks
from .specialty_caches.dynamic_logging_cache import DynamicLoggingCache

try:
    from ..proxy.enterprise.enterprise_callbacks.generic_api_callback import (
@@ -156,39 +160,6 @@ class ServiceTraceIDCache:
        return None


import hashlib


class DynamicLoggingCache:
    """
    Prevent memory leaks caused by initializing new logging clients on each request.

    Relevant Issue: https://github.com/BerriAI/litellm/issues/5695
    """

    def __init__(self) -> None:
        self.cache = InMemoryCache()

    def get_cache_key(self, args: dict) -> str:
        args_str = json.dumps(args, sort_keys=True)
        cache_key = hashlib.sha256(args_str.encode("utf-8")).hexdigest()
        return cache_key

    def get_cache(self, credentials: dict, service_name: str) -> Optional[Any]:
        key_name = self.get_cache_key(
            args={**credentials, "service_name": service_name}
        )
        response = self.cache.get_cache(key=key_name)
        return response

    def set_cache(self, credentials: dict, service_name: str, logging_obj: Any) -> None:
        key_name = self.get_cache_key(
            args={**credentials, "service_name": service_name}
        )
        self.cache.set_cache(key=key_name, value=logging_obj)
        return None


in_memory_trace_id_cache = ServiceTraceIDCache()
in_memory_dynamic_logger_cache = DynamicLoggingCache()
@@ -370,24 +341,7 @@ class Logging(LiteLLMLoggingBaseClass):

        checks if langfuse_secret_key, gcs_bucket_name in kwargs and sets the corresponding attributes in StandardCallbackDynamicParams
        """
        from litellm.secret_managers.main import get_secret_str

        standard_callback_dynamic_params = StandardCallbackDynamicParams()
        if kwargs:
            _supported_callback_params = (
                StandardCallbackDynamicParams.__annotations__.keys()
            )
            for param in _supported_callback_params:
                if param in kwargs:
                    _param_value = kwargs.pop(param)
                    if (
                        _param_value is not None
                        and isinstance(_param_value, str)
                        and "os.environ/" in _param_value
                    ):
                        _param_value = get_secret_str(secret_name=_param_value)
                    standard_callback_dynamic_params[param] = _param_value  # type: ignore
        return standard_callback_dynamic_params
        return _initialize_standard_callback_dynamic_params(kwargs)

    def update_environment_variables(
        self,
@@ -963,7 +917,9 @@ class Logging(LiteLLMLoggingBaseClass):
    def success_handler(  # noqa: PLR0915
        self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs
    ):
        print_verbose(f"Logging Details LiteLLM-Success Call: Cache_hit={cache_hit}")
        verbose_logger.debug(
            f"Logging Details LiteLLM-Success Call: Cache_hit={cache_hit}"
        )
        start_time, end_time, result = self._success_handler_helper_fn(
            start_time=start_time,
            end_time=end_time,
@@ -971,9 +927,7 @@ class Logging(LiteLLMLoggingBaseClass):
            cache_hit=cache_hit,
            standard_logging_object=kwargs.get("standard_logging_object", None),
        )
        # print(f"original response in success handler: {self.model_call_details['original_response']}")
        try:
            verbose_logger.debug(f"success callbacks: {litellm.success_callback}")

            ## BUILD COMPLETE STREAMED RESPONSE
            complete_streaming_response: Optional[
@@ -2015,6 +1969,7 @@ class Logging(LiteLLMLoggingBaseClass):
            )

            result = None  # result sent to all loggers, init this to None incase it's not created

            for callback in callbacks:
                try:
                    if isinstance(callback, CustomLogger):  # custom logger class
@@ -0,0 +1,35 @@
import hashlib
import json
from typing import Any, Optional

from ...caching import InMemoryCache


class DynamicLoggingCache:
    """
    Prevent memory leaks caused by initializing new logging clients on each request.

    Relevant Issue: https://github.com/BerriAI/litellm/issues/5695
    """

    def __init__(self) -> None:
        self.cache = InMemoryCache()

    def get_cache_key(self, args: dict) -> str:
        args_str = json.dumps(args, sort_keys=True)
        cache_key = hashlib.sha256(args_str.encode("utf-8")).hexdigest()
        return cache_key

    def get_cache(self, credentials: dict, service_name: str) -> Optional[Any]:
        key_name = self.get_cache_key(
            args={**credentials, "service_name": service_name}
        )
        response = self.cache.get_cache(key=key_name)
        return response

    def set_cache(self, credentials: dict, service_name: str, logging_obj: Any) -> None:
        key_name = self.get_cache_key(
            args={**credentials, "service_name": service_name}
        )
        self.cache.set_cache(key=key_name, value=logging_obj)
        return None
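A short sketch of how the relocated cache is meant to be used: credentials plus a service name hash to a stable key, so the same team's logging client is reused instead of re-created per request. The import path, credential values, and the `object()` stand-in for a real client are placeholders.

```python
# Hedged sketch of DynamicLoggingCache usage (values are placeholders).
from litellm.litellm_core_utils.specialty_caches.dynamic_logging_cache import (
    DynamicLoggingCache,
)

cache = DynamicLoggingCache()
team_credentials = {
    "langfuse_public_key": "pk-lf-team-1",
    "langfuse_host": "https://cloud.langfuse.com",
}

if cache.get_cache(credentials=team_credentials, service_name="langfuse") is None:
    langfuse_client = object()  # stand-in for an initialized Langfuse logger
    cache.set_cache(
        credentials=team_credentials,
        service_name="langfuse",
        logging_obj=langfuse_client,
    )

# Same credentials -> same cache key -> the cached client comes back.
assert cache.get_cache(credentials=team_credentials, service_name="langfuse") is not None
```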
@@ -5,17 +5,28 @@ model_list:
      api_key: os.environ/OPENAI_API_KEY
    model_info:
      access_groups: ["default-models"]
  - model_name: openai/o1-*
  - model_name: openai-deepseek
    litellm_params:
      model: openai/o1-*
      model: deepseek/deepseek-chat
      api_key: os.environ/OPENAI_API_KEY
    model_info:
      access_groups: ["restricted-models"]
  - model_name: azure-o1-preview
    litellm_params:
      model: azure/o1-preview
      api_key: os.environ/AZURE_OPENAI_O1_KEY
      api_base: os.environ/AZURE_API_BASE
    model_info:
      supports_native_streaming: True
      access_groups: ["shared-models"]
      custom_tokenizer:
        identifier: deepseek-ai/DeepSeek-V3-Base
        revision: main
        auth_token: os.environ/HUGGINGFACE_API_KEY

litellm_settings:
  default_team_settings:
    - team_id: "team_1"
      success_callback: ["langfuse"]
      failure_callback: ["langfuse"]
      langfuse_public_key: os.environ/LANGFUSE_PUBLIC_KEY
      langfuse_secret: os.environ/LANGFUSE_SECRET_KEY
      langfuse_host: os.environ/LANGFUSE_HOST
    - team_id: "team_2"
      success_callback: ["langfuse"]
      failure_callback: ["langfuse"]
      langfuse_public_key: os.environ/LANGFUSE_PROJECT3_PUBLIC
      langfuse_secret: os.environ/LANGFUSE_PROJECT3_SECRET
      langfuse_host: os.environ/LANGFUSE_HOST
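With `default_team_settings` in place, a request made through the proxy with a team's virtual key should be logged to that team's Langfuse project. A rough client-side sketch, assuming the proxy runs on `http://localhost:4000` and `sk-team-1-xxxx` is a virtual key mapped to `team_1`; both are placeholders.

```python
# Hedged sketch: call the proxy with a team-scoped virtual key (placeholders).
from openai import OpenAI

client = OpenAI(
    api_key="sk-team-1-xxxx",          # virtual key tied to team_1
    base_url="http://localhost:4000",  # LiteLLM proxy
)

resp = client.chat.completions.create(
    model="openai-deepseek",
    messages=[{"role": "user", "content": "Hello!"}],
)
# Success (and now failure) callbacks fire with team_1's Langfuse credentials,
# so the trace lands in team_1's Langfuse project rather than the default one.
print(resp.choices[0].message.content)
```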
@@ -274,6 +274,8 @@ from litellm.types.llms.anthropic import (
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.router import ModelInfo as RouterModelInfo
from litellm.types.router import RouterGeneralSettings, updateDeployment
from litellm.types.utils import CustomHuggingfaceTokenizer
from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.types.utils import StandardLoggingPayload
from litellm.utils import get_end_user_id_for_cost_tracking
@@ -5526,11 +5528,16 @@ async def token_counter(request: TokenCountRequest):

    deployment = None
    litellm_model_name = None
    model_info: Optional[ModelMapInfo] = None
    if llm_router is not None:
        # get 1 deployment corresponding to the model
        for _model in llm_router.model_list:
            if _model["model_name"] == request.model:
                deployment = _model
                model_info = llm_router.get_router_model_info(
                    deployment=deployment,
                    received_model_name=request.model,
                )
                break
    if deployment is not None:
        litellm_model_name = deployment.get("litellm_params", {}).get("model")
@@ -5541,12 +5548,22 @@ async def token_counter(request: TokenCountRequest):
    model_to_use = (
        litellm_model_name or request.model
    )  # use litellm model name, if it's not avalable then fallback to request.model
    _tokenizer_used = litellm.utils._select_tokenizer(model=model_to_use)

    custom_tokenizer: Optional[CustomHuggingfaceTokenizer] = None
    if model_info is not None:
        custom_tokenizer = cast(
            Optional[CustomHuggingfaceTokenizer],
            model_info.get("custom_tokenizer", None),
        )
    _tokenizer_used = litellm.utils._select_tokenizer(
        model=model_to_use, custom_tokenizer=custom_tokenizer
    )
    tokenizer_used = str(_tokenizer_used["type"])
    total_tokens = token_counter(
        model=model_to_use,
        text=prompt,
        messages=messages,
        custom_tokenizer=_tokenizer_used,
    )
    return TokenCountResponse(
        total_tokens=total_tokens,
@@ -1813,3 +1813,9 @@ class LiteLLMLoggingBaseClass:
        self, original_response, input=None, api_key=None, additional_args={}
    ):
        pass


class CustomHuggingfaceTokenizer(TypedDict):
    identifier: str
    revision: str  # usually 'main'
    auth_token: Optional[str]
@@ -126,6 +126,7 @@ from litellm.types.utils import (
    ChatCompletionMessageToolCall,
    Choices,
    CostPerToken,
    CustomHuggingfaceTokenizer,
    Delta,
    Embedding,
    EmbeddingResponse,
@@ -1242,10 +1243,21 @@ def _is_async_request(
    return False


@lru_cache(maxsize=128)
def _select_tokenizer(
    model: str,
    model: str, custom_tokenizer: Optional[CustomHuggingfaceTokenizer] = None
):
    if custom_tokenizer is not None:
        custom_tokenizer = Tokenizer.from_pretrained(
            custom_tokenizer["identifier"],
            revision=custom_tokenizer["revision"],
            auth_token=custom_tokenizer["auth_token"],
        )
        return {"type": "huggingface_tokenizer", "tokenizer": custom_tokenizer}
    return _select_tokenizer_helper(model=model)


@lru_cache(maxsize=128)
def _select_tokenizer_helper(model: str):
    if model in litellm.cohere_models and "command-r" in model:
        # cohere
        cohere_tokenizer = Tokenizer.from_pretrained(
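To make the new `custom_tokenizer` parameter concrete, here is a hedged sketch of selecting a tokenizer by hand and counting tokens with it, mirroring what the `/utils/token_counter` handler does. The identifier follows the docs example above; downloading it needs network access and, for gated repos, a valid Hugging Face token.

```python
# Hedged sketch: pass a CustomHuggingfaceTokenizer dict to _select_tokenizer.
# Requires network access to huggingface.co; auth_token may be None for
# public repos.
import litellm
from litellm.types.utils import CustomHuggingfaceTokenizer

custom_tokenizer: CustomHuggingfaceTokenizer = {
    "identifier": "deepseek-ai/DeepSeek-V3-Base",
    "revision": "main",
    "auth_token": None,
}

selected = litellm.utils._select_tokenizer(
    model="deepseek/deepseek-chat",
    custom_tokenizer=custom_tokenizer,
)
print(selected["type"])  # expected: "huggingface_tokenizer"

total = litellm.token_counter(
    model="deepseek/deepseek-chat",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    custom_tokenizer=selected,
)
print(total)
```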
@@ -210,9 +210,12 @@ def test_model_info_bedrock_converse(monkeypatch):
    monkeypatch.setenv("LITELLM_LOCAL_MODEL_COST_MAP", "True")
    litellm.model_cost = litellm.get_model_cost_map(url="")

    try:
        # Load whitelist models from file
        with open("whitelisted_bedrock_models.txt", "r") as file:
            whitelist_models = [line.strip() for line in file.readlines()]
    except FileNotFoundError:
        pytest.skip("whitelisted_bedrock_models.txt not found")

    _enforce_bedrock_converse_models(
        model_cost=litellm.model_cost, whitelist_models=whitelist_models