Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
Fix team-based logging to langfuse + allow custom tokenizer on /token_counter endpoint (#7493)
* fix(langfuse_prompt_management.py): migrate dynamic logging to langfuse custom logger compatible class
* fix(langfuse_prompt_management.py): support failure callback logging to langfuse as well
* feat(proxy_server.py): support setting custom tokenizer on config.yaml. Allows customizing value for `/utils/token_counter`
* fix(proxy_server.py): fix linting errors
* test: skip if file not found
* style: cleanup unused import
* docs(configs.md): add docs on setting custom tokenizer
This commit is contained in:
  parent 6705e30d5d
  commit 080de89cfb

11 changed files with 192 additions and 72 deletions
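For context on the team-based logging fix: the proxy resolves each team's Langfuse keys into per-request dynamic callback params, which the `initialize_standard_callback_dynamic_params` helper added in this commit pops out of the request kwargs. Below is a minimal SDK-level sketch of that same mechanism, not part of this diff, assuming the `langfuse_public_key` / `langfuse_secret_key` / `langfuse_host` kwargs and placeholder env vars:

```python
import os

import litellm

# Log both successes and failures to Langfuse (the failure path is what this
# commit fixes for the prompt-management logger).
litellm.success_callback = ["langfuse"]
litellm.failure_callback = ["langfuse"]

response = litellm.completion(
    model="gpt-4o-mini",  # placeholder model name
    messages=[{"role": "user", "content": "hello"}],
    # Per-request (e.g. per-team) Langfuse credentials; these are picked up as
    # standard callback dynamic params instead of the globally configured keys.
    langfuse_public_key=os.environ["LANGFUSE_PUBLIC_KEY"],
    langfuse_secret_key=os.environ["LANGFUSE_SECRET_KEY"],
    langfuse_host=os.environ.get("LANGFUSE_HOST", "https://cloud.langfuse.com"),
)
print(response.choices[0].message.content)
```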
@@ -516,6 +516,32 @@ model_list:
 $ litellm --config /path/to/config.yaml
 ```
 
+### Set custom tokenizer
+
+If you're using the [`/utils/token_counter` endpoint](https://litellm-api.up.railway.app/#/llm%20utils/token_counter_utils_token_counter_post), and want to set a custom huggingface tokenizer for a model, you can do so in the `config.yaml`
+
+```yaml
+model_list:
+  - model_name: openai-deepseek
+    litellm_params:
+      model: deepseek/deepseek-chat
+      api_key: os.environ/OPENAI_API_KEY
+    model_info:
+      access_groups: ["restricted-models"]
+      custom_tokenizer:
+        identifier: deepseek-ai/DeepSeek-V3-Base
+        revision: main
+        auth_token: os.environ/HUGGINGFACE_API_KEY
+```
+
+**Spec**
+```
+custom_tokenizer:
+  identifier: str # huggingface model identifier
+  revision: str # huggingface model revision (usually 'main')
+  auth_token: Optional[str] # huggingface auth token
+```
+
 ## General Settings `general_settings` (DB Connection, etc)
 
 ### Configure DB Pool Limits + Connection Timeouts
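The docs hunk above can be exercised against a running proxy. A hedged sketch of a `/utils/token_counter` call follows, assuming the proxy runs on `http://localhost:4000` with a placeholder virtual key; the response fields come from the `token_counter` handler changed later in this diff:

```python
import requests

# Count tokens for the "openai-deepseek" deployment from the config above; with
# a custom_tokenizer configured, the proxy should report a huggingface tokenizer.
resp = requests.post(
    "http://localhost:4000/utils/token_counter",  # placeholder proxy URL
    headers={"Authorization": "Bearer sk-1234"},  # placeholder virtual key
    json={
        "model": "openai-deepseek",
        "messages": [{"role": "user", "content": "Hello, world!"}],
    },
    timeout=30,
)
print(resp.status_code, resp.json())  # expect a total_tokens count + tokenizer type
```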
@@ -158,6 +158,7 @@ class LangFuseHandler:
         Returns:
             bool: True if the dynamic langfuse credentials are passed, False otherwise
         """
 
         if (
             standard_callback_dynamic_params.get("langfuse_host") is not None
             or standard_callback_dynamic_params.get("langfuse_public_key") is not None
@@ -16,7 +16,11 @@ from litellm.proxy._types import UserAPIKeyAuth
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
 
+from ...litellm_core_utils.specialty_caches.dynamic_logging_cache import (
+    DynamicLoggingCache,
+)
 from .langfuse import LangFuseLogger
+from .langfuse_handler import LangFuseHandler
 
 if TYPE_CHECKING:
     from langfuse import Langfuse
@@ -29,6 +33,8 @@ else:
     PROMPT_CLIENT = Any
     LangfuseClass = Any
 
+in_memory_dynamic_logger_cache = DynamicLoggingCache()
+
 
 @lru_cache(maxsize=10)
 def langfuse_client_init(
@@ -252,7 +258,15 @@ class LangfusePromptManagement(LangFuseLogger, CustomLogger):
         return model, messages, non_default_params
 
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
-        self._old_log_event(
+        standard_callback_dynamic_params = kwargs.get(
+            "standard_callback_dynamic_params"
+        )
+        langfuse_logger_to_use = LangFuseHandler.get_langfuse_logger_for_request(
+            globalLangfuseLogger=self,
+            standard_callback_dynamic_params=standard_callback_dynamic_params,
+            in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
+        )
+        langfuse_logger_to_use._old_log_event(
             kwargs=kwargs,
             response_obj=response_obj,
             start_time=start_time,
@@ -262,13 +276,21 @@ class LangfusePromptManagement(LangFuseLogger, CustomLogger):
         )
 
     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        standard_callback_dynamic_params = kwargs.get(
+            "standard_callback_dynamic_params"
+        )
+        langfuse_logger_to_use = LangFuseHandler.get_langfuse_logger_for_request(
+            globalLangfuseLogger=self,
+            standard_callback_dynamic_params=standard_callback_dynamic_params,
+            in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
+        )
         standard_logging_object = cast(
             Optional[StandardLoggingPayload],
             kwargs.get("standard_logging_object", None),
         )
         if standard_logging_object is None:
             return
-        self._old_log_event(
+        langfuse_logger_to_use._old_log_event(
             start_time=start_time,
             end_time=end_time,
             response_obj=None,
@@ -0,0 +1,32 @@
+from typing import Dict, Optional
+
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.utils import StandardCallbackDynamicParams
+
+
+def initialize_standard_callback_dynamic_params(
+    kwargs: Optional[Dict] = None,
+) -> StandardCallbackDynamicParams:
+    """
+    Initialize the standard callback dynamic params from the kwargs
+
+    checks if langfuse_secret_key, gcs_bucket_name in kwargs and sets the corresponding attributes in StandardCallbackDynamicParams
+    """
+
+    standard_callback_dynamic_params = StandardCallbackDynamicParams()
+    if kwargs:
+        _supported_callback_params = (
+            StandardCallbackDynamicParams.__annotations__.keys()
+        )
+        for param in _supported_callback_params:
+            if param in kwargs:
+                _param_value = kwargs.pop(param)
+                if (
+                    _param_value is not None
+                    and isinstance(_param_value, str)
+                    and "os.environ/" in _param_value
+                ):
+                    _param_value = get_secret_str(secret_name=_param_value)
+                standard_callback_dynamic_params[param] = _param_value  # type: ignore
+
+    return standard_callback_dynamic_params
@@ -94,7 +94,11 @@ from ..integrations.supabase import Supabase
 from ..integrations.traceloop import TraceloopLogger
 from ..integrations.weights_biases import WeightsBiasesLogger
 from .exception_mapping_utils import _get_response_headers
+from .initialize_dynamic_callback_params import (
+    initialize_standard_callback_dynamic_params as _initialize_standard_callback_dynamic_params,
+)
 from .logging_utils import _assemble_complete_response_from_streaming_chunks
+from .specialty_caches.dynamic_logging_cache import DynamicLoggingCache
 
 try:
     from ..proxy.enterprise.enterprise_callbacks.generic_api_callback import (
@@ -156,39 +160,6 @@ class ServiceTraceIDCache:
         return None
 
 
-import hashlib
-
-
-class DynamicLoggingCache:
-    """
-    Prevent memory leaks caused by initializing new logging clients on each request.
-
-    Relevant Issue: https://github.com/BerriAI/litellm/issues/5695
-    """
-
-    def __init__(self) -> None:
-        self.cache = InMemoryCache()
-
-    def get_cache_key(self, args: dict) -> str:
-        args_str = json.dumps(args, sort_keys=True)
-        cache_key = hashlib.sha256(args_str.encode("utf-8")).hexdigest()
-        return cache_key
-
-    def get_cache(self, credentials: dict, service_name: str) -> Optional[Any]:
-        key_name = self.get_cache_key(
-            args={**credentials, "service_name": service_name}
-        )
-        response = self.cache.get_cache(key=key_name)
-        return response
-
-    def set_cache(self, credentials: dict, service_name: str, logging_obj: Any) -> None:
-        key_name = self.get_cache_key(
-            args={**credentials, "service_name": service_name}
-        )
-        self.cache.set_cache(key=key_name, value=logging_obj)
-        return None
-
-
 in_memory_trace_id_cache = ServiceTraceIDCache()
 in_memory_dynamic_logger_cache = DynamicLoggingCache()
 
@@ -370,24 +341,7 @@ class Logging(LiteLLMLoggingBaseClass):
 
         checks if langfuse_secret_key, gcs_bucket_name in kwargs and sets the corresponding attributes in StandardCallbackDynamicParams
         """
-        from litellm.secret_managers.main import get_secret_str
-
-        standard_callback_dynamic_params = StandardCallbackDynamicParams()
-        if kwargs:
-            _supported_callback_params = (
-                StandardCallbackDynamicParams.__annotations__.keys()
-            )
-            for param in _supported_callback_params:
-                if param in kwargs:
-                    _param_value = kwargs.pop(param)
-                    if (
-                        _param_value is not None
-                        and isinstance(_param_value, str)
-                        and "os.environ/" in _param_value
-                    ):
-                        _param_value = get_secret_str(secret_name=_param_value)
-                    standard_callback_dynamic_params[param] = _param_value  # type: ignore
-        return standard_callback_dynamic_params
+        return _initialize_standard_callback_dynamic_params(kwargs)
 
     def update_environment_variables(
         self,
@@ -963,7 +917,9 @@ class Logging(LiteLLMLoggingBaseClass):
     def success_handler(  # noqa: PLR0915
        self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs
     ):
-        print_verbose(f"Logging Details LiteLLM-Success Call: Cache_hit={cache_hit}")
+        verbose_logger.debug(
+            f"Logging Details LiteLLM-Success Call: Cache_hit={cache_hit}"
+        )
         start_time, end_time, result = self._success_handler_helper_fn(
             start_time=start_time,
             end_time=end_time,
@@ -971,9 +927,7 @@ class Logging(LiteLLMLoggingBaseClass):
             cache_hit=cache_hit,
             standard_logging_object=kwargs.get("standard_logging_object", None),
         )
-        # print(f"original response in success handler: {self.model_call_details['original_response']}")
         try:
-            verbose_logger.debug(f"success callbacks: {litellm.success_callback}")
 
             ## BUILD COMPLETE STREAMED RESPONSE
             complete_streaming_response: Optional[
@@ -2015,6 +1969,7 @@ class Logging(LiteLLMLoggingBaseClass):
         )
 
         result = None  # result sent to all loggers, init this to None incase it's not created
 
         for callback in callbacks:
             try:
                 if isinstance(callback, CustomLogger):  # custom logger class
@@ -0,0 +1,35 @@
+import hashlib
+import json
+from typing import Any, Optional
+
+from ...caching import InMemoryCache
+
+
+class DynamicLoggingCache:
+    """
+    Prevent memory leaks caused by initializing new logging clients on each request.
+
+    Relevant Issue: https://github.com/BerriAI/litellm/issues/5695
+    """
+
+    def __init__(self) -> None:
+        self.cache = InMemoryCache()
+
+    def get_cache_key(self, args: dict) -> str:
+        args_str = json.dumps(args, sort_keys=True)
+        cache_key = hashlib.sha256(args_str.encode("utf-8")).hexdigest()
+        return cache_key
+
+    def get_cache(self, credentials: dict, service_name: str) -> Optional[Any]:
+        key_name = self.get_cache_key(
+            args={**credentials, "service_name": service_name}
+        )
+        response = self.cache.get_cache(key=key_name)
+        return response
+
+    def set_cache(self, credentials: dict, service_name: str, logging_obj: Any) -> None:
+        key_name = self.get_cache_key(
+            args={**credentials, "service_name": service_name}
+        )
+        self.cache.set_cache(key=key_name, value=logging_obj)
+        return None
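The new `DynamicLoggingCache` above keys clients on a hash of the credentials plus the service name. A small usage sketch, assuming the module lands at `litellm/litellm_core_utils/specialty_caches/dynamic_logging_cache.py` (consistent with the relative import earlier in this diff); the cached "client" here is just a stand-in object:

```python
from litellm.litellm_core_utils.specialty_caches.dynamic_logging_cache import (
    DynamicLoggingCache,
)

cache = DynamicLoggingCache()
credentials = {
    "langfuse_public_key": "pk-team-1",  # placeholder team credentials
    "langfuse_secret": "sk-team-1",
    "langfuse_host": "https://cloud.langfuse.com",
}

# Reuse an existing client for this (credentials, service_name) pair if present.
client = cache.get_cache(credentials=credentials, service_name="langfuse")
if client is None:
    client = object()  # stand-in for a real Langfuse client
    cache.set_cache(credentials=credentials, service_name="langfuse", logging_obj=client)

# The same credentials + service name map to the same cache entry next time.
assert cache.get_cache(credentials=credentials, service_name="langfuse") is client
```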
@@ -5,17 +5,28 @@ model_list:
       api_key: os.environ/OPENAI_API_KEY
     model_info:
       access_groups: ["default-models"]
-  - model_name: openai/o1-*
+  - model_name: openai-deepseek
     litellm_params:
-      model: openai/o1-*
+      model: deepseek/deepseek-chat
       api_key: os.environ/OPENAI_API_KEY
     model_info:
       access_groups: ["restricted-models"]
-  - model_name: azure-o1-preview
-    litellm_params:
-      model: azure/o1-preview
-      api_key: os.environ/AZURE_OPENAI_O1_KEY
-      api_base: os.environ/AZURE_API_BASE
-    model_info:
-      supports_native_streaming: True
-      access_groups: ["shared-models"]
+      custom_tokenizer:
+        identifier: deepseek-ai/DeepSeek-V3-Base
+        revision: main
+        auth_token: os.environ/HUGGINGFACE_API_KEY
+
+litellm_settings:
+  default_team_settings:
+    - team_id: "team_1"
+      success_callback: ["langfuse"]
+      failure_callback: ["langfuse"]
+      langfuse_public_key: os.environ/LANGFUSE_PUBLIC_KEY
+      langfuse_secret: os.environ/LANGFUSE_SECRET_KEY
+      langfuse_host: os.environ/LANGFUSE_HOST
+    - team_id: "team_2"
+      success_callback: ["langfuse"]
+      failure_callback: ["langfuse"]
+      langfuse_public_key: os.environ/LANGFUSE_PROJECT3_PUBLIC
+      langfuse_secret: os.environ/LANGFUSE_PROJECT3_SECRET
+      langfuse_host: os.environ/LANGFUSE_HOST
@@ -274,6 +274,8 @@ from litellm.types.llms.anthropic import (
 from litellm.types.llms.openai import HttpxBinaryResponseContent
 from litellm.types.router import ModelInfo as RouterModelInfo
 from litellm.types.router import RouterGeneralSettings, updateDeployment
+from litellm.types.utils import CustomHuggingfaceTokenizer
+from litellm.types.utils import ModelInfo as ModelMapInfo
 from litellm.types.utils import StandardLoggingPayload
 from litellm.utils import get_end_user_id_for_cost_tracking
 
@@ -5526,11 +5528,16 @@ async def token_counter(request: TokenCountRequest):
 
     deployment = None
     litellm_model_name = None
+    model_info: Optional[ModelMapInfo] = None
     if llm_router is not None:
         # get 1 deployment corresponding to the model
         for _model in llm_router.model_list:
             if _model["model_name"] == request.model:
                 deployment = _model
+                model_info = llm_router.get_router_model_info(
+                    deployment=deployment,
+                    received_model_name=request.model,
+                )
                 break
     if deployment is not None:
         litellm_model_name = deployment.get("litellm_params", {}).get("model")
@@ -5541,12 +5548,22 @@ async def token_counter(request: TokenCountRequest):
     model_to_use = (
         litellm_model_name or request.model
     )  # use litellm model name, if it's not avalable then fallback to request.model
-    _tokenizer_used = litellm.utils._select_tokenizer(model=model_to_use)
+
+    custom_tokenizer: Optional[CustomHuggingfaceTokenizer] = None
+    if model_info is not None:
+        custom_tokenizer = cast(
+            Optional[CustomHuggingfaceTokenizer],
+            model_info.get("custom_tokenizer", None),
+        )
+    _tokenizer_used = litellm.utils._select_tokenizer(
+        model=model_to_use, custom_tokenizer=custom_tokenizer
+    )
     tokenizer_used = str(_tokenizer_used["type"])
     total_tokens = token_counter(
         model=model_to_use,
         text=prompt,
         messages=messages,
+        custom_tokenizer=_tokenizer_used,
     )
     return TokenCountResponse(
         total_tokens=total_tokens,
@@ -1813,3 +1813,9 @@ class LiteLLMLoggingBaseClass:
         self, original_response, input=None, api_key=None, additional_args={}
     ):
         pass
+
+
+class CustomHuggingfaceTokenizer(TypedDict):
+    identifier: str
+    revision: str  # usually 'main'
+    auth_token: Optional[str]
@@ -126,6 +126,7 @@ from litellm.types.utils import (
     ChatCompletionMessageToolCall,
     Choices,
     CostPerToken,
+    CustomHuggingfaceTokenizer,
     Delta,
     Embedding,
     EmbeddingResponse,
@@ -1242,10 +1243,21 @@ def _is_async_request(
     return False
 
 
-@lru_cache(maxsize=128)
 def _select_tokenizer(
-    model: str,
+    model: str, custom_tokenizer: Optional[CustomHuggingfaceTokenizer] = None
 ):
+    if custom_tokenizer is not None:
+        custom_tokenizer = Tokenizer.from_pretrained(
+            custom_tokenizer["identifier"],
+            revision=custom_tokenizer["revision"],
+            auth_token=custom_tokenizer["auth_token"],
+        )
+        return {"type": "huggingface_tokenizer", "tokenizer": custom_tokenizer}
+    return _select_tokenizer_helper(model=model)
+
+
+@lru_cache(maxsize=128)
+def _select_tokenizer_helper(model: str):
     if model in litellm.cohere_models and "command-r" in model:
         # cohere
         cohere_tokenizer = Tokenizer.from_pretrained(
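The `_select_tokenizer` change above builds the custom tokenizer with the HuggingFace `tokenizers` library. A standalone sketch of that underlying call, using the identifier and revision from the docs example as placeholder values:

```python
from tokenizers import Tokenizer

# Mirrors the Tokenizer.from_pretrained(...) call in the diff above; pass a real
# HF token via auth_token for gated or private repos.
tok = Tokenizer.from_pretrained(
    "deepseek-ai/DeepSeek-V3-Base",
    revision="main",
    auth_token=None,
)
encoding = tok.encode("Hello, world!")
print(len(encoding.ids), encoding.tokens[:10])
```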
@@ -210,9 +210,12 @@ def test_model_info_bedrock_converse(monkeypatch):
     monkeypatch.setenv("LITELLM_LOCAL_MODEL_COST_MAP", "True")
     litellm.model_cost = litellm.get_model_cost_map(url="")
 
-    # Load whitelist models from file
-    with open("whitelisted_bedrock_models.txt", "r") as file:
-        whitelist_models = [line.strip() for line in file.readlines()]
+    try:
+        # Load whitelist models from file
+        with open("whitelisted_bedrock_models.txt", "r") as file:
+            whitelist_models = [line.strip() for line in file.readlines()]
+    except FileNotFoundError:
+        pytest.skip("whitelisted_bedrock_models.txt not found")
 
     _enforce_bedrock_converse_models(
         model_cost=litellm.model_cost, whitelist_models=whitelist_models