fix(vertex_ai.py): fix exception mapping for vertex ai

Krrish Dholakia 2023-11-23 17:35:26 -08:00
parent 704af2ca34
commit f24786095a
4 changed files with 107 additions and 87 deletions


@@ -1,7 +1,7 @@
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';
 
-# Local OpenAI Proxy Server
+# [OLD PROXY 👉 [**NEW** proxy here](./simple_proxy.md)] Local OpenAI Proxy Server
 
 A fast, and lightweight OpenAI-compatible server to call 100+ LLM APIs.


@@ -73,93 +73,96 @@ def completion(
     try:
         import vertexai
     except:
-        raise Exception("vertexai import failed please run `pip install google-cloud-aiplatform`")
-    from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
-    from vertexai.language_models import TextGenerationModel, CodeGenerationModel
-    vertexai.init(
-        project=vertex_project, location=vertex_location
-    )
-    ## Load Config
-    config = litellm.VertexAIConfig.get_config()
-    for k, v in config.items():
-        if k not in optional_params:
-            optional_params[k] = v
-    # vertexai does not use an API key, it looks for credentials.json in the environment
-    prompt = " ".join([message["content"] for message in messages])
-    mode = ""
-    if model in litellm.vertex_chat_models:
-        chat_model = ChatModel.from_pretrained(model)
-        mode = "chat"
-    elif model in litellm.vertex_text_models:
-        text_model = TextGenerationModel.from_pretrained(model)
-        mode = "text"
-    elif model in litellm.vertex_code_text_models:
-        text_model = CodeGenerationModel.from_pretrained(model)
-        mode = "text"
-    else: # vertex_code_chat_models
-        chat_model = CodeChatModel.from_pretrained(model)
-        mode = "chat"
-    if mode == "chat":
-        chat = chat_model.start_chat()
-        ## LOGGING
-        logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params})
-        if "stream" in optional_params and optional_params["stream"] == True:
-            # NOTE: VertexAI does not accept stream=True as a param and raises an error,
-            # we handle this by removing 'stream' from optional params and sending the request
-            # after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
-            optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params
-            model_response = chat.send_message_streaming(prompt, **optional_params)
-            optional_params["stream"] = True
-            return model_response
-        completion_response = chat.send_message(prompt, **optional_params).text
-    elif mode == "text":
-        ## LOGGING
-        logging_obj.pre_call(input=prompt, api_key=None)
-        if "stream" in optional_params and optional_params["stream"] == True:
-            optional_params.pop("stream", None) # See note above on handling streaming for vertex ai
-            model_response = text_model.predict_streaming(prompt, **optional_params)
-            optional_params["stream"] = True
-            return model_response
-        completion_response = text_model.predict(prompt, **optional_params).text
-    ## LOGGING
-    logging_obj.post_call(
-        input=prompt, api_key=None, original_response=completion_response
-    )
-    ## RESPONSE OBJECT
-    if len(str(completion_response)) > 0:
-        model_response["choices"][0]["message"][
-            "content"
-        ] = str(completion_response)
-    model_response["choices"][0]["message"]["content"] = str(completion_response)
-    model_response["created"] = int(time.time())
-    model_response["model"] = model
-    ## CALCULATING USAGE
-    prompt_tokens = len(
-        encoding.encode(prompt)
-    )
-    completion_tokens = len(
-        encoding.encode(model_response["choices"][0]["message"].get("content", ""))
-    )
-    usage = Usage(
-        prompt_tokens=prompt_tokens,
-        completion_tokens=completion_tokens,
-        total_tokens=prompt_tokens + completion_tokens
-    )
-    model_response.usage = usage
-    return model_response
+        raise VertexAIError(status_code=400, message="vertexai import failed please run `pip install google-cloud-aiplatform`")
+    try:
+        from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
+        from vertexai.language_models import TextGenerationModel, CodeGenerationModel
+        vertexai.init(
+            project=vertex_project, location=vertex_location
+        )
+        ## Load Config
+        config = litellm.VertexAIConfig.get_config()
+        for k, v in config.items():
+            if k not in optional_params:
+                optional_params[k] = v
+        # vertexai does not use an API key, it looks for credentials.json in the environment
+        prompt = " ".join([message["content"] for message in messages])
+        mode = ""
+        if model in litellm.vertex_chat_models:
+            chat_model = ChatModel.from_pretrained(model)
+            mode = "chat"
+        elif model in litellm.vertex_text_models:
+            text_model = TextGenerationModel.from_pretrained(model)
+            mode = "text"
+        elif model in litellm.vertex_code_text_models:
+            text_model = CodeGenerationModel.from_pretrained(model)
+            mode = "text"
+        else: # vertex_code_chat_models
+            chat_model = CodeChatModel.from_pretrained(model)
+            mode = "chat"
+        if mode == "chat":
+            chat = chat_model.start_chat()
+            ## LOGGING
+            logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params})
+            if "stream" in optional_params and optional_params["stream"] == True:
+                # NOTE: VertexAI does not accept stream=True as a param and raises an error,
+                # we handle this by removing 'stream' from optional params and sending the request
+                # after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
+                optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params
+                model_response = chat.send_message_streaming(prompt, **optional_params)
+                optional_params["stream"] = True
+                return model_response
+            completion_response = chat.send_message(prompt, **optional_params).text
+        elif mode == "text":
+            ## LOGGING
+            logging_obj.pre_call(input=prompt, api_key=None)
+            if "stream" in optional_params and optional_params["stream"] == True:
+                optional_params.pop("stream", None) # See note above on handling streaming for vertex ai
+                model_response = text_model.predict_streaming(prompt, **optional_params)
+                optional_params["stream"] = True
+                return model_response
+            completion_response = text_model.predict(prompt, **optional_params).text
+        ## LOGGING
+        logging_obj.post_call(
+            input=prompt, api_key=None, original_response=completion_response
+        )
+        ## RESPONSE OBJECT
+        if len(str(completion_response)) > 0:
+            model_response["choices"][0]["message"][
+                "content"
+            ] = str(completion_response)
+        model_response["choices"][0]["message"]["content"] = str(completion_response)
+        model_response["created"] = int(time.time())
+        model_response["model"] = model
+        ## CALCULATING USAGE
+        prompt_tokens = len(
+            encoding.encode(prompt)
+        )
+        completion_tokens = len(
+            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+        )
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens
+        )
+        model_response.usage = usage
+        return model_response
+    except Exception as e:
+        raise VertexAIError(status_code=500, message=str(e))
 
 def embedding():
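
The new error path raises `VertexAIError` with a `status_code` and a `message`. The class definition is not part of this hunk; a minimal sketch consistent with how it is constructed above (the base class and attribute names here are assumptions, not the repository's actual definition) could look like:

```python
# Hypothetical sketch of an exception type matching the
# VertexAIError(status_code=..., message=...) calls in the diff above;
# the real definition lives elsewhere in vertex_ai.py.
class VertexAIError(Exception):
    def __init__(self, status_code: int, message: str):
        self.status_code = status_code  # HTTP-style code, later used for exception mapping
        self.message = message          # human-readable error detail
        super().__init__(self.message)
```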


@@ -64,7 +64,7 @@ def test_context_window_with_fallbacks(model):
 # for model in litellm.models_by_provider["bedrock"]:
 #     test_context_window(model=model)
-# test_context_window(model="command-nightly")
+# test_context_window(model="chat-bison")
 # test_context_window_with_fallbacks(model="command-nightly")
 
 # Test 2: InvalidAuth Errors
 @pytest.mark.parametrize("model", models)


@@ -3816,6 +3816,23 @@ def exception_type(
                     llm_provider="vertex_ai",
                     response=original_exception.response
                 )
+            if hasattr(original_exception, "status_code"):
+                if original_exception.status_code == 400:
+                    exception_mapping_worked = True
+                    raise BadRequestError(
+                        message=f"VertexAIException - {error_str}",
+                        model=model,
+                        llm_provider="vertex_ai",
+                        response=original_exception.response
+                    )
+                if original_exception.status_code == 500:
+                    exception_mapping_worked = True
+                    raise APIError(
+                        message=f"VertexAIException - {error_str}",
+                        model=model,
+                        llm_provider="vertex_ai",
+                        request=original_exception.request
+                    )
         elif custom_llm_provider == "palm":
             if "503 Getting metadata" in error_str:
                 # auth errors look like this
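
With this mapping in place, a 400 from Vertex AI should surface to callers as `BadRequestError` and a 500 as `APIError`, rather than an unmapped exception. A minimal sketch of the caller-side effect, assuming Vertex AI credentials are configured and that these exception classes are importable from `litellm.exceptions` (the project name and prompt below are placeholders):

```python
import litellm
from litellm.exceptions import BadRequestError, APIError

# Placeholder Vertex AI settings; replace with a real GCP project/region.
litellm.vertex_project = "my-gcp-project"
litellm.vertex_location = "us-central1"

try:
    response = litellm.completion(
        model="chat-bison",  # Vertex AI chat model, as used in the tests above
        messages=[{"role": "user", "content": "hello"}],
    )
except BadRequestError as e:
    print(f"Vertex AI rejected the request (400): {e}")  # mapped via the new status_code check
except APIError as e:
    print(f"Vertex AI server error (500): {e}")
```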