diff --git a/enterprise/enterprise_hooks/banned_keywords.py b/enterprise/enterprise_hooks/banned_keywords.py
index 4cf68b2fd..3f3e01f5b 100644
--- a/enterprise/enterprise_hooks/banned_keywords.py
+++ b/enterprise/enterprise_hooks/banned_keywords.py
@@ -93,7 +93,7 @@ class _ENTERPRISE_BannedKeywords(CustomLogger):
             response.choices[0], litellm.utils.Choices
         ):
             for word in self.banned_keywords_list:
-                self.test_violation(test_str=response.choices[0].message.content)
+                self.test_violation(test_str=response.choices[0].message.content or "")
 
     async def async_post_call_streaming_hook(
         self,
diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index bd9cfaa8d..ba16598be 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -12,7 +12,12 @@ from litellm.llms.prompt_templates.factory import (
     convert_to_gemini_tool_call_result,
     convert_to_gemini_tool_call_invoke,
 )
-from litellm.types.files import get_file_mime_type_for_file_type, get_file_type_from_extension, is_gemini_1_5_accepted_file_type, is_video_file_type
+from litellm.types.files import (
+    get_file_mime_type_for_file_type,
+    get_file_type_from_extension,
+    is_gemini_1_5_accepted_file_type,
+    is_video_file_type,
+)
 
 
 class VertexAIError(Exception):
@@ -301,15 +306,15 @@ def _process_gemini_image(image_url: str) -> PartType:
     # GCS URIs
     if "gs://" in image_url:
         # Figure out file type
-        extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png"
-        extension = extension_with_dot[1:] # Ex: "png"
+        extension_with_dot = os.path.splitext(image_url)[-1]  # Ex: ".png"
+        extension = extension_with_dot[1:]  # Ex: "png"
         file_type = get_file_type_from_extension(extension)
 
         # Validate the file type is supported by Gemini
         if not is_gemini_1_5_accepted_file_type(file_type):
             raise Exception(f"File type not supported by gemini - {file_type}")
-        
+
         mime_type = get_file_mime_type_for_file_type(file_type)
         file_data = FileDataType(mime_type=mime_type, file_uri=image_url)
 
         return PartType(file_data=file_data)
@@ -320,7 +325,7 @@ def _process_gemini_image(image_url: str) -> PartType:
         image = _load_image_from_url(image_url)
         _blob = BlobType(data=image.data, mime_type=image._mime_type)
         return PartType(inline_data=_blob)
-        
+
     # Base64 encoding
     elif "base64" in image_url:
         import base64, re
@@ -611,7 +616,7 @@ def completion(
             llm_model = None
 
     # NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now
-    if acompletion == True:
+    if acompletion is True:
         data = {
             "llm_model": llm_model,
             "mode": mode,
@@ -643,7 +648,7 @@ def completion(
             tools = optional_params.pop("tools", None)
             content = _gemini_convert_messages_with_history(messages=messages)
             stream = optional_params.pop("stream", False)
-            if stream == True:
+            if stream is True:
                 request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
                 logging_obj.pre_call(
                     input=prompt,
diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index b8c698c90..acf79bebb 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -9,6 +9,14 @@ import litellm, uuid
 import httpx, inspect  # type: ignore
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from .base import BaseLLM
+from litellm.types.llms.vertex_ai import (
+    ContentType,
+    SystemInstructions,
+    PartType,
+    RequestBody,
+    GenerateContentResponseBody,
+)
+from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
 
 
 class VertexAIError(Exception):
@@ -33,16 +41,110 @@ class VertexLLM(BaseLLM):
         self.project_id: Optional[str] = None
         self.async_handler: Optional[AsyncHTTPHandler] = None
 
-    def load_auth(self) -> Tuple[Any, str]:
+    def _process_response(
+        self,
+        model: str,
+        response: httpx.Response,
+        model_response: ModelResponse,
+        stream: bool,
+        logging_obj: litellm.utils.Logging,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: List,
+        print_verbose,
+        encoding,
+    ) -> ModelResponse:
+
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+
+        print_verbose(f"raw model_response: {response.text}")
+
+        ## RESPONSE OBJECT
+        try:
+            completion_response = GenerateContentResponseBody(**response.json())  # type: ignore
+        except Exception as e:
+            raise VertexAIError(
+                message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
+                    response.text, str(e)
+                ),
+                status_code=422,
+            )
+
+        model_response.choices = []
+
+        ## GET MODEL ##
+        model_response.model = model
+        ## GET TEXT ##
+        for idx, candidate in enumerate(completion_response["candidates"]):
+            if candidate.get("content", None) is None:
+                continue
+
+            message = litellm.Message(
+                content=candidate["content"]["parts"][0]["text"],
+                role="assistant",
+                logprobs=None,
+                function_call=None,
+                tool_calls=None,
+            )
+            choice = litellm.Choices(
+                finish_reason=candidate.get("finishReason", "stop"),
+                index=candidate.get("index", idx),
+                message=message,
+                logprobs=None,
+                enhancements=None,
+            )
+
+            model_response.choices.append(choice)
+
+        ## GET USAGE ##
+        usage = litellm.Usage(
+            prompt_tokens=completion_response["usageMetadata"]["promptTokenCount"],
+            completion_tokens=completion_response["usageMetadata"][
+                "candidatesTokenCount"
+            ],
+            total_tokens=completion_response["usageMetadata"]["totalTokenCount"],
+        )
+
+        setattr(model_response, "usage", usage)
+
+        return model_response
+
+    def get_vertex_region(self, vertex_region: Optional[str]) -> str:
+        return vertex_region or "us-central1"
+
+    def load_auth(
+        self, credentials: Optional[str], project_id: Optional[str]
+    ) -> Tuple[Any, str]:
         from google.auth.transport.requests import Request  # type: ignore[import-untyped]
         from google.auth.credentials import Credentials  # type: ignore[import-untyped]
         import google.auth as google_auth
 
-        credentials, project_id = google_auth.default(
-            scopes=["https://www.googleapis.com/auth/cloud-platform"],
-        )
+        if credentials is not None and isinstance(credentials, str):
+            import google.oauth2.service_account
 
-        credentials.refresh(Request())
+            json_obj = json.loads(credentials)
+
+            creds = google.oauth2.service_account.Credentials.from_service_account_info(
+                json_obj,
+                scopes=["https://www.googleapis.com/auth/cloud-platform"],
+            )
+
+            if project_id is None:
+                project_id = creds.project_id
+        else:
+            creds, project_id = google_auth.default(
+                quota_project_id=project_id,
+                scopes=["https://www.googleapis.com/auth/cloud-platform"],
+            )
+
+        creds.refresh(Request())
 
         if not project_id:
             raise ValueError("Could not resolve project_id")
@@ -52,38 +154,135 @@ class VertexLLM(BaseLLM):
             f"Expected project_id to be a str but got {type(project_id)}"
         )
 
-        return credentials, project_id
+        return creds, project_id
 
     def refresh_auth(self, credentials: Any) -> None:
         from google.auth.transport.requests import Request  # type: ignore[import-untyped]
 
         credentials.refresh(Request())
 
-    def _prepare_request(self, request: httpx.Request) -> None:
-        access_token = self._ensure_access_token()
-
-        if request.headers.get("Authorization"):
-            # already authenticated, nothing for us to do
-            return
-
-        request.headers["Authorization"] = f"Bearer {access_token}"
-
-    def _ensure_access_token(self) -> str:
-        if self.access_token is not None:
-            return self.access_token
+    def _ensure_access_token(
+        self, credentials: Optional[str], project_id: Optional[str]
+    ) -> Tuple[str, str]:
+        """
+        Returns auth token and project id
+        """
+        if self.access_token is not None and self.project_id is not None:
+            return self.access_token, self.project_id
 
         if not self._credentials:
-            self._credentials, project_id = self.load_auth()
+            self._credentials, project_id = self.load_auth(
+                credentials=credentials, project_id=project_id
+            )
             if not self.project_id:
                 self.project_id = project_id
         else:
             self.refresh_auth(self._credentials)
 
-        if not self._credentials.token:
+        if not self.project_id:
+            self.project_id = self._credentials.project_id
+
+        if not self.project_id:
+            raise ValueError("Could not resolve project_id")
+
+        if not self._credentials or not self._credentials.token:
             raise RuntimeError("Could not resolve API token from the environment")
 
-        assert isinstance(self._credentials.token, str)
-        return self._credentials.token
+        return self._credentials.token, self.project_id
+
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        logging_obj,
+        optional_params: dict,
+        acompletion: bool,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        vertex_credentials: Optional[str],
+        litellm_params=None,
+        logger_fn=None,
+        extra_headers: Optional[dict] = None,
+        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
+
+        auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+        vertex_location = self.get_vertex_region(vertex_region=vertex_location)
+        stream = optional_params.pop("stream", None)
+
+        ### SET RUNTIME ENDPOINT ###
+        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:generateContent"
+
+        ## TRANSFORMATION ##
+        # Separate system prompt from rest of message
+        system_prompt_indices = []
+        system_content_blocks: List[PartType] = []
+        for idx, message in enumerate(messages):
+            if message["role"] == "system":
+                _system_content_block = PartType(text=message["content"])
+                system_content_blocks.append(_system_content_block)
+                system_prompt_indices.append(idx)
+        if len(system_prompt_indices) > 0:
+            for idx in reversed(system_prompt_indices):
+                messages.pop(idx)
+        system_instructions = SystemInstructions(parts=system_content_blocks)
+        content = _gemini_convert_messages_with_history(messages=messages)
+
+        data = RequestBody(system_instruction=system_instructions, contents=content)
+        headers = {
+            "Content-Type": "application/json; charset=utf-8",
+            "Authorization": f"Bearer {auth_header}",
+        }
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=messages,
+            api_key="",
+            additional_args={
+                "complete_input_dict": data,
+                "api_base": url,
+                "headers": headers,
+            },
+        )
+
+        ## COMPLETION CALL ##
+        if client is None or isinstance(client, AsyncHTTPHandler):
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+ _params["timeout"] = timeout + client = HTTPHandler(**_params) # type: ignore + else: + client = client + try: + response = client.post(url=url, headers=headers, json=data) # type: ignore + response.raise_for_status() + except httpx.HTTPStatusError as err: + error_code = err.response.status_code + raise VertexAIError(status_code=error_code, message=response.text) + except httpx.TimeoutException: + raise VertexAIError(status_code=408, message="Timeout error occurred.") + + return self._process_response( + model=model, + response=response, + model_response=model_response, + stream=stream, + logging_obj=logging_obj, + optional_params=optional_params, + api_key="", + data=data, # type: ignore + messages=messages, + print_verbose=print_verbose, + encoding=encoding, + ) def image_generation( self, @@ -163,7 +362,7 @@ class VertexLLM(BaseLLM): } \ "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict" """ - auth_header = self._ensure_access_token() + auth_header, _ = self._ensure_access_token(credentials=None, project_id=None) optional_params = optional_params or { "sampleCount": 1 } # default optional params diff --git a/litellm/main.py b/litellm/main.py index 8133e3517..63b86b43b 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1893,6 +1893,7 @@ def completion( or optional_params.pop("vertex_ai_credentials", None) or get_secret("VERTEXAI_CREDENTIALS") ) + new_params = deepcopy(optional_params) if "claude-3" in model: model_response = vertex_ai_anthropic.completion( @@ -1910,6 +1911,26 @@ def completion( logging_obj=logging, acompletion=acompletion, ) + elif ( + model in litellm.vertex_language_models + or model in litellm.vertex_vision_models + ): + model_response = vertex_chat_completion.completion( # type: ignore + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=new_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + vertex_location=vertex_ai_location, + vertex_project=vertex_ai_project, + vertex_credentials=vertex_credentials, + logging_obj=logging, + acompletion=acompletion, + timeout=timeout, + ) else: model_response = vertex_ai.completion( model=model, diff --git a/litellm/proxy/hooks/azure_content_safety.py b/litellm/proxy/hooks/azure_content_safety.py index 47ba36a68..972ac9992 100644 --- a/litellm/proxy/hooks/azure_content_safety.py +++ b/litellm/proxy/hooks/azure_content_safety.py @@ -140,7 +140,7 @@ class _PROXY_AzureContentSafety( response.choices[0], litellm.utils.Choices ): await self.test_violation( - content=response.choices[0].message.content, source="output" + content=response.choices[0].message.content or "", source="output" ) # async def async_post_call_streaming_hook( diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 84d3a2bfc..cf49fd130 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -532,6 +532,8 @@ def test_gemini_pro_vision(): # DO Not DELETE this ASSERT # Google counts the prompt tokens for us, we should ensure we use the tokens from the orignal response assert prompt_tokens == 263 # the gemini api returns 263 to us + + # assert False except litellm.RateLimitError as e: pass except Exception as e: diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py index 3ad3e62c4..fe903841e 100644 --- 
+++ b/litellm/types/llms/vertex_ai.py
@@ -9,6 +9,7 @@ from typing_extensions import (
     runtime_checkable,
     Required,
 )
+from enum import Enum
 
 
 class Field(TypedDict):
@@ -51,3 +52,161 @@ class PartType(TypedDict, total=False):
 class ContentType(TypedDict, total=False):
     role: Literal["user", "model"]
     parts: Required[List[PartType]]
+
+
+class SystemInstructions(TypedDict):
+    parts: Required[List[PartType]]
+
+
+class Schema(TypedDict, total=False):
+    type: Literal["STRING", "INTEGER", "BOOLEAN", "NUMBER", "ARRAY", "OBJECT"]
+    description: str
+    enum: List[str]
+    items: List["Schema"]
+    properties: "Schema"
+    required: List[str]
+    nullable: bool
+
+
+class FunctionDeclaration(TypedDict, total=False):
+    name: Required[str]
+    description: str
+    parameters: Schema
+    response: Schema
+
+
+class FunctionCallingConfig(TypedDict, total=False):
+    mode: Literal["ANY", "AUTO", "NONE"]
+    allowed_function_names: List[str]
+
+
+HarmCategory = Literal[
+    "HARM_CATEGORY_UNSPECIFIED",
+    "HARM_CATEGORY_HATE_SPEECH",
+    "HARM_CATEGORY_DANGEROUS_CONTENT",
+    "HARM_CATEGORY_HARASSMENT",
+    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+]
+HarmBlockThreshold = Literal[
+    "HARM_BLOCK_THRESHOLD_UNSPECIFIED",
+    "BLOCK_LOW_AND_ABOVE",
+    "BLOCK_MEDIUM_AND_ABOVE",
+    "BLOCK_ONLY_HIGH",
+    "BLOCK_NONE",
+]
+HarmBlockMethod = Literal["HARM_BLOCK_METHOD_UNSPECIFIED", "SEVERITY", "PROBABILITY"]
+
+HarmProbability = Literal[
+    "HARM_PROBABILITY_UNSPECIFIED", "NEGLIGIBLE", "LOW", "MEDIUM", "HIGH"
+]
+
+HarmSeverity = Literal[
+    "HARM_SEVERITY_UNSPECIFIED",
+    "HARM_SEVERITY_NEGLIGIBLE",
+    "HARM_SEVERITY_LOW",
+    "HARM_SEVERITY_MEDIUM",
+    "HARM_SEVERITY_HIGH",
+]
+
+
+class SafetSettingsConfig(TypedDict, total=False):
+    category: HarmCategory
+    threshold: HarmBlockThreshold
+    max_influential_terms: int
+    method: HarmBlockMethod
+
+
+class GenerationConfig(TypedDict, total=False):
+    temperature: float
+    top_p: float
+    top_k: float
+    candidate_count: int
+    max_output_tokens: int
+    stop_sequences: List[str]
+    presence_penalty: float
+    frequency_penalty: float
+    response_mime_type: Literal["text/plain", "application/json"]
+
+
+class RequestBody(TypedDict, total=False):
+    contents: Required[List[ContentType]]
+    system_instruction: SystemInstructions
+    tools: FunctionDeclaration
+    tool_config: FunctionCallingConfig
+    safety_settings: SafetSettingsConfig
+    generation_config: GenerationConfig
+
+
+class SafetyRatings(TypedDict):
+    category: HarmCategory
+    probability: HarmProbability
+    probabilityScore: int
+    severity: HarmSeverity
+    blocked: bool
+
+
+class Date(TypedDict):
+    year: int
+    month: int
+    date: int
+
+
+class Citation(TypedDict):
+    startIndex: int
+    endIndex: int
+    uri: str
+    title: str
+    license: str
+    publicationDate: Date
+
+
+class CitationMetadata(TypedDict):
+    citations: List[Citation]
+
+
+class SearchEntryPoint(TypedDict, total=False):
+    renderedContent: str
+    sdkBlob: str
+
+
+class GroundingMetadata(TypedDict, total=False):
+    webSearchQueries: List[str]
+    searchEntryPoint: SearchEntryPoint
+
+
+class Candidates(TypedDict, total=False):
+    index: int
+    content: ContentType
+    finishReason: Literal[
+        "FINISH_REASON_UNSPECIFIED",
+        "STOP",
+        "MAX_TOKENS",
+        "SAFETY",
+        "RECITATION",
+        "OTHER",
+        "BLOCKLIST",
+        "PROHIBITED_CONTENT",
+        "SPII",
+    ]
+    safetyRatings: SafetyRatings
+    citationMetadata: CitationMetadata
+    groundingMetadata: GroundingMetadata
+    finishMessage: str
+
+
+class PromptFeedback(TypedDict):
+    blockReason: str
+    safetyRatings: List[SafetyRatings]
+    blockReasonMessage: str
+
+
+class UsageMetadata(TypedDict):
+    promptTokenCount: int
+    totalTokenCount: int
+    candidatesTokenCount: int
+
+
+class GenerateContentResponseBody(TypedDict, total=False):
+    candidates: Required[List[Candidates]]
+    promptFeedback: PromptFeedback
+    usageMetadata: Required[UsageMetadata]
diff --git a/litellm/utils.py b/litellm/utils.py
index 49ff7cd98..14041d2b6 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -518,15 +518,18 @@ class Choices(OpenAIObject):
         self,
         finish_reason=None,
         index=0,
-        message=None,
+        message: Optional[Union[Message, dict]] = None,
         logprobs=None,
         enhancements=None,
         **params,
     ):
         super(Choices, self).__init__(**params)
-        self.finish_reason = (
-            map_finish_reason(finish_reason) or "stop"
-        )  # set finish_reason for all responses
+        if finish_reason is not None:
+            self.finish_reason = map_finish_reason(
+                finish_reason
+            )  # set finish_reason for all responses
+        else:
+            self.finish_reason = "stop"
         self.index = index
         if message is None:
             self.message = Message()
@@ -2822,7 +2825,9 @@ class Rules:
                     raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider="", model=model)  # type: ignore
         return True
 
-    def post_call_rules(self, input: str, model: str):
+    def post_call_rules(self, input: Optional[str], model: str) -> bool:
+        if input is None:
+            return True
         for rule in litellm.post_call_rules:
             if callable(rule):
                 decision = rule(input)
@@ -3101,9 +3106,9 @@ def client(original_function):
                         pass
                     else:
                         if isinstance(original_response, ModelResponse):
-                            model_response = original_response["choices"][0]["message"][
-                                "content"
-                            ]
+                            model_response = original_response.choices[
+                                0
+                            ].message.content
                             ### POST-CALL RULES ###
                             rules_obj.post_call_rules(input=model_response, model=model)
         except Exception as e: