fix(vertex_httpx.py): cover gemini content violation (on prompt)

Krrish Dholakia 2024-06-24 19:13:56 -07:00
parent 123477b55a
commit 1ff0129a94
3 changed files with 79 additions and 17 deletions

vertex_httpx.py

@@ -563,6 +563,43 @@ class VertexLLM(BaseLLM):
)
## CHECK IF RESPONSE FLAGGED
if "promptFeedback" in completion_response:
if "blockReason" in completion_response["promptFeedback"]:
# If set, the prompt was blocked and no candidates are returned. Rephrase your prompt
model_response.choices[0].finish_reason = "content_filter"
chat_completion_message: ChatCompletionResponseMessage = {
"role": "assistant",
"content": None,
}
choice = litellm.Choices(
finish_reason="content_filter",
index=0,
message=chat_completion_message, # type: ignore
logprobs=None,
enhancements=None,
)
model_response.choices = [choice]
## GET USAGE ##
usage = litellm.Usage(
prompt_tokens=completion_response["usageMetadata"][
"promptTokenCount"
],
completion_tokens=completion_response["usageMetadata"].get(
"candidatesTokenCount", 0
),
total_tokens=completion_response["usageMetadata"][
"totalTokenCount"
],
)
setattr(model_response, "usage", usage)
return model_response
if len(completion_response["candidates"]) > 0:
content_policy_violations = (
VertexGeminiConfig().get_flagged_finish_reasons()
@@ -573,16 +610,40 @@ class VertexLLM(BaseLLM):
in content_policy_violations.keys()
):
## CONTENT POLICY VIOLATION ERROR
raise VertexAIError(
status_code=400,
message="The response was blocked. Reason={}. Raw Response={}".format(
content_policy_violations[
completion_response["candidates"][0]["finishReason"]
],
completion_response,
),
model_response.choices[0].finish_reason = "content_filter"
chat_completion_message = {
"role": "assistant",
"content": None,
}
choice = litellm.Choices(
finish_reason="content_filter",
index=0,
message=chat_completion_message, # type: ignore
logprobs=None,
enhancements=None,
)
model_response.choices = [choice]
## GET USAGE ##
usage = litellm.Usage(
prompt_tokens=completion_response["usageMetadata"][
"promptTokenCount"
],
completion_tokens=completion_response["usageMetadata"].get(
"candidatesTokenCount", 0
),
total_tokens=completion_response["usageMetadata"][
"totalTokenCount"
],
)
setattr(model_response, "usage", usage)
return model_response
model_response.choices = [] # type: ignore
## GET MODEL ##
@@ -590,9 +651,7 @@ class VertexLLM(BaseLLM):
try:
## GET TEXT ##
chat_completion_message: ChatCompletionResponseMessage = {
"role": "assistant"
}
chat_completion_message = {"role": "assistant"}
content_str = ""
tools: List[ChatCompletionToolCallChunk] = []
for idx, candidate in enumerate(completion_response["candidates"]):
@@ -632,9 +691,9 @@ class VertexLLM(BaseLLM):
## GET USAGE ##
usage = litellm.Usage(
prompt_tokens=completion_response["usageMetadata"]["promptTokenCount"],
completion_tokens=completion_response["usageMetadata"][
"candidatesTokenCount"
],
completion_tokens=completion_response["usageMetadata"].get(
"candidatesTokenCount", 0
),
total_tokens=completion_response["usageMetadata"]["totalTokenCount"],
)
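
With this change, a content-policy block, whether on the prompt (promptFeedback.blockReason) or on a candidate's finishReason, is returned as a normal completion with finish_reason "content_filter", a null message, and usage metadata populated from usageMetadata, instead of raising a VertexAIError. A minimal sketch of how a caller might check for it; the model name and prompt below are illustrative, and Vertex AI credential/project setup is elided:

import litellm

response = litellm.completion(
    model="vertex_ai/gemini-1.5-pro",  # illustrative model name, not prescribed by this commit
    messages=[{"role": "user", "content": "some prompt"}],
)

choice = response.choices[0]
if choice.finish_reason == "content_filter":
    # After this commit, a blocked prompt/candidate is returned rather than raised;
    # message content is None while token usage is still reported.
    print("Request was blocked by Gemini's safety filters.")
    print(response.usage)
else:
    print(choice.message.content)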