fix(palm.py): exception mapping bad requests / filtered responses

This commit is contained in:
Krrish Dholakia 2023-11-14 11:53:06 -08:00
parent e1a567a353
commit e9e86cac79
3 changed files with 24 additions and 7 deletions

View file

@ -5,12 +5,14 @@ import time
from typing import Callable, Optional
from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
import litellm
import sys
import sys, httpx
class PalmError(Exception):
def __init__(self, status_code, message):
    """Initialize a PaLM API error with an HTTP status code and message.

    Builds synthetic httpx request/response objects so downstream error
    mapping (which expects httpx-style exceptions with a ``.response``
    attribute) can introspect this error uniformly.
    """
    self.status_code = status_code
    self.message = message
    # No real outbound httpx call happens here; the request object just
    # points at the PaLM chat docs URL so the synthetic response below
    # has something to attach to.
    self.request = httpx.Request(method="POST", url="https://developers.generativeai.google/api/python/google/generativeai/chat")
    self.response = httpx.Response(status_code=status_code, request=self.request)
    super().__init__(
        self.message
    )  # Call the base class constructor with the parameters it needs
@ -146,6 +148,11 @@ def completion(
except Exception as e:
traceback.print_exc()
raise PalmError(message=traceback.format_exc(), status_code=response.status_code)
try:
completion_response = model_response["choices"][0]["message"].get("content")
except:
raise PalmError(status_code=400, message=f"No response received. Original response - {response}")
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(

View file

@ -58,7 +58,7 @@ def test_completion_claude():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
test_completion_claude()
# test_completion_claude()
# def test_completion_oobabooga():
# try:
@ -1352,7 +1352,7 @@ def test_completion_deep_infra():
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
test_completion_deep_infra()
# test_completion_deep_infra()
def test_completion_deep_infra_mistral():
print("deep infra test with temp=0")
@ -1374,15 +1374,16 @@ def test_completion_deep_infra_mistral():
# Palm tests
def test_completion_palm():
# litellm.set_verbose = True
litellm.set_verbose = True
model_name = "palm/chat-bison"
messages = [{"role": "user", "content": "what model are you?"}]
try:
response = completion(model=model_name, messages=messages, stop=["stop"])
response = completion(model=model_name, messages=messages)
# Add any assertions here to check the response
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_palm()
test_completion_palm()
# test palm with streaming
def test_completion_palm_stream():
@ -1462,4 +1463,4 @@ def test_moderation():
print(output)
return output
test_moderation()
# test_moderation()

View file

@ -3445,6 +3445,15 @@ def exception_type(
llm_provider="palm",
response=original_exception.response
)
if hasattr(original_exception, "status_code"):
if original_exception.status_code == 400:
exception_mapping_worked = True
raise BadRequestError(
message=f"PalmException - {error_str}",
model=model,
llm_provider="palm",
response=original_exception.response
)
# Failed: Error occurred: 400 Request payload size exceeds the limit: 20000 bytes
elif custom_llm_provider == "cohere": # Cohere
if (