fix(vertex_ai_anthropic.py): support pre-filling "{" for json mode

This commit is contained in:
Krrish Dholakia 2024-06-29 18:54:10 -07:00
parent b699d9a8b9
commit 4b1e85f54e
4 changed files with 137 additions and 55 deletions

View file

@ -1,24 +1,32 @@
# What is this?
## Handler file for calling claude-3 on vertex ai
import os, types
import copy
import json
import os
import time
import types
import uuid
from enum import Enum
import requests, copy # type: ignore
import time, uuid
from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from typing import Any, Callable, List, Optional, Tuple
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import ResponseFormatChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from .prompt_templates.factory import (
contains_tag,
prompt_factory,
custom_prompt,
construct_tool_use_system_prompt,
contains_tag,
custom_prompt,
extract_between_tags,
parse_xml_params,
prompt_factory,
response_schema_prompt,
)
import httpx # type: ignore
class VertexAIError(Exception):
@ -104,6 +112,7 @@ class VertexAIAnthropicConfig:
"stop",
"temperature",
"top_p",
"response_format",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
@ -120,6 +129,8 @@ class VertexAIAnthropicConfig:
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "response_format" and "response_schema" in value:
optional_params["response_format"] = ResponseFormatChunk(**value) # type: ignore
return optional_params
@ -129,7 +140,6 @@ class VertexAIAnthropicConfig:
"""
# makes headers for API call
def refresh_auth(
credentials,
) -> str: # used when user passes in credentials as json string
@ -144,6 +154,40 @@ def refresh_auth(
return credentials.token
def get_vertex_client(
client: Any,
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
) -> Tuple[Any, Optional[str]]:
args = locals()
from litellm.llms.vertex_httpx import VertexLLM
try:
from anthropic import AnthropicVertex
except Exception:
raise VertexAIError(
status_code=400,
message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",
)
access_token: Optional[str] = None
if client is None:
_credentials, cred_project_id = VertexLLM().load_auth(
credentials=vertex_credentials, project_id=vertex_project
)
vertex_ai_client = AnthropicVertex(
project_id=vertex_project or cred_project_id,
region=vertex_location or "us-central1",
access_token=_credentials.token,
)
else:
vertex_ai_client = client
return vertex_ai_client, access_token
def completion(
model: str,
messages: list,
@ -151,10 +195,10 @@ def completion(
print_verbose: Callable,
encoding,
logging_obj,
optional_params: dict,
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
acompletion: bool = False,
@ -178,6 +222,13 @@ def completion(
)
try:
vertex_ai_client, access_token = get_vertex_client(
client=client,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
)
## Load Config
config = litellm.VertexAIAnthropicConfig.get_config()
for k, v in config.items():
@ -186,6 +237,7 @@ def completion(
## Format Prompt
_is_function_call = False
_is_json_schema = False
messages = copy.deepcopy(messages)
optional_params = copy.deepcopy(optional_params)
# Separate system prompt from rest of message
@ -200,6 +252,29 @@ def completion(
messages.pop(idx)
if len(system_prompt) > 0:
optional_params["system"] = system_prompt
# Checks for 'response_schema' support - if passed in
if "response_format" in optional_params:
response_format_chunk = ResponseFormatChunk(
**optional_params["response_format"] # type: ignore
)
supports_response_schema = litellm.supports_response_schema(
model=model, custom_llm_provider="vertex_ai"
)
if (
supports_response_schema is False
and response_format_chunk["type"] == "json_object"
and "response_schema" in response_format_chunk
):
_is_json_schema = True
user_response_schema_message = response_schema_prompt(
model=model,
response_schema=response_format_chunk["response_schema"],
)
messages.append(
{"role": "user", "content": user_response_schema_message}
)
messages.append({"role": "assistant", "content": "{"})
optional_params.pop("response_format")
# Format rest of message according to anthropic guidelines
try:
messages = prompt_factory(
@ -233,32 +308,6 @@ def completion(
print_verbose(
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}"
)
access_token = None
if client is None:
if vertex_credentials is not None and isinstance(vertex_credentials, str):
import google.oauth2.service_account
try:
json_obj = json.loads(vertex_credentials)
except json.JSONDecodeError:
json_obj = json.load(open(vertex_credentials))
creds = (
google.oauth2.service_account.Credentials.from_service_account_info(
json_obj,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
)
### CHECK IF ACCESS
access_token = refresh_auth(credentials=creds)
vertex_ai_client = AnthropicVertex(
project_id=vertex_project,
region=vertex_location,
access_token=access_token,
)
else:
vertex_ai_client = client
if acompletion == True:
"""
@ -315,7 +364,16 @@ def completion(
)
message = vertex_ai_client.messages.create(**data) # type: ignore
text_content = message.content[0].text
## LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=message,
additional_args={"complete_input_dict": data},
)
text_content: str = message.content[0].text
## TOOL CALLING - OUTPUT PARSE
if text_content is not None and contains_tag("invoke", text_content):
function_name = extract_between_tags("tool_name", text_content)[0]
@ -339,7 +397,13 @@ def completion(
)
model_response.choices[0].message = _message # type: ignore
else:
model_response.choices[0].message.content = text_content # type: ignore
if (
_is_json_schema
): # follows https://github.com/anthropics/anthropic-cookbook/blob/main/misc/how_to_enable_json_mode.ipynb
json_response = "{" + text_content[: text_content.rfind("}") + 1]
model_response.choices[0].message.content = json_response # type: ignore
else:
model_response.choices[0].message.content = text_content # type: ignore
model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)
## CALCULATING USAGE

