From 1ff0129a94e8fa3b422e38b49d6ec24df6745791 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 24 Jun 2024 19:13:56 -0700
Subject: [PATCH] fix(vertex_httpx.py): cover gemini content violation (on
 prompt)

---
 litellm/llms/vertex_httpx.py            | 87 +++++++++++++++++++++----
 litellm/proxy/_super_secret_config.yaml |  3 +
 litellm/types/llms/vertex_ai.py         |  6 +-
 3 files changed, 79 insertions(+), 17 deletions(-)

diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index 63bcd9f4f..028c3f721 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -563,6 +563,43 @@ class VertexLLM(BaseLLM):
             )
 
         ## CHECK IF RESPONSE FLAGGED
+        if "promptFeedback" in completion_response:
+            if "blockReason" in completion_response["promptFeedback"]:
+                # If set, the prompt was blocked and no candidates are returned. Rephrase your prompt
+                model_response.choices[0].finish_reason = "content_filter"
+
+                chat_completion_message: ChatCompletionResponseMessage = {
+                    "role": "assistant",
+                    "content": None,
+                }
+
+                choice = litellm.Choices(
+                    finish_reason="content_filter",
+                    index=0,
+                    message=chat_completion_message,  # type: ignore
+                    logprobs=None,
+                    enhancements=None,
+                )
+
+                model_response.choices = [choice]
+
+                ## GET USAGE ##
+                usage = litellm.Usage(
+                    prompt_tokens=completion_response["usageMetadata"][
+                        "promptTokenCount"
+                    ],
+                    completion_tokens=completion_response["usageMetadata"].get(
+                        "candidatesTokenCount", 0
+                    ),
+                    total_tokens=completion_response["usageMetadata"][
+                        "totalTokenCount"
+                    ],
+                )
+
+                setattr(model_response, "usage", usage)
+
+                return model_response
+
         if len(completion_response["candidates"]) > 0:
             content_policy_violations = (
                 VertexGeminiConfig().get_flagged_finish_reasons()
@@ -573,16 +610,40 @@ class VertexLLM(BaseLLM):
             )
             if (
                 completion_response["candidates"][0].get("finishReason")
                 in content_policy_violations.keys()
             ):
                 ## CONTENT POLICY VIOLATION ERROR
-                raise VertexAIError(
-                    status_code=400,
-                    message="The response was blocked. Reason={}. Raw Response={}".format(
-                        content_policy_violations[
-                            completion_response["candidates"][0]["finishReason"]
-                        ],
-                        completion_response,
-                    ),
+                model_response.choices[0].finish_reason = "content_filter"
+
+                chat_completion_message = {
+                    "role": "assistant",
+                    "content": None,
+                }
+
+                choice = litellm.Choices(
+                    finish_reason="content_filter",
+                    index=0,
+                    message=chat_completion_message,  # type: ignore
+                    logprobs=None,
+                    enhancements=None,
                 )
 
+                model_response.choices = [choice]
+
+                ## GET USAGE ##
+                usage = litellm.Usage(
+                    prompt_tokens=completion_response["usageMetadata"][
+                        "promptTokenCount"
+                    ],
+                    completion_tokens=completion_response["usageMetadata"].get(
+                        "candidatesTokenCount", 0
+                    ),
+                    total_tokens=completion_response["usageMetadata"][
+                        "totalTokenCount"
+                    ],
+                )
+
+                setattr(model_response, "usage", usage)
+
+                return model_response
+
         model_response.choices = []  # type: ignore
 
         ## GET MODEL ##
@@ -590,9 +651,7 @@ class VertexLLM(BaseLLM):
 
         try:
             ## GET TEXT ##
-            chat_completion_message: ChatCompletionResponseMessage = {
-                "role": "assistant"
-            }
+            chat_completion_message = {"role": "assistant"}
             content_str = ""
             tools: List[ChatCompletionToolCallChunk] = []
             for idx, candidate in enumerate(completion_response["candidates"]):
@@ -632,9 +691,9 @@ class VertexLLM(BaseLLM):
         ## GET USAGE ##
         usage = litellm.Usage(
             prompt_tokens=completion_response["usageMetadata"]["promptTokenCount"],
-            completion_tokens=completion_response["usageMetadata"][
-                "candidatesTokenCount"
-            ],
+            completion_tokens=completion_response["usageMetadata"].get(
+                "candidatesTokenCount", 0
+            ),
             total_tokens=completion_response["usageMetadata"]["totalTokenCount"],
         )
 
diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml
index 04a4806c1..c5f1b4768 100644
--- a/litellm/proxy/_super_secret_config.yaml
+++ b/litellm/proxy/_super_secret_config.yaml
@@ -1,4 +1,7 @@
 model_list:
+- model_name: gemini-1.5-flash-gemini
+  litellm_params:
+    model: gemini/gemini-1.5-flash
 - litellm_params:
     api_base: http://0.0.0.0:8080
     api_key: ''
diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py
index 1612f8761..2dda57c2e 100644
--- a/litellm/types/llms/vertex_ai.py
+++ b/litellm/types/llms/vertex_ai.py
@@ -227,9 +227,9 @@ class PromptFeedback(TypedDict):
     blockReasonMessage: str
 
 
-class UsageMetadata(TypedDict):
-    promptTokenCount: int
-    totalTokenCount: int
+class UsageMetadata(TypedDict, total=False):
+    promptTokenCount: Required[int]
+    totalTokenCount: Required[int]
     candidatesTokenCount: int
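
Note for reviewers: a minimal usage sketch of the behaviour this patch introduces. Before the change, a flagged candidate raised VertexAIError(400) and a blocked prompt (promptFeedback.blockReason) was not handled at all; after it, both cases come back as a normal response whose choice has finish_reason "content_filter", an empty assistant message, and usage filled from usageMetadata (candidatesTokenCount defaulting to 0 when absent). The snippet is illustrative only, not part of the patch: it assumes Gemini credentials are already configured, the model name is taken from the config change above, and the prompt is a hypothetical one that trips the safety filters.

import litellm

# Hypothetical request; anything Gemini's safety filters block on would behave the same way.
response = litellm.completion(
    model="gemini/gemini-1.5-flash",
    messages=[{"role": "user", "content": "a prompt that trips the safety filters"}],
)

choice = response.choices[0]
if choice.finish_reason == "content_filter":
    # Blocked prompt or blocked candidate: no exception is raised and content is None.
    print("Blocked by Gemini safety filters; usage:", response.usage)
else:
    print(choice.message.content)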