fix(vertex_httpx.py): cover gemini content violation (on prompt)

Krrish Dholakia 2024-06-24 19:13:56 -07:00
parent 123477b55a
commit 1ff0129a94
3 changed files with 79 additions and 17 deletions
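Background for the patch below: when Gemini blocks the prompt itself, the REST response carries promptFeedback.blockReason plus usage metadata, but no candidates and no candidatesTokenCount, so the previous code fell through to candidate/usage lookups that fail when nothing was generated. A rough sketch of such a payload, limited to the fields the handler reads (the blockReason value and token counts are illustrative, not taken from a real response):

# Illustrative shape of a prompt-blocked Gemini response (values are made up).
blocked_prompt_response = {
    "promptFeedback": {
        "blockReason": "SAFETY",  # present only when the prompt was blocked
    },
    "usageMetadata": {
        "promptTokenCount": 12,
        "totalTokenCount": 12,
        # no "candidatesTokenCount": nothing was generated
    },
}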

File: vertex_httpx.py

@@ -563,6 +563,43 @@ class VertexLLM(BaseLLM):
         )

         ## CHECK IF RESPONSE FLAGGED
+        if "promptFeedback" in completion_response:
+            if "blockReason" in completion_response["promptFeedback"]:
+                # If set, the prompt was blocked and no candidates are returned. Rephrase your prompt
+                model_response.choices[0].finish_reason = "content_filter"
+
+                chat_completion_message: ChatCompletionResponseMessage = {
+                    "role": "assistant",
+                    "content": None,
+                }
+
+                choice = litellm.Choices(
+                    finish_reason="content_filter",
+                    index=0,
+                    message=chat_completion_message,  # type: ignore
+                    logprobs=None,
+                    enhancements=None,
+                )
+
+                model_response.choices = [choice]
+
+                ## GET USAGE ##
+                usage = litellm.Usage(
+                    prompt_tokens=completion_response["usageMetadata"][
+                        "promptTokenCount"
+                    ],
+                    completion_tokens=completion_response["usageMetadata"].get(
+                        "candidatesTokenCount", 0
+                    ),
+                    total_tokens=completion_response["usageMetadata"][
+                        "totalTokenCount"
+                    ],
+                )
+
+                setattr(model_response, "usage", usage)
+
+                return model_response
+
         if len(completion_response["candidates"]) > 0:
             content_policy_violations = (
                 VertexGeminiConfig().get_flagged_finish_reasons()
@@ -573,16 +610,40 @@ class VertexLLM(BaseLLM):
                     in content_policy_violations.keys()
                 ):
                     ## CONTENT POLICY VIOLATION ERROR
-                    raise VertexAIError(
-                        status_code=400,
-                        message="The response was blocked. Reason={}. Raw Response={}".format(
-                            content_policy_violations[
-                                completion_response["candidates"][0]["finishReason"]
-                            ],
-                            completion_response,
-                        ),
-                    )
+                    model_response.choices[0].finish_reason = "content_filter"
+
+                    chat_completion_message = {
+                        "role": "assistant",
+                        "content": None,
+                    }
+
+                    choice = litellm.Choices(
+                        finish_reason="content_filter",
+                        index=0,
+                        message=chat_completion_message,  # type: ignore
+                        logprobs=None,
+                        enhancements=None,
+                    )
+
+                    model_response.choices = [choice]
+
+                    ## GET USAGE ##
+                    usage = litellm.Usage(
+                        prompt_tokens=completion_response["usageMetadata"][
+                            "promptTokenCount"
+                        ],
+                        completion_tokens=completion_response["usageMetadata"].get(
+                            "candidatesTokenCount", 0
+                        ),
+                        total_tokens=completion_response["usageMetadata"][
+                            "totalTokenCount"
+                        ],
+                    )
+
+                    setattr(model_response, "usage", usage)
+
+                    return model_response
+
         model_response.choices = []  # type: ignore

         ## GET MODEL ##
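The hunk above covers the other flagged case: the prompt is accepted but a candidate finishes with a policy-violation finishReason, which previously raised VertexAIError. A rough sketch of that response shape, again limited to the fields the handler reads (the finishReason value and counts are illustrative):

# Illustrative candidate-level flagged response (values are made up).
flagged_candidate_response = {
    "candidates": [{"finishReason": "SAFETY"}],
    "usageMetadata": {"promptTokenCount": 11, "totalTokenCount": 11},
}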
@@ -590,9 +651,7 @@ class VertexLLM(BaseLLM):
         try:
             ## GET TEXT ##
-            chat_completion_message: ChatCompletionResponseMessage = {
-                "role": "assistant"
-            }
+            chat_completion_message = {"role": "assistant"}
             content_str = ""
             tools: List[ChatCompletionToolCallChunk] = []
             for idx, candidate in enumerate(completion_response["candidates"]):
@@ -632,9 +691,9 @@ class VertexLLM(BaseLLM):
             ## GET USAGE ##
             usage = litellm.Usage(
                 prompt_tokens=completion_response["usageMetadata"]["promptTokenCount"],
-                completion_tokens=completion_response["usageMetadata"][
-                    "candidatesTokenCount"
-                ],
+                completion_tokens=completion_response["usageMetadata"].get(
+                    "candidatesTokenCount", 0
+                ),
                 total_tokens=completion_response["usageMetadata"]["totalTokenCount"],
             )
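Net effect of the vertex_httpx.py changes: a flagged request now comes back as a normal OpenAI-style response with finish_reason set to "content_filter", a None message content, and populated usage, instead of an exception. A minimal caller-side sketch (assumes a valid Gemini API key is configured; the prompt string is a placeholder):

import litellm

response = litellm.completion(
    model="gemini/gemini-1.5-flash",
    messages=[{"role": "user", "content": "..."}],  # placeholder prompt
)

if response.choices[0].finish_reason == "content_filter":
    # Blocked: no text is returned, but usage is still reported.
    print("blocked; prompt tokens:", response.usage.prompt_tokens)
else:
    print(response.choices[0].message.content)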

File: litellm proxy config (YAML)

@@ -1,4 +1,7 @@
 model_list:
+- model_name: gemini-1.5-flash-gemini
+  litellm_params:
+    model: gemini/gemini-1.5-flash
 - litellm_params:
     api_base: http://0.0.0.0:8080
     api_key: ''
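The new model_list entry exposes the model through the proxy under the alias gemini-1.5-flash-gemini. A minimal sketch of calling it with the OpenAI SDK, assuming a locally running proxy on its default port 4000 and a placeholder key (adjust base_url and api_key to your deployment):

import openai

client = openai.OpenAI(
    base_url="http://0.0.0.0:4000",  # assumed local proxy address
    api_key="sk-1234",  # placeholder key
)

resp = client.chat.completions.create(
    model="gemini-1.5-flash-gemini",  # model_name from the config above
    messages=[{"role": "user", "content": "hello"}],
)
print(resp.choices[0].finish_reason)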

File: Vertex AI response types (Python)

@@ -227,9 +227,9 @@ class PromptFeedback(TypedDict):
     blockReasonMessage: str

-class UsageMetadata(TypedDict):
-    promptTokenCount: int
-    totalTokenCount: int
+class UsageMetadata(TypedDict, total=False):
+    promptTokenCount: Required[int]
+    totalTokenCount: Required[int]
     candidatesTokenCount: int
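The typing change mirrors the runtime change: candidatesTokenCount can be absent when nothing was generated, so it becomes optional while the other counters stay required. A standalone sketch of the pattern (not the repo's module; Required and TypedDict are taken from typing_extensions here):

from typing_extensions import Required, TypedDict

class UsageMetadata(TypedDict, total=False):
    promptTokenCount: Required[int]
    totalTokenCount: Required[int]
    candidatesTokenCount: int  # may be missing when the prompt was blocked

blocked_usage: UsageMetadata = {"promptTokenCount": 12, "totalTokenCount": 12}
completion_tokens = blocked_usage.get("candidatesTokenCount", 0)  # -> 0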