fix(vertex_ai.py): add exception mapping for acompletion calls

Krrish Dholakia 2023-12-13 16:35:50 -08:00
parent 16b7b8097b
commit c673a23769
5 changed files with 99 additions and 70 deletions
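Taken together, the changes make gemini-pro acompletion failures surface as typed, OpenAI-compatible exceptions instead of raw SDK errors: vertex_ai.py wraps its async body and re-raises VertexAIError, and utils.py maps blocked responses to the new UnprocessableEntityError. A minimal caller-side sketch of the intended effect (the model, prompt, and printed fields are illustrative, not part of this commit):

import asyncio
import litellm
from litellm.exceptions import UnprocessableEntityError

async def main():
    try:
        response = await litellm.acompletion(
            model="gemini-pro",
            messages=[{"role": "user", "content": "hello"}],
        )
        print(response)
    except UnprocessableEntityError as e:
        # after this commit, a blocked Vertex AI response arrives here with 422 semantics
        print(f"blocked: {e.status_code} - {e.message}")

asyncio.run(main())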


@@ -18,10 +18,12 @@ from openai import (
     APIError,
     APITimeoutError,
     APIConnectionError,
-    APIResponseValidationError
+    APIResponseValidationError,
+    UnprocessableEntityError
 )
 import httpx

 class AuthenticationError(AuthenticationError): # type: ignore
     def __init__(self, message, llm_provider, model, response: httpx.Response):
         self.status_code = 401
@@ -46,6 +48,18 @@ class BadRequestError(BadRequestError): # type: ignore
             body=None
         ) # Call the base class constructor with the parameters it needs

+class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
+    def __init__(self, message, model, llm_provider, response: httpx.Response):
+        self.status_code = 422
+        self.message = message
+        self.model = model
+        self.llm_provider = llm_provider
+        super().__init__(
+            self.message,
+            response=response,
+            body=None
+        ) # Call the base class constructor with the parameters it needs
+
 class Timeout(APITimeoutError): # type: ignore
     def __init__(self, message, model, llm_provider):
         self.status_code = 408
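The new wrapper mirrors the existing AuthenticationError and BadRequestError subclasses: pin status_code to 422, keep message, model, and llm_provider on the instance, and delegate to the OpenAI base class. A quick sketch of constructing one directly, assuming only httpx and the class above (the request and response objects are placeholders for what a real provider call would yield):

import httpx
from litellm.exceptions import UnprocessableEntityError

# placeholder request/response standing in for a real provider exchange
request = httpx.Request("POST", "https://example.com/v1/chat")
response = httpx.Response(status_code=422, request=request)

err = UnprocessableEntityError(
    message="VertexAIException - The response was blocked.",
    model="gemini-pro",
    llm_provider="vertex_ai",
    response=response,
)
assert err.status_code == 422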


@@ -220,6 +220,7 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
     """
     Add support for acompletion calls for gemini-pro
     """
+    try:
         from vertexai.preview.generative_models import GenerationConfig

         if mode == "":
@@ -278,6 +279,8 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
         )
         model_response.usage = usage
         return model_response
+    except Exception as e:
+        raise VertexAIError(status_code=500, message=str(e))

 async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, **optional_params):
     """


@@ -63,21 +63,22 @@ def load_vertex_ai_credentials():
     # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
     os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.abspath(temp_file.name)

-def test_vertex_ai_sdk():
+async def get_response():
     load_vertex_ai_credentials()
-    from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig
-    llm_model = GenerativeModel("gemini-pro")
-    chat = llm_model.start_chat()
-    print(chat.send_message("write code for saying hi from LiteLLM", generation_config=GenerationConfig(**{})).text)
-test_vertex_ai_sdk()
+    prompt = '\ndef count_nums(arr):\n """\n Write a function count_nums which takes an array of integers and returns\n the number of elements which has a sum of digits > 0.\n If a number is negative, then its first signed digit will be negative:\n e.g. -123 has signed digits -1, 2, and 3.\n >>> count_nums([]) == 0\n >>> count_nums([-1, 11, -11]) == 1\n >>> count_nums([1, 1, 2]) == 3\n """\n'
+    response = await acompletion(
+        model="gemini-pro",
+        messages=[
+            {
+                "role": "system",
+                "content": "Complete the given code with no more explanation. Remember that there is a 4-space indent before the first line of your generated code.",
+            },
+            {"role": "user", "content": prompt},
+        ],
+    )
+    return response

-def simple_test_vertex_ai():
-    try:
-        load_vertex_ai_credentials()
-        response = completion(model="gemini-pro", messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}])
-        print(f'response: {response}')
-    except Exception as e:
-        pytest.fail(f"An exception occurred - {str(e)}")
+asyncio.run(get_response())

 def test_vertex_ai():
     import random
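One caveat worth noting: asyncio.run(get_response()) sits at module level, so the network call fires as soon as pytest imports the file. A sketch of the same check wrapped as a collectable test instead (the test name is illustrative):

import asyncio

def test_gemini_pro_acompletion():
    response = asyncio.run(get_response())
    assert response is not None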


@@ -1040,6 +1040,8 @@ def test_completion_together_ai_mixtral():
         cost = completion_cost(completion_response=response)
         assert cost > 0.0
         print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}")
+    except litellm.Timeout as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
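The added clause lets the test tolerate transient provider timeouts while still failing on genuine errors; order matters, because litellm.Timeout must be caught before the broad Exception handler turns it into pytest.fail. The same pattern in isolation (model and messages are placeholders):

import litellm
import pytest
from litellm import completion

def call_tolerating_timeouts(model: str, messages: list):
    try:
        return completion(model=model, messages=messages)
    except litellm.Timeout:
        pass  # transient provider timeout: tolerated, not a test failure
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")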


@@ -52,7 +52,8 @@ from .exceptions import (
     Timeout,
     APIConnectionError,
     APIError,
-    BudgetExceededError
+    BudgetExceededError,
+    UnprocessableEntityError
 )
 from typing import cast, List, Dict, Union, Optional, Literal
 from .caching import Cache
@@ -4432,7 +4433,15 @@ def exception_type(
                 )
             elif "403" in error_str:
                 exception_mapping_worked = True
-                raise AuthenticationError(
+                raise UnprocessableEntityError(
                     message=f"VertexAIException - {error_str}",
                     model=model,
                     llm_provider="vertex_ai",
+                    response=original_exception.response
+                )
+            elif "The response was blocked." in error_str:
+                exception_mapping_worked = True
+                raise UnprocessableEntityError(
+                    message=f"VertexAIException - {error_str}",
+                    model=model,
+                    llm_provider="vertex_ai",
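With this branch, content-filter blocks ("The response was blocked.") surface as a typed 422-style exception instead of an unmapped error. exception_type dispatches on substrings of error_str; a reduced sketch of that dispatch (exc stands in for original_exception, which carries the provider's httpx response):

from litellm.exceptions import UnprocessableEntityError

def map_vertex_error(error_str: str, model: str, exc) -> None:
    # substring dispatch, mirroring the shape of exception_type in utils.py
    if "The response was blocked." in error_str:
        raise UnprocessableEntityError(
            message=f"VertexAIException - {error_str}",
            model=model,
            llm_provider="vertex_ai",
            response=exc.response,
        )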