mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 02:34:29 +00:00)
fix(together_ai.py): exception mapping for tgai
This commit is contained in:
parent aa8ca781ba, commit d4de55b053

4 changed files with 103 additions and 50 deletions

litellm/llms/together_ai.py
@@ -5,6 +5,7 @@ import requests
 import time
 from typing import Callable, Optional
 import litellm
+import httpx
 from litellm.utils import ModelResponse, Usage
 from .prompt_templates.factory import prompt_factory, custom_prompt

@@ -12,6 +13,8 @@ class TogetherAIError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
         self.message = message
+        self.request = httpx.Request(method="POST", url="https://api.together.xyz/inference")
+        self.response = httpx.Response(status_code=status_code, request=self.request)
         super().__init__(
             self.message
         )  # Call the base class constructor with the parameters it needs
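Attaching a synthetic request/response pair to the error is what lets the shared exception mapper (the utils.py hunks below) branch on original_exception.response even though no live HTTP exchange object survived. A minimal sketch of the idea; the classify() helper is hypothetical:

import httpx

# Synthesize an httpx request/response pair so a generic exception mapper
# can branch on a status code without a real HTTP exchange in hand.
request = httpx.Request(method="POST", url="https://api.together.xyz/inference")
error_response = httpx.Response(status_code=504, request=request)

def classify(response: httpx.Response) -> str:
    # 504 (gateway timeout) and 524 (Cloudflare timeout) are the codes
    # the utils.py hunks below map to litellm's Timeout exception
    if response.status_code in (504, 524):
        return "timeout"
    return "api_error"

print(classify(error_response))  # -> timeout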

litellm/main.py (106 changes)
@@ -877,28 +877,39 @@ def completion(
                 return response
         response = model_response
     elif model in litellm.openrouter_models or custom_llm_provider == "openrouter":
-        openai.api_type = "openai"
-        # not sure if this will work after someone first uses another API
-        openai.base_url = (
-            litellm.api_base
-            if litellm.api_base is not None
-            else "https://openrouter.ai/api/v1"
+        api_base = (
+            api_base
+            or litellm.api_base
+            or "https://openrouter.ai/api/v1"
         )
+
+        api_key = (
+            api_key or
+            litellm.api_key or
+            litellm.openrouter_key or
+            get_secret("OPENROUTER_API_KEY") or
+            get_secret("OR_API_KEY")
+        )
+
+        openrouter_site_url = (
+            get_secret("OR_SITE_URL")
+            or "https://litellm.ai"
+        )
+
+        openrouter_app_name = (
+            get_secret("OR_APP_NAME")
+            or "liteLLM"
+        )
-        openai.api_version = None
-        if litellm.organization:
-            openai.organization = litellm.organization
-        if api_key:
-            openai.api_key = api_key
-        elif litellm.openrouter_key:
-            openai.api_key = litellm.openrouter_key
-        else:
-            openai.api_key = get_secret("OPENROUTER_API_KEY") or get_secret(
-                "OR_API_KEY"
-            ) or litellm.api_key

         headers = (
             headers or
-            litellm.headers
+            litellm.headers or
+            {
+                "HTTP-Referer": openrouter_site_url,
+                "X-Title": openrouter_app_name,
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {api_key}"
+            }
         )

         data = {
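The rewritten block stops mutating openai module globals and resolves configuration locally with an or-chain, giving a clear precedence order: explicit argument, then module-level setting, then secret lookup, then hard default. A minimal standalone sketch of the same pattern, assuming an illustrative env-var name in place of the real get_secret calls:

import os
from typing import Optional

# Resolution order: explicit argument > module-level setting > env lookup
# > hard default. OPENROUTER_API_BASE is illustrative, not litellm's name.
def resolve_api_base(api_base: Optional[str] = None,
                     module_api_base: Optional[str] = None) -> str:
    return (
        api_base
        or module_api_base
        or os.environ.get("OPENROUTER_API_BASE")
        or "https://openrouter.ai/api/v1"
    )

print(resolve_api_base())  # https://openrouter.ai/api/v1
print(resolve_api_base(api_base="http://localhost:8080"))  # explicit argument wins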
@@ -909,27 +920,44 @@ def completion(
         ## LOGGING
         logging.pre_call(input=messages, api_key=openai.api_key, additional_args={"complete_input_dict": data, "headers": headers})
-        ## COMPLETION CALL
-        if headers:
-            response = openai.chat.completions.create(
-                headers=headers, # type: ignore
-                **data, # type: ignore
-            )
-        else:
-            openrouter_site_url = get_secret("OR_SITE_URL")
-            openrouter_app_name = get_secret("OR_APP_NAME")
-            # if openrouter_site_url is None, set it to https://litellm.ai
-            if openrouter_site_url is None:
-                openrouter_site_url = "https://litellm.ai"
-            # if openrouter_app_name is None, set it to liteLLM
-            if openrouter_app_name is None:
-                openrouter_app_name = "liteLLM"
-            response = openai.chat.completions.create( # type: ignore
-                extra_headers=httpx.Headers({ # type: ignore
-                    "HTTP-Referer": openrouter_site_url, # type: ignore
-                    "X-Title": openrouter_app_name, # type: ignore
-                }), # type: ignore
-                **data,
-            )
+        ## COMPLETION CALL
+        response = openai_chat_completions.completion(
+            model=model,
+            messages=messages,
+            headers=headers,
+            api_key=api_key,
+            api_base=api_base,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            logging_obj=logging,
+            acompletion=acompletion
+        )
+
+        # if headers:
+        #     response = openai.chat.completions.create(
+        #         headers=headers, # type: ignore
+        #         **data, # type: ignore
+        #     )
+        # else:
+        #     openrouter_site_url = get_secret("OR_SITE_URL")
+        #     openrouter_app_name = get_secret("OR_APP_NAME")
+        #     # if openrouter_site_url is None, set it to https://litellm.ai
+        #     if openrouter_site_url is None:
+        #         openrouter_site_url = "https://litellm.ai"
+        #     # if openrouter_app_name is None, set it to liteLLM
+        #     if openrouter_app_name is None:
+        #         openrouter_app_name = "liteLLM"
+        #     response = openai.chat.completions.create( # type: ignore
+        #         extra_headers=httpx.Headers({ # type: ignore
+        #             "HTTP-Referer": openrouter_site_url, # type: ignore
+        #             "X-Title": openrouter_app_name, # type: ignore
+        #         }), # type: ignore
+        #         **data,
+        #     )
         ## LOGGING
         logging.post_call(
             input=messages, api_key=openai.api_key, original_response=response
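The branchy inline openai.chat.completions.create call gives way to a single delegation to openai_chat_completions.completion, with the already-resolved api_base/api_key/headers passed in and the old path retained as comments. A hypothetical sketch (not litellm's actual handler) of what such an OpenAI-compatible pass-through looks like:

from typing import Optional

import httpx

# Hypothetical pass-through for OpenAI-compatible endpoints: one function
# that posts the payload to the resolved base URL instead of mutating
# openai module globals per provider.
def openai_compatible_completion(api_base: str, api_key: str,
                                 headers: Optional[dict], data: dict) -> dict:
    merged = {"Authorization": f"Bearer {api_key}", **(headers or {})}
    resp = httpx.post(f"{api_base}/chat/completions", headers=merged,
                      json=data, timeout=600.0)
    resp.raise_for_status()  # non-2xx becomes an exception for the mapper to classify
    return resp.json()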
@@ -1961,7 +1989,7 @@ def moderation(input: str, api_key: Optional[str]=None):
     openai.api_key = api_key
     openai.api_type = "open_ai" # type: ignore
     openai.api_version = None
-    openai.base_url = "https://api.openai.com/v1"
+    openai.base_url = "https://api.openai.com/v1/"
     response = openai.moderations.create(input=input)
     return response
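The one-character change to base_url matters because the v1 SDK's underlying client joins relative endpoint paths onto the base URL, and RFC 3986 joining drops the last path segment when the base has no trailing slash. urllib.parse.urljoin shows the same effect:

from urllib.parse import urljoin

# Without the trailing slash, the /v1 segment is dropped when a relative
# path is joined onto the base URL; with it, the join keeps /v1.
print(urljoin("https://api.openai.com/v1", "moderations"))
# -> https://api.openai.com/moderations   (wrong endpoint)
print(urljoin("https://api.openai.com/v1/", "moderations"))
# -> https://api.openai.com/v1/moderations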

litellm/tests/test_completion.py
@@ -873,18 +873,20 @@ def test_completion_together_ai():
 # test_completion_together_ai()
 def test_customprompt_together_ai():
     try:
-        litellm.set_verbose = True
+        litellm.set_verbose = False
+        litellm.num_retries = 0
         response = completion(
             model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10",
             messages=messages,
             roles={"system":{"pre_message":"<|im_start|>system\n", "post_message":"<|im_end|>"}, "assistant":{"pre_message":"<|im_start|>assistant\n","post_message":"<|im_end|>"}, "user":{"pre_message":"<|im_start|>user\n","post_message":"<|im_end|>"}}
         )
         print(response)
-    except litellm.APIError as e:
+    except litellm.exceptions.Timeout as e:
+        print(f"Timeout Error")
+        litellm.num_retries = 3 # reset retries
         pass
     except Exception as e:
-        print(type(e))
-        print(e)
+        print(f"ERROR TYPE {type(e)}")
         pytest.fail(f"Error occurred: {e}")

 # test_customprompt_together_ai()
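The roles argument exercises litellm's custom prompt template: each message is wrapped in that role's pre_message/post_message strings, yielding ChatML for the OpenAssistant model. A rough sketch of the wrapping (build_prompt is hypothetical; the real factory lives in prompt_templates/factory.py):

# Each message becomes pre_message + content + post_message, per the roles
# dict used in the test above.
roles = {
    "system": {"pre_message": "<|im_start|>system\n", "post_message": "<|im_end|>"},
    "assistant": {"pre_message": "<|im_start|>assistant\n", "post_message": "<|im_end|>"},
    "user": {"pre_message": "<|im_start|>user\n", "post_message": "<|im_end|>"},
}

def build_prompt(messages: list) -> str:
    parts = []
    for m in messages:
        fmt = roles[m["role"]]
        parts.append(fmt["pre_message"] + m["content"] + fmt["post_message"])
    return "\n".join(parts)

print(build_prompt([{"role": "user", "content": "hey, how's it going?"}]))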
@@ -1364,9 +1366,11 @@ def test_completion_deep_infra_mistral():
         )
         # Add any assertions here to check the response
         print(response)
+    except litellm.exceptions.Timeout as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_deep_infra_mistral()
+# test_completion_deep_infra_mistral()

 # Palm tests
 def test_completion_palm():
@@ -1454,8 +1458,8 @@ def test_moderation():
     openai.api_version = "GM"
     response = litellm.moderation(input="i'm ishaan cto of litellm")
     print(response)
-    output = response["results"][0]
+    output = response.results[0]
     print(output)
     return output

-# test_moderation()
+test_moderation()
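The switch from response["results"][0] to response.results[0] tracks the openai-python v1 SDK, which returns typed objects rather than dicts, so dict-style subscripting raises TypeError. A stand-in pydantic model shows the difference:

from pydantic import BaseModel

# Stand-in for the SDK's typed moderation response: attribute access works,
# subscripting does not.
class ModerationResult(BaseModel):
    flagged: bool

class ModerationResponse(BaseModel):
    results: list

response = ModerationResponse(results=[ModerationResult(flagged=False)])
print(response.results[0])   # fine
try:
    response["results"]      # type: ignore[index]
except TypeError as err:
    print(f"subscripting fails: {err}")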

litellm/utils.py
@@ -3182,6 +3182,13 @@ def exception_type(
                     llm_provider="openai",
                     response=original_exception.response
                 )
+            elif original_exception.status_code == 504: # gateway timeout error
+                exception_mapping_worked = True
+                raise Timeout(
+                    message=f"OpenAIException - {original_exception.message}",
+                    model=model,
+                    llm_provider="openai",
+                )
             else:
                 exception_mapping_worked = True
                 raise APIError(
@@ -3707,7 +3714,10 @@ def exception_type(
             )
         elif custom_llm_provider == "together_ai":
             import json
-            error_response = json.loads(error_str)
+            try:
+                error_response = json.loads(error_str)
+            except:
+                error_response = {"error": error_str}
             if "error" in error_response and "`inputs` tokens + `max_new_tokens` must be <=" in error_response["error"]:
                 exception_mapping_worked = True
                 raise ContextWindowExceededError(
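Together AI error bodies are not always JSON, so the parse is now wrapped and anything unparseable is folded into the same {"error": ...} shape the substring checks expect. The pattern in isolation:

import json

# Fold non-JSON bodies into a uniform {"error": ...} dict so downstream
# substring checks work on either form.
def parse_error_body(error_str: str) -> dict:
    try:
        return json.loads(error_str)
    except json.JSONDecodeError:
        return {"error": error_str}

print(parse_error_body('{"error": "API key doesn\'t match expected format."}'))
print(parse_error_body("<html>502 Bad Gateway</html>"))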
@@ -3732,6 +3742,7 @@ def exception_type(
                     llm_provider="together_ai",
                     response=original_exception.response
                 )
+
             elif "error" in error_response and "API key doesn't match expected format." in error_response["error"]:
                 exception_mapping_worked = True
                 raise BadRequestError(
@@ -3764,6 +3775,13 @@ def exception_type(
                     model=model,
                     response=original_exception.response
                 )
+            elif original_exception.status_code == 524:
+                exception_mapping_worked = True
+                raise Timeout(
+                    message=f"TogetherAIException - {original_exception.message}",
+                    llm_provider="together_ai",
+                    model=model,
+                )
             else:
                 exception_mapping_worked = True
                 raise APIError(
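Both new branches have the same shape: 504 (gateway timeout) for OpenAI and 524 (Cloudflare's "a timeout occurred", which fronts Together AI) now surface as litellm's Timeout instead of falling through to the generic APIError. A condensed sketch with stand-in exception classes:

# Stand-ins for litellm's Timeout and APIError, for illustration only.
class Timeout(Exception): ...
class APIError(Exception): ...

TIMEOUT_STATUS_CODES = {504, 524}  # gateway timeout / Cloudflare "a timeout occurred"

def map_status(provider: str, status_code: int, message: str) -> Exception:
    if status_code in TIMEOUT_STATUS_CODES:
        return Timeout(f"{provider}Exception - {message}")
    return APIError(f"{provider}Exception - {message}")

exc = map_status("TogetherAI", 524, "request timed out")
print(type(exc).__name__, exc)  # Timeout TogetherAIException - request timed out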
@@ -3967,8 +3985,8 @@ def exception_type(
                 model=model,
                 request=original_exception.request
             )
-        exception_mapping_worked = True
         if "BadRequestError.__init__() missing 1 required positional argument: 'param'" in str(original_exception): # deal with edge-case invalid request error bug in openai-python sdk
+            exception_mapping_worked = True
             raise BadRequestError(
                 message=f"OpenAIException: This can happen due to missing AZURE_API_VERSION: {str(original_exception)}",
                 model=model,