From 4b1e85f54e06bf8186b87c7872bdb3baaf921b23 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 29 Jun 2024 18:54:10 -0700 Subject: [PATCH] fix(vertex_ai_anthropic.py): support pre-filling "{" for json mode --- litellm/llms/vertex_ai_anthropic.py | 144 +++++++++++++----- .../tests/test_amazing_vertex_completion.py | 39 +++-- litellm/types/utils.py | 5 + litellm/utils.py | 4 +- 4 files changed, 137 insertions(+), 55 deletions(-) diff --git a/litellm/llms/vertex_ai_anthropic.py b/litellm/llms/vertex_ai_anthropic.py index ee6653afc..6b39716f1 100644 --- a/litellm/llms/vertex_ai_anthropic.py +++ b/litellm/llms/vertex_ai_anthropic.py @@ -1,24 +1,32 @@ # What is this? ## Handler file for calling claude-3 on vertex ai -import os, types +import copy import json +import os +import time +import types +import uuid from enum import Enum -import requests, copy # type: ignore -import time, uuid -from typing import Callable, Optional, List -from litellm.utils import ModelResponse, Usage, CustomStreamWrapper -from litellm.litellm_core_utils.core_helpers import map_finish_reason +from typing import Any, Callable, List, Optional, Tuple + +import httpx # type: ignore +import requests # type: ignore + import litellm +from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.types.utils import ResponseFormatChunk +from litellm.utils import CustomStreamWrapper, ModelResponse, Usage + from .prompt_templates.factory import ( - contains_tag, - prompt_factory, - custom_prompt, construct_tool_use_system_prompt, + contains_tag, + custom_prompt, extract_between_tags, parse_xml_params, + prompt_factory, + response_schema_prompt, ) -import httpx # type: ignore class VertexAIError(Exception): @@ -104,6 +112,7 @@ class VertexAIAnthropicConfig: "stop", "temperature", "top_p", + "response_format", ] def map_openai_params(self, non_default_params: dict, optional_params: dict): @@ -120,6 +129,8 @@ class VertexAIAnthropicConfig: optional_params["temperature"] = value if param == "top_p": optional_params["top_p"] = value + if param == "response_format" and "response_schema" in value: + optional_params["response_format"] = ResponseFormatChunk(**value) # type: ignore return optional_params @@ -129,7 +140,6 @@ class VertexAIAnthropicConfig: """ -# makes headers for API call def refresh_auth( credentials, ) -> str: # used when user passes in credentials as json string @@ -144,6 +154,40 @@ def refresh_auth( return credentials.token +def get_vertex_client( + client: Any, + vertex_project: Optional[str], + vertex_location: Optional[str], + vertex_credentials: Optional[str], +) -> Tuple[Any, Optional[str]]: + args = locals() + from litellm.llms.vertex_httpx import VertexLLM + + try: + from anthropic import AnthropicVertex + except Exception: + raise VertexAIError( + status_code=400, + message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""", + ) + + access_token: Optional[str] = None + + if client is None: + _credentials, cred_project_id = VertexLLM().load_auth( + credentials=vertex_credentials, project_id=vertex_project + ) + vertex_ai_client = AnthropicVertex( + project_id=vertex_project or cred_project_id, + region=vertex_location or "us-central1", + access_token=_credentials.token, + ) + else: + vertex_ai_client = client + + return vertex_ai_client, access_token + + def completion( model: str, messages: list, @@ -151,10 +195,10 @@ def completion( print_verbose: 
Callable, encoding, logging_obj, + optional_params: dict, vertex_project=None, vertex_location=None, vertex_credentials=None, - optional_params=None, litellm_params=None, logger_fn=None, acompletion: bool = False, @@ -178,6 +222,13 @@ def completion( ) try: + vertex_ai_client, access_token = get_vertex_client( + client=client, + vertex_project=vertex_project, + vertex_location=vertex_location, + vertex_credentials=vertex_credentials, + ) + ## Load Config config = litellm.VertexAIAnthropicConfig.get_config() for k, v in config.items(): @@ -186,6 +237,7 @@ def completion( ## Format Prompt _is_function_call = False + _is_json_schema = False messages = copy.deepcopy(messages) optional_params = copy.deepcopy(optional_params) # Separate system prompt from rest of message @@ -200,6 +252,29 @@ def completion( messages.pop(idx) if len(system_prompt) > 0: optional_params["system"] = system_prompt + # Checks for 'response_schema' support - if passed in + if "response_format" in optional_params: + response_format_chunk = ResponseFormatChunk( + **optional_params["response_format"] # type: ignore + ) + supports_response_schema = litellm.supports_response_schema( + model=model, custom_llm_provider="vertex_ai" + ) + if ( + supports_response_schema is False + and response_format_chunk["type"] == "json_object" + and "response_schema" in response_format_chunk + ): + _is_json_schema = True + user_response_schema_message = response_schema_prompt( + model=model, + response_schema=response_format_chunk["response_schema"], + ) + messages.append( + {"role": "user", "content": user_response_schema_message} + ) + messages.append({"role": "assistant", "content": "{"}) + optional_params.pop("response_format") # Format rest of message according to anthropic guidelines try: messages = prompt_factory( @@ -233,32 +308,6 @@ def completion( print_verbose( f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}" ) - access_token = None - if client is None: - if vertex_credentials is not None and isinstance(vertex_credentials, str): - import google.oauth2.service_account - - try: - json_obj = json.loads(vertex_credentials) - except json.JSONDecodeError: - json_obj = json.load(open(vertex_credentials)) - - creds = ( - google.oauth2.service_account.Credentials.from_service_account_info( - json_obj, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) - ) - ### CHECK IF ACCESS - access_token = refresh_auth(credentials=creds) - - vertex_ai_client = AnthropicVertex( - project_id=vertex_project, - region=vertex_location, - access_token=access_token, - ) - else: - vertex_ai_client = client if acompletion == True: """ @@ -315,7 +364,16 @@ def completion( ) message = vertex_ai_client.messages.create(**data) # type: ignore - text_content = message.content[0].text + + ## LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response=message, + additional_args={"complete_input_dict": data}, + ) + + text_content: str = message.content[0].text ## TOOL CALLING - OUTPUT PARSE if text_content is not None and contains_tag("invoke", text_content): function_name = extract_between_tags("tool_name", text_content)[0] @@ -339,7 +397,13 @@ def completion( ) model_response.choices[0].message = _message # type: ignore else: - model_response.choices[0].message.content = text_content # type: ignore + if ( + _is_json_schema + ): # follows https://github.com/anthropics/anthropic-cookbook/blob/main/misc/how_to_enable_json_mode.ipynb + json_response = "{" + 
text_content[: text_content.rfind("}") + 1] + model_response.choices[0].message.content = json_response # type: ignore + else: + model_response.choices[0].message.content = text_content # type: ignore model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason) ## CALCULATING USAGE diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index ecad16fe6..505933fc2 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -993,10 +993,10 @@ def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs): @pytest.mark.parametrize( - "model, supports_response_schema", + "model, vertex_location, supports_response_schema", [ - ("vertex_ai_beta/gemini-1.5-pro-001", True), - ("vertex_ai_beta/gemini-1.5-flash", False), + ("vertex_ai_beta/gemini-1.5-pro-001", "us-central1", True), + ("vertex_ai_beta/gemini-1.5-flash", "us-central1", False), ], ) @pytest.mark.parametrize( @@ -1005,7 +1005,7 @@ def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs): ) @pytest.mark.asyncio async def test_gemini_pro_json_schema_args_sent_httpx( - model, supports_response_schema, invalid_response + model, supports_response_schema, vertex_location, invalid_response ): load_vertex_ai_credentials() os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" @@ -1015,17 +1015,27 @@ async def test_gemini_pro_json_schema_args_sent_httpx( messages = [{"role": "user", "content": "List 5 cookie recipes"}] from litellm.llms.custom_httpx.http_handler import HTTPHandler + # response_schema = { + # "type": "array", + # "items": { + # "type": "object", + # "properties": { + # "recipe_name": { + # "type": "string", + # }, + # }, + # "required": ["recipe_name"], + # }, + # } response_schema = { - "type": "array", - "items": { - "type": "object", - "properties": { - "recipe_name": { - "type": "string", - }, - }, - "required": ["recipe_name"], + "type": "object", + "properties": { + "recipe_name": {"type": "string"}, + "ingredients": {"type": "array", "items": {"type": "string"}}, + "prep_time": {"type": "number"}, + "difficulty": {"type": "string", "enum": ["easy", "medium", "hard"]}, }, + "required": ["recipe_name", "ingredients", "prep_time"], } client = HTTPHandler() @@ -1044,12 +1054,13 @@ async def test_gemini_pro_json_schema_args_sent_httpx( "type": "json_object", "response_schema": response_schema, }, + vertex_location=vertex_location, client=client, ) if invalid_response is True: pytest.fail("Expected this to fail") except litellm.JSONSchemaValidationError as e: - if invalid_response is False: + if invalid_response is False and "claude-3" not in model: pytest.fail("Expected this to pass. 
Got={}".format(e)) mock_call.assert_called_once() diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 51ce08671..d6b7bf744 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -995,3 +995,8 @@ class GenericImageParsingChunk(TypedDict): type: str media_type: str data: str + + +class ResponseFormatChunk(TypedDict, total=False): + type: Required[Literal["json_object", "text"]] + response_schema: dict diff --git a/litellm/utils.py b/litellm/utils.py index e75a5df73..877685416 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1351,7 +1351,9 @@ def client(original_function): ).total_seconds() * 1000 # return response latency in ms like openai ### POST-CALL RULES ### - post_call_processing(original_response=result, model=model) + post_call_processing( + original_response=result, model=model, optional_params=kwargs + ) # [OPTIONAL] ADD TO CACHE if (