fix(together_ai.py): exception mapping for tgai

Krrish Dholakia 2023-11-13 13:17:15 -08:00
parent aa8ca781ba
commit d4de55b053
4 changed files with 103 additions and 50 deletions

litellm/llms/together_ai.py

@@ -5,6 +5,7 @@ import requests
 import time
 from typing import Callable, Optional
 import litellm
+import httpx
 from litellm.utils import ModelResponse, Usage
 from .prompt_templates.factory import prompt_factory, custom_prompt
@@ -12,6 +13,8 @@ class TogetherAIError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
         self.message = message
+        self.request = httpx.Request(method="POST", url="https://api.together.xyz/inference")
+        self.response = httpx.Response(status_code=status_code, request=self.request)
         super().__init__(
             self.message
         )  # Call the base class constructor with the parameters it needs

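Why attach httpx objects here: the exception mapper in utils.py re-raises typed errors with `response=original_exception.response`, so the provider error needs to carry a real httpx.Response even when the failure never produced one. A minimal sketch of the pattern, with illustrative names rather than the repo's actual code:

import httpx

class ProviderError(Exception):
    """Provider error carrying a synthetic httpx request/response pair."""
    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        self.message = message
        # Synthesize the pair so downstream mappers can always read
        # `.response.status_code`, even when no real HTTP response exists.
        self.request = httpx.Request(method="POST", url="https://api.example.com/inference")
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(self.message)

err = ProviderError(524, "request timed out")
assert err.response.status_code == 524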
litellm/main.py

@@ -877,28 +877,39 @@ def completion(
             return response
         response = model_response
     elif model in litellm.openrouter_models or custom_llm_provider == "openrouter":
-        openai.api_type = "openai"
-        # not sure if this will work after someone first uses another API
-        openai.base_url = (
-            litellm.api_base
-            if litellm.api_base is not None
-            else "https://openrouter.ai/api/v1"
+        api_base = (
+            api_base
+            or litellm.api_base
+            or "https://openrouter.ai/api/v1"
+        )
+
+        api_key = (
+            api_key or
+            litellm.api_key or
+            litellm.openrouter_key or
+            get_secret("OPENROUTER_API_KEY") or
+            get_secret("OR_API_KEY")
+        )
+
+        openrouter_site_url = (
+            get_secret("OR_SITE_URL")
+            or "https://litellm.ai"
+        )
+
+        openrouter_app_name = (
+            get_secret("OR_APP_NAME")
+            or "liteLLM"
         )
-        openai.api_version = None
-        if litellm.organization:
-            openai.organization = litellm.organization
-        if api_key:
-            openai.api_key = api_key
-        elif litellm.openrouter_key:
-            openai.api_key = litellm.openrouter_key
-        else:
-            openai.api_key = get_secret("OPENROUTER_API_KEY") or get_secret(
-                "OR_API_KEY"
-            ) or litellm.api_key
+
         headers = (
             headers or
-            litellm.headers
+            litellm.headers or
+            {
+                "HTTP-Referer": openrouter_site_url,
+                "X-Title": openrouter_app_name,
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {api_key}"
+            }
        )
        data = {
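The refactor above replaces mutation of the global `openai` module with local `or`-chains that resolve each setting from the call argument, then library-level config, then environment secrets. A small sketch of the fallback pattern, assuming environment variables as the last resort:

import os
from typing import Optional

def resolve_api_key(explicit: Optional[str], module_default: Optional[str]) -> Optional[str]:
    # First truthy value wins: call arg > library config > environment.
    return (
        explicit
        or module_default
        or os.environ.get("OPENROUTER_API_KEY")
        or os.environ.get("OR_API_KEY")
    )

print(resolve_api_key(None, None))  # falls through to the environment

Note that `or` skips empty strings as well as None, which is the desired behavior for credentials and URLs.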
@@ -909,27 +920,44 @@ def completion(
         ## LOGGING
         logging.pre_call(input=messages, api_key=openai.api_key, additional_args={"complete_input_dict": data, "headers": headers})
         ## COMPLETION CALL
-        if headers:
-            response = openai.chat.completions.create(
-                headers=headers, # type: ignore
-                **data, # type: ignore
-            )
-        else:
-            openrouter_site_url = get_secret("OR_SITE_URL")
-            openrouter_app_name = get_secret("OR_APP_NAME")
-            # if openrouter_site_url is None, set it to https://litellm.ai
-            if openrouter_site_url is None:
-                openrouter_site_url = "https://litellm.ai"
-            # if openrouter_app_name is None, set it to liteLLM
-            if openrouter_app_name is None:
-                openrouter_app_name = "liteLLM"
-            response = openai.chat.completions.create( # type: ignore
-                extra_headers=httpx.Headers({ # type: ignore
-                    "HTTP-Referer": openrouter_site_url, # type: ignore
-                    "X-Title": openrouter_app_name, # type: ignore
-                }), # type: ignore
-                **data,
-            )
+
+        ## COMPLETION CALL
+        response = openai_chat_completions.completion(
+            model=model,
+            messages=messages,
+            headers=headers,
+            api_key=api_key,
+            api_base=api_base,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            logging_obj=logging,
+            acompletion=acompletion
+        )
+
+        # if headers:
+        #     response = openai.chat.completions.create(
+        #         headers=headers, # type: ignore
+        #         **data, # type: ignore
+        #     )
+        # else:
+        #     openrouter_site_url = get_secret("OR_SITE_URL")
+        #     openrouter_app_name = get_secret("OR_APP_NAME")
+        #     # if openrouter_site_url is None, set it to https://litellm.ai
+        #     if openrouter_site_url is None:
+        #         openrouter_site_url = "https://litellm.ai"
+        #     # if openrouter_app_name is None, set it to liteLLM
+        #     if openrouter_app_name is None:
+        #         openrouter_app_name = "liteLLM"
+        #     response = openai.chat.completions.create( # type: ignore
+        #         extra_headers=httpx.Headers({ # type: ignore
+        #             "HTTP-Referer": openrouter_site_url, # type: ignore
+        #             "X-Title": openrouter_app_name, # type: ignore
+        #         }), # type: ignore
+        #         **data,
+        #     )
         ## LOGGING
         logging.post_call(
             input=messages, api_key=openai.api_key, original_response=response
@@ -1961,7 +1989,7 @@ def moderation(input: str, api_key: Optional[str]=None):
     openai.api_key = api_key
     openai.api_type = "open_ai" # type: ignore
     openai.api_version = None
-    openai.base_url = "https://api.openai.com/v1"
+    openai.base_url = "https://api.openai.com/v1/"
     response = openai.moderations.create(input=input)
     return response
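The trailing slash added to `base_url` matters because the openai v1 client joins relative endpoint paths against it via httpx, which follows RFC 3986 semantics: a base without a trailing slash has its last path segment replaced rather than extended. A quick demonstration with httpx directly:

import httpx

# Without the trailing slash, "v1" is dropped from the joined URL:
print(httpx.URL("https://api.openai.com/v1").join("moderations"))
# -> https://api.openai.com/moderations
print(httpx.URL("https://api.openai.com/v1/").join("moderations"))
# -> https://api.openai.com/v1/moderations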

litellm/tests/test_completion.py

@@ -491,7 +491,7 @@ def test_completion_openrouter1():
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 # test_completion_openrouter1()
 
 def test_completion_openrouter2():
     try:
@@ -873,18 +873,20 @@ def test_completion_together_ai():
 # test_completion_together_ai()
 
 def test_customprompt_together_ai():
     try:
-        litellm.set_verbose = True
+        litellm.set_verbose = False
+        litellm.num_retries = 0
         response = completion(
             model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10",
             messages=messages,
             roles={"system":{"pre_message":"<|im_start|>system\n", "post_message":"<|im_end|>"}, "assistant":{"pre_message":"<|im_start|>assistant\n","post_message":"<|im_end|>"}, "user":{"pre_message":"<|im_start|>user\n","post_message":"<|im_end|>"}}
         )
         print(response)
-    except litellm.APIError as e:
+    except litellm.exceptions.Timeout as e:
+        print(f"Timeout Error")
+        litellm.num_retries = 3 # reset retries
         pass
     except Exception as e:
-        print(type(e))
+        print(f"ERROR TYPE {type(e)}")
+        print(e)
         pytest.fail(f"Error occurred: {e}")
 
 # test_customprompt_together_ai()
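The test now treats a Together AI timeout as a soft pass rather than a hard failure, and sets num_retries to 0 so the timeout surfaces immediately instead of being retried away. The same pattern in a self-contained sketch, with a hypothetical stand-in for the provider call:

import pytest

def flaky_provider_call():
    # Hypothetical stand-in for completion(...); real code would hit the API.
    raise TimeoutError("upstream timed out")

def test_tolerates_provider_timeout():
    try:
        flaky_provider_call()
    except TimeoutError:
        pass  # a slow upstream is acceptable; anything else is a real failure
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")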
@@ -1364,9 +1366,11 @@ def test_completion_deep_infra_mistral():
         )
         # Add any assertions here to check the response
         print(response)
+    except litellm.exceptions.Timeout as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_deep_infra_mistral()
+# test_completion_deep_infra_mistral()
 
 # Palm tests
 def test_completion_palm():
@@ -1454,8 +1458,8 @@ def test_moderation():
     openai.api_version = "GM"
     response = litellm.moderation(input="i'm ishaan cto of litellm")
     print(response)
-    output = response["results"][0]
+    output = response.results[0]
     print(output)
     return output
-# test_moderation()
+test_moderation()

litellm/utils.py

@@ -3182,6 +3182,13 @@ def exception_type(
                     llm_provider="openai",
                     response=original_exception.response
                 )
+            elif original_exception.status_code == 504: # gateway timeout error
+                exception_mapping_worked = True
+                raise Timeout(
+                    message=f"OpenAIException - {original_exception.message}",
+                    model=model,
+                    llm_provider="openai",
+                )
             else:
                 exception_mapping_worked = True
                 raise APIError(
@@ -3707,7 +3714,10 @@ def exception_type(
                 )
         elif custom_llm_provider == "together_ai":
             import json
-            error_response = json.loads(error_str)
+            try:
+                error_response = json.loads(error_str)
+            except:
+                error_response = {"error": error_str}
             if "error" in error_response and "`inputs` tokens + `max_new_tokens` must be <=" in error_response["error"]:
                 exception_mapping_worked = True
                 raise ContextWindowExceededError(
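The try/except around `json.loads` matters because Together AI error bodies are not always JSON: a gateway timeout, for instance, can return plain text or HTML, which previously made the exception mapper itself crash. The defensive shape, sketched on its own (the narrower except clause is a stylistic choice; the commit uses a bare except):

import json

def parse_error_body(error_str: str) -> dict:
    # Provider errors are usually JSON, but gateways can return raw text.
    try:
        return json.loads(error_str)
    except (json.JSONDecodeError, TypeError):
        return {"error": error_str}

assert parse_error_body('{"error": "bad key"}') == {"error": "bad key"}
assert parse_error_body("<html>524</html>") == {"error": "<html>524</html>"}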
@@ -3732,6 +3742,7 @@ def exception_type(
                     llm_provider="together_ai",
                     response=original_exception.response
                 )
+
             elif "error" in error_response and "API key doesn't match expected format." in error_response["error"]:
                 exception_mapping_worked = True
                 raise BadRequestError(
@@ -3764,6 +3775,13 @@ def exception_type(
                     model=model,
                     response=original_exception.response
                 )
+            elif original_exception.status_code == 524:
+                exception_mapping_worked = True
+                raise Timeout(
+                    message=f"TogetherAIException - {original_exception.message}",
+                    llm_provider="together_ai",
+                    model=model,
+                )
             else:
                 exception_mapping_worked = True
                 raise APIError(
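Both new branches map gateway-timeout status codes (504 from OpenAI, 524 from Cloudflare sitting in front of Together AI) to litellm's Timeout type, so callers catch one exception class regardless of provider. A simplified sketch of the mapping idea, with illustrative class names rather than litellm's actual hierarchy:

class ProviderTimeout(Exception):
    pass

class ProviderAPIError(Exception):
    pass

def map_status_code(provider: str, status_code: int, message: str) -> Exception:
    # 504 = gateway timeout; 524 = Cloudflare "a timeout occurred".
    if status_code in (504, 524):
        return ProviderTimeout(f"{provider}Exception - {message}")
    return ProviderAPIError(f"{provider}Exception - {message}")

exc = map_status_code("TogetherAI", 524, "request timed out")
assert isinstance(exc, ProviderTimeout)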
@@ -3967,8 +3985,8 @@ def exception_type(
                 model=model,
                 request=original_exception.request
             )
-        exception_mapping_worked = True
         if "BadRequestError.__init__() missing 1 required positional argument: 'param'" in str(original_exception): # deal with edge-case invalid request error bug in openai-python sdk
+            exception_mapping_worked = True
             raise BadRequestError(
                 message=f"OpenAIException: This can happen due to missing AZURE_API_VERSION: {str(original_exception)}",
                 model=model,