fix(together_ai.py): exception mapping for TogetherAI

This commit is contained in:
Krrish Dholakia 2023-11-13 13:17:15 -08:00
parent aa8ca781ba
commit d4de55b053
4 changed files with 103 additions and 50 deletions

View file

@@ -5,6 +5,7 @@ import requests
import time
from typing import Callable, Optional
import litellm
import httpx
from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt
@@ -12,6 +13,8 @@ class TogetherAIError(Exception):
def __init__(self, status_code, message):
    """Store the HTTP status and message, plus a synthetic httpx pair.

    A fabricated ``httpx.Request``/``httpx.Response`` is attached so that
    litellm's generic exception mapping (which expects httpx-backed errors)
    can consume TogetherAI failures like any other provider's.
    """
    self.status_code = status_code
    self.message = message
    # Fabricate the request first; the response must reference it.
    synthetic_request = httpx.Request(method="POST", url="https://api.together.xyz/inference")
    self.request = synthetic_request
    self.response = httpx.Response(status_code=status_code, request=synthetic_request)
    # Base Exception carries the human-readable message.
    super().__init__(self.message)

View file

@@ -877,28 +877,39 @@ def completion(
return response
response = model_response
elif model in litellm.openrouter_models or custom_llm_provider == "openrouter":
openai.api_type = "openai"
# not sure if this will work after someone first uses another API
openai.base_url = (
litellm.api_base
if litellm.api_base is not None
else "https://openrouter.ai/api/v1"
api_base = (
api_base
or litellm.api_base
or "https://openrouter.ai/api/v1"
)
api_key = (
api_key or
litellm.api_key or
litellm.openrouter_key or
get_secret("OPENROUTER_API_KEY") or
get_secret("OR_API_KEY")
)
openrouter_site_url = (
get_secret("OR_SITE_URL")
or "https://litellm.ai"
)
openrouter_app_name = (
get_secret("OR_APP_NAME")
or "liteLLM"
)
openai.api_version = None
if litellm.organization:
openai.organization = litellm.organization
if api_key:
openai.api_key = api_key
elif litellm.openrouter_key:
openai.api_key = litellm.openrouter_key
else:
openai.api_key = get_secret("OPENROUTER_API_KEY") or get_secret(
"OR_API_KEY"
) or litellm.api_key
headers = (
headers or
litellm.headers
litellm.headers or
{
"HTTP-Referer": openrouter_site_url,
"X-Title": openrouter_app_name,
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}"
}
)
data = {
@@ -909,27 +920,44 @@ def completion(
## LOGGING
logging.pre_call(input=messages, api_key=openai.api_key, additional_args={"complete_input_dict": data, "headers": headers})
## COMPLETION CALL
if headers:
response = openai.chat.completions.create(
headers=headers, # type: ignore
**data, # type: ignore
)
else:
openrouter_site_url = get_secret("OR_SITE_URL")
openrouter_app_name = get_secret("OR_APP_NAME")
# if openrouter_site_url is None, set it to https://litellm.ai
if openrouter_site_url is None:
openrouter_site_url = "https://litellm.ai"
# if openrouter_app_name is None, set it to liteLLM
if openrouter_app_name is None:
openrouter_app_name = "liteLLM"
response = openai.chat.completions.create( # type: ignore
extra_headers=httpx.Headers({ # type: ignore
"HTTP-Referer": openrouter_site_url, # type: ignore
"X-Title": openrouter_app_name, # type: ignore
}), # type: ignore
**data,
## COMPLETION CALL
response = openai_chat_completions.completion(
model=model,
messages=messages,
headers=headers,
api_key=api_key,
api_base=api_base,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
logging_obj=logging,
acompletion=acompletion
)
# if headers:
# response = openai.chat.completions.create(
# headers=headers, # type: ignore
# **data, # type: ignore
# )
# else:
# openrouter_site_url = get_secret("OR_SITE_URL")
# openrouter_app_name = get_secret("OR_APP_NAME")
# # if openrouter_site_url is None, set it to https://litellm.ai
# if openrouter_site_url is None:
# openrouter_site_url = "https://litellm.ai"
# # if openrouter_app_name is None, set it to liteLLM
# if openrouter_app_name is None:
# openrouter_app_name = "liteLLM"
# response = openai.chat.completions.create( # type: ignore
# extra_headers=httpx.Headers({ # type: ignore
# "HTTP-Referer": openrouter_site_url, # type: ignore
# "X-Title": openrouter_app_name, # type: ignore
# }), # type: ignore
# **data,
# )
## LOGGING
logging.post_call(
input=messages, api_key=openai.api_key, original_response=response
@@ -1961,7 +1989,7 @@ def moderation(input: str, api_key: Optional[str]=None):
openai.api_key = api_key
openai.api_type = "open_ai" # type: ignore
openai.api_version = None
openai.base_url = "https://api.openai.com/v1"
openai.base_url = "https://api.openai.com/v1/"
response = openai.moderations.create(input=input)
return response

View file

@@ -873,18 +873,20 @@ def test_completion_together_ai():
# test_completion_together_ai()
# NOTE(review): this span is diff residue from a rendered commit page — the
# pre-change and post-change lines are interleaved with no +/- markers (e.g.
# set_verbose is assigned True and then False on consecutive lines, and two
# `except` headers appear back-to-back). As shown it is NOT valid Python;
# reconstruct from the original commit before treating it as source.
def test_customprompt_together_ai():
try:
# Presumably the removed line; the commit appears to switch verbosity off.
litellm.set_verbose = True
litellm.set_verbose = False
# Disable retries so a timeout surfaces immediately — TODO confirm intent.
litellm.num_retries = 0
response = completion(
model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10",
messages=messages,
# Custom ChatML-style role delimiters exercised by this test.
roles={"system":{"pre_message":"<|im_start|>system\n", "post_message":"<|im_end|>"}, "assistant":{"pre_message":"<|im_start|>assistant\n","post_message":"<|im_end|>"}, "user":{"pre_message":"<|im_start|>user\n","post_message":"<|im_end|>"}}
)
print(response)
# Old handler (APIError) and new handler (Timeout) both visible here —
# the commit narrows the tolerated exception to a timeout.
except litellm.APIError as e:
except litellm.exceptions.Timeout as e:
print(f"Timeout Error")
litellm.num_retries = 3 # reset retries
pass
except Exception as e:
# Old debug prints followed by their replacement f-string line.
print(type(e))
print(e)
print(f"ERROR TYPE {type(e)}")
pytest.fail(f"Error occurred: {e}")
# test_customprompt_together_ai()
@@ -1364,9 +1366,11 @@ def test_completion_deep_infra_mistral():
)
# Add any assertions here to check the response
print(response)
except litellm.exceptions.Timeout as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
test_completion_deep_infra_mistral()
# test_completion_deep_infra_mistral()
# Palm tests
def test_completion_palm():
@@ -1454,8 +1458,8 @@ def test_moderation():
openai.api_version = "GM"
response = litellm.moderation(input="i'm ishaan cto of litellm")
print(response)
output = response["results"][0]
output = response.results[0]
print(output)
return output
# test_moderation()
test_moderation()

View file

@@ -3182,6 +3182,13 @@ def exception_type(
llm_provider="openai",
response=original_exception.response
)
elif original_exception.status_code == 504: # gateway timeout error
exception_mapping_worked = True
raise Timeout(
message=f"OpenAIException - {original_exception.message}",
model=model,
llm_provider="openai",
)
else:
exception_mapping_worked = True
raise APIError(
@@ -3707,7 +3714,10 @@ def exception_type(
)
elif custom_llm_provider == "together_ai":
import json
try:
error_response = json.loads(error_str)
except:
error_response = {"error": error_str}
if "error" in error_response and "`inputs` tokens + `max_new_tokens` must be <=" in error_response["error"]:
exception_mapping_worked = True
raise ContextWindowExceededError(
@@ -3732,6 +3742,7 @@ def exception_type(
llm_provider="together_ai",
response=original_exception.response
)
elif "error" in error_response and "API key doesn't match expected format." in error_response["error"]:
exception_mapping_worked = True
raise BadRequestError(
@@ -3764,6 +3775,13 @@ def exception_type(
model=model,
response=original_exception.response
)
elif original_exception.status_code == 524:
exception_mapping_worked = True
raise Timeout(
message=f"TogetherAIException - {original_exception.message}",
llm_provider="together_ai",
model=model,
)
else:
exception_mapping_worked = True
raise APIError(
@@ -3967,8 +3985,8 @@ def exception_type(
model=model,
request=original_exception.request
)
exception_mapping_worked = True
if "BadRequestError.__init__() missing 1 required positional argument: 'param'" in str(original_exception): # deal with edge-case invalid request error bug in openai-python sdk
exception_mapping_worked = True
raise BadRequestError(
message=f"OpenAIException: This can happen due to missing AZURE_API_VERSION: {str(original_exception)}",
model=model,