fix(together_ai.py): exception mapping for tgai

commit d4de55b053 (parent aa8ca781ba)
4 changed files with 103 additions and 50 deletions
The first two hunks update the Together AI handler (presumably litellm/llms/together_ai.py, per the commit message):

```diff
@@ -5,6 +5,7 @@ import requests
 import time
 from typing import Callable, Optional
 import litellm
+import httpx
 from litellm.utils import ModelResponse, Usage
 from .prompt_templates.factory import prompt_factory, custom_prompt
```
```diff
@@ -12,6 +13,8 @@ class TogetherAIError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
         self.message = message
+        self.request = httpx.Request(method="POST", url="https://api.together.xyz/inference")
+        self.response = httpx.Response(status_code=status_code, request=self.request)
         super().__init__(
             self.message
         )  # Call the base class constructor with the parameters it needs
```
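Why the synthetic request/response pair matters: the exception mapper later in this commit reads a status code off `.response`, which Together AI's raw-HTTP path never provided before. A minimal sketch of the effect — the class body mirrors the diff, while the last two lines are assumed usage, not code from this commit:

```python
import httpx

class TogetherAIError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        # Fabricate an httpx request/response so downstream exception
        # mapping can treat this like any SDK error that carries a response.
        self.request = httpx.Request(method="POST", url="https://api.together.xyz/inference")
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(self.message)

err = TogetherAIError(status_code=524, message="origin timed out")
print(err.response.status_code)  # 524 -> mapped to Timeout further below
```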
litellm/main.py (108 lines changed):
```diff
@@ -877,28 +877,39 @@ def completion(
             return response
         response = model_response
     elif model in litellm.openrouter_models or custom_llm_provider == "openrouter":
-        openai.api_type = "openai"
-        # not sure if this will work after someone first uses another API
-        openai.base_url = (
-            litellm.api_base
-            if litellm.api_base is not None
-            else "https://openrouter.ai/api/v1"
+        api_base = (
+            api_base
+            or litellm.api_base
+            or "https://openrouter.ai/api/v1"
+        )
+
+        api_key = (
+            api_key or
+            litellm.api_key or
+            litellm.openrouter_key or
+            get_secret("OPENROUTER_API_KEY") or
+            get_secret("OR_API_KEY")
+        )
+
+        openrouter_site_url = (
+            get_secret("OR_SITE_URL")
+            or "https://litellm.ai"
+        )
+
+        openrouter_app_name = (
+            get_secret("OR_APP_NAME")
+            or "liteLLM"
         )
-        openai.api_version = None
-        if litellm.organization:
-            openai.organization = litellm.organization
-        if api_key:
-            openai.api_key = api_key
-        elif litellm.openrouter_key:
-            openai.api_key = litellm.openrouter_key
-        else:
-            openai.api_key = get_secret("OPENROUTER_API_KEY") or get_secret(
-                "OR_API_KEY"
-            ) or litellm.api_key

         headers = (
             headers or
-            litellm.headers
+            litellm.headers or
+            {
+                "HTTP-Referer": openrouter_site_url,
+                "X-Title": openrouter_app_name,
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {api_key}"
+            }
         )

         data = {
```
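A side note on the rewrite: the if/elif configuration lookup becomes `or`-chains, which return the first truthy operand, so each chain picks the first configured value in precedence order. A minimal illustrative sketch (names below are placeholders, not litellm internals):

```python
import os

def resolve_api_key(explicit_key=None):
    # First truthy value wins: call-site argument, then env vars,
    # mirroring the precedence order the new code establishes.
    return (
        explicit_key
        or os.getenv("OPENROUTER_API_KEY")
        or os.getenv("OR_API_KEY")  # legacy alias, checked last
    )

print(resolve_api_key("sk-test"))  # explicit argument wins
```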
```diff
@@ -909,27 +920,44 @@ def completion(
         ## LOGGING
         logging.pre_call(input=messages, api_key=openai.api_key, additional_args={"complete_input_dict": data, "headers": headers})
         ## COMPLETION CALL
-        if headers:
-            response = openai.chat.completions.create(
-                headers=headers, # type: ignore
-                **data, # type: ignore
-            )
-        else:
-            openrouter_site_url = get_secret("OR_SITE_URL")
-            openrouter_app_name = get_secret("OR_APP_NAME")
-            # if openrouter_site_url is None, set it to https://litellm.ai
-            if openrouter_site_url is None:
-                openrouter_site_url = "https://litellm.ai"
-            # if openrouter_app_name is None, set it to liteLLM
-            if openrouter_app_name is None:
-                openrouter_app_name = "liteLLM"
-            response = openai.chat.completions.create( # type: ignore
-                extra_headers=httpx.Headers({ # type: ignore
-                    "HTTP-Referer": openrouter_site_url, # type: ignore
-                    "X-Title": openrouter_app_name, # type: ignore
-                }), # type: ignore
-                **data,
-            )
+        response = openai_chat_completions.completion(
+            model=model,
+            messages=messages,
+            headers=headers,
+            api_key=api_key,
+            api_base=api_base,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            logging_obj=logging,
+            acompletion=acompletion
+        )
+
+        # if headers:
+        #     response = openai.chat.completions.create(
+        #         headers=headers, # type: ignore
+        #         **data, # type: ignore
+        #     )
+        # else:
+        #     openrouter_site_url = get_secret("OR_SITE_URL")
+        #     openrouter_app_name = get_secret("OR_APP_NAME")
+        #     # if openrouter_site_url is None, set it to https://litellm.ai
+        #     if openrouter_site_url is None:
+        #         openrouter_site_url = "https://litellm.ai"
+        #     # if openrouter_app_name is None, set it to liteLLM
+        #     if openrouter_app_name is None:
+        #         openrouter_app_name = "liteLLM"
+        #     response = openai.chat.completions.create( # type: ignore
+        #         extra_headers=httpx.Headers({ # type: ignore
+        #             "HTTP-Referer": openrouter_site_url, # type: ignore
+        #             "X-Title": openrouter_app_name, # type: ignore
+        #         }), # type: ignore
+        #         **data,
+        #     )
         ## LOGGING
         logging.post_call(
             input=messages, api_key=openai.api_key, original_response=response
```
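Call sites should be unaffected by the routing change: OpenRouter requests now flow through litellm's shared OpenAI chat-completions handler instead of mutating the `openai` module globals. A hedged usage sketch (the model name is illustrative):

```python
import litellm

# Same caller-facing API as before the refactor; only the internal
# dispatch changed to the shared handler.
response = litellm.completion(
    model="openrouter/openai/gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
)
print(response)
```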
```diff
@@ -1961,7 +1989,7 @@ def moderation(input: str, api_key: Optional[str]=None):
     openai.api_key = api_key
     openai.api_type = "open_ai" # type: ignore
     openai.api_version = None
-    openai.base_url = "https://api.openai.com/v1"
+    openai.base_url = "https://api.openai.com/v1/"
     response = openai.moderations.create(input=input)
     return response
```
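The trailing slash on `base_url` is not cosmetic: the openai v1 client joins relative endpoint paths against the base URL, and without the slash the last path segment is dropped. A quick demonstration of the general joining rule (standard library, not from the diff):

```python
from urllib.parse import urljoin

# Without the trailing slash, "v1" is treated as a file and replaced.
print(urljoin("https://api.openai.com/v1", "moderations"))   # https://api.openai.com/moderations
print(urljoin("https://api.openai.com/v1/", "moderations"))  # https://api.openai.com/v1/moderations
```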
The next hunks update the completion tests (file name not shown in this capture):

```diff
@@ -491,7 +491,7 @@ def test_completion_openrouter1():
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 # test_completion_openrouter1()

 def test_completion_openrouter2():
     try:
```
```diff
@@ -873,18 +873,20 @@ def test_completion_together_ai():
 # test_completion_together_ai()
 def test_customprompt_together_ai():
     try:
-        litellm.set_verbose = True
+        litellm.set_verbose = False
+        litellm.num_retries = 0
         response = completion(
             model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10",
             messages=messages,
             roles={"system":{"pre_message":"<|im_start|>system\n", "post_message":"<|im_end|>"}, "assistant":{"pre_message":"<|im_start|>assistant\n","post_message":"<|im_end|>"}, "user":{"pre_message":"<|im_start|>user\n","post_message":"<|im_end|>"}}
         )
         print(response)
-    except litellm.APIError as e:
+    except litellm.exceptions.Timeout as e:
+        print(f"Timeout Error")
+        litellm.num_retries = 3 # reset retries
         pass
     except Exception as e:
-        print(type(e))
-        print(e)
+        print(f"ERROR TYPE {type(e)}")
         pytest.fail(f"Error occurred: {e}")

 # test_customprompt_together_ai()
```
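The `roles` dict in this test wraps each message in model-specific delimiters. An illustrative sketch of what such pre/post wrappers produce (this is not litellm's prompt factory, just the idea):

```python
roles = {
    "system": {"pre_message": "<|im_start|>system\n", "post_message": "<|im_end|>"},
    "user": {"pre_message": "<|im_start|>user\n", "post_message": "<|im_end|>"},
    "assistant": {"pre_message": "<|im_start|>assistant\n", "post_message": "<|im_end|>"},
}

def render(messages):
    # Wrap each message in its role's delimiters, ChatML-style.
    return "".join(
        roles[m["role"]]["pre_message"] + m["content"] + roles[m["role"]]["post_message"]
        for m in messages
    )

print(render([{"role": "user", "content": "hey, how's it going"}]))
```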
```diff
@@ -1364,9 +1366,11 @@ def test_completion_deep_infra_mistral():
         )
         # Add any assertions here to check the response
         print(response)
+    except litellm.exceptions.Timeout as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_deep_infra_mistral()
+# test_completion_deep_infra_mistral()

 # Palm tests
 def test_completion_palm():
```
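Both test changes follow the same shape: a provider timeout becomes a soft pass so a slow live endpoint does not fail CI, while any other exception still fails the test. A minimal sketch of the pattern (the helper name is illustrative):

```python
import litellm
import pytest

def call_tolerating_timeout(**kwargs):
    try:
        return litellm.completion(**kwargs)
    except litellm.exceptions.Timeout:
        pass  # slow upstream, not a code bug
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
```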
```diff
@@ -1454,8 +1458,8 @@ def test_moderation():
     openai.api_version = "GM"
     response = litellm.moderation(input="i'm ishaan cto of litellm")
     print(response)
-    output = response["results"][0]
+    output = response.results[0]
     print(output)
     return output

-# test_moderation()
+test_moderation()
```
The remaining hunks extend `exception_type`, litellm's exception-mapping helper (file name not shown in this capture; presumably litellm/utils.py):

```diff
@@ -3182,6 +3182,13 @@ def exception_type(
                     llm_provider="openai",
                     response=original_exception.response
                 )
+            elif original_exception.status_code == 504: # gateway timeout error
+                exception_mapping_worked = True
+                raise Timeout(
+                    message=f"OpenAIException - {original_exception.message}",
+                    model=model,
+                    llm_provider="openai",
+                )
             else:
                 exception_mapping_worked = True
                 raise APIError(
```
```diff
@@ -3707,7 +3714,10 @@ def exception_type(
             )
         elif custom_llm_provider == "together_ai":
             import json
-            error_response = json.loads(error_str)
+            try:
+                error_response = json.loads(error_str)
+            except:
+                error_response = {"error": error_str}
             if "error" in error_response and "`inputs` tokens + `max_new_tokens` must be <=" in error_response["error"]:
                 exception_mapping_worked = True
                 raise ContextWindowExceededError(
```
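The try/except around `json.loads` guards against Together AI returning a plain-text error body; previously a non-JSON error would raise `JSONDecodeError` and mask the real failure. A standalone sketch of the same defensive parse (the helper name is illustrative, and the bare `except` from the diff is narrowed here):

```python
import json

def parse_error(error_str: str) -> dict:
    try:
        return json.loads(error_str)
    except (json.JSONDecodeError, TypeError):
        # Non-JSON body: wrap it so callers can still inspect ["error"].
        return {"error": error_str}

print(parse_error('{"error": "`inputs` tokens + `max_new_tokens` must be <= 4097"}'))
print(parse_error("upstream 524: a timeout occurred"))
```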
```diff
@@ -3732,6 +3742,7 @@ def exception_type(
                     llm_provider="together_ai",
                     response=original_exception.response
                 )
+
             elif "error" in error_response and "API key doesn't match expected format." in error_response["error"]:
                 exception_mapping_worked = True
                 raise BadRequestError(
```
```diff
@@ -3764,6 +3775,13 @@ def exception_type(
                     model=model,
                     response=original_exception.response
                 )
+            elif original_exception.status_code == 524:
+                exception_mapping_worked = True
+                raise Timeout(
+                    message=f"TogetherAIException - {original_exception.message}",
+                    llm_provider="together_ai",
+                    model=model,
+                )
             else:
                 exception_mapping_worked = True
                 raise APIError(
```
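Both new branches apply the same rule: translate a gateway-timeout status (504 on the OpenAI path, Cloudflare's 524 "a timeout occurred" from Together AI's edge) into litellm's typed `Timeout`, so callers can catch one exception class across providers. A condensed sketch of the rule (the function below is illustrative, not from the diff):

```python
import litellm

def raise_timeout_if_gateway(status_code: int, message: str, provider: str, model: str):
    # 504 = gateway timeout; 524 = Cloudflare origin timeout.
    if status_code in (504, 524):
        raise litellm.exceptions.Timeout(
            message=f"{provider}Exception - {message}",
            model=model,
            llm_provider=provider,
        )
```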
```diff
@@ -3967,8 +3985,8 @@ def exception_type(
                 model=model,
                 request=original_exception.request
             )
-        exception_mapping_worked = True
         if "BadRequestError.__init__() missing 1 required positional argument: 'param'" in str(original_exception): # deal with edge-case invalid request error bug in openai-python sdk
+            exception_mapping_worked = True
             raise BadRequestError(
                 message=f"OpenAIException: This can happen due to missing AZURE_API_VERSION: {str(original_exception)}",
                 model=model,
```