forked from phoenix/litellm-mirror

fix(vertex_httpx.py): cover gemini content violation (on prompt)

parent 123477b55a
commit 1ff0129a94

3 changed files with 79 additions and 17 deletions
@@ -563,6 +563,43 @@ class VertexLLM(BaseLLM):
             )
 
+        ## CHECK IF RESPONSE FLAGGED
+        if "promptFeedback" in completion_response:
+            if "blockReason" in completion_response["promptFeedback"]:
+                # If set, the prompt was blocked and no candidates are returned. Rephrase your prompt
+                model_response.choices[0].finish_reason = "content_filter"
+
+                chat_completion_message: ChatCompletionResponseMessage = {
+                    "role": "assistant",
+                    "content": None,
+                }
+
+                choice = litellm.Choices(
+                    finish_reason="content_filter",
+                    index=0,
+                    message=chat_completion_message,  # type: ignore
+                    logprobs=None,
+                    enhancements=None,
+                )
+
+                model_response.choices = [choice]
+
+                ## GET USAGE ##
+                usage = litellm.Usage(
+                    prompt_tokens=completion_response["usageMetadata"][
+                        "promptTokenCount"
+                    ],
+                    completion_tokens=completion_response["usageMetadata"].get(
+                        "candidatesTokenCount", 0
+                    ),
+                    total_tokens=completion_response["usageMetadata"][
+                        "totalTokenCount"
+                    ],
+                )
+
+                setattr(model_response, "usage", usage)
+
+                return model_response
 
         if len(completion_response["candidates"]) > 0:
             content_policy_violations = (
                 VertexGeminiConfig().get_flagged_finish_reasons()
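For context, this is roughly the shape the new branch handles. A minimal standalone sketch, assuming a hypothetical blocked-prompt payload (the field names match Google's generateContent response; the values are invented):

    # Hypothetical Gemini response when the prompt itself is blocked:
    # promptFeedback.blockReason is set and "candidates" is absent.
    completion_response = {
        "promptFeedback": {"blockReason": "SAFETY"},
        "usageMetadata": {"promptTokenCount": 11, "totalTokenCount": 11},
    }

    if "blockReason" in completion_response.get("promptFeedback", {}):
        # No candidates exist, so candidatesTokenCount must default to 0.
        completion_tokens = completion_response["usageMetadata"].get(
            "candidatesTokenCount", 0
        )
        print("finish_reason=content_filter, completion_tokens =", completion_tokens)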
@@ -573,16 +610,40 @@ class VertexLLM(BaseLLM):
                 in content_policy_violations.keys()
             ):
                 ## CONTENT POLICY VIOLATION ERROR
-                raise VertexAIError(
-                    status_code=400,
-                    message="The response was blocked. Reason={}. Raw Response={}".format(
-                        content_policy_violations[
-                            completion_response["candidates"][0]["finishReason"]
-                        ],
-                        completion_response,
-                    ),
-                )
+                model_response.choices[0].finish_reason = "content_filter"
+
+                chat_completion_message = {
+                    "role": "assistant",
+                    "content": None,
+                }
+
+                choice = litellm.Choices(
+                    finish_reason="content_filter",
+                    index=0,
+                    message=chat_completion_message,  # type: ignore
+                    logprobs=None,
+                    enhancements=None,
+                )
+
+                model_response.choices = [choice]
+
+                ## GET USAGE ##
+                usage = litellm.Usage(
+                    prompt_tokens=completion_response["usageMetadata"][
+                        "promptTokenCount"
+                    ],
+                    completion_tokens=completion_response["usageMetadata"].get(
+                        "candidatesTokenCount", 0
+                    ),
+                    total_tokens=completion_response["usageMetadata"][
+                        "totalTokenCount"
+                    ],
+                )
+
+                setattr(model_response, "usage", usage)
+
+                return model_response
 
         model_response.choices = []  # type: ignore
 
         ## GET MODEL ##
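The practical effect of this hunk: a flagged candidate no longer raises VertexAIError with status_code=400; callers instead receive a normal response with finish_reason set to "content_filter", matching the prompt-block path above. A hedged sketch of the caller side (the model name is an assumption, not part of this diff):

    # Caller-side sketch: after this change, content violations are detected
    # by inspecting finish_reason rather than catching an exception.
    import litellm

    response = litellm.completion(
        model="vertex_ai/gemini-1.5-flash",  # assumed model name, not from this diff
        messages=[{"role": "user", "content": "..."}],
    )
    if response.choices[0].finish_reason == "content_filter":
        # Blocked: message content is None, but usage is still populated.
        print("filtered; usage:", response.usage)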
@@ -590,9 +651,7 @@ class VertexLLM(BaseLLM):
 
         try:
             ## GET TEXT ##
-            chat_completion_message: ChatCompletionResponseMessage = {
-                "role": "assistant"
-            }
+            chat_completion_message = {"role": "assistant"}
             content_str = ""
             tools: List[ChatCompletionToolCallChunk] = []
             for idx, candidate in enumerate(completion_response["candidates"]):
@@ -632,9 +691,9 @@ class VertexLLM(BaseLLM):
         ## GET USAGE ##
         usage = litellm.Usage(
             prompt_tokens=completion_response["usageMetadata"]["promptTokenCount"],
-            completion_tokens=completion_response["usageMetadata"][
-                "candidatesTokenCount"
-            ],
+            completion_tokens=completion_response["usageMetadata"].get(
+                "candidatesTokenCount", 0
+            ),
             total_tokens=completion_response["usageMetadata"]["totalTokenCount"],
         )
 
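The switch from indexing to .get() mirrors the new flagged paths: on a blocked response Gemini's usageMetadata can omit candidatesTokenCount entirely. A two-line illustration with an invented payload:

    # Hypothetical usageMetadata from a blocked response: no candidatesTokenCount.
    usage_metadata = {"promptTokenCount": 11, "totalTokenCount": 11}
    print(usage_metadata.get("candidatesTokenCount", 0))  # 0; indexing would raise KeyError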
@@ -1,4 +1,7 @@
 model_list:
+  - model_name: gemini-1.5-flash-gemini
+    litellm_params:
+      model: gemini/gemini-1.5-flash
   - litellm_params:
       api_base: http://0.0.0.0:8080
       api_key: ''
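A quick way to exercise the new model_list entry, assuming a litellm proxy running locally on its default port 4000 (the port, route, and key below are assumptions, not part of this diff):

    # Smoke test for the new alias against a locally running litellm proxy.
    import requests

    resp = requests.post(
        "http://0.0.0.0:4000/v1/chat/completions",  # assumed default proxy port/route
        headers={"Authorization": "Bearer sk-1234"},  # placeholder key
        json={
            "model": "gemini-1.5-flash-gemini",
            "messages": [{"role": "user", "content": "hi"}],
        },
    )
    print(resp.json())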
@@ -227,9 +227,9 @@ class PromptFeedback(TypedDict):
     blockReasonMessage: str
 
 
-class UsageMetadata(TypedDict):
-    promptTokenCount: int
-    totalTokenCount: int
+class UsageMetadata(TypedDict, total=False):
+    promptTokenCount: Required[int]
+    totalTokenCount: Required[int]
+    candidatesTokenCount: int
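This TypedDict change is what lets the .get("candidatesTokenCount", 0) calls above type-check: total=False makes keys optional by default, while Required[...] keeps the two token counts mandatory. A self-contained sketch of the semantics:

    # With total=False, keys are optional unless wrapped in Required[...].
    from typing_extensions import Required, TypedDict

    class UsageMetadata(TypedDict, total=False):
        promptTokenCount: Required[int]
        totalTokenCount: Required[int]
        candidatesTokenCount: int  # may be absent on blocked responses

    ok: UsageMetadata = {"promptTokenCount": 11, "totalTokenCount": 11}  # valid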