mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

fix(palm.py): exception mapping bad requests / filtered responses

parent e1a567a353
commit e9e86cac79

3 changed files with 24 additions and 7 deletions
@@ -5,12 +5,14 @@ import time
 from typing import Callable, Optional
 from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
 import litellm
-import sys
+import sys, httpx
 
 class PalmError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
         self.message = message
+        self.request = httpx.Request(method="POST", url="https://developers.generativeai.google/api/python/google/generativeai/chat")
+        self.response = httpx.Response(status_code=status_code, request=self.request)
         super().__init__(
             self.message
         ) # Call the base class constructor with the parameters it needs
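
The point of attaching a request/response pair here is that litellm's generic exception mapping reads original_exception.response (an httpx.Response) when translating provider errors, so PalmError needs to expose one even though the PaLM SDK call never produced an httpx object. A minimal standalone sketch of the same pattern; the class name and URL below are illustrative, not the library's:

```python
import httpx


class ProviderError(Exception):
    """Illustrative stand-in for PalmError: carries a synthetic httpx.Response."""

    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        # Build a placeholder request/response pair so downstream mapping code
        # that expects `.response.status_code` keeps working.
        self.request = httpx.Request(method="POST", url="https://example.invalid/palm")
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(self.message)


err = ProviderError(status_code=400, message="Request payload size exceeds the limit")
assert err.response.status_code == 400
```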
@@ -147,6 +149,11 @@ def completion(
             traceback.print_exc()
             raise PalmError(message=traceback.format_exc(), status_code=response.status_code)
 
+    try:
+        completion_response = model_response["choices"][0]["message"].get("content")
+    except:
+        raise PalmError(status_code=400, message=f"No response received. Original response - {response}")
+
     ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
     prompt_tokens = len(
         encoding.encode(prompt)
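
The added try/except handles the "filtered responses" half of the commit message: when PaLM blocks or empties a completion, the nested lookup into model_response used to escape as an unmapped error, whereas it now becomes a PalmError with status 400 that the mapping further below can translate. A self-contained sketch of that guard, assuming a plain dict-shaped response (the helper name and error type are illustrative):

```python
def extract_first_content(model_response, raw_response):
    """Pull choices[0].message.content, raising a 400-style error if it is absent."""
    try:
        return model_response["choices"][0]["message"].get("content")
    except (KeyError, IndexError, AttributeError) as exc:
        # Mirror the diff: a filtered/blocked completion is reported as a bad request
        # that still includes the provider's original payload for debugging.
        raise ValueError(
            f"No response received. Original response - {raw_response}"
        ) from exc


# Example: a response with no choices triggers the mapped error.
try:
    extract_first_content({"choices": []}, raw_response="<blocked by safety filter>")
except ValueError as e:
    print(e)
```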
@@ -58,7 +58,7 @@ def test_completion_claude():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_completion_claude()
+# test_completion_claude()
 
 # def test_completion_oobabooga():
 #     try:
@@ -1352,7 +1352,7 @@ def test_completion_deep_infra():
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_deep_infra()
+# test_completion_deep_infra()
 
 def test_completion_deep_infra_mistral():
     print("deep infra test with temp=0")
@@ -1374,15 +1374,16 @@ def test_completion_deep_infra_mistral():
 
 # Palm tests
 def test_completion_palm():
-    # litellm.set_verbose = True
+    litellm.set_verbose = True
     model_name = "palm/chat-bison"
+    messages = [{"role": "user", "content": "what model are you?"}]
     try:
-        response = completion(model=model_name, messages=messages, stop=["stop"])
+        response = completion(model=model_name, messages=messages)
         # Add any assertions here to check the response
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-# test_completion_palm()
+test_completion_palm()
 
 # test palm with streaming
 def test_completion_palm_stream():
@@ -1462,4 +1463,4 @@ def test_moderation():
     print(output)
     return output
 
-test_moderation()
+# test_moderation()
@@ -3445,6 +3445,15 @@ def exception_type(
                     llm_provider="palm",
                     response=original_exception.response
                 )
+            if hasattr(original_exception, "status_code"):
+                if original_exception.status_code == 400:
+                    exception_mapping_worked = True
+                    raise BadRequestError(
+                        message=f"PalmException - {error_str}",
+                        model=model,
+                        llm_provider="palm",
+                        response=original_exception.response
+                    )
             # Dailed: Error occurred: 400 Request payload size exceeds the limit: 20000 bytes
         elif custom_llm_provider == "cohere": # Cohere
             if (
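
With this branch added to exception_type, a PaLM 400 (such as the oversized-payload failure quoted in the comment above) should reach callers as litellm's BadRequestError rather than an unmapped provider exception. A hedged usage sketch; it assumes litellm re-exports BadRequestError at the top level and that an oversized prompt actually triggers the 400, which the in-code comment suggests but this commit does not itself guarantee:

```python
import litellm

try:
    litellm.completion(
        model="palm/chat-bison",
        # Deliberately larger than the 20000-byte payload limit quoted in the comment.
        messages=[{"role": "user", "content": "x" * 50_000}],
    )
except litellm.BadRequestError as e:
    # The mapping above formats the message as "PalmException - <original error string>".
    print("Mapped to BadRequestError:", e)
```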