Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)

Commit: c673a23769 (parent: 16b7b8097b)
fix(vertex_ai.py): add exception mapping for acompletion calls

5 changed files with 99 additions and 70 deletions
@@ -18,10 +18,12 @@ from openai import (
     APIError,
     APITimeoutError,
     APIConnectionError,
-    APIResponseValidationError
+    APIResponseValidationError,
+    UnprocessableEntityError
 )
 import httpx

 class AuthenticationError(AuthenticationError): # type: ignore
     def __init__(self, message, llm_provider, model, response: httpx.Response):
         self.status_code = 401
@@ -46,6 +48,18 @@ class BadRequestError(BadRequestError): # type: ignore
             body=None
         ) # Call the base class constructor with the parameters it needs
 
+class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
+    def __init__(self, message, model, llm_provider, response: httpx.Response):
+        self.status_code = 422
+        self.message = message
+        self.model = model
+        self.llm_provider = llm_provider
+        super().__init__(
+            self.message,
+            response=response,
+            body=None
+        ) # Call the base class constructor with the parameters it needs
+
 class Timeout(APITimeoutError): # type: ignore
     def __init__(self, message, model, llm_provider):
         self.status_code = 408
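The new wrapper pins status_code to 422 and records message, model, and llm_provider before delegating to the OpenAI SDK class of the same name. A minimal sketch of how calling code might catch it, assuming the class is exposed from litellm.exceptions like the neighbouring wrappers (the model and prompt below are illustrative, not part of this commit):

from litellm import completion
from litellm.exceptions import UnprocessableEntityError

try:
    response = completion(
        model="gemini-pro",
        messages=[{"role": "user", "content": "a prompt the provider refuses to answer"}],
    )
    print(response["choices"][0]["message"]["content"])
except UnprocessableEntityError as e:
    # 422: the request was well-formed but the provider could not process it,
    # e.g. Vertex AI blocked the response for safety reasons.
    print(f"unprocessable: {e.message} (provider={e.llm_provider}, model={e.model})")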
@@ -220,64 +220,67 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
     """
     Add support for acompletion calls for gemini-pro
     """
-    from vertexai.preview.generative_models import GenerationConfig
-
-    if mode == "":
-        # gemini-pro
-        chat = llm_model.start_chat()
-        ## LOGGING
-        logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
-        response_obj = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params))
-        completion_response = response_obj.text
-        response_obj = response_obj._raw_response
-    elif mode == "chat":
-        # chat-bison etc.
-        chat = llm_model.start_chat()
-        ## LOGGING
-        logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
-        response_obj = await chat.send_message_async(prompt, **optional_params)
-        completion_response = response_obj.text
-    elif mode == "text":
-        # gecko etc.
-        request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
-        ## LOGGING
-        logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
-        response_obj = await llm_model.predict_async(prompt, **optional_params)
-        completion_response = response_obj.text
-
-    ## LOGGING
-    logging_obj.post_call(
-        input=prompt, api_key=None, original_response=completion_response
-    )
-
-    ## RESPONSE OBJECT
-    if len(str(completion_response)) > 0:
-        model_response["choices"][0]["message"][
-            "content"
-        ] = str(completion_response)
-    model_response["choices"][0]["message"]["content"] = str(completion_response)
-    model_response["created"] = int(time.time())
-    model_response["model"] = model
-    ## CALCULATING USAGE
-    if model in litellm.vertex_language_models and response_obj is not None:
-        model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name
-        usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count,
-                    completion_tokens=response_obj.usage_metadata.candidates_token_count,
-                    total_tokens=response_obj.usage_metadata.total_token_count)
-    else:
-        prompt_tokens = len(
-            encoding.encode(prompt)
-        )
-        completion_tokens = len(
-            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
-        )
-        usage = Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens
-        )
-    model_response.usage = usage
-    return model_response
+    try:
+        from vertexai.preview.generative_models import GenerationConfig
+
+        if mode == "":
+            # gemini-pro
+            chat = llm_model.start_chat()
+            ## LOGGING
+            logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+            response_obj = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params))
+            completion_response = response_obj.text
+            response_obj = response_obj._raw_response
+        elif mode == "chat":
+            # chat-bison etc.
+            chat = llm_model.start_chat()
+            ## LOGGING
+            logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+            response_obj = await chat.send_message_async(prompt, **optional_params)
+            completion_response = response_obj.text
+        elif mode == "text":
+            # gecko etc.
+            request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
+            ## LOGGING
+            logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+            response_obj = await llm_model.predict_async(prompt, **optional_params)
+            completion_response = response_obj.text
+
+        ## LOGGING
+        logging_obj.post_call(
+            input=prompt, api_key=None, original_response=completion_response
+        )
+
+        ## RESPONSE OBJECT
+        if len(str(completion_response)) > 0:
+            model_response["choices"][0]["message"][
+                "content"
+            ] = str(completion_response)
+        model_response["choices"][0]["message"]["content"] = str(completion_response)
+        model_response["created"] = int(time.time())
+        model_response["model"] = model
+        ## CALCULATING USAGE
+        if model in litellm.vertex_language_models and response_obj is not None:
+            model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name
+            usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count,
+                        completion_tokens=response_obj.usage_metadata.candidates_token_count,
+                        total_tokens=response_obj.usage_metadata.total_token_count)
+        else:
+            prompt_tokens = len(
+                encoding.encode(prompt)
+            )
+            completion_tokens = len(
+                encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+            )
+            usage = Usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens
+            )
+        model_response.usage = usage
+        return model_response
+    except Exception as e:
+        raise VertexAIError(status_code=500, message=str(e))
 
 
 async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, **optional_params):
     """
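Wrapping the body in try/except means any failure inside the Vertex AI SDK calls now surfaces as a VertexAIError (status 500) rather than an unmapped exception, which litellm's exception_type mapping (see the last hunk below) can then translate into a typed litellm error. A minimal sketch of an async caller under that assumption (the model, prompt, and handling choices here are illustrative):

import asyncio

import litellm
from litellm import acompletion

async def main():
    try:
        response = await acompletion(
            model="gemini-pro",
            messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}],
        )
        print(response["choices"][0]["message"]["content"])
    except litellm.Timeout:
        # mapped timeout from the provider call
        print("request timed out")
    except Exception as e:
        # other mapped errors (AuthenticationError, UnprocessableEntityError, ...)
        print(f"{type(e).__name__}: {e}")

asyncio.run(main())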
@@ -63,21 +63,22 @@ def load_vertex_ai_credentials():
     # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
     os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.abspath(temp_file.name)
 
-def test_vertex_ai_sdk():
+async def get_response():
     load_vertex_ai_credentials()
-    from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig
-    llm_model = GenerativeModel("gemini-pro")
-    chat = llm_model.start_chat()
-    print(chat.send_message("write code for saying hi from LiteLLM", generation_config=GenerationConfig(**{})).text)
-test_vertex_ai_sdk()
-
-def simple_test_vertex_ai():
-    try:
-        load_vertex_ai_credentials()
-        response = completion(model="gemini-pro", messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}])
-        print(f'response: {response}')
-    except Exception as e:
-        pytest.fail(f"An exception occurred - {str(e)}")
+    prompt = '\ndef count_nums(arr):\n """\n Write a function count_nums which takes an array of integers and returns\n the number of elements which has a sum of digits > 0.\n If a number is negative, then its first signed digit will be negative:\n e.g. -123 has signed digits -1, 2, and 3.\n >>> count_nums([]) == 0\n >>> count_nums([-1, 11, -11]) == 1\n >>> count_nums([1, 1, 2]) == 3\n """\n'
+    response = await acompletion(
+        model="gemini-pro",
+        messages=[
+            {
+                "role": "system",
+                "content": "Complete the given code with no more explanation. Remember that there is a 4-space indent before the first line of your generated code.",
+            },
+            {"role": "user", "content": prompt},
+        ],
+    )
+    return response
+
+asyncio.run(get_response())
 
 def test_vertex_ai():
     import random
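The new helper exercises acompletion for gemini-pro and is driven with asyncio.run() at module level. As a sketch only, assuming pytest-asyncio were available in the test environment (it is not something this commit adds), the same call could instead be collected as an async test in this module:

import pytest

from litellm import acompletion

@pytest.mark.asyncio
async def test_gemini_pro_acompletion():
    load_vertex_ai_credentials()  # helper defined earlier in this test module
    response = await acompletion(
        model="gemini-pro",
        messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}],
    )
    assert len(response["choices"][0]["message"]["content"]) > 0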
@@ -1040,6 +1040,8 @@ def test_completion_together_ai_mixtral():
         cost = completion_cost(completion_response=response)
         assert cost > 0.0
         print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}")
+    except litellm.Timeout as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
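The added except litellm.Timeout branch lets genuine upstream latency pass without failing the suite while still flagging real errors. The same pattern as a standalone sketch (the model id and prompt here are illustrative, not taken from the test):

import litellm
import pytest
from litellm import completion

def test_completion_tolerates_timeouts():
    try:
        response = completion(
            model="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1",  # illustrative model id
            messages=[{"role": "user", "content": "hello"}],
        )
        print(response)
    except litellm.Timeout:
        pass  # upstream latency should not fail the suite
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")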
@@ -52,7 +52,8 @@ from .exceptions import (
     Timeout,
     APIConnectionError,
     APIError,
-    BudgetExceededError
+    BudgetExceededError,
+    UnprocessableEntityError
 )
 from typing import cast, List, Dict, Union, Optional, Literal
 from .caching import Cache
@@ -4432,7 +4433,15 @@ def exception_type(
                     )
                 elif "403" in error_str:
                     exception_mapping_worked = True
-                    raise AuthenticationError(
+                    raise U(
+                        message=f"VertexAIException - {error_str}",
+                        model=model,
+                        llm_provider="vertex_ai",
+                        response=original_exception.response
+                    )
+                elif "The response was blocked." in error_str:
+                    exception_mapping_worked = True
+                    raise UnprocessableEntityError(
                         message=f"VertexAIException - {error_str}",
                         model=model,
                         llm_provider="vertex_ai",
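exception_type() dispatches on substrings of the provider's error message; the new branch maps Vertex AI's "The response was blocked." refusals to the 422-style UnprocessableEntityError. A simplified, standalone sketch of that dispatch pattern (not litellm's actual implementation; the fallback httpx response below is an assumption for illustration):

import httpx

from litellm.exceptions import UnprocessableEntityError

def map_vertex_error(error: Exception, model: str):
    """Re-raise a Vertex AI SDK error as a typed litellm exception based on its message."""
    error_str = str(error)
    if "The response was blocked." in error_str:
        # Safety filters refused to return content -> HTTP 422 semantics.
        response = getattr(error, "response", None) or httpx.Response(
            status_code=422, request=httpx.Request("POST", "https://example.invalid")
        )
        raise UnprocessableEntityError(
            message=f"VertexAIException - {error_str}",
            model=model,
            llm_provider="vertex_ai",
            response=response,
        )
    raise error  # anything unmapped propagates unchanged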