mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
fix(vertex_ai.py): add exception mapping for acompletion calls
This commit is contained in:
parent
16b7b8097b
commit
c673a23769
5 changed files with 99 additions and 70 deletions
|
@ -18,10 +18,12 @@ from openai import (
|
||||||
APIError,
|
APIError,
|
||||||
APITimeoutError,
|
APITimeoutError,
|
||||||
APIConnectionError,
|
APIConnectionError,
|
||||||
APIResponseValidationError
|
APIResponseValidationError,
|
||||||
|
UnprocessableEntityError
|
||||||
)
|
)
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationError(AuthenticationError): # type: ignore
|
class AuthenticationError(AuthenticationError): # type: ignore
|
||||||
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
||||||
self.status_code = 401
|
self.status_code = 401
|
||||||
|
@ -46,6 +48,18 @@ class BadRequestError(BadRequestError): # type: ignore
|
||||||
body=None
|
body=None
|
||||||
) # Call the base class constructor with the parameters it needs
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
|
||||||
|
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
||||||
|
self.status_code = 422
|
||||||
|
self.message = message
|
||||||
|
self.model = model
|
||||||
|
self.llm_provider = llm_provider
|
||||||
|
super().__init__(
|
||||||
|
self.message,
|
||||||
|
response=response,
|
||||||
|
body=None
|
||||||
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
class Timeout(APITimeoutError): # type: ignore
|
class Timeout(APITimeoutError): # type: ignore
|
||||||
def __init__(self, message, model, llm_provider):
|
def __init__(self, message, model, llm_provider):
|
||||||
self.status_code = 408
|
self.status_code = 408
|
||||||
|
|
|
@ -220,6 +220,7 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
|
||||||
"""
|
"""
|
||||||
Add support for acompletion calls for gemini-pro
|
Add support for acompletion calls for gemini-pro
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
from vertexai.preview.generative_models import GenerationConfig
|
from vertexai.preview.generative_models import GenerationConfig
|
||||||
|
|
||||||
if mode == "":
|
if mode == "":
|
||||||
|
@ -278,6 +279,8 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
|
||||||
)
|
)
|
||||||
model_response.usage = usage
|
model_response.usage = usage
|
||||||
return model_response
|
return model_response
|
||||||
|
except Exception as e:
|
||||||
|
raise VertexAIError(status_code=500, message=str(e))
|
||||||
|
|
||||||
async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, **optional_params):
|
async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, **optional_params):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -63,21 +63,22 @@ def load_vertex_ai_credentials():
|
||||||
# Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
|
# Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
|
||||||
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.abspath(temp_file.name)
|
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.abspath(temp_file.name)
|
||||||
|
|
||||||
def test_vertex_ai_sdk():
|
async def get_response():
|
||||||
load_vertex_ai_credentials()
|
load_vertex_ai_credentials()
|
||||||
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig
|
prompt = '\ndef count_nums(arr):\n """\n Write a function count_nums which takes an array of integers and returns\n the number of elements which has a sum of digits > 0.\n If a number is negative, then its first signed digit will be negative:\n e.g. -123 has signed digits -1, 2, and 3.\n >>> count_nums([]) == 0\n >>> count_nums([-1, 11, -11]) == 1\n >>> count_nums([1, 1, 2]) == 3\n """\n'
|
||||||
llm_model = GenerativeModel("gemini-pro")
|
response = await acompletion(
|
||||||
chat = llm_model.start_chat()
|
model="gemini-pro",
|
||||||
print(chat.send_message("write code for saying hi from LiteLLM", generation_config=GenerationConfig(**{})).text)
|
messages=[
|
||||||
test_vertex_ai_sdk()
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "Complete the given code with no more explanation. Remember that there is a 4-space indent before the first line of your generated code.",
|
||||||
|
},
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
def simple_test_vertex_ai():
|
asyncio.run(get_response())
|
||||||
try:
|
|
||||||
load_vertex_ai_credentials()
|
|
||||||
response = completion(model="gemini-pro", messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}])
|
|
||||||
print(f'response: {response}')
|
|
||||||
except Exception as e:
|
|
||||||
pytest.fail(f"An exception occurred - {str(e)}")
|
|
||||||
|
|
||||||
def test_vertex_ai():
|
def test_vertex_ai():
|
||||||
import random
|
import random
|
||||||
|
|
|
@ -1040,6 +1040,8 @@ def test_completion_together_ai_mixtral():
|
||||||
cost = completion_cost(completion_response=response)
|
cost = completion_cost(completion_response=response)
|
||||||
assert cost > 0.0
|
assert cost > 0.0
|
||||||
print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}")
|
print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}")
|
||||||
|
except litellm.Timeout as e:
|
||||||
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
|
@ -52,7 +52,8 @@ from .exceptions import (
|
||||||
Timeout,
|
Timeout,
|
||||||
APIConnectionError,
|
APIConnectionError,
|
||||||
APIError,
|
APIError,
|
||||||
BudgetExceededError
|
BudgetExceededError,
|
||||||
|
UnprocessableEntityError
|
||||||
)
|
)
|
||||||
from typing import cast, List, Dict, Union, Optional, Literal
|
from typing import cast, List, Dict, Union, Optional, Literal
|
||||||
from .caching import Cache
|
from .caching import Cache
|
||||||
|
@ -4432,7 +4433,15 @@ def exception_type(
|
||||||
)
|
)
|
||||||
elif "403" in error_str:
|
elif "403" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise AuthenticationError(
|
raise U(
|
||||||
|
message=f"VertexAIException - {error_str}",
|
||||||
|
model=model,
|
||||||
|
llm_provider="vertex_ai",
|
||||||
|
response=original_exception.response
|
||||||
|
)
|
||||||
|
elif "The response was blocked." in error_str:
|
||||||
|
exception_mapping_worked = True
|
||||||
|
raise UnprocessableEntityError(
|
||||||
message=f"VertexAIException - {error_str}",
|
message=f"VertexAIException - {error_str}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="vertex_ai",
|
llm_provider="vertex_ai",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue