fix(utils.py): add more exception mapping for huggingface

Krrish Dholakia 2024-02-15 21:26:22 -08:00
parent 1b844aafdc
commit c37aad50ea


@@ -12,6 +12,7 @@ import litellm
 import dotenv, json, traceback, threading, base64, ast
 import subprocess, os
+from os.path import abspath, join, dirname
 import litellm, openai
 import itertools
 import random, uuid, requests
@@ -34,6 +35,7 @@ from dataclasses import (
 from importlib import resources
 # filename = pkg_resources.resource_filename(__name__, "llms/tokenizers")
 try:
     filename = str(
         resources.files().joinpath("llms/tokenizers")  # type: ignore
@@ -42,9 +44,10 @@ except:
     filename = str(
         resources.files(litellm).joinpath("llms/tokenizers")  # for python 3.10
     )  # for python 3.10+
-os.environ[
-    "TIKTOKEN_CACHE_DIR"
-] = filename  # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071
+os.environ["TIKTOKEN_CACHE_DIR"] = (
+    filename  # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071
+)
 encoding = tiktoken.get_encoding("cl100k_base")
 import importlib.metadata
 from ._logging import verbose_logger
@@ -82,6 +85,20 @@ from .exceptions import (
     BudgetExceededError,
     UnprocessableEntityError,
 )
+# Import Enterprise features
+project_path = abspath(join(dirname(__file__), "..", ".."))
+# Add the "enterprise" directory to sys.path
+verbose_logger.debug(f"current project_path: {project_path}")
+enterprise_path = abspath(join(project_path, "enterprise"))
+sys.path.append(enterprise_path)
+verbose_logger.debug(f"sys.path: {sys.path}")
+try:
+    from enterprise.callbacks.generic_api_callback import GenericAPILogger
+except Exception as e:
+    verbose_logger.debug(f"Exception import enterprise features {str(e)}")
 from typing import cast, List, Dict, Union, Optional, Literal, Any
 from .caching import Cache
 from concurrent.futures import ThreadPoolExecutor
@@ -107,6 +124,7 @@ customLogger = None
 langFuseLogger = None
 dynamoLogger = None
 s3Logger = None
+genericAPILogger = None
 llmonitorLogger = None
 aispendLogger = None
 berrispendLogger = None
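
Note: the GenericAPILogger imported above lives in the enterprise package and is not shown in this commit. Purely as a hedged sketch of what a compatible logger could look like (the class name, constructor, and GENERIC_LOGGER_ENDPOINT variable are assumptions; only the log_event keyword arguments are taken from the call site added later in this diff):

# Hypothetical sketch only; the real enterprise/callbacks/generic_api_callback.py
# is not part of this diff, and GENERIC_LOGGER_ENDPOINT is an assumed setting.
import os
import requests


class GenericAPILoggerSketch:
    def __init__(self, endpoint=None):
        # assumed: the target URL comes from the constructor or an env var
        self.endpoint = endpoint or os.getenv("GENERIC_LOGGER_ENDPOINT", "")

    def log_event(
        self, kwargs, response_obj, start_time, end_time, user_id, print_verbose
    ):
        # mirrors the keyword arguments utils.py passes in the success handler
        payload = {
            "model": kwargs.get("model"),
            "user": user_id,
            "start_time": str(start_time),
            "end_time": str(end_time),
            "response_obj": str(response_obj),
        }
        print_verbose(f"GenericAPILoggerSketch: posting event to {self.endpoint}")
        requests.post(self.endpoint, json=payload, timeout=10)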
@@ -1087,26 +1105,26 @@ class Logging:
     self.call_type == CallTypes.aimage_generation.value
     or self.call_type == CallTypes.image_generation.value
 ):
-    self.model_call_details[
-        "response_cost"
-    ] = litellm.completion_cost(
-        completion_response=result,
-        model=self.model,
-        call_type=self.call_type,
-    )
+    self.model_call_details["response_cost"] = (
+        litellm.completion_cost(
+            completion_response=result,
+            model=self.model,
+            call_type=self.call_type,
+        )
+    )
 else:
     # check if base_model set on azure
     base_model = _get_base_model_from_metadata(
         model_call_details=self.model_call_details
     )
     # base_model defaults to None if not set on model_info
-    self.model_call_details[
-        "response_cost"
-    ] = litellm.completion_cost(
-        completion_response=result,
-        call_type=self.call_type,
-        model=base_model,
-    )
+    self.model_call_details["response_cost"] = (
+        litellm.completion_cost(
+            completion_response=result,
+            call_type=self.call_type,
+            model=base_model,
+        )
+    )
 verbose_logger.debug(
     f"Model={self.model}; cost={self.model_call_details['response_cost']}"
 )
@@ -1174,9 +1192,9 @@ class Logging:
 verbose_logger.debug(
     f"Logging Details LiteLLM-Success Call streaming complete"
 )
-self.model_call_details[
-    "complete_streaming_response"
-] = complete_streaming_response
+self.model_call_details["complete_streaming_response"] = (
+    complete_streaming_response
+)
 try:
     if self.model_call_details.get("cache_hit", False) == True:
         self.model_call_details["response_cost"] = 0.0
@@ -1186,12 +1204,12 @@ class Logging:
     model_call_details=self.model_call_details
 )
 # base_model defaults to None if not set on model_info
-self.model_call_details[
-    "response_cost"
-] = litellm.completion_cost(
-    completion_response=complete_streaming_response,
-    model=base_model,
-)
+self.model_call_details["response_cost"] = (
+    litellm.completion_cost(
+        completion_response=complete_streaming_response,
+        model=base_model,
+    )
+)
 verbose_logger.debug(
     f"Model={self.model}; cost={self.model_call_details['response_cost']}"
 )
@@ -1369,6 +1387,35 @@ class Logging:
         user_id=kwargs.get("user", None),
         print_verbose=print_verbose,
     )
+if callback == "generic":
+    global genericAPILogger
+    verbose_logger.debug("reaches langfuse for success logging!")
+    kwargs = {}
+    for k, v in self.model_call_details.items():
+        if (
+            k != "original_response"
+        ):  # copy.deepcopy raises errors as this could be a coroutine
+            kwargs[k] = v
+    # this only logs streaming once, complete_streaming_response exists i.e when stream ends
+    if self.stream:
+        verbose_logger.debug(
+            f"is complete_streaming_response in kwargs: {kwargs.get('complete_streaming_response', None)}"
+        )
+        if complete_streaming_response is None:
+            break
+        else:
+            print_verbose("reaches langfuse for streaming logging!")
+            result = kwargs["complete_streaming_response"]
+    if genericAPILogger is None:
+        genericAPILogger = GenericAPILogger()
+    genericAPILogger.log_event(
+        kwargs=kwargs,
+        response_obj=result,
+        start_time=start_time,
+        end_time=end_time,
+        user_id=kwargs.get("user", None),
+        print_verbose=print_verbose,
+    )
 if callback == "cache" and litellm.cache is not None:
     # this only logs streaming once, complete_streaming_response exists i.e when stream ends
     print_verbose("success_callback: reaches cache for logging!")
@@ -1448,11 +1495,11 @@ class Logging:
     )
 else:
     if self.stream and complete_streaming_response:
-        self.model_call_details[
-            "complete_response"
-        ] = self.model_call_details.get(
-            "complete_streaming_response", {}
-        )
+        self.model_call_details["complete_response"] = (
+            self.model_call_details.get(
+                "complete_streaming_response", {}
+            )
+        )
     result = self.model_call_details["complete_response"]
     callback.log_success_event(
         kwargs=self.model_call_details,
@@ -1531,9 +1578,9 @@ class Logging:
 verbose_logger.debug(
     "Async success callbacks: Got a complete streaming response"
 )
-self.model_call_details[
-    "complete_streaming_response"
-] = complete_streaming_response
+self.model_call_details["complete_streaming_response"] = (
+    complete_streaming_response
+)
 try:
     if self.model_call_details.get("cache_hit", False) == True:
         self.model_call_details["response_cost"] = 0.0
@@ -2272,9 +2319,9 @@ def client(original_function):
 ):
     print_verbose(f"Checking Cache")
     preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
-    kwargs[
-        "preset_cache_key"
-    ] = preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
+    kwargs["preset_cache_key"] = (
+        preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
+    )
     cached_result = litellm.cache.get_cache(*args, **kwargs)
     if cached_result != None:
         if "detail" in cached_result:
@@ -2572,17 +2619,17 @@ def client(original_function):
     cached_result = None
 elif isinstance(litellm.cache.cache, RedisSemanticCache):
     preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
-    kwargs[
-        "preset_cache_key"
-    ] = preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
+    kwargs["preset_cache_key"] = (
+        preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
+    )
     cached_result = await litellm.cache.async_get_cache(
         *args, **kwargs
     )
 else:
     preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
-    kwargs[
-        "preset_cache_key"
-    ] = preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
+    kwargs["preset_cache_key"] = (
+        preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
+    )
     cached_result = litellm.cache.get_cache(*args, **kwargs)
 if cached_result is not None and not isinstance(
@@ -3912,16 +3959,16 @@ def get_optional_params(
     True  # so that main.py adds the function call to the prompt
 )
 if "tools" in non_default_params:
-    optional_params[
-        "functions_unsupported_model"
-    ] = non_default_params.pop("tools")
+    optional_params["functions_unsupported_model"] = (
+        non_default_params.pop("tools")
+    )
     non_default_params.pop(
         "tool_choice", None
     )  # causes ollama requests to hang
 elif "functions" in non_default_params:
-    optional_params[
-        "functions_unsupported_model"
-    ] = non_default_params.pop("functions")
+    optional_params["functions_unsupported_model"] = (
+        non_default_params.pop("functions")
+    )
 elif (
     litellm.add_function_to_prompt
 ):  # if user opts to add it to prompt instead
@@ -4101,9 +4148,9 @@ def get_optional_params(
     optional_params["top_p"] = top_p
 if n is not None:
     optional_params["best_of"] = n
-    optional_params[
-        "do_sample"
-    ] = True  # Need to sample if you want best of for hf inference endpoints
+    optional_params["do_sample"] = (
+        True  # Need to sample if you want best of for hf inference endpoints
+    )
 if stream is not None:
     optional_params["stream"] = stream
 if stop is not None:
@@ -4148,9 +4195,9 @@ def get_optional_params(
 if max_tokens is not None:
     optional_params["max_tokens"] = max_tokens
 if frequency_penalty is not None:
-    optional_params[
-        "repetition_penalty"
-    ] = frequency_penalty  # https://docs.together.ai/reference/inference
+    optional_params["repetition_penalty"] = (
+        frequency_penalty  # https://docs.together.ai/reference/inference
+    )
 if stop is not None:
     optional_params["stop"] = stop
 if tools is not None:
@@ -4259,9 +4306,9 @@ def get_optional_params(
     optional_params["top_p"] = top_p
 if n is not None:
     optional_params["best_of"] = n
-    optional_params[
-        "do_sample"
-    ] = True  # Need to sample if you want best of for hf inference endpoints
+    optional_params["do_sample"] = (
+        True  # Need to sample if you want best of for hf inference endpoints
+    )
 if stream is not None:
     optional_params["stream"] = stream
 if stop is not None:
@@ -4584,9 +4631,9 @@ def get_optional_params(
         extra_body["safe_mode"] = safe_mode
     if random_seed is not None:
         extra_body["random_seed"] = random_seed
-    optional_params[
-        "extra_body"
-    ] = extra_body  # openai client supports `extra_body` param
+    optional_params["extra_body"] = (
+        extra_body  # openai client supports `extra_body` param
+    )
 elif custom_llm_provider == "openrouter":
     supported_params = [
         "functions",
@@ -4655,9 +4702,9 @@ def get_optional_params(
         extra_body["models"] = models
     if route is not None:
         extra_body["route"] = route
-    optional_params[
-        "extra_body"
-    ] = extra_body  # openai client supports `extra_body` param
+    optional_params["extra_body"] = (
+        extra_body  # openai client supports `extra_body` param
+    )
 else:  # assume passing in params for openai/azure openai
     supported_params = [
         "functions",
@@ -6982,6 +7029,14 @@ def exception_type(
         model=model,
         response=original_exception.response,
     )
+elif original_exception.status_code == 503:
+    exception_mapping_worked = True
+    raise ServiceUnavailableError(
+        message=f"HuggingfaceException - {original_exception.message}",
+        llm_provider="huggingface",
+        model=model,
+        response=original_exception.response,
+    )
 else:
     exception_mapping_worked = True
     raise APIError(
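
With this branch, a 503 from a Hugging Face endpoint surfaces as a typed ServiceUnavailableError instead of falling through to the generic APIError below, so callers can retry on it specifically. A hedged caller-side example (the model id is illustrative):

import litellm
from litellm.exceptions import ServiceUnavailableError

try:
    resp = litellm.completion(
        model="huggingface/HuggingFaceH4/zephyr-7b-beta",  # example model id
        messages=[{"role": "user", "content": "hello"}],
    )
except ServiceUnavailableError:
    # a HuggingfaceException with status_code == 503 now maps here,
    # so a retry/backoff path can target it directly
    pass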
@@ -8413,11 +8468,11 @@ class CustomStreamWrapper:
 try:
     completion_obj["content"] = chunk.text
     if hasattr(chunk.candidates[0], "finish_reason"):
-        model_response.choices[
-            0
-        ].finish_reason = map_finish_reason(
-            chunk.candidates[0].finish_reason.name
-        )
+        model_response.choices[0].finish_reason = (
+            map_finish_reason(
+                chunk.candidates[0].finish_reason.name
+            )
+        )
 except:
     if chunk.candidates[0].finish_reason.name == "SAFETY":
         raise Exception(