fix(vertex_ai.py): add exception mapping for acompletion calls

This commit is contained in:
Krrish Dholakia 2023-12-13 16:35:50 -08:00
parent 16b7b8097b
commit c673a23769
5 changed files with 99 additions and 70 deletions

View file

@ -18,10 +18,12 @@ from openai import (
APIError, APIError,
APITimeoutError, APITimeoutError,
APIConnectionError, APIConnectionError,
APIResponseValidationError APIResponseValidationError,
UnprocessableEntityError
) )
import httpx import httpx
class AuthenticationError(AuthenticationError): # type: ignore class AuthenticationError(AuthenticationError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response): def __init__(self, message, llm_provider, model, response: httpx.Response):
self.status_code = 401 self.status_code = 401
@ -46,6 +48,18 @@ class BadRequestError(BadRequestError): # type: ignore
body=None body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response):
self.status_code = 422
self.message = message
self.model = model
self.llm_provider = llm_provider
super().__init__(
self.message,
response=response,
body=None
) # Call the base class constructor with the parameters it needs
class Timeout(APITimeoutError): # type: ignore class Timeout(APITimeoutError): # type: ignore
def __init__(self, message, model, llm_provider): def __init__(self, message, model, llm_provider):
self.status_code = 408 self.status_code = 408

View file

@ -220,6 +220,7 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
""" """
Add support for acompletion calls for gemini-pro Add support for acompletion calls for gemini-pro
""" """
try:
from vertexai.preview.generative_models import GenerationConfig from vertexai.preview.generative_models import GenerationConfig
if mode == "": if mode == "":
@ -278,6 +279,8 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))
async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, **optional_params): async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, **optional_params):
""" """

View file

@ -63,21 +63,22 @@ def load_vertex_ai_credentials():
# Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.abspath(temp_file.name) os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.abspath(temp_file.name)
def test_vertex_ai_sdk(): async def get_response():
load_vertex_ai_credentials() load_vertex_ai_credentials()
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig prompt = '\ndef count_nums(arr):\n """\n Write a function count_nums which takes an array of integers and returns\n the number of elements which has a sum of digits > 0.\n If a number is negative, then its first signed digit will be negative:\n e.g. -123 has signed digits -1, 2, and 3.\n >>> count_nums([]) == 0\n >>> count_nums([-1, 11, -11]) == 1\n >>> count_nums([1, 1, 2]) == 3\n """\n'
llm_model = GenerativeModel("gemini-pro") response = await acompletion(
chat = llm_model.start_chat() model="gemini-pro",
print(chat.send_message("write code for saying hi from LiteLLM", generation_config=GenerationConfig(**{})).text) messages=[
test_vertex_ai_sdk() {
"role": "system",
"content": "Complete the given code with no more explanation. Remember that there is a 4-space indent before the first line of your generated code.",
},
{"role": "user", "content": prompt},
],
)
return response
def simple_test_vertex_ai(): asyncio.run(get_response())
try:
load_vertex_ai_credentials()
response = completion(model="gemini-pro", messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}])
print(f'response: {response}')
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
def test_vertex_ai(): def test_vertex_ai():
import random import random

View file

@ -1040,6 +1040,8 @@ def test_completion_together_ai_mixtral():
cost = completion_cost(completion_response=response) cost = completion_cost(completion_response=response)
assert cost > 0.0 assert cost > 0.0
print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}") print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}")
except litellm.Timeout as e:
pass
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")

View file

@ -52,7 +52,8 @@ from .exceptions import (
Timeout, Timeout,
APIConnectionError, APIConnectionError,
APIError, APIError,
BudgetExceededError BudgetExceededError,
UnprocessableEntityError
) )
from typing import cast, List, Dict, Union, Optional, Literal from typing import cast, List, Dict, Union, Optional, Literal
from .caching import Cache from .caching import Cache
@ -4432,7 +4433,15 @@ def exception_type(
) )
elif "403" in error_str: elif "403" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
raise AuthenticationError( raise U(
message=f"VertexAIException - {error_str}",
model=model,
llm_provider="vertex_ai",
response=original_exception.response
)
elif "The response was blocked." in error_str:
exception_mapping_worked = True
raise UnprocessableEntityError(
message=f"VertexAIException - {error_str}", message=f"VertexAIException - {error_str}",
model=model, model=model,
llm_provider="vertex_ai", llm_provider="vertex_ai",