Merge branch 'main' into litellm-fix-vertexaibeta

commit 26630cd263
Tiger Yu authored on 2024-07-02 09:49:44 -07:00; committed by GitHub
73 changed files with 3482 additions and 782 deletions

litellm/utils.py

@@ -48,8 +48,10 @@ from tokenizers import Tokenizer
import litellm
import litellm._service_logger # for storing API inputs, outputs, and metadata
import litellm.litellm_core_utils
import litellm.litellm_core_utils.json_validation_rule
from litellm.caching import DualCache
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.exception_mapping_utils import get_error_message
from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_safe
from litellm.litellm_core_utils.redact_messages import (
redact_message_input_output_from_logging,
@@ -579,7 +581,7 @@ def client(original_function):
else:
return False
def post_call_processing(original_response, model):
def post_call_processing(original_response, model, optional_params: Optional[dict]):
try:
if original_response is None:
pass
@@ -594,11 +596,47 @@ def client(original_function):
pass
else:
if isinstance(original_response, ModelResponse):
model_response = original_response.choices[
model_response: Optional[str] = original_response.choices[
0
].message.content
### POST-CALL RULES ###
rules_obj.post_call_rules(input=model_response, model=model)
].message.content # type: ignore
if model_response is not None:
### POST-CALL RULES ###
rules_obj.post_call_rules(
input=model_response, model=model
)
### JSON SCHEMA VALIDATION ###
if (
optional_params is not None
and "response_format" in optional_params
and isinstance(
optional_params["response_format"], dict
)
and "type" in optional_params["response_format"]
and optional_params["response_format"]["type"]
== "json_object"
and "response_schema"
in optional_params["response_format"]
and isinstance(
optional_params["response_format"][
"response_schema"
],
dict,
)
and "enforce_validation"
in optional_params["response_format"]
and optional_params["response_format"][
"enforce_validation"
]
is True
):
# schema given, json response expected, and validation enforced
litellm.litellm_core_utils.json_validation_rule.validate_schema(
schema=optional_params["response_format"][
"response_schema"
],
response=model_response,
)
except Exception as e:
raise e
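The branch added above only runs when the caller opts in through response_format. Below is a minimal sketch of a request that would exercise it, assuming the standard litellm.completion entry point; the model name and JSON schema are illustrative placeholders, not part of this diff.

import litellm

# The keys mirror the condition checked above: type == "json_object",
# a dict-valued "response_schema", and "enforce_validation" set to True.
response = litellm.completion(
    model="vertex_ai_beta/gemini-1.5-pro",  # placeholder model
    messages=[{"role": "user", "content": "Return two colors as JSON."}],
    response_format={
        "type": "json_object",
        "response_schema": {
            "type": "object",
            "properties": {
                "colors": {"type": "array", "items": {"type": "string"}}
            },
            "required": ["colors"],
        },
        "enforce_validation": True,  # post_call_processing calls validate_schema()
    },
)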
@@ -867,7 +905,11 @@ def client(original_function):
return result
### POST-CALL RULES ###
post_call_processing(original_response=result, model=model or None)
post_call_processing(
original_response=result,
model=model or None,
optional_params=kwargs,
)
# [OPTIONAL] ADD TO CACHE
if (
@@ -1316,7 +1358,9 @@ def client(original_function):
).total_seconds() * 1000 # return response latency in ms like openai
### POST-CALL RULES ###
post_call_processing(original_response=result, model=model)
post_call_processing(
original_response=result, model=model, optional_params=kwargs
)
# [OPTIONAL] ADD TO CACHE
if (
@@ -1847,9 +1891,10 @@ def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> bool:
Parameters:
model (str): The model name to be checked.
custom_llm_provider (str): The provider to be checked.
Returns:
bool: True if the model supports function calling, False otherwise.
bool: True if the model supports system messages, False otherwise.
Raises:
Exception: If the given model is not found in model_prices_and_context_window.json.
@@ -1867,6 +1912,43 @@ def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> bool:
)
def supports_response_schema(model: str, custom_llm_provider: Optional[str]) -> bool:
"""
Check if the given model + provider supports 'response_schema' as a param.
Parameters:
model (str): The model name to be checked.
custom_llm_provider (str): The provider to be checked.
Returns:
bool: True if the model supports response_schema, False otherwise.
Does not raise error. Defaults to 'False'. Outputs logging.error.
"""
try:
## GET LLM PROVIDER ##
model, custom_llm_provider, _, _ = get_llm_provider(
model=model, custom_llm_provider=custom_llm_provider
)
if custom_llm_provider == "predibase": # predibase supports this globally
return True
## GET MODEL INFO
model_info = litellm.get_model_info(
model=model, custom_llm_provider=custom_llm_provider
)
if model_info.get("supports_response_schema", False) is True:
return True
return False
except Exception:
verbose_logger.error(
f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}."
)
return False
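A hedged usage sketch of the new helper, assuming it stays importable from litellm.utils where this diff adds it; the model and provider values are examples only.

from litellm.utils import supports_response_schema

# True when the model's entry in model_prices_and_context_window.json sets
# supports_response_schema, or when the provider is predibase; False otherwise.
if supports_response_schema(model="gemini-1.5-pro", custom_llm_provider="vertex_ai"):
    print("response_schema accepted for this model")
else:
    print("response_schema not advertised; consider dropping the param")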
def supports_function_calling(model: str) -> bool:
"""
Check if the given model supports function calling and return a boolean value.
@@ -2324,7 +2406,9 @@ def get_optional_params(
elif k == "hf_model_name" and custom_llm_provider != "sagemaker":
continue
elif (
k.startswith("vertex_") and custom_llm_provider != "vertex_ai" and custom_llm_provider != "vertex_ai_beta"
k.startswith("vertex_")
and custom_llm_provider != "vertex_ai"
and custom_llm_provider != "vertex_ai_beta"
): # allow dynamically setting vertex ai init logic
continue
passed_params[k] = v
@@ -2756,6 +2840,11 @@ def get_optional_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif (
custom_llm_provider == "vertex_ai" and model in litellm.vertex_anthropic_models
@@ -2824,12 +2913,7 @@ def get_optional_params(
optional_params=optional_params,
)
)
else:
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
else: # bedrock httpx route
elif model in litellm.BEDROCK_CONVERSE_MODELS:
optional_params = litellm.AmazonConverseConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
@@ -2840,6 +2924,11 @@ def get_optional_params(
else False
),
)
else:
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
elif "amazon" in model: # amazon titan llms
_check_valid_arg(supported_params=supported_params)
# see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
@@ -3755,23 +3844,18 @@ def get_supported_openai_params(
return litellm.AzureOpenAIConfig().get_supported_openai_params()
elif custom_llm_provider == "openrouter":
return [
"functions",
"function_call",
"temperature",
"top_p",
"n",
"stream",
"stop",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
"response_format",
"presence_penalty",
"repetition_penalty",
"seed",
"tools",
"tool_choice",
"max_retries",
"max_tokens",
"logit_bias",
"logprobs",
"top_logprobs",
"response_format",
"stop",
]
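To see the reshuffled openrouter list at runtime, a small sketch; the import path and model name are assumed for illustration.

from litellm.utils import get_supported_openai_params

params = get_supported_openai_params(
    model="openrouter/anthropic/claude-3-opus",
    custom_llm_provider="openrouter",
)
print(params)  # the hard-coded list shown in the hunk above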
elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral":
# mistral and codestral api have the exact same params
@@ -3789,6 +3873,10 @@ def get_supported_openai_params(
"top_p",
"stop",
"seed",
"tools",
"tool_choice",
"functions",
"function_call",
]
elif custom_llm_provider == "huggingface":
return litellm.HuggingfaceConfig().get_supported_openai_params()
@@ -4434,8 +4522,7 @@ def get_max_tokens(model: str) -> Optional[int]:
def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
"""
Get a dict for the maximum tokens (context window),
input_cost_per_token, output_cost_per_token for a given model.
Get a dict for the maximum tokens (context window), input_cost_per_token, output_cost_per_token for a given model.
Parameters:
- model (str): The name of the model.
@@ -4520,6 +4607,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
mode="chat",
supported_openai_params=supported_openai_params,
supports_system_messages=None,
supports_response_schema=None,
)
else:
"""
@@ -4541,36 +4629,6 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
pass
else:
raise Exception
return ModelInfo(
max_tokens=_model_info.get("max_tokens", None),
max_input_tokens=_model_info.get("max_input_tokens", None),
max_output_tokens=_model_info.get("max_output_tokens", None),
input_cost_per_token=_model_info.get("input_cost_per_token", 0),
input_cost_per_character=_model_info.get(
"input_cost_per_character", None
),
input_cost_per_token_above_128k_tokens=_model_info.get(
"input_cost_per_token_above_128k_tokens", None
),
output_cost_per_token=_model_info.get("output_cost_per_token", 0),
output_cost_per_character=_model_info.get(
"output_cost_per_character", None
),
output_cost_per_token_above_128k_tokens=_model_info.get(
"output_cost_per_token_above_128k_tokens", None
),
output_cost_per_character_above_128k_tokens=_model_info.get(
"output_cost_per_character_above_128k_tokens", None
),
litellm_provider=_model_info.get(
"litellm_provider", custom_llm_provider
),
mode=_model_info.get("mode"),
supported_openai_params=supported_openai_params,
supports_system_messages=_model_info.get(
"supports_system_messages", None
),
)
elif model in litellm.model_cost:
_model_info = litellm.model_cost[model]
_model_info["supported_openai_params"] = supported_openai_params
@@ -4584,36 +4642,6 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
pass
else:
raise Exception
return ModelInfo(
max_tokens=_model_info.get("max_tokens", None),
max_input_tokens=_model_info.get("max_input_tokens", None),
max_output_tokens=_model_info.get("max_output_tokens", None),
input_cost_per_token=_model_info.get("input_cost_per_token", 0),
input_cost_per_character=_model_info.get(
"input_cost_per_character", None
),
input_cost_per_token_above_128k_tokens=_model_info.get(
"input_cost_per_token_above_128k_tokens", None
),
output_cost_per_token=_model_info.get("output_cost_per_token", 0),
output_cost_per_character=_model_info.get(
"output_cost_per_character", None
),
output_cost_per_token_above_128k_tokens=_model_info.get(
"output_cost_per_token_above_128k_tokens", None
),
output_cost_per_character_above_128k_tokens=_model_info.get(
"output_cost_per_character_above_128k_tokens", None
),
litellm_provider=_model_info.get(
"litellm_provider", custom_llm_provider
),
mode=_model_info.get("mode"),
supported_openai_params=supported_openai_params,
supports_system_messages=_model_info.get(
"supports_system_messages", None
),
)
elif split_model in litellm.model_cost:
_model_info = litellm.model_cost[split_model]
_model_info["supported_openai_params"] = supported_openai_params
@@ -4627,40 +4655,48 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
pass
else:
raise Exception
return ModelInfo(
max_tokens=_model_info.get("max_tokens", None),
max_input_tokens=_model_info.get("max_input_tokens", None),
max_output_tokens=_model_info.get("max_output_tokens", None),
input_cost_per_token=_model_info.get("input_cost_per_token", 0),
input_cost_per_character=_model_info.get(
"input_cost_per_character", None
),
input_cost_per_token_above_128k_tokens=_model_info.get(
"input_cost_per_token_above_128k_tokens", None
),
output_cost_per_token=_model_info.get("output_cost_per_token", 0),
output_cost_per_character=_model_info.get(
"output_cost_per_character", None
),
output_cost_per_token_above_128k_tokens=_model_info.get(
"output_cost_per_token_above_128k_tokens", None
),
output_cost_per_character_above_128k_tokens=_model_info.get(
"output_cost_per_character_above_128k_tokens", None
),
litellm_provider=_model_info.get(
"litellm_provider", custom_llm_provider
),
mode=_model_info.get("mode"),
supported_openai_params=supported_openai_params,
supports_system_messages=_model_info.get(
"supports_system_messages", None
),
)
else:
raise ValueError(
"This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
)
## PROVIDER-SPECIFIC INFORMATION
if custom_llm_provider == "predibase":
_model_info["supports_response_schema"] = True
return ModelInfo(
max_tokens=_model_info.get("max_tokens", None),
max_input_tokens=_model_info.get("max_input_tokens", None),
max_output_tokens=_model_info.get("max_output_tokens", None),
input_cost_per_token=_model_info.get("input_cost_per_token", 0),
input_cost_per_character=_model_info.get(
"input_cost_per_character", None
),
input_cost_per_token_above_128k_tokens=_model_info.get(
"input_cost_per_token_above_128k_tokens", None
),
output_cost_per_token=_model_info.get("output_cost_per_token", 0),
output_cost_per_character=_model_info.get(
"output_cost_per_character", None
),
output_cost_per_token_above_128k_tokens=_model_info.get(
"output_cost_per_token_above_128k_tokens", None
),
output_cost_per_character_above_128k_tokens=_model_info.get(
"output_cost_per_character_above_128k_tokens", None
),
litellm_provider=_model_info.get(
"litellm_provider", custom_llm_provider
),
mode=_model_info.get("mode"),
supported_openai_params=supported_openai_params,
supports_system_messages=_model_info.get(
"supports_system_messages", None
),
supports_response_schema=_model_info.get(
"supports_response_schema", None
),
)
except Exception:
raise Exception(
"This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
@@ -5278,6 +5314,27 @@ def convert_to_model_response_object(
hidden_params: Optional[dict] = None,
):
received_args = locals()
### CHECK IF ERROR IN RESPONSE ### - openrouter returns these in the dictionary
if (
response_object is not None
and "error" in response_object
and response_object["error"] is not None
):
error_args = {"status_code": 422, "message": "Error in response object"}
if isinstance(response_object["error"], dict):
if "code" in response_object["error"]:
error_args["status_code"] = response_object["error"]["code"]
if "message" in response_object["error"]:
if isinstance(response_object["error"]["message"], dict):
message_str = json.dumps(response_object["error"]["message"])
else:
message_str = str(response_object["error"]["message"])
error_args["message"] = message_str
raised_exception = Exception()
setattr(raised_exception, "status_code", error_args["status_code"])
setattr(raised_exception, "message", error_args["message"])
raise raised_exception
try:
if response_type == "completion" and (
model_response_object is None
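To make the new guard concrete, a hypothetical payload of the shape it intercepts; the values are invented for illustration.

# openrouter can return HTTP 200 with the real failure tucked into "error".
response_object = {
    "error": {"code": 429, "message": "Rate limit exceeded"},
}
# With the block above, convert_to_model_response_object() now raises an
# Exception carrying status_code=429 and message="Rate limit exceeded"
# instead of trying to build a ModelResponse from the broken body.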
@@ -5733,7 +5790,10 @@ def exception_type(
print() # noqa
try:
if model:
error_str = str(original_exception)
if hasattr(original_exception, "message"):
error_str = str(original_exception.message)
else:
error_str = str(original_exception)
if isinstance(original_exception, BaseException):
exception_type = type(original_exception).__name__
else:
@@ -5755,6 +5815,18 @@ def exception_type(
_model_group = _metadata.get("model_group")
_deployment = _metadata.get("deployment")
extra_information = f"\nModel: {model}"
exception_provider = "Unknown"
if (
isinstance(custom_llm_provider, str)
and len(custom_llm_provider) > 0
):
exception_provider = (
custom_llm_provider[0].upper()
+ custom_llm_provider[1:]
+ "Exception"
)
if _api_base:
extra_information += f"\nAPI Base: `{_api_base}`"
if (
@@ -5805,10 +5877,13 @@ def exception_type(
or custom_llm_provider in litellm.openai_compatible_providers
):
# custom_llm_provider is openai, make it OpenAI
if hasattr(original_exception, "message"):
message = original_exception.message
else:
message = str(original_exception)
message = get_error_message(error_obj=original_exception)
if message is None:
if hasattr(original_exception, "message"):
message = original_exception.message
else:
message = str(original_exception)
if message is not None and isinstance(message, str):
message = message.replace("OPENAI", custom_llm_provider.upper())
message = message.replace("openai", custom_llm_provider)
@@ -6141,7 +6216,6 @@ def exception_type(
)
elif (
original_exception.status_code == 400
or original_exception.status_code == 422
or original_exception.status_code == 413
):
exception_mapping_worked = True
@@ -6151,6 +6225,14 @@ def exception_type(
llm_provider="replicate",
response=original_exception.response,
)
elif original_exception.status_code == 422:
exception_mapping_worked = True
raise UnprocessableEntityError(
message=f"ReplicateException - {original_exception.message}",
model=model,
llm_provider="replicate",
response=original_exception.response,
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
raise Timeout(
@@ -7254,10 +7336,17 @@ def exception_type(
request=original_exception.request,
)
elif custom_llm_provider == "azure":
message = get_error_message(error_obj=original_exception)
if message is None:
if hasattr(original_exception, "message"):
message = original_exception.message
else:
message = str(original_exception)
if "Internal server error" in error_str:
exception_mapping_worked = True
raise litellm.InternalServerError(
message=f"AzureException Internal server error - {original_exception.message}",
message=f"AzureException Internal server error - {message}",
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
@@ -7270,7 +7359,7 @@ def exception_type(
elif "This model's maximum context length is" in error_str:
exception_mapping_worked = True
raise ContextWindowExceededError(
message=f"AzureException ContextWindowExceededError - {original_exception.message}",
message=f"AzureException ContextWindowExceededError - {message}",
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
@@ -7279,7 +7368,7 @@ def exception_type(
elif "DeploymentNotFound" in error_str:
exception_mapping_worked = True
raise NotFoundError(
message=f"AzureException NotFoundError - {original_exception.message}",
message=f"AzureException NotFoundError - {message}",
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
@@ -7299,7 +7388,7 @@ def exception_type(
):
exception_mapping_worked = True
raise ContentPolicyViolationError(
message=f"litellm.ContentPolicyViolationError: AzureException - {original_exception.message}",
message=f"litellm.ContentPolicyViolationError: AzureException - {message}",
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
@@ -7308,7 +7397,7 @@ def exception_type(
elif "invalid_request_error" in error_str:
exception_mapping_worked = True
raise BadRequestError(
message=f"AzureException BadRequestError - {original_exception.message}",
message=f"AzureException BadRequestError - {message}",
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
@@ -7320,7 +7409,7 @@ def exception_type(
):
exception_mapping_worked = True
raise AuthenticationError(
message=f"{exception_provider} AuthenticationError - {original_exception.message}",
message=f"{exception_provider} AuthenticationError - {message}",
llm_provider=custom_llm_provider,
model=model,
litellm_debug_info=extra_information,
@@ -7331,7 +7420,7 @@ def exception_type(
if original_exception.status_code == 400:
exception_mapping_worked = True
raise BadRequestError(
message=f"AzureException - {original_exception.message}",
message=f"AzureException - {message}",
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
@@ -7340,7 +7429,7 @@ def exception_type(
elif original_exception.status_code == 401:
exception_mapping_worked = True
raise AuthenticationError(
message=f"AzureException AuthenticationError - {original_exception.message}",
message=f"AzureException AuthenticationError - {message}",
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
@@ -7349,7 +7438,7 @@ def exception_type(
elif original_exception.status_code == 408:
exception_mapping_worked = True
raise Timeout(
message=f"AzureException Timeout - {original_exception.message}",
message=f"AzureException Timeout - {message}",
model=model,
litellm_debug_info=extra_information,
llm_provider="azure",
@@ -7357,7 +7446,7 @@ def exception_type(
elif original_exception.status_code == 422:
exception_mapping_worked = True
raise BadRequestError(
message=f"AzureException BadRequestError - {original_exception.message}",
message=f"AzureException BadRequestError - {message}",
model=model,
llm_provider="azure",
litellm_debug_info=extra_information,
@@ -7366,7 +7455,7 @@ def exception_type(
elif original_exception.status_code == 429:
exception_mapping_worked = True
raise RateLimitError(
message=f"AzureException RateLimitError - {original_exception.message}",
message=f"AzureException RateLimitError - {message}",
model=model,
llm_provider="azure",
litellm_debug_info=extra_information,
@@ -7375,7 +7464,7 @@ def exception_type(
elif original_exception.status_code == 503:
exception_mapping_worked = True
raise ServiceUnavailableError(
message=f"AzureException ServiceUnavailableError - {original_exception.message}",
message=f"AzureException ServiceUnavailableError - {message}",
model=model,
llm_provider="azure",
litellm_debug_info=extra_information,
@@ -7384,7 +7473,7 @@ def exception_type(
elif original_exception.status_code == 504: # gateway timeout error
exception_mapping_worked = True
raise Timeout(
message=f"AzureException Timeout - {original_exception.message}",
message=f"AzureException Timeout - {message}",
model=model,
litellm_debug_info=extra_information,
llm_provider="azure",
@@ -7393,7 +7482,7 @@ def exception_type(
exception_mapping_worked = True
raise APIError(
status_code=original_exception.status_code,
message=f"AzureException APIError - {original_exception.message}",
message=f"AzureException APIError - {message}",
llm_provider="azure",
litellm_debug_info=extra_information,
model=model,