fix(utils.py): add more exception mapping for huggingface

Krrish Dholakia 2024-02-15 21:26:22 -08:00
parent 1b844aafdc
commit c37aad50ea

@@ -12,6 +12,7 @@ import litellm
import dotenv, json, traceback, threading, base64, ast
import subprocess, os
from os.path import abspath, join, dirname
import litellm, openai
import itertools
import random, uuid, requests
@@ -34,6 +35,7 @@ from dataclasses import (
from importlib import resources
# filename = pkg_resources.resource_filename(__name__, "llms/tokenizers")
try:
filename = str(
resources.files().joinpath("llms/tokenizers") # type: ignore
@@ -42,9 +44,10 @@ except:
filename = str(
resources.files(litellm).joinpath("llms/tokenizers") # for python 3.10
) # for python 3.10+
os.environ[
"TIKTOKEN_CACHE_DIR"
] = filename # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071
os.environ["TIKTOKEN_CACHE_DIR"] = (
filename # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071
)
encoding = tiktoken.get_encoding("cl100k_base")
import importlib.metadata
from ._logging import verbose_logger
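
Context for the hunk above: setting TIKTOKEN_CACHE_DIR before the first get_encoding call makes tiktoken read its BPE files from a local copy instead of downloading them (see the linked issue #1071). A minimal sketch of the same idea, assuming the cache directory already holds the cl100k_base files; the path is illustrative:

import os
import tiktoken

# Point tiktoken at a vendored cache so no network download is needed.
os.environ["TIKTOKEN_CACHE_DIR"] = "/path/to/llms/tokenizers"  # illustrative path
encoding = tiktoken.get_encoding("cl100k_base")
print(len(encoding.encode("hello world")))
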
@@ -82,6 +85,20 @@ from .exceptions import (
BudgetExceededError,
UnprocessableEntityError,
)
# Import Enterprise features
project_path = abspath(join(dirname(__file__), "..", ".."))
# Add the "enterprise" directory to sys.path
verbose_logger.debug(f"current project_path: {project_path}")
enterprise_path = abspath(join(project_path, "enterprise"))
sys.path.append(enterprise_path)
verbose_logger.debug(f"sys.path: {sys.path}")
try:
from enterprise.callbacks.generic_api_callback import GenericAPILogger
except Exception as e:
verbose_logger.debug(f"Exception import enterprise features {str(e)}")
from typing import cast, List, Dict, Union, Optional, Literal, Any
from .caching import Cache
from concurrent.futures import ThreadPoolExecutor
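
The enterprise import block above is an optional-dependency pattern: extend sys.path, then guard the import so the open-source build keeps working without the enterprise package. A minimal sketch of that pattern, with the directory layout assumed rather than taken from the repo:

import sys
from os.path import abspath, dirname, join

# Assumed layout: the optional "enterprise" directory lives under the project
# root, two levels above this file.
project_path = abspath(join(dirname(__file__), "..", ".."))
sys.path.append(abspath(join(project_path, "enterprise")))

try:
    from enterprise.callbacks.generic_api_callback import GenericAPILogger
except Exception as e:
    GenericAPILogger = None  # fall back: run without enterprise callbacks
    print(f"enterprise callbacks unavailable: {e}")
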
@@ -107,6 +124,7 @@ customLogger = None
langFuseLogger = None
dynamoLogger = None
s3Logger = None
genericAPILogger = None
llmonitorLogger = None
aispendLogger = None
berrispendLogger = None
@@ -1087,12 +1105,12 @@ class Logging:
self.call_type == CallTypes.aimage_generation.value
or self.call_type == CallTypes.image_generation.value
):
self.model_call_details[
"response_cost"
] = litellm.completion_cost(
completion_response=result,
model=self.model,
call_type=self.call_type,
self.model_call_details["response_cost"] = (
litellm.completion_cost(
completion_response=result,
model=self.model,
call_type=self.call_type,
)
)
else:
# check if base_model set on azure
@@ -1100,12 +1118,12 @@ class Logging:
model_call_details=self.model_call_details
)
# base_model defaults to None if not set on model_info
self.model_call_details[
"response_cost"
] = litellm.completion_cost(
completion_response=result,
call_type=self.call_type,
model=base_model,
self.model_call_details["response_cost"] = (
litellm.completion_cost(
completion_response=result,
call_type=self.call_type,
model=base_model,
)
)
verbose_logger.debug(
f"Model={self.model}; cost={self.model_call_details['response_cost']}"
@@ -1174,9 +1192,9 @@ class Logging:
verbose_logger.debug(
f"Logging Details LiteLLM-Success Call streaming complete"
)
self.model_call_details[
"complete_streaming_response"
] = complete_streaming_response
self.model_call_details["complete_streaming_response"] = (
complete_streaming_response
)
try:
if self.model_call_details.get("cache_hit", False) == True:
self.model_call_details["response_cost"] = 0.0
@@ -1186,11 +1204,11 @@ class Logging:
model_call_details=self.model_call_details
)
# base_model defaults to None if not set on model_info
self.model_call_details[
"response_cost"
] = litellm.completion_cost(
completion_response=complete_streaming_response,
model=base_model,
self.model_call_details["response_cost"] = (
litellm.completion_cost(
completion_response=complete_streaming_response,
model=base_model,
)
)
verbose_logger.debug(
f"Model={self.model}; cost={self.model_call_details['response_cost']}"
@@ -1369,6 +1387,35 @@ class Logging:
user_id=kwargs.get("user", None),
print_verbose=print_verbose,
)
if callback == "generic":
global genericAPILogger
verbose_logger.debug("reaches langfuse for success logging!")
kwargs = {}
for k, v in self.model_call_details.items():
if (
k != "original_response"
): # copy.deepcopy raises errors as this could be a coroutine
kwargs[k] = v
# this only logs streaming once, complete_streaming_response exists i.e when stream ends
if self.stream:
verbose_logger.debug(
f"is complete_streaming_response in kwargs: {kwargs.get('complete_streaming_response', None)}"
)
if complete_streaming_response is None:
break
else:
print_verbose("reaches langfuse for streaming logging!")
result = kwargs["complete_streaming_response"]
if genericAPILogger is None:
genericAPILogger = GenericAPILogger()
genericAPILogger.log_event(
kwargs=kwargs,
response_obj=result,
start_time=start_time,
end_time=end_time,
user_id=kwargs.get("user", None),
print_verbose=print_verbose,
)
if callback == "cache" and litellm.cache is not None:
# this only logs streaming once, complete_streaming_response exists i.e when stream ends
print_verbose("success_callback: reaches cache for logging!")
@@ -1448,10 +1495,10 @@ class Logging:
)
else:
if self.stream and complete_streaming_response:
self.model_call_details[
"complete_response"
] = self.model_call_details.get(
"complete_streaming_response", {}
self.model_call_details["complete_response"] = (
self.model_call_details.get(
"complete_streaming_response", {}
)
)
result = self.model_call_details["complete_response"]
callback.log_success_event(
@@ -1531,9 +1578,9 @@ class Logging:
verbose_logger.debug(
"Async success callbacks: Got a complete streaming response"
)
self.model_call_details[
"complete_streaming_response"
] = complete_streaming_response
self.model_call_details["complete_streaming_response"] = (
complete_streaming_response
)
try:
if self.model_call_details.get("cache_hit", False) == True:
self.model_call_details["response_cost"] = 0.0
@@ -2272,9 +2319,9 @@ def client(original_function):
):
print_verbose(f"Checking Cache")
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
kwargs[
"preset_cache_key"
] = preset_cache_key # for streaming calls, we need to pass the preset_cache_key
kwargs["preset_cache_key"] = (
preset_cache_key # for streaming calls, we need to pass the preset_cache_key
)
cached_result = litellm.cache.get_cache(*args, **kwargs)
if cached_result != None:
if "detail" in cached_result:
@@ -2572,17 +2619,17 @@ def client(original_function):
cached_result = None
elif isinstance(litellm.cache.cache, RedisSemanticCache):
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
kwargs[
"preset_cache_key"
] = preset_cache_key # for streaming calls, we need to pass the preset_cache_key
kwargs["preset_cache_key"] = (
preset_cache_key # for streaming calls, we need to pass the preset_cache_key
)
cached_result = await litellm.cache.async_get_cache(
*args, **kwargs
)
else:
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
kwargs[
"preset_cache_key"
] = preset_cache_key # for streaming calls, we need to pass the preset_cache_key
kwargs["preset_cache_key"] = (
preset_cache_key # for streaming calls, we need to pass the preset_cache_key
)
cached_result = litellm.cache.get_cache(*args, **kwargs)
if cached_result is not None and not isinstance(
@@ -3912,16 +3959,16 @@ def get_optional_params(
True # so that main.py adds the function call to the prompt
)
if "tools" in non_default_params:
optional_params[
"functions_unsupported_model"
] = non_default_params.pop("tools")
optional_params["functions_unsupported_model"] = (
non_default_params.pop("tools")
)
non_default_params.pop(
"tool_choice", None
) # causes ollama requests to hang
elif "functions" in non_default_params:
optional_params[
"functions_unsupported_model"
] = non_default_params.pop("functions")
optional_params["functions_unsupported_model"] = (
non_default_params.pop("functions")
)
elif (
litellm.add_function_to_prompt
): # if user opts to add it to prompt instead
@@ -4101,9 +4148,9 @@ def get_optional_params(
optional_params["top_p"] = top_p
if n is not None:
optional_params["best_of"] = n
optional_params[
"do_sample"
] = True # Need to sample if you want best of for hf inference endpoints
optional_params["do_sample"] = (
True # Need to sample if you want best of for hf inference endpoints
)
if stream is not None:
optional_params["stream"] = stream
if stop is not None:
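
The hunk above only re-wraps the trailing comment, but the behavior it documents is worth stating: for Hugging Face inference endpoints, n is translated to best_of, and best_of only takes effect when sampling, so do_sample is forced on. Illustrative result for n=2, stream=True:

# Mirrors the huggingface branch above; values are illustrative.
optional_params = {}
n, stream = 2, True
if n is not None:
    optional_params["best_of"] = n
    optional_params["do_sample"] = True  # best_of requires sampling on HF endpoints
if stream is not None:
    optional_params["stream"] = stream
print(optional_params)  # {'best_of': 2, 'do_sample': True, 'stream': True}
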
@@ -4148,9 +4195,9 @@ def get_optional_params(
if max_tokens is not None:
optional_params["max_tokens"] = max_tokens
if frequency_penalty is not None:
optional_params[
"repetition_penalty"
] = frequency_penalty # https://docs.together.ai/reference/inference
optional_params["repetition_penalty"] = (
frequency_penalty # https://docs.together.ai/reference/inference
)
if stop is not None:
optional_params["stop"] = stop
if tools is not None:
@@ -4259,9 +4306,9 @@ def get_optional_params(
optional_params["top_p"] = top_p
if n is not None:
optional_params["best_of"] = n
optional_params[
"do_sample"
] = True # Need to sample if you want best of for hf inference endpoints
optional_params["do_sample"] = (
True # Need to sample if you want best of for hf inference endpoints
)
if stream is not None:
optional_params["stream"] = stream
if stop is not None:
@@ -4584,9 +4631,9 @@ def get_optional_params(
extra_body["safe_mode"] = safe_mode
if random_seed is not None:
extra_body["random_seed"] = random_seed
optional_params[
"extra_body"
] = extra_body # openai client supports `extra_body` param
optional_params["extra_body"] = (
extra_body # openai client supports `extra_body` param
)
elif custom_llm_provider == "openrouter":
supported_params = [
"functions",
@@ -4655,9 +4702,9 @@ def get_optional_params(
extra_body["models"] = models
if route is not None:
extra_body["route"] = route
optional_params[
"extra_body"
] = extra_body # openai client supports `extra_body` param
optional_params["extra_body"] = (
extra_body # openai client supports `extra_body` param
)
else: # assume passing in params for openai/azure openai
supported_params = [
"functions",
@@ -6982,6 +7029,14 @@ def exception_type(
model=model,
response=original_exception.response,
)
elif original_exception.status_code == 503:
exception_mapping_worked = True
raise ServiceUnavailableError(
message=f"HuggingfaceException - {original_exception.message}",
llm_provider="huggingface",
model=model,
response=original_exception.response,
)
else:
exception_mapping_worked = True
raise APIError(
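
This hunk is the change the commit title refers to: a 503 from Hugging Face (typically a model still loading or an overloaded endpoint) now surfaces as litellm's ServiceUnavailableError instead of falling through to the generic APIError branch. A hedged sketch of how a caller can lean on that mapping; the model name and retry policy are illustrative:

import time

import litellm
from litellm.exceptions import ServiceUnavailableError

for attempt in range(3):
    try:
        response = litellm.completion(
            model="huggingface/HuggingFaceH4/zephyr-7b-beta",  # illustrative model
            messages=[{"role": "user", "content": "hello"}],
        )
        break
    except ServiceUnavailableError:
        # The HF endpoint answered 503; back off and retry instead of failing hard.
        time.sleep(2**attempt)
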
@@ -8413,10 +8468,10 @@ class CustomStreamWrapper:
try:
completion_obj["content"] = chunk.text
if hasattr(chunk.candidates[0], "finish_reason"):
model_response.choices[
0
].finish_reason = map_finish_reason(
chunk.candidates[0].finish_reason.name
model_response.choices[0].finish_reason = (
map_finish_reason(
chunk.candidates[0].finish_reason.name
)
)
except:
if chunk.candidates[0].finish_reason.name == "SAFETY":