From 29169b3039744f914c28ae024e9246e3cc1cfad6 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 12 Jun 2024 16:47:00 -0700 Subject: [PATCH 1/6] feat(vertex_httpx.py): Moving to call vertex ai via httpx (instead of their sdk). Allows us to support all their api updates. --- .../enterprise_hooks/banned_keywords.py | 2 +- litellm/llms/vertex_ai.py | 19 +- litellm/llms/vertex_httpx.py | 245 ++++++++++++++++-- litellm/main.py | 21 ++ litellm/proxy/hooks/azure_content_safety.py | 2 +- .../tests/test_amazing_vertex_completion.py | 2 + litellm/types/llms/vertex_ai.py | 159 ++++++++++++ litellm/utils.py | 21 +- 8 files changed, 431 insertions(+), 40 deletions(-) diff --git a/enterprise/enterprise_hooks/banned_keywords.py b/enterprise/enterprise_hooks/banned_keywords.py index 4cf68b2fd9..3f3e01f5b6 100644 --- a/enterprise/enterprise_hooks/banned_keywords.py +++ b/enterprise/enterprise_hooks/banned_keywords.py @@ -93,7 +93,7 @@ class _ENTERPRISE_BannedKeywords(CustomLogger): response.choices[0], litellm.utils.Choices ): for word in self.banned_keywords_list: - self.test_violation(test_str=response.choices[0].message.content) + self.test_violation(test_str=response.choices[0].message.content or "") async def async_post_call_streaming_hook( self, diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py index bd9cfaa8d6..ba16598be2 100644 --- a/litellm/llms/vertex_ai.py +++ b/litellm/llms/vertex_ai.py @@ -12,7 +12,12 @@ from litellm.llms.prompt_templates.factory import ( convert_to_gemini_tool_call_result, convert_to_gemini_tool_call_invoke, ) -from litellm.types.files import get_file_mime_type_for_file_type, get_file_type_from_extension, is_gemini_1_5_accepted_file_type, is_video_file_type +from litellm.types.files import ( + get_file_mime_type_for_file_type, + get_file_type_from_extension, + is_gemini_1_5_accepted_file_type, + is_video_file_type, +) class VertexAIError(Exception): @@ -301,15 +306,15 @@ def _process_gemini_image(image_url: str) -> PartType: # GCS URIs if "gs://" in image_url: # Figure out file type - extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png" - extension = extension_with_dot[1:] # Ex: "png" + extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png" + extension = extension_with_dot[1:] # Ex: "png" file_type = get_file_type_from_extension(extension) # Validate the file type is supported by Gemini if not is_gemini_1_5_accepted_file_type(file_type): raise Exception(f"File type not supported by gemini - {file_type}") - + mime_type = get_file_mime_type_for_file_type(file_type) file_data = FileDataType(mime_type=mime_type, file_uri=image_url) @@ -320,7 +325,7 @@ def _process_gemini_image(image_url: str) -> PartType: image = _load_image_from_url(image_url) _blob = BlobType(data=image.data, mime_type=image._mime_type) return PartType(inline_data=_blob) - + # Base64 encoding elif "base64" in image_url: import base64, re @@ -611,7 +616,7 @@ def completion( llm_model = None # NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now - if acompletion == True: + if acompletion is True: data = { "llm_model": llm_model, "mode": mode, @@ -643,7 +648,7 @@ def completion( tools = optional_params.pop("tools", None) content = _gemini_convert_messages_with_history(messages=messages) stream = optional_params.pop("stream", False) - if stream == True: + if stream is True: request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), 
safety_settings={safety_settings}, stream={stream})\n" logging_obj.pre_call( input=prompt, diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index b8c698c901..acf79bebbe 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -9,6 +9,14 @@ import litellm, uuid import httpx, inspect # type: ignore from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from .base import BaseLLM +from litellm.types.llms.vertex_ai import ( + ContentType, + SystemInstructions, + PartType, + RequestBody, + GenerateContentResponseBody, +) +from litellm.llms.vertex_ai import _gemini_convert_messages_with_history class VertexAIError(Exception): @@ -33,16 +41,110 @@ class VertexLLM(BaseLLM): self.project_id: Optional[str] = None self.async_handler: Optional[AsyncHTTPHandler] = None - def load_auth(self) -> Tuple[Any, str]: + def _process_response( + self, + model: str, + response: httpx.Response, + model_response: ModelResponse, + stream: bool, + logging_obj: litellm.utils.Logging, + optional_params: dict, + api_key: str, + data: Union[dict, str], + messages: List, + print_verbose, + encoding, + ) -> ModelResponse: + + ## LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) + + print_verbose(f"raw model_response: {response.text}") + + ## RESPONSE OBJECT + try: + completion_response = GenerateContentResponseBody(**response.json()) # type: ignore + except Exception as e: + raise VertexAIError( + message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format( + response.text, str(e) + ), + status_code=422, + ) + + model_response.choices = [] + + ## GET MODEL ## + model_response.model = model + ## GET TEXT ## + for idx, candidate in enumerate(completion_response["candidates"]): + if candidate.get("content", None) is None: + continue + + message = litellm.Message( + content=candidate["content"]["parts"][0]["text"], + role="assistant", + logprobs=None, + function_call=None, + tool_calls=None, + ) + choice = litellm.Choices( + finish_reason=candidate.get("finishReason", "stop"), + index=candidate.get("index", idx), + message=message, + logprobs=None, + enhancements=None, + ) + + model_response.choices.append(choice) + + ## GET USAGE ## + usage = litellm.Usage( + prompt_tokens=completion_response["usageMetadata"]["promptTokenCount"], + completion_tokens=completion_response["usageMetadata"][ + "candidatesTokenCount" + ], + total_tokens=completion_response["usageMetadata"]["totalTokenCount"], + ) + + setattr(model_response, "usage", usage) + + return model_response + + def get_vertex_region(self, vertex_region: Optional[str]) -> str: + return vertex_region or "us-central1" + + def load_auth( + self, credentials: Optional[str], project_id: Optional[str] + ) -> Tuple[Any, str]: from google.auth.transport.requests import Request # type: ignore[import-untyped] from google.auth.credentials import Credentials # type: ignore[import-untyped] import google.auth as google_auth - credentials, project_id = google_auth.default( - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) + if credentials is not None and isinstance(credentials, str): + import google.oauth2.service_account - credentials.refresh(Request()) + json_obj = json.loads(credentials) + + creds = google.oauth2.service_account.Credentials.from_service_account_info( + json_obj, + 
scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + + if project_id is None: + project_id = creds.project_id + else: + creds, project_id = google_auth.default( + quota_project_id=project_id, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + + creds.refresh(Request()) if not project_id: raise ValueError("Could not resolve project_id") @@ -52,38 +154,135 @@ class VertexLLM(BaseLLM): f"Expected project_id to be a str but got {type(project_id)}" ) - return credentials, project_id + return creds, project_id def refresh_auth(self, credentials: Any) -> None: from google.auth.transport.requests import Request # type: ignore[import-untyped] credentials.refresh(Request()) - def _prepare_request(self, request: httpx.Request) -> None: - access_token = self._ensure_access_token() - - if request.headers.get("Authorization"): - # already authenticated, nothing for us to do - return - - request.headers["Authorization"] = f"Bearer {access_token}" - - def _ensure_access_token(self) -> str: - if self.access_token is not None: - return self.access_token + def _ensure_access_token( + self, credentials: Optional[str], project_id: Optional[str] + ) -> Tuple[str, str]: + """ + Returns auth token and project id + """ + if self.access_token is not None and self.project_id is not None: + return self.access_token, self.project_id if not self._credentials: - self._credentials, project_id = self.load_auth() + self._credentials, project_id = self.load_auth( + credentials=credentials, project_id=project_id + ) if not self.project_id: self.project_id = project_id else: self.refresh_auth(self._credentials) - if not self._credentials.token: + if not self.project_id: + self.project_id = self._credentials.project_id + + if not self.project_id: + raise ValueError("Could not resolve project_id") + + if not self._credentials or not self._credentials.token: raise RuntimeError("Could not resolve API token from the environment") - assert isinstance(self._credentials.token, str) - return self._credentials.token + return self._credentials.token, self.project_id + + def completion( + self, + model: str, + messages: list, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + logging_obj, + optional_params: dict, + acompletion: bool, + timeout: Optional[Union[float, httpx.Timeout]], + vertex_project: Optional[str], + vertex_location: Optional[str], + vertex_credentials: Optional[str], + litellm_params=None, + logger_fn=None, + extra_headers: Optional[dict] = None, + client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, + ) -> Union[ModelResponse, CustomStreamWrapper]: + + auth_header, vertex_project = self._ensure_access_token( + credentials=vertex_credentials, project_id=vertex_project + ) + vertex_location = self.get_vertex_region(vertex_region=vertex_location) + stream = optional_params.pop("stream", None) + + ### SET RUNTIME ENDPOINT ### + url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:generateContent" + + ## TRANSFORMATION ## + # Separate system prompt from rest of message + system_prompt_indices = [] + system_content_blocks: List[PartType] = [] + for idx, message in enumerate(messages): + if message["role"] == "system": + _system_content_block = PartType(text=message["content"]) + system_content_blocks.append(_system_content_block) + system_prompt_indices.append(idx) + if len(system_prompt_indices) > 0: + for idx in reversed(system_prompt_indices): + messages.pop(idx) + 
system_instructions = SystemInstructions(parts=system_content_blocks) + content = _gemini_convert_messages_with_history(messages=messages) + + data = RequestBody(system_instruction=system_instructions, contents=content) + headers = { + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {auth_header}", + } + + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key="", + additional_args={ + "complete_input_dict": data, + "api_base": url, + "headers": headers, + }, + ) + + ## COMPLETION CALL ## + if client is None or isinstance(client, AsyncHTTPHandler): + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + timeout = httpx.Timeout(timeout) + _params["timeout"] = timeout + client = HTTPHandler(**_params) # type: ignore + else: + client = client + try: + response = client.post(url=url, headers=headers, json=data) # type: ignore + response.raise_for_status() + except httpx.HTTPStatusError as err: + error_code = err.response.status_code + raise VertexAIError(status_code=error_code, message=response.text) + except httpx.TimeoutException: + raise VertexAIError(status_code=408, message="Timeout error occurred.") + + return self._process_response( + model=model, + response=response, + model_response=model_response, + stream=stream, + logging_obj=logging_obj, + optional_params=optional_params, + api_key="", + data=data, # type: ignore + messages=messages, + print_verbose=print_verbose, + encoding=encoding, + ) def image_generation( self, @@ -163,7 +362,7 @@ class VertexLLM(BaseLLM): } \ "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict" """ - auth_header = self._ensure_access_token() + auth_header, _ = self._ensure_access_token(credentials=None, project_id=None) optional_params = optional_params or { "sampleCount": 1 } # default optional params diff --git a/litellm/main.py b/litellm/main.py index 8133e35170..63b86b43bb 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1893,6 +1893,7 @@ def completion( or optional_params.pop("vertex_ai_credentials", None) or get_secret("VERTEXAI_CREDENTIALS") ) + new_params = deepcopy(optional_params) if "claude-3" in model: model_response = vertex_ai_anthropic.completion( @@ -1910,6 +1911,26 @@ def completion( logging_obj=logging, acompletion=acompletion, ) + elif ( + model in litellm.vertex_language_models + or model in litellm.vertex_vision_models + ): + model_response = vertex_chat_completion.completion( # type: ignore + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=new_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + vertex_location=vertex_ai_location, + vertex_project=vertex_ai_project, + vertex_credentials=vertex_credentials, + logging_obj=logging, + acompletion=acompletion, + timeout=timeout, + ) else: model_response = vertex_ai.completion( model=model, diff --git a/litellm/proxy/hooks/azure_content_safety.py b/litellm/proxy/hooks/azure_content_safety.py index 47ba36a683..972ac99928 100644 --- a/litellm/proxy/hooks/azure_content_safety.py +++ b/litellm/proxy/hooks/azure_content_safety.py @@ -140,7 +140,7 @@ class _PROXY_AzureContentSafety( response.choices[0], litellm.utils.Choices ): await self.test_violation( - content=response.choices[0].message.content, source="output" + content=response.choices[0].message.content or "", source="output" ) # async def 
async_post_call_streaming_hook( diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 84d3a2bfc6..cf49fd130c 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -532,6 +532,8 @@ def test_gemini_pro_vision(): # DO Not DELETE this ASSERT # Google counts the prompt tokens for us, we should ensure we use the tokens from the orignal response assert prompt_tokens == 263 # the gemini api returns 263 to us + + # assert False except litellm.RateLimitError as e: pass except Exception as e: diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py index 3ad3e62c46..fe903841ef 100644 --- a/litellm/types/llms/vertex_ai.py +++ b/litellm/types/llms/vertex_ai.py @@ -9,6 +9,7 @@ from typing_extensions import ( runtime_checkable, Required, ) +from enum import Enum class Field(TypedDict): @@ -51,3 +52,161 @@ class PartType(TypedDict, total=False): class ContentType(TypedDict, total=False): role: Literal["user", "model"] parts: Required[List[PartType]] + + +class SystemInstructions(TypedDict): + parts: Required[List[PartType]] + + +class Schema(TypedDict, total=False): + type: Literal["STRING", "INTEGER", "BOOLEAN", "NUMBER", "ARRAY", "OBJECT"] + description: str + enum: List[str] + items: List["Schema"] + properties: "Schema" + required: List[str] + nullable: bool + + +class FunctionDeclaration(TypedDict, total=False): + name: Required[str] + description: str + parameters: Schema + response: Schema + + +class FunctionCallingConfig(TypedDict, total=False): + mode: Literal["ANY", "AUTO", "NONE"] + allowed_function_names: List[str] + + +HarmCategory = Literal[ + "HARM_CATEGORY_UNSPECIFIED", + "HARM_CATEGORY_HATE_SPEECH", + "HARM_CATEGORY_DANGEROUS_CONTENT", + "HARM_CATEGORY_HARASSMENT", + "HARM_CATEGORY_SEXUALLY_EXPLICIT", +] +HarmBlockThreshold = Literal[ + "HARM_BLOCK_THRESHOLD_UNSPECIFIED", + "BLOCK_LOW_AND_ABOVE", + "BLOCK_MEDIUM_AND_ABOVE", + "BLOCK_ONLY_HIGH", + "BLOCK_NONE", +] +HarmBlockMethod = Literal["HARM_BLOCK_METHOD_UNSPECIFIED", "SEVERITY", "PROBABILITY"] + +HarmProbability = Literal[ + "HARM_PROBABILITY_UNSPECIFIED", "NEGLIGIBLE", "LOW", "MEDIUM", "HIGH" +] + +HarmSeverity = Literal[ + "HARM_SEVERITY_UNSPECIFIED", + "HARM_SEVERITY_NEGLIGIBLE", + "HARM_SEVERITY_LOW", + "HARM_SEVERITY_MEDIUM", + "HARM_SEVERITY_HIGH", +] + + +class SafetSettingsConfig(TypedDict, total=False): + category: HarmCategory + threshold: HarmBlockThreshold + max_influential_terms: int + method: HarmBlockMethod + + +class GenerationConfig(TypedDict, total=False): + temperature: float + top_p: float + top_k: float + candidate_count: int + max_output_tokens: int + stop_sequences: List[str] + presence_penalty: float + frequency_penalty: float + response_mime_type: Literal["text/plain", "application/json"] + + +class RequestBody(TypedDict, total=False): + contents: Required[List[ContentType]] + system_instruction: SystemInstructions + tools: FunctionDeclaration + tool_config: FunctionCallingConfig + safety_settings: SafetSettingsConfig + generation_config: GenerationConfig + + +class SafetyRatings(TypedDict): + category: HarmCategory + probability: HarmProbability + probabilityScore: int + severity: HarmSeverity + blocked: bool + + +class Date(TypedDict): + year: int + month: int + date: int + + +class Citation(TypedDict): + startIndex: int + endIndex: int + uri: str + title: str + license: str + publicationDate: Date + + +class CitationMetadata(TypedDict): + citations: 
List[Citation] + + +class SearchEntryPoint(TypedDict, total=False): + renderedContent: str + sdkBlob: str + + +class GroundingMetadata(TypedDict, total=False): + webSearchQueries: List[str] + searchEntryPoint: SearchEntryPoint + + +class Candidates(TypedDict, total=False): + index: int + content: ContentType + finishReason: Literal[ + "FINISH_REASON_UNSPECIFIED", + "STOP", + "MAX_TOKENS", + "SAFETY", + "RECITATION", + "OTHER", + "BLOCKLIST", + "PROHIBITED_CONTENT", + "SPII", + ] + safetyRatings: SafetyRatings + citationMetadata: CitationMetadata + groundingMetadata: GroundingMetadata + finishMessage: str + + +class PromptFeedback(TypedDict): + blockReason: str + safetyRatings: List[SafetyRatings] + blockReasonMessage: str + + +class UsageMetadata(TypedDict): + promptTokenCount: int + totalTokenCount: int + candidatesTokenCount: int + + +class GenerateContentResponseBody(TypedDict, total=False): + candidates: Required[List[Candidates]] + promptFeedback: PromptFeedback + usageMetadata: Required[UsageMetadata] diff --git a/litellm/utils.py b/litellm/utils.py index 49ff7cd984..14041d2b6f 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -518,15 +518,18 @@ class Choices(OpenAIObject): self, finish_reason=None, index=0, - message=None, + message: Optional[Union[Message, dict]] = None, logprobs=None, enhancements=None, **params, ): super(Choices, self).__init__(**params) - self.finish_reason = ( - map_finish_reason(finish_reason) or "stop" - ) # set finish_reason for all responses + if finish_reason is not None: + self.finish_reason = map_finish_reason( + finish_reason + ) # set finish_reason for all responses + else: + self.finish_reason = "stop" self.index = index if message is None: self.message = Message() @@ -2822,7 +2825,9 @@ class Rules: raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider="", model=model) # type: ignore return True - def post_call_rules(self, input: str, model: str): + def post_call_rules(self, input: Optional[str], model: str) -> bool: + if input is None: + return True for rule in litellm.post_call_rules: if callable(rule): decision = rule(input) @@ -3101,9 +3106,9 @@ def client(original_function): pass else: if isinstance(original_response, ModelResponse): - model_response = original_response["choices"][0]["message"][ - "content" - ] + model_response = original_response.choices[ + 0 + ].message.content ### POST-CALL RULES ### rules_obj.post_call_rules(input=model_response, model=model) except Exception as e: From 1dac2aa59f0c629cb6395d133b28c8128b3272a4 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 12 Jun 2024 19:55:14 -0700 Subject: [PATCH 2/6] fix(vertex_httpx.py): support streaming via httpx client --- litellm/__init__.py | 1 + litellm/llms/vertex_httpx.py | 201 +++++++++++++++++++++++++++++++- litellm/main.py | 56 +++++---- litellm/tests/test_streaming.py | 5 +- litellm/types/llms/openai.py | 6 + litellm/types/utils.py | 11 ++ litellm/utils.py | 29 +++++ 7 files changed, 283 insertions(+), 26 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index e18be347da..19c3bcca6b 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -605,6 +605,7 @@ provider_list: List = [ "together_ai", "openrouter", "vertex_ai", + "vertex_ai_beta", "palm", "gemini", "ai21", diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index acf79bebbe..70a408c2b2 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -1,3 +1,7 @@ +# What is this? 
+## httpx client for vertex ai calls +## Initial implementation - covers gemini + image gen calls +from functools import partial import os, types import json from enum import Enum @@ -17,6 +21,86 @@ from litellm.types.llms.vertex_ai import ( GenerateContentResponseBody, ) from litellm.llms.vertex_ai import _gemini_convert_messages_with_history +from litellm.types.utils import GenericStreamingChunk +from litellm.types.llms.openai import ( + ChatCompletionUsageBlock, + ChatCompletionToolCallChunk, + ChatCompletionToolCallFunctionChunk, +) + + +class VertexGeminiConfig: + def __init__(self) -> None: + pass + + def supports_system_message(self) -> bool: + """ + Not all gemini models support system instructions + """ + return True + + +async def make_call( + client: Optional[AsyncHTTPHandler], + api_base: str, + headers: dict, + data: str, + model: str, + messages: list, + logging_obj, +): + if client is None: + client = AsyncHTTPHandler() # Create a new client if none provided + + response = await client.post(api_base, headers=headers, data=data, stream=True) + + if response.status_code != 200: + raise VertexAIError(status_code=response.status_code, message=response.text) + + completion_stream = ModelResponseIterator( + streaming_response=response.aiter_bytes(chunk_size=2056) + ) + # LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response="first stream response received", + additional_args={"complete_input_dict": data}, + ) + + return completion_stream + + +def make_sync_call( + client: Optional[HTTPHandler], + api_base: str, + headers: dict, + data: str, + model: str, + messages: list, + logging_obj, +): + if client is None: + client = HTTPHandler() # Create a new client if none provided + + response = client.post(api_base, headers=headers, data=data, stream=True) + + if response.status_code != 200: + raise VertexAIError(status_code=response.status_code, message=response.read()) + + completion_stream = ModelResponseIterator( + streaming_response=response.iter_bytes(chunk_size=2056) + ) + + # LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response="first stream response received", + additional_args={"complete_input_dict": data}, + ) + + return completion_stream class VertexAIError(Exception): @@ -46,7 +130,6 @@ class VertexLLM(BaseLLM): model: str, response: httpx.Response, model_response: ModelResponse, - stream: bool, logging_obj: litellm.utils.Logging, optional_params: dict, api_key: str, @@ -77,7 +160,7 @@ class VertexLLM(BaseLLM): status_code=422, ) - model_response.choices = [] + model_response.choices = [] # type: ignore ## GET MODEL ## model_response.model = model @@ -190,6 +273,16 @@ class VertexLLM(BaseLLM): return self._credentials.token, self.project_id + async def async_streaming( + self, + ): + pass + + async def async_completion( + self, + ): + pass + def completion( self, model: str, @@ -214,7 +307,7 @@ class VertexLLM(BaseLLM): credentials=vertex_credentials, project_id=vertex_project ) vertex_location = self.get_vertex_region(vertex_region=vertex_location) - stream = optional_params.pop("stream", None) + stream: Optional[bool] = optional_params.pop("stream", None) # type: ignore ### SET RUNTIME ENDPOINT ### url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:generateContent" @@ -251,6 +344,26 @@ class VertexLLM(BaseLLM): }, ) + ## SYNC STREAMING CALL ## + if stream is not None and stream is True: + streaming_response = 
CustomStreamWrapper( + completion_stream=None, + make_call=partial( + make_sync_call, + client=None, + api_base=url, + headers=headers, # type: ignore + data=json.dumps(data), + model=model, + messages=messages, + logging_obj=logging_obj, + ), + model=model, + custom_llm_provider="bedrock", + logging_obj=logging_obj, + ) + + return streaming_response ## COMPLETION CALL ## if client is None or isinstance(client, AsyncHTTPHandler): _params = {} @@ -274,7 +387,6 @@ class VertexLLM(BaseLLM): model=model, response=response, model_response=model_response, - stream=stream, logging_obj=logging_obj, optional_params=optional_params, api_key="", @@ -421,3 +533,84 @@ class VertexLLM(BaseLLM): model_response.data = _response_data return model_response + + +class ModelResponseIterator: + def __init__(self, streaming_response): + self.streaming_response = streaming_response + self.response_iterator = iter(self.streaming_response) + + def chunk_parser(self, chunk: dict) -> GenericStreamingChunk: + try: + processed_chunk = GenerateContentResponseBody(**chunk) # type: ignore + text = "" + tool_use: Optional[ChatCompletionToolCallChunk] = None + is_finished = False + finish_reason = "" + usage: Optional[ChatCompletionUsageBlock] = None + + gemini_chunk = processed_chunk["candidates"][0] + + if ( + "content" in gemini_chunk + and "text" in gemini_chunk["content"]["parts"][0] + ): + text = gemini_chunk["content"]["parts"][0]["text"] + + if "finishReason" in gemini_chunk: + finish_reason = map_finish_reason( + finish_reason=gemini_chunk["finishReason"] + ) + is_finished = True + + if "usageMetadata" in processed_chunk: + usage = ChatCompletionUsageBlock( + prompt_tokens=processed_chunk["usageMetadata"]["promptTokenCount"], + completion_tokens=processed_chunk["usageMetadata"][ + "candidatesTokenCount" + ], + total_tokens=processed_chunk["usageMetadata"]["totalTokenCount"], + ) + + returned_chunk = GenericStreamingChunk( + text=text, + tool_use=tool_use, + is_finished=is_finished, + finish_reason=finish_reason, + usage=usage, + index=0, + ) + return returned_chunk + except json.JSONDecodeError: + raise ValueError(f"Failed to decode JSON from chunk: {chunk}") + + # Sync iterator + def __iter__(self): + return self + + def __next__(self): + try: + chunk = next(self.response_iterator) + chunk = chunk.decode() + json_chunk = json.loads(chunk) + return self.chunk_parser(chunk=json_chunk) + except StopIteration: + raise StopIteration + except ValueError as e: + raise RuntimeError(f"Error parsing chunk: {e}") + + # Async iterator + def __aiter__(self): + self.async_response_iterator = self.streaming_response.__aiter__() + return self + + async def __anext__(self): + try: + chunk = await self.async_response_iterator.__anext__() + chunk = chunk.decode() + json_chunk = json.loads(chunk) + return self.chunk_parser(chunk=json_chunk) + except StopAsyncIteration: + raise StopAsyncIteration + except ValueError as e: + raise RuntimeError(f"Error parsing chunk: {e}") diff --git a/litellm/main.py b/litellm/main.py index 63b86b43bb..16fd394f84 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1875,6 +1875,42 @@ def completion( ) return response response = model_response + elif custom_llm_provider == "vertex_ai_beta": + vertex_ai_project = ( + optional_params.pop("vertex_project", None) + or optional_params.pop("vertex_ai_project", None) + or litellm.vertex_project + or get_secret("VERTEXAI_PROJECT") + ) + vertex_ai_location = ( + optional_params.pop("vertex_location", None) + or optional_params.pop("vertex_ai_location", None) 
+ or litellm.vertex_location + or get_secret("VERTEXAI_LOCATION") + ) + vertex_credentials = ( + optional_params.pop("vertex_credentials", None) + or optional_params.pop("vertex_ai_credentials", None) + or get_secret("VERTEXAI_CREDENTIALS") + ) + new_params = deepcopy(optional_params) + response = vertex_chat_completion.completion( # type: ignore + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=new_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + vertex_location=vertex_ai_location, + vertex_project=vertex_ai_project, + vertex_credentials=vertex_credentials, + logging_obj=logging, + acompletion=acompletion, + timeout=timeout, + ) + elif custom_llm_provider == "vertex_ai": vertex_ai_project = ( optional_params.pop("vertex_project", None) @@ -1911,26 +1947,6 @@ def completion( logging_obj=logging, acompletion=acompletion, ) - elif ( - model in litellm.vertex_language_models - or model in litellm.vertex_vision_models - ): - model_response = vertex_chat_completion.completion( # type: ignore - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=new_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - vertex_location=vertex_ai_location, - vertex_project=vertex_ai_project, - vertex_credentials=vertex_credentials, - logging_obj=logging, - acompletion=acompletion, - timeout=timeout, - ) else: model_response = vertex_ai.completion( model=model, diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index ac107d28ee..c23dbb7d9a 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -1029,7 +1029,8 @@ def test_completion_claude_stream_bad_key(): # test_completion_replicate_stream() -def test_vertex_ai_stream(): +@pytest.mark.parametrize("provider", ["vertex_ai", "vertex_ai_beta"]) +def test_vertex_ai_stream(provider): from litellm.tests.test_amazing_vertex_completion import load_vertex_ai_credentials load_vertex_ai_credentials() @@ -1042,7 +1043,7 @@ def test_vertex_ai_stream(): try: print("making request", model) response = completion( - model=model, + model="{}/{}".format(provider, model), messages=[ {"role": "user", "content": "write 10 line code code for saying hi"} ], diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index 66aec4906c..88f498ede9 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -323,3 +323,9 @@ class ChatCompletionResponseMessage(TypedDict, total=False): content: Optional[str] tool_calls: List[ChatCompletionToolCallChunk] role: Literal["assistant"] + + +class ChatCompletionUsageBlock(TypedDict): + prompt_tokens: int + completion_tokens: int + total_tokens: int diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 2b6aefcf59..1fbb375d30 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1,6 +1,8 @@ from typing import List, Optional, Union, Dict, Tuple, Literal from typing_extensions import TypedDict from enum import Enum +from typing_extensions import override, Required, Dict +from .llms.openai import ChatCompletionUsageBlock, ChatCompletionToolCallChunk class LiteLLMCommonStrings(Enum): @@ -37,3 +39,12 @@ class ModelInfo(TypedDict): "completion", "embedding", "image_generation", "chat", "audio_transcription" ] supported_openai_params: Optional[List[str]] + + +class GenericStreamingChunk(TypedDict): + text: Required[str] + tool_use: 
Optional[ChatCompletionToolCallChunk] + is_finished: Required[bool] + finish_reason: Required[str] + usage: Optional[ChatCompletionUsageBlock] + index: int diff --git a/litellm/utils.py b/litellm/utils.py index 14041d2b6f..f132e3202b 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -11223,6 +11223,34 @@ class CustomStreamWrapper: ) else: completion_obj["content"] = str(chunk) + elif self.custom_llm_provider and ( + self.custom_llm_provider == "vertex_ai_beta" + ): + from litellm.types.utils import ( + GenericStreamingChunk as UtilsStreamingChunk, + ) + + if self.received_finish_reason is not None: + raise StopIteration + response_obj: UtilsStreamingChunk = chunk + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] + + if ( + self.stream_options + and self.stream_options.get("include_usage", False) is True + and response_obj["usage"] is not None + ): + self.sent_stream_usage = True + model_response.usage = litellm.Usage( + prompt_tokens=response_obj["usage"]["prompt_tokens"], + completion_tokens=response_obj["usage"]["completion_tokens"], + total_tokens=response_obj["usage"]["total_tokens"], + ) + + if "tool_use" in response_obj and response_obj["tool_use"] is not None: + completion_obj["tool_calls"] = [response_obj["tool_use"]] elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"): import proto # type: ignore @@ -11900,6 +11928,7 @@ class CustomStreamWrapper: or self.custom_llm_provider == "ollama" or self.custom_llm_provider == "ollama_chat" or self.custom_llm_provider == "vertex_ai" + or self.custom_llm_provider == "vertex_ai_beta" or self.custom_llm_provider == "sagemaker" or self.custom_llm_provider == "gemini" or self.custom_llm_provider == "replicate" From afebf867f69139eebc1912cf2e203f7ce24331ce Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 12 Jun 2024 20:15:03 -0700 Subject: [PATCH 3/6] fix(vertex_httpx.py): support async completion calls --- litellm/llms/vertex_httpx.py | 107 +++++++++++++++++- litellm/main.py | 1 + .../tests/test_amazing_vertex_completion.py | 56 ++++++--- 3 files changed, 142 insertions(+), 22 deletions(-) diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index 70a408c2b2..550fffe4a3 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -275,13 +275,89 @@ class VertexLLM(BaseLLM): async def async_streaming( self, - ): - pass + model: str, + messages: list, + api_base: str, + model_response: ModelResponse, + print_verbose: Callable, + data: str, + timeout: Optional[Union[float, httpx.Timeout]], + encoding, + logging_obj, + stream, + optional_params: dict, + litellm_params=None, + logger_fn=None, + headers={}, + client: Optional[AsyncHTTPHandler] = None, + ) -> CustomStreamWrapper: + streaming_response = CustomStreamWrapper( + completion_stream=None, + make_call=partial( + make_call, + client=client, + api_base=api_base, + headers=headers, + data=data, + model=model, + messages=messages, + logging_obj=logging_obj, + ), + model=model, + custom_llm_provider="vertex_ai_beta", + logging_obj=logging_obj, + ) + return streaming_response async def async_completion( self, - ): - pass + model: str, + messages: list, + api_base: str, + model_response: ModelResponse, + print_verbose: Callable, + data: str, + timeout: Optional[Union[float, httpx.Timeout]], + encoding, + logging_obj, + stream, + optional_params: dict, + litellm_params=None, + logger_fn=None, + headers={}, + client: 
Optional[AsyncHTTPHandler] = None, + ) -> Union[ModelResponse, CustomStreamWrapper]: + if client is None: + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + timeout = httpx.Timeout(timeout) + _params["timeout"] = timeout + client = AsyncHTTPHandler(**_params) # type: ignore + else: + client = client # type: ignore + + try: + response = await client.post(api_base, headers=headers, json=data) # type: ignore + response.raise_for_status() + except httpx.HTTPStatusError as err: + error_code = err.response.status_code + raise VertexAIError(status_code=error_code, message=err.response.text) + except httpx.TimeoutException: + raise VertexAIError(status_code=408, message="Timeout error occurred.") + + return self._process_response( + model=model, + response=response, + model_response=model_response, + logging_obj=logging_obj, + api_key="", + data=data, + messages=messages, + print_verbose=print_verbose, + optional_params=optional_params, + encoding=encoding, + ) def completion( self, @@ -344,6 +420,27 @@ class VertexLLM(BaseLLM): }, ) + ### ROUTING (ASYNC, STREAMING, SYNC) + if acompletion: + ### ASYNC COMPLETION + return self.async_completion( + model=model, + messages=messages, + data=data, # type: ignore + api_base=url, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + logging_obj=logging_obj, + optional_params=optional_params, + stream=stream, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + client=client, # type: ignore + ) + ## SYNC STREAMING CALL ## if stream is not None and stream is True: streaming_response = CustomStreamWrapper( @@ -359,7 +456,7 @@ class VertexLLM(BaseLLM): logging_obj=logging_obj, ), model=model, - custom_llm_provider="bedrock", + custom_llm_provider="vertex_ai_beta", logging_obj=logging_obj, ) diff --git a/litellm/main.py b/litellm/main.py index 16fd394f84..83104290d4 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -329,6 +329,7 @@ async def acompletion( or custom_llm_provider == "ollama_chat" or custom_llm_provider == "replicate" or custom_llm_provider == "vertex_ai" + or custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini" or custom_llm_provider == "sagemaker" or custom_llm_provider == "anthropic" diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index cf49fd130c..7f0b49808c 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -503,28 +503,50 @@ async def test_async_vertexai_streaming_response(): # asyncio.run(test_async_vertexai_streaming_response()) -def test_gemini_pro_vision(): +@pytest.mark.parametrize("provider", ["vertex_ai", "vertex_ai_beta"]) +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_gemini_pro_vision(provider, sync_mode): try: load_vertex_ai_credentials() litellm.set_verbose = True litellm.num_retries = 3 - resp = litellm.completion( - model="vertex_ai/gemini-1.5-flash-preview-0514", - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": "Whats in this image?"}, - { - "type": "image_url", - "image_url": { - "url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg" + if sync_mode: + resp = litellm.completion( + model="{}/gemini-1.5-flash-preview-0514".format(provider), + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "Whats in this image?"}, + { + "type": "image_url", + 
"image_url": { + "url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg" + }, }, - }, - ], - } - ], - ) + ], + } + ], + ) + else: + resp = await litellm.acompletion( + model="{}/gemini-1.5-flash-preview-0514".format(provider), + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "Whats in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg" + }, + }, + ], + } + ], + ) print(resp) prompt_tokens = resp.usage.prompt_tokens From e60b0e96e47e3777d07abb3171e39bf3dd60b39e Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 12 Jun 2024 21:11:00 -0700 Subject: [PATCH 4/6] fix(vertex_httpx.py): add function calling support to httpx route --- litellm/__init__.py | 1 + litellm/llms/vertex_httpx.py | 239 +++++++++++++++++- .../tests/test_amazing_vertex_completion.py | 73 +++++- litellm/types/llms/vertex_ai.py | 32 ++- litellm/utils.py | 10 + log.txt | 10 + 6 files changed, 345 insertions(+), 20 deletions(-) create mode 100644 log.txt diff --git a/litellm/__init__.py b/litellm/__init__.py index 19c3bcca6b..523ce4684a 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -767,6 +767,7 @@ from .llms.nlp_cloud import NLPCloudConfig from .llms.aleph_alpha import AlephAlphaConfig from .llms.petals import PetalsConfig from .llms.vertex_ai import VertexAIConfig +from .llms.vertex_httpx import VertexGeminiConfig from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig from .llms.sagemaker import SagemakerConfig from .llms.ollama import OllamaConfig diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index 550fffe4a3..e660e4d729 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -19,6 +19,10 @@ from litellm.types.llms.vertex_ai import ( PartType, RequestBody, GenerateContentResponseBody, + FunctionCallingConfig, + FunctionDeclaration, + Tools, + ToolConfig, ) from litellm.llms.vertex_ai import _gemini_convert_messages_with_history from litellm.types.utils import GenericStreamingChunk @@ -26,18 +30,203 @@ from litellm.types.llms.openai import ( ChatCompletionUsageBlock, ChatCompletionToolCallChunk, ChatCompletionToolCallFunctionChunk, + ChatCompletionResponseMessage, ) class VertexGeminiConfig: - def __init__(self) -> None: - pass + """ + Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts + Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference - def supports_system_message(self) -> bool: + The class `VertexAIConfig` provides configuration for the VertexAI's API interface. Below are the parameters: + + - `temperature` (float): This controls the degree of randomness in token selection. + + - `max_output_tokens` (integer): This sets the limitation for the maximum amount of token in the text output. In this case, the default value is 256. + + - `top_p` (float): The tokens are selected from the most probable to the least probable until the sum of their probabilities equals the `top_p` value. Default is 0.95. + + - `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40. + + - `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'. + + - `candidate_count` (int): Number of generated responses to return. 
+ + - `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response. + + - `frequency_penalty` (float): This parameter is used to penalize the model from repeating the same output. The default value is 0.0. + + - `presence_penalty` (float): This parameter is used to penalize the model from generating the same output as the input. The default value is 0.0. + + Note: Please make sure to modify the default parameters as required for your use case. + """ + + temperature: Optional[float] = None + max_output_tokens: Optional[int] = None + top_p: Optional[float] = None + top_k: Optional[int] = None + response_mime_type: Optional[str] = None + candidate_count: Optional[int] = None + stop_sequences: Optional[list] = None + frequency_penalty: Optional[float] = None + presence_penalty: Optional[float] = None + + def __init__( + self, + temperature: Optional[float] = None, + max_output_tokens: Optional[int] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + response_mime_type: Optional[str] = None, + candidate_count: Optional[int] = None, + stop_sequences: Optional[list] = None, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def get_supported_openai_params(self): + return [ + "temperature", + "top_p", + "max_tokens", + "stream", + "tools", + "tool_choice", + "response_format", + "n", + "stop", + ] + + def map_tool_choice_values( + self, model: str, tool_choice: Union[str, dict] + ) -> Optional[ToolConfig]: + if tool_choice == "none": + return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="NONE")) + elif tool_choice == "required": + return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="ANY")) + elif tool_choice == "auto": + return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="AUTO")) + elif isinstance(tool_choice, dict): + # only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html + name = tool_choice.get("function", {}).get("name", "") + return ToolConfig( + functionCallingConfig=FunctionCallingConfig( + mode="ANY", allowed_function_names=[name] + ) + ) + else: + raise litellm.utils.UnsupportedParamsError( + message="VertexAI doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. 
To drop it from the call, set `litellm.drop_params = True.".format( + tool_choice + ), + status_code=400, + ) + + def map_openai_params( + self, + model: str, + non_default_params: dict, + optional_params: dict, + ): + for param, value in non_default_params.items(): + if param == "temperature": + optional_params["temperature"] = value + if param == "top_p": + optional_params["top_p"] = value + if ( + param == "stream" and value is True + ): # sending stream = False, can cause it to get passed unchecked and raise issues + optional_params["stream"] = value + if param == "n": + optional_params["candidate_count"] = value + if param == "stop": + if isinstance(value, str): + optional_params["stop_sequences"] = [value] + elif isinstance(value, list): + optional_params["stop_sequences"] = value + if param == "max_tokens": + optional_params["max_output_tokens"] = value + if param == "response_format" and value["type"] == "json_object": # type: ignore + optional_params["response_mime_type"] = "application/json" + if param == "frequency_penalty": + optional_params["frequency_penalty"] = value + if param == "presence_penalty": + optional_params["presence_penalty"] = value + if param == "tools" and isinstance(value, list): + gtool_func_declarations = [] + for tool in value: + gtool_func_declaration = FunctionDeclaration( + name=tool["function"]["name"], + description=tool["function"].get("description", ""), + parameters=tool["function"].get("parameters", {}), + ) + gtool_func_declarations.append(gtool_func_declaration) + optional_params["tools"] = [ + Tools(function_declarations=gtool_func_declarations) + ] + if param == "tool_choice" and ( + isinstance(value, str) or isinstance(value, dict) + ): + _tool_choice_value = self.map_tool_choice_values( + model=model, tool_choice=value # type: ignore + ) + if _tool_choice_value is not None: + optional_params["tool_choice"] = _tool_choice_value + return optional_params + + def get_mapped_special_auth_params(self) -> dict: """ - Not all gemini models support system instructions + Common auth params across bedrock/vertex_ai/azure/watsonx """ - return True + return {"project": "vertex_project", "region_name": "vertex_location"} + + def map_special_auth_params(self, non_default_params: dict, optional_params: dict): + mapped_params = self.get_mapped_special_auth_params() + + for param, value in non_default_params.items(): + if param in mapped_params: + optional_params[mapped_params[param]] = value + return optional_params + + def get_eu_regions(self) -> List[str]: + """ + Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions + """ + return [ + "europe-central2", + "europe-north1", + "europe-southwest1", + "europe-west1", + "europe-west2", + "europe-west3", + "europe-west4", + "europe-west6", + "europe-west8", + "europe-west9", + ] async def make_call( @@ -165,21 +354,37 @@ class VertexLLM(BaseLLM): ## GET MODEL ## model_response.model = model ## GET TEXT ## + chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"} + content_str = "" + tools: List[ChatCompletionToolCallChunk] = [] for idx, candidate in enumerate(completion_response["candidates"]): - if candidate.get("content", None) is None: + if "content" not in candidate: continue - message = litellm.Message( - content=candidate["content"]["parts"][0]["text"], - role="assistant", - logprobs=None, - function_call=None, - tool_calls=None, - ) + if "text" in candidate["content"]["parts"][0]: + content_str = candidate["content"]["parts"][0]["text"] + + 
if "functionCall" in candidate["content"]["parts"][0]: + _function_chunk = ChatCompletionToolCallFunctionChunk( + name=candidate["content"]["parts"][0]["functionCall"]["name"], + arguments=json.dumps( + candidate["content"]["parts"][0]["functionCall"]["args"] + ), + ) + _tool_response_chunk = ChatCompletionToolCallChunk( + id=f"call_{str(uuid.uuid4())}", + type="function", + function=_function_chunk, + ) + tools.append(_tool_response_chunk) + + chat_completion_message["content"] = content_str + chat_completion_message["tool_calls"] = tools + choice = litellm.Choices( finish_reason=candidate.get("finishReason", "stop"), index=candidate.get("index", idx), - message=message, + message=chat_completion_message, # type: ignore logprobs=None, enhancements=None, ) @@ -402,8 +607,14 @@ class VertexLLM(BaseLLM): messages.pop(idx) system_instructions = SystemInstructions(parts=system_content_blocks) content = _gemini_convert_messages_with_history(messages=messages) + tools: Optional[Tools] = optional_params.pop("tools", None) + tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None) data = RequestBody(system_instruction=system_instructions, contents=content) + if tools is not None: + data["tools"] = tools + if tool_choice is not None: + data["toolConfig"] = tool_choice headers = { "Content-Type": "application/json; charset=utf-8", "Authorization": f"Bearer {auth_header}", diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 7f0b49808c..3037f51e6e 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -615,9 +615,76 @@ def test_gemini_pro_vision_base64(): pytest.fail(f"An exception occurred - {str(e)}") -@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai", +@pytest.mark.parametrize("sync_mode", [True, False]) # "vertex_ai", @pytest.mark.asyncio -async def test_gemini_pro_function_calling(sync_mode): +async def test_gemini_pro_function_calling_httpx(provider, sync_mode): + try: + load_vertex_ai_credentials() + litellm.set_verbose = True + + messages = [ + { + "role": "system", + "content": "Your name is Litellm Bot, you are a helpful assistant", + }, + # User asks for their name and weather in San Francisco + { + "role": "user", + "content": "Hello, what is your name and can you tell me the weather?", + }, + ] + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + } + }, + "required": ["location"], + }, + }, + } + ] + + data = { + "model": "{}/gemini-1.5-pro-preview-0514".format(provider), + "messages": messages, + "tools": tools, + "tool_choice": "required", + } + if sync_mode: + response = litellm.completion(**data) + else: + response = await litellm.acompletion(**data) + + print(f"response: {response}") + + assert response.choices[0].message.tool_calls[0].function.arguments is not None + assert isinstance( + response.choices[0].message.tool_calls[0].function.arguments, str + ) + except litellm.RateLimitError as e: + pass + except Exception as e: + if "429 Quota exceeded" in str(e): + pass + else: + pytest.fail("An unexpected exception occurred - {}".format(str(e))) + + +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.parametrize("provider", ["vertex_ai"]) +@pytest.mark.asyncio +async def test_gemini_pro_function_calling(provider, sync_mode): try: load_vertex_ai_credentials() litellm.set_verbose = True @@ -679,7 +746,7 @@ async def test_gemini_pro_function_calling(sync_mode): ] data = { - "model": "vertex_ai/gemini-1.5-pro-preview-0514", + "model": "{}/gemini-1.5-pro-preview-0514".format(provider), "messages": messages, "tools": tools, } diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py index fe903841ef..18207b88e9 100644 --- a/litellm/types/llms/vertex_ai.py +++ b/litellm/types/llms/vertex_ai.py @@ -49,6 +49,24 @@ class PartType(TypedDict, total=False): function_response: FunctionResponse +class HttpxFunctionCall(TypedDict): + name: str + args: dict + + +class HttpxPartType(TypedDict, total=False): + text: str + inline_data: BlobType + file_data: FileDataType + functionCall: HttpxFunctionCall + function_response: FunctionResponse + + +class HttpxContentType(TypedDict, total=False): + role: Literal["user", "model"] + parts: Required[List[HttpxPartType]] + + class ContentType(TypedDict, total=False): role: Literal["user", "model"] parts: Required[List[PartType]] @@ -128,11 +146,19 @@ class GenerationConfig(TypedDict, total=False): response_mime_type: Literal["text/plain", "application/json"] +class Tools(TypedDict): + function_declarations: List[FunctionDeclaration] + + +class ToolConfig(TypedDict): + functionCallingConfig: FunctionCallingConfig + + class RequestBody(TypedDict, total=False): contents: Required[List[ContentType]] system_instruction: SystemInstructions - tools: FunctionDeclaration - tool_config: FunctionCallingConfig + tools: Tools + toolConfig: ToolConfig safety_settings: SafetSettingsConfig generation_config: GenerationConfig @@ -176,7 +202,7 @@ class GroundingMetadata(TypedDict, total=False): class Candidates(TypedDict, total=False): index: int - content: ContentType + content: HttpxContentType finishReason: Literal[ "FINISH_REASON_UNSPECIFIED", "STOP", diff --git a/litellm/utils.py b/litellm/utils.py index f132e3202b..cfec3fd4a4 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5386,6 +5386,16 @@ def get_optional_params( print_verbose( f"(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {optional_params}" ) + elif custom_llm_provider == "vertex_ai_beta": + supported_params = get_supported_openai_params( + model=model, custom_llm_provider=custom_llm_provider + ) + _check_valid_arg(supported_params=supported_params) + optional_params = litellm.VertexGeminiConfig().map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + model=model, + ) elif ( custom_llm_provider == "vertex_ai" and model in 
litellm.vertex_anthropic_models ): diff --git a/log.txt b/log.txt new file mode 100644 index 0000000000..9f76605636 --- /dev/null +++ b/log.txt @@ -0,0 +1,10 @@ +============================= test session starts ============================== +platform darwin -- Python 3.11.4, pytest-8.2.0, pluggy-1.5.0 -- /Users/krrishdholakia/Documents/litellm/litellm/proxy/myenv/bin/python3.11 +cachedir: .pytest_cache +rootdir: /Users/krrishdholakia/Documents/litellm +configfile: pyproject.toml +plugins: logfire-0.35.0, asyncio-0.23.6, mock-3.14.0, anyio-4.2.0 +asyncio: mode=Mode.STRICT +collecting ... collected 0 items + +============================ no tests ran in 0.00s ============================= From ab4b1d931bf1d543ff11082a6486d8917f38b5f6 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 12 Jun 2024 21:45:47 -0700 Subject: [PATCH 5/6] fix(vertex_httpx.py): support json schema --- docs/my-website/docs/providers/vertex.md | 151 +++++++++++++++++- litellm/llms/vertex_httpx.py | 8 +- .../tests/test_amazing_vertex_completion.py | 32 ++++ litellm/types/llms/vertex_ai.py | 4 +- 4 files changed, 188 insertions(+), 7 deletions(-) diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md index a5c8e06c9f..33ce3e3406 100644 --- a/docs/my-website/docs/providers/vertex.md +++ b/docs/my-website/docs/providers/vertex.md @@ -8,6 +8,152 @@ import TabItem from '@theme/TabItem'; Open In Colab +## 🆕 `vertex_ai_beta/` route + +New `vertex_ai_beta/` route. Adds support for system messages, tool_choice params, etc. by moving to httpx client (instead of vertex sdk). + +```python +from litellm import completion +import json + +## GET CREDENTIALS +file_path = 'path/to/vertex_ai_service_account.json' + +# Load the JSON file +with open(file_path, 'r') as file: + vertex_credentials = json.load(file) + +# Convert to JSON string +vertex_credentials_json = json.dumps(vertex_credentials) + +## COMPLETION CALL +response = completion( + model="vertex_ai_beta/gemini-pro", + messages=[{ "content": "Hello, how are you?","role": "user"}], + vertex_credentials=vertex_credentials_json +) +``` + +### **System Message** + +```python +from litellm import completion +import json + +## GET CREDENTIALS +file_path = 'path/to/vertex_ai_service_account.json' + +# Load the JSON file +with open(file_path, 'r') as file: + vertex_credentials = json.load(file) + +# Convert to JSON string +vertex_credentials_json = json.dumps(vertex_credentials) + + +response = completion( + model="vertex_ai_beta/gemini-pro", + messages=[{"content": "You are a good bot.","role": "system"}, {"content": "Hello, how are you?","role": "user"}], + vertex_credentials=vertex_credentials_json +) +``` + +### **Function Calling** + +Force Gemini to make tool calls with `tool_choice="required"`. 
+
+```python
+from litellm import completion
+import json
+
+## GET CREDENTIALS
+file_path = 'path/to/vertex_ai_service_account.json'
+
+# Load the JSON file
+with open(file_path, 'r') as file:
+    vertex_credentials = json.load(file)
+
+# Convert to JSON string
+vertex_credentials_json = json.dumps(vertex_credentials)
+
+
+messages = [
+    {
+        "role": "system",
+        "content": "Your name is Litellm Bot, you are a helpful assistant",
+    },
+    # User asks for their name and weather in San Francisco
+    {
+        "role": "user",
+        "content": "Hello, what is your name and can you tell me the weather?",
+    },
+]
+
+tools = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_weather",
+            "description": "Get the current weather in a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA",
+                    }
+                },
+                "required": ["location"],
+            },
+        },
+    }
+]
+
+data = {
+    "model": "vertex_ai_beta/gemini-1.5-pro-preview-0514",
+    "messages": messages,
+    "tools": tools,
+    "tool_choice": "required",
+    "vertex_credentials": vertex_credentials_json
+}
+
+## COMPLETION CALL
+print(completion(**data))
+```
+
+### **JSON Schema**
+
+```python
+from litellm import completion
+import json
+
+## GET CREDENTIALS
+file_path = 'path/to/vertex_ai_service_account.json'
+
+# Load the JSON file
+with open(file_path, 'r') as file:
+    vertex_credentials = json.load(file)
+
+# Convert to JSON string
+vertex_credentials_json = json.dumps(vertex_credentials)
+
+messages = [
+    {
+        "role": "user",
+        "content": """
+List 5 popular cookie recipes.
+
+Using this JSON schema:
+
+    Recipe = {"recipe_name": str}
+
+Return a `list[Recipe]`
+        """
+    }
+]
+
+completion(model="vertex_ai_beta/gemini-1.5-flash-preview-0514", messages=messages, response_format={"type": "json_object"}, vertex_credentials=vertex_credentials_json)
+```
+
 ## Pre-requisites
 * `pip install google-cloud-aiplatform` (pre-installed on proxy docker image)
 * Authentication:
@@ -140,7 +286,7 @@ In certain use-cases you may need to make calls to the models and pass [safety s
 
 ```python
 response = completion(
-    model="gemini/gemini-pro",
+    model="vertex_ai/gemini-pro",
     messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}]
     safety_settings=[
         {
@@ -680,6 +826,3 @@ s/o @[Darien Kindlund](https://www.linkedin.com/in/kindlund/) for this tutorial
 
 
 
-
-
-
diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index e660e4d729..b1c38f0bc5 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -23,6 +23,7 @@ from litellm.types.llms.vertex_ai import (
     FunctionDeclaration,
     Tools,
     ToolConfig,
+    GenerationConfig,
 )
 from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
 from litellm.types.utils import GenericStreamingChunk
@@ -609,12 +610,17 @@ class VertexLLM(BaseLLM):
         content = _gemini_convert_messages_with_history(messages=messages)
         tools: Optional[Tools] = optional_params.pop("tools", None)
         tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
-
+        generation_config: Optional[GenerationConfig] = GenerationConfig(
+            **optional_params
+        )
         data = RequestBody(system_instruction=system_instructions, contents=content)
         if tools is not None:
             data["tools"] = tools
        if tool_choice is not None:
             data["toolConfig"] = tool_choice
+        if generation_config is not None:
+            data["generationConfig"] = generation_config
+
         headers = {
             "Content-Type": "application/json; charset=utf-8",
             "Authorization": f"Bearer {auth_header}",
diff --git 
a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 7281b8a08e..eacadf5299 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -681,6 +681,38 @@ async def test_gemini_pro_function_calling_httpx(provider, sync_mode): pytest.fail("An unexpected exception occurred - {}".format(str(e))) +@pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai", +@pytest.mark.asyncio +async def test_gemini_pro_json_schema_httpx(provider): + load_vertex_ai_credentials() + litellm.set_verbose = True + messages = [ + { + "role": "user", + "content": """ + List 5 popular cookie recipes. + + Using this JSON schema: + + Recipe = {"recipe_name": str} + + Return a `list[Recipe]` + """, + } + ] + + response = completion( + model="vertex_ai_beta/gemini-1.5-flash-preview-0514", + messages=messages, + response_format={"type": "json_object"}, + ) + + assert response.choices[0].message.content is not None + response_json = json.loads(response.choices[0].message.content) + + assert isinstance(response_json, dict) or isinstance(response_json, list) + + @pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.parametrize("provider", ["vertex_ai"]) @pytest.mark.asyncio diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py index 18207b88e9..d9cd25f302 100644 --- a/litellm/types/llms/vertex_ai.py +++ b/litellm/types/llms/vertex_ai.py @@ -159,8 +159,8 @@ class RequestBody(TypedDict, total=False): system_instruction: SystemInstructions tools: Tools toolConfig: ToolConfig - safety_settings: SafetSettingsConfig - generation_config: GenerationConfig + safetySettings: SafetSettingsConfig + generationConfig: GenerationConfig class SafetyRatings(TypedDict): From a84021c84871973c5d7a9ffe6010b5a15e390348 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 12 Jun 2024 22:08:43 -0700 Subject: [PATCH 6/6] test(test_amazing_vertex_completion.py): reduce vertex tests to avoid exhausting quota --- litellm/tests/test_amazing_vertex_completion.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index eacadf5299..72c2e95c15 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -615,8 +615,9 @@ def test_gemini_pro_vision_base64(): pytest.fail(f"An exception occurred - {str(e)}") +@pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call") @pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai", -@pytest.mark.parametrize("sync_mode", [True, False]) # "vertex_ai", +@pytest.mark.parametrize("sync_mode", [True]) # "vertex_ai", @pytest.mark.asyncio async def test_gemini_pro_function_calling_httpx(provider, sync_mode): try: @@ -656,7 +657,7 @@ async def test_gemini_pro_function_calling_httpx(provider, sync_mode): ] data = { - "model": "{}/gemini-1.5-pro-preview-0514".format(provider), + "model": "{}/gemini-1.5-pro".format(provider), "messages": messages, "tools": tools, "tool_choice": "required", @@ -681,6 +682,7 @@ async def test_gemini_pro_function_calling_httpx(provider, sync_mode): pytest.fail("An unexpected exception occurred - {}".format(str(e))) +@pytest.mark.skip(reason="exhausted vertex quota. 
need to refactor to mock the call") @pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai", @pytest.mark.asyncio async def test_gemini_pro_json_schema_httpx(provider): @@ -713,7 +715,8 @@ async def test_gemini_pro_json_schema_httpx(provider): assert isinstance(response_json, dict) or isinstance(response_json, list) -@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call") +@pytest.mark.parametrize("sync_mode", [True]) @pytest.mark.parametrize("provider", ["vertex_ai"]) @pytest.mark.asyncio async def test_gemini_pro_function_calling(provider, sync_mode):
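
The skip reasons above call for refactoring these tests to mock the HTTP call instead of spending Vertex AI quota. One possible direction, sketched under two assumptions that would need checking against the actual call sites: that auth can be short-circuited by patching `VertexLLM.load_auth`, and that the request leaves through `HTTPHandler.post`. The fake payload carries only the `candidates` and `usageMetadata` fields litellm reads back.

```python
import json
from unittest.mock import MagicMock, patch

import httpx

import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.llms.vertex_httpx import VertexLLM

# Minimal Gemini generateContent payload with just the fields litellm parses out.
FAKE_GEMINI_RESPONSE = {
    "candidates": [
        {
            "content": {
                "role": "model",
                "parts": [{"text": '{"recipe_name": "Chocolate Chip Cookies"}'}],
            },
            "finishReason": "STOP",
            "index": 0,
        }
    ],
    "usageMetadata": {
        "promptTokenCount": 10,
        "candidatesTokenCount": 8,
        "totalTokenCount": 18,
    },
}


@patch.object(
    VertexLLM, "load_auth", return_value=(MagicMock(token="fake-token"), "fake-project")
)
@patch.object(
    HTTPHandler, "post", return_value=httpx.Response(200, json=FAKE_GEMINI_RESPONSE)
)
def test_gemini_pro_json_schema_mocked(mock_post, mock_load_auth):
    response = litellm.completion(
        model="vertex_ai_beta/gemini-1.5-flash-preview-0514",
        messages=[{"role": "user", "content": "List one cookie recipe as JSON."}],
        response_format={"type": "json_object"},
    )

    # No quota consumed: the outbound HTTP call was stubbed out.
    mock_post.assert_called_once()
    response_json = json.loads(response.choices[0].message.content)
    assert response_json["recipe_name"] == "Chocolate Chip Cookies"
```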
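For reference, the mocked `post` above would receive a body shaped by the renamed `RequestBody` keys in this patch (top-level `generationConfig`, `toolConfig`, `safetySettings` are camelCase). A hand-assembled sketch; treating `response_format={"type": "json_object"}` as mapping onto `response_mime_type` is an assumption about what `VertexGeminiConfig.map_openai_params` produces.

```python
import json

from litellm.types.llms.vertex_ai import RequestBody

# Hand-built request body using the TypedDicts from litellm/types/llms/vertex_ai.py.
body: RequestBody = {
    "contents": [
        {"role": "user", "parts": [{"text": "List 5 popular cookie recipes as JSON."}]}
    ],
    "generationConfig": {"response_mime_type": "application/json"},
}

# Roughly the JSON that vertex_httpx sends to the model's generateContent endpoint.
print(json.dumps(body, indent=2))
```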