fix add get_first_chars_messages in utils

This commit is contained in:
Ishaan Jaff 2024-05-04 12:43:09 -07:00
parent 76825e1d2c
commit 855c7caf0b
2 changed files with 37 additions and 14 deletions

View file

@ -638,6 +638,7 @@ from .utils import (
get_secret, get_secret,
get_supported_openai_params, get_supported_openai_params,
get_api_base, get_api_base,
get_first_chars_messages,
) )
from .llms.huggingface_restapi import HuggingfaceConfig from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig from .llms.anthropic import AnthropicConfig

View file

@ -5897,6 +5897,15 @@ def get_api_base(model: str, optional_params: dict) -> Optional[str]:
return None return None
def get_first_chars_messages(kwargs: dict) -> str:
try:
_messages = kwargs.get("messages")
_messages = str(_messages)[:100]
return _messages
except:
return ""
def get_supported_openai_params(model: str, custom_llm_provider: str): def get_supported_openai_params(model: str, custom_llm_provider: str):
""" """
Returns the supported openai params for a given model + provider Returns the supported openai params for a given model + provider
@ -7885,6 +7894,9 @@ def exception_type(
except: except:
_api_base = "" _api_base = ""
error_str += f" \n model: {model} \n api_base: {_api_base} \n"
error_str += str(completion_kwargs)
if "Request Timeout Error" in error_str or "Request timed out" in error_str: if "Request Timeout Error" in error_str or "Request timed out" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
raise Timeout( raise Timeout(
@ -9049,11 +9061,21 @@ def exception_type(
request=original_exception.request, request=original_exception.request,
) )
elif custom_llm_provider == "azure": elif custom_llm_provider == "azure":
_api_base = litellm.get_api_base(
model=model, optional_params=extra_kwargs
)
messages = litellm.get_first_chars_messages(kwargs=completion_kwargs)
extra_information = f"\nModel: {model}"
if _api_base:
extra_information += f"\nAPI Base: {_api_base}"
if messages and len(messages) > 0:
extra_information += f"\nMessages: {messages}"
if "Internal server error" in error_str: if "Internal server error" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
raise APIError( raise APIError(
status_code=500, status_code=500,
message=f"AzureException - {original_exception.message}", message=f"AzureException - {original_exception.message} {extra_information}",
llm_provider="azure", llm_provider="azure",
model=model, model=model,
request=httpx.Request(method="POST", url="https://openai.com/"), request=httpx.Request(method="POST", url="https://openai.com/"),
@ -9061,7 +9083,7 @@ def exception_type(
elif "This model's maximum context length is" in error_str: elif "This model's maximum context length is" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
raise ContextWindowExceededError( raise ContextWindowExceededError(
message=f"AzureException - {original_exception.message}", message=f"AzureException - {original_exception.message} {extra_information}",
llm_provider="azure", llm_provider="azure",
model=model, model=model,
response=original_exception.response, response=original_exception.response,
@ -9069,7 +9091,7 @@ def exception_type(
elif "DeploymentNotFound" in error_str: elif "DeploymentNotFound" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
raise NotFoundError( raise NotFoundError(
message=f"AzureException - {original_exception.message}", message=f"AzureException - {original_exception.message} {extra_information}",
llm_provider="azure", llm_provider="azure",
model=model, model=model,
response=original_exception.response, response=original_exception.response,
@ -9083,7 +9105,7 @@ def exception_type(
): ):
exception_mapping_worked = True exception_mapping_worked = True
raise ContentPolicyViolationError( raise ContentPolicyViolationError(
message=f"AzureException - {original_exception.message}", message=f"AzureException - {original_exception.message} {extra_information}",
llm_provider="azure", llm_provider="azure",
model=model, model=model,
response=original_exception.response, response=original_exception.response,
@ -9091,7 +9113,7 @@ def exception_type(
elif "invalid_request_error" in error_str: elif "invalid_request_error" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
raise BadRequestError( raise BadRequestError(
message=f"AzureException - {original_exception.message}", message=f"AzureException - {original_exception.message} {extra_information}",
llm_provider="azure", llm_provider="azure",
model=model, model=model,
response=original_exception.response, response=original_exception.response,
@ -9102,7 +9124,7 @@ def exception_type(
): ):
exception_mapping_worked = True exception_mapping_worked = True
raise AuthenticationError( raise AuthenticationError(
message=f"{exception_provider} - {original_exception.message}", message=f"{exception_provider} - {original_exception.message} {extra_information}",
llm_provider=custom_llm_provider, llm_provider=custom_llm_provider,
model=model, model=model,
response=original_exception.response, response=original_exception.response,
@ -9112,7 +9134,7 @@ def exception_type(
if original_exception.status_code == 401: if original_exception.status_code == 401:
exception_mapping_worked = True exception_mapping_worked = True
raise AuthenticationError( raise AuthenticationError(
message=f"AzureException - {original_exception.message}", message=f"AzureException - {original_exception.message} {extra_information}",
llm_provider="azure", llm_provider="azure",
model=model, model=model,
response=original_exception.response, response=original_exception.response,
@ -9120,14 +9142,14 @@ def exception_type(
elif original_exception.status_code == 408: elif original_exception.status_code == 408:
exception_mapping_worked = True exception_mapping_worked = True
raise Timeout( raise Timeout(
message=f"AzureException - {original_exception.message}", message=f"AzureException - {original_exception.message} {extra_information}",
model=model, model=model,
llm_provider="azure", llm_provider="azure",
) )
if original_exception.status_code == 422: if original_exception.status_code == 422:
exception_mapping_worked = True exception_mapping_worked = True
raise BadRequestError( raise BadRequestError(
message=f"AzureException - {original_exception.message}", message=f"AzureException - {original_exception.message} {extra_information}",
model=model, model=model,
llm_provider="azure", llm_provider="azure",
response=original_exception.response, response=original_exception.response,
@ -9135,7 +9157,7 @@ def exception_type(
elif original_exception.status_code == 429: elif original_exception.status_code == 429:
exception_mapping_worked = True exception_mapping_worked = True
raise RateLimitError( raise RateLimitError(
message=f"AzureException - {original_exception.message}", message=f"AzureException - {original_exception.message} {extra_information}",
model=model, model=model,
llm_provider="azure", llm_provider="azure",
response=original_exception.response, response=original_exception.response,
@ -9143,7 +9165,7 @@ def exception_type(
elif original_exception.status_code == 503: elif original_exception.status_code == 503:
exception_mapping_worked = True exception_mapping_worked = True
raise ServiceUnavailableError( raise ServiceUnavailableError(
message=f"AzureException - {original_exception.message}", message=f"AzureException - {original_exception.message} {extra_information}",
model=model, model=model,
llm_provider="azure", llm_provider="azure",
response=original_exception.response, response=original_exception.response,
@ -9151,7 +9173,7 @@ def exception_type(
elif original_exception.status_code == 504: # gateway timeout error elif original_exception.status_code == 504: # gateway timeout error
exception_mapping_worked = True exception_mapping_worked = True
raise Timeout( raise Timeout(
message=f"AzureException - {original_exception.message}", message=f"AzureException - {original_exception.message} {extra_information}",
model=model, model=model,
llm_provider="azure", llm_provider="azure",
) )
@ -9159,7 +9181,7 @@ def exception_type(
exception_mapping_worked = True exception_mapping_worked = True
raise APIError( raise APIError(
status_code=original_exception.status_code, status_code=original_exception.status_code,
message=f"AzureException - {original_exception.message}", message=f"AzureException - {original_exception.message} {extra_information}",
llm_provider="azure", llm_provider="azure",
model=model, model=model,
request=httpx.Request( request=httpx.Request(
@ -9169,7 +9191,7 @@ def exception_type(
else: else:
# if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors
raise APIConnectionError( raise APIConnectionError(
message=f"{exception_provider} - {message}", message=f"{exception_provider} - {message} {extra_information}",
llm_provider="azure", llm_provider="azure",
model=model, model=model,
request=httpx.Request(method="POST", url="https://openai.com/"), request=httpx.Request(method="POST", url="https://openai.com/"),