fix(vertex_ai_anthropic.py): support pre-filling "{" for json mode

Krrish Dholakia 2024-06-29 18:54:10 -07:00
parent b699d9a8b9
commit 4b1e85f54e
4 changed files with 137 additions and 55 deletions
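
Note: with this change, a caller can pass an OpenAI-style `response_format` carrying a `response_schema` to claude-3 models on Vertex AI; for models without native response-schema support, litellm appends the schema as a user prompt, pre-fills the assistant turn with "{", and re-attaches the brace when parsing the reply. A minimal sketch of the call shape (the model alias and schema here are illustrative, not taken from this commit):

import litellm

# Illustrative model alias and schema -- placeholders, not from this diff.
resp = litellm.completion(
    model="vertex_ai/claude-3-sonnet@20240229",
    messages=[{"role": "user", "content": "List 5 cookie recipes"}],
    response_format={
        "type": "json_object",
        "response_schema": {
            "type": "object",
            "properties": {"recipe_name": {"type": "string"}},
            "required": ["recipe_name"],
        },
    },
)
print(resp.choices[0].message.content)  # a JSON object starting with "{"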

litellm/llms/vertex_ai_anthropic.py

@@ -1,24 +1,32 @@
 # What is this?
 ## Handler file for calling claude-3 on vertex ai
-import os, types
+import copy
 import json
+import os
+import time
+import types
+import uuid
 from enum import Enum
-import requests, copy  # type: ignore
-import time, uuid
-from typing import Callable, Optional, List
-from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
-from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from typing import Any, Callable, List, Optional, Tuple
+
+import httpx  # type: ignore
+import requests  # type: ignore
+
 import litellm
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.utils import ResponseFormatChunk
+from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
+
 from .prompt_templates.factory import (
-    contains_tag,
-    prompt_factory,
-    custom_prompt,
     construct_tool_use_system_prompt,
+    contains_tag,
+    custom_prompt,
     extract_between_tags,
     parse_xml_params,
+    prompt_factory,
+    response_schema_prompt,
 )
-import httpx  # type: ignore


 class VertexAIError(Exception):
@@ -104,6 +112,7 @@ class VertexAIAnthropicConfig:
             "stop",
             "temperature",
             "top_p",
+            "response_format",
         ]

     def map_openai_params(self, non_default_params: dict, optional_params: dict):
@@ -120,6 +129,8 @@ class VertexAIAnthropicConfig:
                 optional_params["temperature"] = value
             if param == "top_p":
                 optional_params["top_p"] = value
+            if param == "response_format" and "response_schema" in value:
+                optional_params["response_format"] = ResponseFormatChunk(**value)  # type: ignore
         return optional_params
@@ -129,7 +140,6 @@ class VertexAIAnthropicConfig:
     """

-# makes headers for API call
 def refresh_auth(
     credentials,
 ) -> str:  # used when user passes in credentials as json string
@@ -144,6 +154,40 @@ def refresh_auth(
     return credentials.token


+def get_vertex_client(
+    client: Any,
+    vertex_project: Optional[str],
+    vertex_location: Optional[str],
+    vertex_credentials: Optional[str],
+) -> Tuple[Any, Optional[str]]:
+    args = locals()
+    from litellm.llms.vertex_httpx import VertexLLM
+
+    try:
+        from anthropic import AnthropicVertex
+    except Exception:
+        raise VertexAIError(
+            status_code=400,
+            message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",
+        )
+
+    access_token: Optional[str] = None
+
+    if client is None:
+        _credentials, cred_project_id = VertexLLM().load_auth(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+        vertex_ai_client = AnthropicVertex(
+            project_id=vertex_project or cred_project_id,
+            region=vertex_location or "us-central1",
+            access_token=_credentials.token,
+        )
+    else:
+        vertex_ai_client = client
+
+    return vertex_ai_client, access_token
+
+
 def completion(
     model: str,
     messages: list,
@@ -151,10 +195,10 @@ def completion(
     print_verbose: Callable,
     encoding,
     logging_obj,
+    optional_params: dict,
     vertex_project=None,
     vertex_location=None,
     vertex_credentials=None,
-    optional_params=None,
     litellm_params=None,
     logger_fn=None,
     acompletion: bool = False,
@@ -178,6 +222,13 @@ def completion(
     )
     try:
+        vertex_ai_client, access_token = get_vertex_client(
+            client=client,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+        )
+
         ## Load Config
         config = litellm.VertexAIAnthropicConfig.get_config()
         for k, v in config.items():
@@ -186,6 +237,7 @@ def completion(
         ## Format Prompt
         _is_function_call = False
+        _is_json_schema = False
         messages = copy.deepcopy(messages)
         optional_params = copy.deepcopy(optional_params)
         # Separate system prompt from rest of message
@@ -200,6 +252,29 @@ def completion(
                 messages.pop(idx)
         if len(system_prompt) > 0:
             optional_params["system"] = system_prompt
+
+        # Checks for 'response_schema' support - if passed in
+        if "response_format" in optional_params:
+            response_format_chunk = ResponseFormatChunk(
+                **optional_params["response_format"]  # type: ignore
+            )
+            supports_response_schema = litellm.supports_response_schema(
+                model=model, custom_llm_provider="vertex_ai"
+            )
+            if (
+                supports_response_schema is False
+                and response_format_chunk["type"] == "json_object"
+                and "response_schema" in response_format_chunk
+            ):
+                _is_json_schema = True
+                user_response_schema_message = response_schema_prompt(
+                    model=model,
+                    response_schema=response_format_chunk["response_schema"],
+                )
+                messages.append(
+                    {"role": "user", "content": user_response_schema_message}
+                )
+                messages.append({"role": "assistant", "content": "{"})
+                optional_params.pop("response_format")
         # Format rest of message according to anthropic guidelines
         try:
             messages = prompt_factory(
@@ -233,32 +308,6 @@ def completion(
         print_verbose(
             f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}"
         )
-        access_token = None
-        if client is None:
-            if vertex_credentials is not None and isinstance(vertex_credentials, str):
-                import google.oauth2.service_account
-
-                try:
-                    json_obj = json.loads(vertex_credentials)
-                except json.JSONDecodeError:
-                    json_obj = json.load(open(vertex_credentials))
-
-                creds = (
-                    google.oauth2.service_account.Credentials.from_service_account_info(
-                        json_obj,
-                        scopes=["https://www.googleapis.com/auth/cloud-platform"],
-                    )
-                )
-                ### CHECK IF ACCESS
-                access_token = refresh_auth(credentials=creds)
-
-            vertex_ai_client = AnthropicVertex(
-                project_id=vertex_project,
-                region=vertex_location,
-                access_token=access_token,
-            )
-        else:
-            vertex_ai_client = client

         if acompletion == True:
             """
@@ -315,7 +364,16 @@ def completion(
             )
             message = vertex_ai_client.messages.create(**data)  # type: ignore
-            text_content = message.content[0].text
+
+            ## LOGGING
+            logging_obj.post_call(
+                input=messages,
+                api_key="",
+                original_response=message,
+                additional_args={"complete_input_dict": data},
+            )
+
+            text_content: str = message.content[0].text
             ## TOOL CALLING - OUTPUT PARSE
             if text_content is not None and contains_tag("invoke", text_content):
                 function_name = extract_between_tags("tool_name", text_content)[0]
@@ -338,6 +396,12 @@ def completion(
                     content=None,
                 )
                 model_response.choices[0].message = _message  # type: ignore
-            else:
-                model_response.choices[0].message.content = text_content  # type: ignore
+            else:
+                if (
+                    _is_json_schema
+                ):  # follows https://github.com/anthropics/anthropic-cookbook/blob/main/misc/how_to_enable_json_mode.ipynb
+                    json_response = "{" + text_content[: text_content.rfind("}") + 1]
+                    model_response.choices[0].message.content = json_response  # type: ignore
+                else:
+                    model_response.choices[0].message.content = text_content  # type: ignore
             model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)
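
The pre-fill logic above follows the Anthropic cookbook's JSON-mode recipe linked in the diff: the schema prompt goes in as a user message, the assistant turn is pre-filled with "{" so the model continues the object, and the handler then re-attaches the brace and trims trailing chatter. A standalone sketch of that reconstruction step (the helper name is mine, not part of the diff):

def reconstruct_prefilled_json(text_content: str) -> str:
    # The opening "{" was supplied as the pre-filled assistant message,
    # so the model's reply omits it; prepend it and drop anything the
    # model emitted after the final closing brace.
    return "{" + text_content[: text_content.rfind("}") + 1]


# e.g. a reply of '"recipe_name": "snickerdoodles"} Enjoy!' becomes a clean object:
assert reconstruct_prefilled_json('"recipe_name": "snickerdoodles"} Enjoy!') == (
    '{"recipe_name": "snickerdoodles"}'
)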

tests (test_gemini_pro_json_schema_args_sent_httpx)

@@ -993,10 +993,10 @@ def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
 @pytest.mark.parametrize(
-    "model, supports_response_schema",
+    "model, vertex_location, supports_response_schema",
     [
-        ("vertex_ai_beta/gemini-1.5-pro-001", True),
-        ("vertex_ai_beta/gemini-1.5-flash", False),
+        ("vertex_ai_beta/gemini-1.5-pro-001", "us-central1", True),
+        ("vertex_ai_beta/gemini-1.5-flash", "us-central1", False),
     ],
 )
 @pytest.mark.parametrize(
@@ -1005,7 +1005,7 @@ def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
 )
 @pytest.mark.asyncio
 async def test_gemini_pro_json_schema_args_sent_httpx(
-    model, supports_response_schema, invalid_response
+    model, vertex_location, supports_response_schema, invalid_response
 ):
     load_vertex_ai_credentials()
     os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
@@ -1015,17 +1015,27 @@ async def test_gemini_pro_json_schema_args_sent_httpx(
     messages = [{"role": "user", "content": "List 5 cookie recipes"}]
     from litellm.llms.custom_httpx.http_handler import HTTPHandler

+    # response_schema = {
+    #     "type": "array",
+    #     "items": {
+    #         "type": "object",
+    #         "properties": {
+    #             "recipe_name": {
+    #                 "type": "string",
+    #             },
+    #         },
+    #         "required": ["recipe_name"],
+    #     },
+    # }
     response_schema = {
-        "type": "array",
-        "items": {
-            "type": "object",
-            "properties": {
-                "recipe_name": {
-                    "type": "string",
-                },
-            },
-            "required": ["recipe_name"],
-        },
+        "type": "object",
+        "properties": {
+            "recipe_name": {"type": "string"},
+            "ingredients": {"type": "array", "items": {"type": "string"}},
+            "prep_time": {"type": "number"},
+            "difficulty": {"type": "string", "enum": ["easy", "medium", "hard"]},
+        },
+        "required": ["recipe_name", "ingredients", "prep_time"],
     }

     client = HTTPHandler()
@@ -1044,12 +1054,13 @@ async def test_gemini_pro_json_schema_args_sent_httpx(
                 "type": "json_object",
                 "response_schema": response_schema,
             },
+            vertex_location=vertex_location,
             client=client,
         )
         if invalid_response is True:
             pytest.fail("Expected this to fail")
     except litellm.JSONSchemaValidationError as e:
-        if invalid_response is False:
+        if invalid_response is False and "claude-3" not in model:
             pytest.fail("Expected this to pass. Got={}".format(e))

     mock_call.assert_called_once()

litellm/types/utils.py

@@ -995,3 +995,8 @@ class GenericImageParsingChunk(TypedDict):
     type: str
     media_type: str
     data: str
+
+
+class ResponseFormatChunk(TypedDict, total=False):
+    type: Required[Literal["json_object", "text"]]
+    response_schema: dict
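
Because the TypedDict is declared with `total=False`, only the `Required[...]` key is mandatory, so both plain JSON mode and schema-constrained JSON mode type-check. A usage sketch:

rf_text: ResponseFormatChunk = {"type": "text"}
rf_json: ResponseFormatChunk = {
    "type": "json_object",
    "response_schema": {"type": "object"},  # optional key
}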

litellm/utils.py

@@ -1351,7 +1351,9 @@ def client(original_function):
                 ).total_seconds() * 1000  # return response latency in ms like openai

                 ### POST-CALL RULES ###
-                post_call_processing(original_response=result, model=model)
+                post_call_processing(
+                    original_response=result, model=model, optional_params=kwargs
+                )

                 # [OPTIONAL] ADD TO CACHE
                 if (
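
Passing the request kwargs into post_call_processing is what lets the post-call rules see the caller's response_format and surface a litellm.JSONSchemaValidationError on a mismatch, as the test above expects. A rough sketch of that kind of check, assuming a jsonschema-style validator (this is not litellm's actual implementation):

import json

import jsonschema  # assumption: pip install jsonschema; litellm may validate differently


def check_json_response(content: str, optional_params: dict) -> None:
    # Hypothetical helper mirroring what post-call processing can now do with
    # the request kwargs; litellm wraps failures in
    # litellm.JSONSchemaValidationError rather than raising directly.
    response_format = optional_params.get("response_format") or {}
    schema = response_format.get("response_schema")
    if response_format.get("type") == "json_object" and schema is not None:
        jsonschema.validate(instance=json.loads(content), schema=schema)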