forked from phoenix/litellm-mirror
fix(vertex_httpx.py): cover gemini content violation (on prompt)
parent 123477b55a
commit 1ff0129a94
3 changed files with 79 additions and 17 deletions
@@ -563,6 +563,43 @@ class VertexLLM(BaseLLM):
             )
 
         ## CHECK IF RESPONSE FLAGGED
+        if "promptFeedback" in completion_response:
+            if "blockReason" in completion_response["promptFeedback"]:
+                # If set, the prompt was blocked and no candidates are returned. Rephrase your prompt
+                model_response.choices[0].finish_reason = "content_filter"
+
+                chat_completion_message: ChatCompletionResponseMessage = {
+                    "role": "assistant",
+                    "content": None,
+                }
+
+                choice = litellm.Choices(
+                    finish_reason="content_filter",
+                    index=0,
+                    message=chat_completion_message,  # type: ignore
+                    logprobs=None,
+                    enhancements=None,
+                )
+
+                model_response.choices = [choice]
+
+                ## GET USAGE ##
+                usage = litellm.Usage(
+                    prompt_tokens=completion_response["usageMetadata"][
+                        "promptTokenCount"
+                    ],
+                    completion_tokens=completion_response["usageMetadata"].get(
+                        "candidatesTokenCount", 0
+                    ),
+                    total_tokens=completion_response["usageMetadata"][
+                        "totalTokenCount"
+                    ],
+                )
+
+                setattr(model_response, "usage", usage)
+
+                return model_response
+
         if len(completion_response["candidates"]) > 0:
             content_policy_violations = (
                 VertexGeminiConfig().get_flagged_finish_reasons()
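For context, the branch added above handles responses where Gemini blocks the prompt itself: the payload carries promptFeedback.blockReason and no candidates. Below is a minimal standalone sketch of that mapping; the payload and the helper function are illustrative assumptions, not part of litellm.

# Illustrative prompt-blocked payload. Key names (promptFeedback, blockReason,
# usageMetadata) mirror the fields read in the diff; the values are made up.
blocked_response = {
    "promptFeedback": {"blockReason": "SAFETY"},
    "usageMetadata": {"promptTokenCount": 12, "totalTokenCount": 12},
}


def map_blocked_prompt(completion_response: dict) -> dict:
    """Hypothetical helper mirroring the new branch: blocked prompt -> content_filter choice."""
    feedback = completion_response.get("promptFeedback", {})
    if "blockReason" not in feedback:
        raise ValueError("prompt was not blocked")
    usage = completion_response["usageMetadata"]
    return {
        "choices": [
            {
                "index": 0,
                "finish_reason": "content_filter",
                "message": {"role": "assistant", "content": None},
            }
        ],
        "usage": {
            "prompt_tokens": usage["promptTokenCount"],
            "completion_tokens": usage.get("candidatesTokenCount", 0),
            "total_tokens": usage["totalTokenCount"],
        },
    }


print(map_blocked_prompt(blocked_response)["choices"][0]["finish_reason"])  # content_filter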
@@ -573,16 +610,40 @@ class VertexLLM(BaseLLM):
                 in content_policy_violations.keys()
             ):
                 ## CONTENT POLICY VIOLATION ERROR
-                raise VertexAIError(
-                    status_code=400,
-                    message="The response was blocked. Reason={}. Raw Response={}".format(
-                        content_policy_violations[
-                            completion_response["candidates"][0]["finishReason"]
-                        ],
-                        completion_response,
-                    ),
+                model_response.choices[0].finish_reason = "content_filter"
+
+                chat_completion_message = {
+                    "role": "assistant",
+                    "content": None,
+                }
+
+                choice = litellm.Choices(
+                    finish_reason="content_filter",
+                    index=0,
+                    message=chat_completion_message,  # type: ignore
+                    logprobs=None,
+                    enhancements=None,
                 )
+
+                model_response.choices = [choice]
+
+                ## GET USAGE ##
+                usage = litellm.Usage(
+                    prompt_tokens=completion_response["usageMetadata"][
+                        "promptTokenCount"
+                    ],
+                    completion_tokens=completion_response["usageMetadata"].get(
+                        "candidatesTokenCount", 0
+                    ),
+                    total_tokens=completion_response["usageMetadata"][
+                        "totalTokenCount"
+                    ],
+                )
+
+                setattr(model_response, "usage", usage)
+
+                return model_response
 
         model_response.choices = []  # type: ignore
 
         ## GET MODEL ##
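The substantive change in this hunk: a candidate flagged with a content-policy finishReason no longer raises VertexAIError(status_code=400) but is returned as a normal response whose only choice has finish_reason="content_filter" and a None message content. A hedged sketch of how calling code might branch on that; the model name and prompt are placeholders, and it assumes litellm is installed with Vertex AI credentials configured.

import litellm

# Placeholder model/prompt; requires Vertex AI credentials to actually run.
response = litellm.completion(
    model="vertex_ai/gemini-1.5-flash",
    messages=[{"role": "user", "content": "..."}],
)

choice = response.choices[0]
if choice.finish_reason == "content_filter":
    # Before this commit, this case surfaced as a raised VertexAIError instead.
    print("Blocked by Gemini safety filters; no content returned.")
else:
    print(choice.message.content)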
@@ -590,9 +651,7 @@ class VertexLLM(BaseLLM):
 
         try:
             ## GET TEXT ##
-            chat_completion_message: ChatCompletionResponseMessage = {
-                "role": "assistant"
-            }
+            chat_completion_message = {"role": "assistant"}
             content_str = ""
             tools: List[ChatCompletionToolCallChunk] = []
             for idx, candidate in enumerate(completion_response["candidates"]):
@@ -632,9 +691,9 @@ class VertexLLM(BaseLLM):
             ## GET USAGE ##
             usage = litellm.Usage(
                 prompt_tokens=completion_response["usageMetadata"]["promptTokenCount"],
-                completion_tokens=completion_response["usageMetadata"][
-                    "candidatesTokenCount"
-                ],
+                completion_tokens=completion_response["usageMetadata"].get(
+                    "candidatesTokenCount", 0
+                ),
                 total_tokens=completion_response["usageMetadata"]["totalTokenCount"],
             )
 
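The switch from indexing to .get() matters because usageMetadata omits candidatesTokenCount when no candidates were produced (for example, a blocked prompt). A small illustration with made-up token counts:

# Example usageMetadata payloads: a normal response vs. one with no candidates.
normal = {"promptTokenCount": 10, "candidatesTokenCount": 25, "totalTokenCount": 35}
blocked = {"promptTokenCount": 10, "totalTokenCount": 10}

for meta in (normal, blocked):
    # The old meta["candidatesTokenCount"] lookup would raise KeyError on the second payload.
    completion_tokens = meta.get("candidatesTokenCount", 0)
    print(meta["promptTokenCount"], completion_tokens, meta["totalTokenCount"])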
@@ -1,4 +1,7 @@
 model_list:
+  - model_name: gemini-1.5-flash-gemini
+    litellm_params:
+      model: gemini/gemini-1.5-flash
   - litellm_params:
       api_base: http://0.0.0.0:8080
       api_key: ''
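With the proxy running on this config, the new gemini-1.5-flash-gemini entry is reachable through any OpenAI-compatible client. A sketch, assuming the openai Python package is installed and the litellm proxy is listening on its default port; the base URL and key are placeholders, not values from the commit.

from openai import OpenAI

# Placeholders: point base_url at wherever the litellm proxy is actually served.
client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-anything")

resp = client.chat.completions.create(
    model="gemini-1.5-flash-gemini",  # the model_name added in the config above
    messages=[{"role": "user", "content": "hello"}],
)
print(resp.choices[0].message.content)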
@@ -227,9 +227,9 @@ class PromptFeedback(TypedDict):
     blockReasonMessage: str
 
 
-class UsageMetadata(TypedDict):
-    promptTokenCount: int
-    totalTokenCount: int
+class UsageMetadata(TypedDict, total=False):
+    promptTokenCount: Required[int]
+    totalTokenCount: Required[int]
     candidatesTokenCount: int
 
 
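The TypedDict change expresses the same optionality at the type level: total=False makes keys optional by default, and Required[...] re-marks the two counts that are always present. A standalone sketch of the pattern; the class is redeclared here for illustration rather than imported from litellm.

from typing_extensions import Required, TypedDict


class UsageMetadata(TypedDict, total=False):
    promptTokenCount: Required[int]
    totalTokenCount: Required[int]
    candidatesTokenCount: int  # may be absent, e.g. when the prompt was blocked


# Both literals type-check: candidatesTokenCount is optional, the other two keys are not.
full: UsageMetadata = {"promptTokenCount": 10, "candidatesTokenCount": 25, "totalTokenCount": 35}
blocked: UsageMetadata = {"promptTokenCount": 10, "totalTokenCount": 10}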