fix(vertex_ai.py): add exception mapping for acompletion calls

This commit is contained in:
Krrish Dholakia 2023-12-13 16:35:50 -08:00
parent 16b7b8097b
commit c673a23769
5 changed files with 99 additions and 70 deletions

View file

@ -18,10 +18,12 @@ from openai import (
APIError, APIError,
APITimeoutError, APITimeoutError,
APIConnectionError, APIConnectionError,
APIResponseValidationError APIResponseValidationError,
UnprocessableEntityError
) )
import httpx import httpx
class AuthenticationError(AuthenticationError): # type: ignore class AuthenticationError(AuthenticationError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response): def __init__(self, message, llm_provider, model, response: httpx.Response):
self.status_code = 401 self.status_code = 401
@ -46,6 +48,18 @@ class BadRequestError(BadRequestError): # type: ignore
body=None body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response):
self.status_code = 422
self.message = message
self.model = model
self.llm_provider = llm_provider
super().__init__(
self.message,
response=response,
body=None
) # Call the base class constructor with the parameters it needs
class Timeout(APITimeoutError): # type: ignore class Timeout(APITimeoutError): # type: ignore
def __init__(self, message, model, llm_provider): def __init__(self, message, model, llm_provider):
self.status_code = 408 self.status_code = 408

View file

@ -220,64 +220,67 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
""" """
Add support for acompletion calls for gemini-pro Add support for acompletion calls for gemini-pro
""" """
from vertexai.preview.generative_models import GenerationConfig try:
from vertexai.preview.generative_models import GenerationConfig
if mode == "": if mode == "":
# gemini-pro # gemini-pro
chat = llm_model.start_chat() chat = llm_model.start_chat()
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response_obj = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params)) response_obj = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params))
completion_response = response_obj.text completion_response = response_obj.text
response_obj = response_obj._raw_response response_obj = response_obj._raw_response
elif mode == "chat": elif mode == "chat":
# chat-bison etc. # chat-bison etc.
chat = llm_model.start_chat() chat = llm_model.start_chat()
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response_obj = await chat.send_message_async(prompt, **optional_params) response_obj = await chat.send_message_async(prompt, **optional_params)
completion_response = response_obj.text completion_response = response_obj.text
elif mode == "text": elif mode == "text":
# gecko etc. # gecko etc.
request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n" request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response_obj = await llm_model.predict_async(prompt, **optional_params) response_obj = await llm_model.predict_async(prompt, **optional_params)
completion_response = response_obj.text completion_response = response_obj.text
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=prompt, api_key=None, original_response=completion_response input=prompt, api_key=None, original_response=completion_response
)
## RESPONSE OBJECT
if len(str(completion_response)) > 0:
model_response["choices"][0]["message"][
"content"
] = str(completion_response)
model_response["choices"][0]["message"]["content"] = str(completion_response)
model_response["created"] = int(time.time())
model_response["model"] = model
## CALCULATING USAGE
if model in litellm.vertex_language_models and response_obj is not None:
model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name
usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count,
completion_tokens=response_obj.usage_metadata.candidates_token_count,
total_tokens=response_obj.usage_metadata.total_token_count)
else:
prompt_tokens = len(
encoding.encode(prompt)
)
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
usage = Usage(
prompt_tokens=prompt_tokens, ## RESPONSE OBJECT
completion_tokens=completion_tokens, if len(str(completion_response)) > 0:
total_tokens=prompt_tokens + completion_tokens model_response["choices"][0]["message"][
"content"
] = str(completion_response)
model_response["choices"][0]["message"]["content"] = str(completion_response)
model_response["created"] = int(time.time())
model_response["model"] = model
## CALCULATING USAGE
if model in litellm.vertex_language_models and response_obj is not None:
model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name
usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count,
completion_tokens=response_obj.usage_metadata.candidates_token_count,
total_tokens=response_obj.usage_metadata.total_token_count)
else:
prompt_tokens = len(
encoding.encode(prompt)
)
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
model_response.usage = usage usage = Usage(
return model_response prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
)
model_response.usage = usage
return model_response
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))
async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, **optional_params): async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, **optional_params):
""" """

View file

@ -63,21 +63,22 @@ def load_vertex_ai_credentials():
# Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.abspath(temp_file.name) os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.abspath(temp_file.name)
def test_vertex_ai_sdk(): async def get_response():
load_vertex_ai_credentials() load_vertex_ai_credentials()
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig prompt = '\ndef count_nums(arr):\n """\n Write a function count_nums which takes an array of integers and returns\n the number of elements which has a sum of digits > 0.\n If a number is negative, then its first signed digit will be negative:\n e.g. -123 has signed digits -1, 2, and 3.\n >>> count_nums([]) == 0\n >>> count_nums([-1, 11, -11]) == 1\n >>> count_nums([1, 1, 2]) == 3\n """\n'
llm_model = GenerativeModel("gemini-pro") response = await acompletion(
chat = llm_model.start_chat() model="gemini-pro",
print(chat.send_message("write code for saying hi from LiteLLM", generation_config=GenerationConfig(**{})).text) messages=[
test_vertex_ai_sdk() {
"role": "system",
"content": "Complete the given code with no more explanation. Remember that there is a 4-space indent before the first line of your generated code.",
},
{"role": "user", "content": prompt},
],
)
return response
def simple_test_vertex_ai(): asyncio.run(get_response())
try:
load_vertex_ai_credentials()
response = completion(model="gemini-pro", messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}])
print(f'response: {response}')
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
def test_vertex_ai(): def test_vertex_ai():
import random import random

View file

@ -1040,6 +1040,8 @@ def test_completion_together_ai_mixtral():
cost = completion_cost(completion_response=response) cost = completion_cost(completion_response=response)
assert cost > 0.0 assert cost > 0.0
print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}") print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}")
except litellm.Timeout as e:
pass
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")

View file

@ -52,7 +52,8 @@ from .exceptions import (
Timeout, Timeout,
APIConnectionError, APIConnectionError,
APIError, APIError,
BudgetExceededError BudgetExceededError,
UnprocessableEntityError
) )
from typing import cast, List, Dict, Union, Optional, Literal from typing import cast, List, Dict, Union, Optional, Literal
from .caching import Cache from .caching import Cache
@ -4432,7 +4433,15 @@ def exception_type(
) )
elif "403" in error_str: elif "403" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
raise AuthenticationError( raise U(
message=f"VertexAIException - {error_str}",
model=model,
llm_provider="vertex_ai",
response=original_exception.response
)
elif "The response was blocked." in error_str:
exception_mapping_worked = True
raise UnprocessableEntityError(
message=f"VertexAIException - {error_str}", message=f"VertexAIException - {error_str}",
model=model, model=model,
llm_provider="vertex_ai", llm_provider="vertex_ai",