View file

@ -993,10 +993,10 @@ def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
@pytest.mark.parametrize(
"model, supports_response_schema",
"model, vertex_location, supports_response_schema",
[
("vertex_ai_beta/gemini-1.5-pro-001", True),
("vertex_ai_beta/gemini-1.5-flash", False),
("vertex_ai_beta/gemini-1.5-pro-001", "us-central1", True),
("vertex_ai_beta/gemini-1.5-flash", "us-central1", False),
],
)
@pytest.mark.parametrize(
@ -1005,7 +1005,7 @@ def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
)
@pytest.mark.asyncio
async def test_gemini_pro_json_schema_args_sent_httpx(
model, supports_response_schema, invalid_response
model, supports_response_schema, vertex_location, invalid_response
):
load_vertex_ai_credentials()
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
@ -1015,17 +1015,27 @@ async def test_gemini_pro_json_schema_args_sent_httpx(
messages = [{"role": "user", "content": "List 5 cookie recipes"}]
from litellm.llms.custom_httpx.http_handler import HTTPHandler
# response_schema = {
# "type": "array",
# "items": {
# "type": "object",
# "properties": {
# "recipe_name": {
# "type": "string",
# },
# },
# "required": ["recipe_name"],
# },
# }
response_schema = {
"type": "array",
"items": {
"type": "object",
"properties": {
"recipe_name": {
"type": "string",
},
},
"required": ["recipe_name"],
"type": "object",
"properties": {
"recipe_name": {"type": "string"},
"ingredients": {"type": "array", "items": {"type": "string"}},
"prep_time": {"type": "number"},
"difficulty": {"type": "string", "enum": ["easy", "medium", "hard"]},
},
"required": ["recipe_name", "ingredients", "prep_time"],
}
client = HTTPHandler()
@ -1044,12 +1054,13 @@ async def test_gemini_pro_json_schema_args_sent_httpx(
"type": "json_object",
"response_schema": response_schema,
},
vertex_location=vertex_location,
client=client,
)
if invalid_response is True:
pytest.fail("Expected this to fail")
except litellm.JSONSchemaValidationError as e:
if invalid_response is False:
if invalid_response is False and "claude-3" not in model:
pytest.fail("Expected this to pass. Got={}".format(e))
mock_call.assert_called_once()

View file

@ -995,3 +995,8 @@ class GenericImageParsingChunk(TypedDict):
type: str
media_type: str
data: str
class ResponseFormatChunk(TypedDict, total=False):
type: Required[Literal["json_object", "text"]]
response_schema: dict

View file

@ -1351,7 +1351,9 @@ def client(original_function):
).total_seconds() * 1000 # return response latency in ms like openai
### POST-CALL RULES ###
post_call_processing(original_response=result, model=model)
post_call_processing(
original_response=result, model=model, optional_params=kwargs
)
# [OPTIONAL] ADD TO CACHE
if (