fix(vertex_ai_anthropic.py): support pre-filling "{" for json mode
parent b699d9a8b9
commit 4b1e85f54e

4 changed files with 137 additions and 55 deletions
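The change implements Anthropic's assistant-prefill trick for JSON mode: when a caller requests a JSON response schema on a model without native `response_schema` support, the handler appends an assistant turn containing only `"{"` so the completion continues as raw JSON, then re-attaches the brace when parsing the output. A minimal sketch of the technique using the `anthropic` SDK directly (project, region, and model values are placeholders, not part of this commit):

```python
# Sketch of the "{" prefill trick, outside litellm; values are illustrative.
from anthropic import AnthropicVertex

client = AnthropicVertex(project_id="my-gcp-project", region="us-central1")

message = client.messages.create(
    model="claude-3-sonnet@20240229",
    max_tokens=256,
    messages=[
        {"role": "user", "content": 'Reply as JSON with a "recipe_name" key.'},
        # Prefilled assistant turn: the model continues from "{",
        # so it emits bare JSON instead of prose.
        {"role": "assistant", "content": "{"},
    ],
)
# The completion continues after the prefill, so "{" must be re-attached.
text = message.content[0].text
json_response = "{" + text[: text.rfind("}") + 1]
```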
litellm/llms/vertex_ai_anthropic.py

@@ -1,24 +1,32 @@
 # What is this?
 ## Handler file for calling claude-3 on vertex ai
-import os, types
+import copy
 import json
+import os
+import time
+import types
+import uuid
 from enum import Enum
-import requests, copy  # type: ignore
-import time, uuid
-from typing import Callable, Optional, List
-from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
-from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from typing import Any, Callable, List, Optional, Tuple
+
+import httpx  # type: ignore
+import requests  # type: ignore
 
 import litellm
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.utils import ResponseFormatChunk
+from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
+
 from .prompt_templates.factory import (
-    contains_tag,
-    prompt_factory,
-    custom_prompt,
     construct_tool_use_system_prompt,
+    contains_tag,
+    custom_prompt,
     extract_between_tags,
     parse_xml_params,
+    prompt_factory,
+    response_schema_prompt,
 )
-import httpx  # type: ignore
 
 
 class VertexAIError(Exception):
@@ -104,6 +112,7 @@ class VertexAIAnthropicConfig:
             "stop",
             "temperature",
             "top_p",
+            "response_format",
         ]
 
     def map_openai_params(self, non_default_params: dict, optional_params: dict):
@@ -120,6 +129,8 @@ class VertexAIAnthropicConfig:
                 optional_params["temperature"] = value
             if param == "top_p":
                 optional_params["top_p"] = value
+            if param == "response_format" and "response_schema" in value:
+                optional_params["response_format"] = ResponseFormatChunk(**value)  # type: ignore
         return optional_params
 
 
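For illustration, an OpenAI-style `response_format` carrying a `response_schema` now survives parameter mapping as a typed chunk (a sketch; the config object is instantiated here only to call the method):

```python
import litellm

config = litellm.VertexAIAnthropicConfig()
optional_params = config.map_openai_params(
    non_default_params={
        "response_format": {
            "type": "json_object",
            "response_schema": {"type": "object"},
        }
    },
    optional_params={},
)
# optional_params["response_format"] is now a ResponseFormatChunk TypedDict
```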
@@ -129,7 +140,6 @@ class VertexAIAnthropicConfig:
     """
 
-
 # makes headers for API call
 def refresh_auth(
     credentials,
 ) -> str:  # used when user passes in credentials as json string
@@ -144,6 +154,40 @@ def refresh_auth(
     return credentials.token
 
 
+def get_vertex_client(
+    client: Any,
+    vertex_project: Optional[str],
+    vertex_location: Optional[str],
+    vertex_credentials: Optional[str],
+) -> Tuple[Any, Optional[str]]:
+    args = locals()
+    from litellm.llms.vertex_httpx import VertexLLM
+
+    try:
+        from anthropic import AnthropicVertex
+    except Exception:
+        raise VertexAIError(
+            status_code=400,
+            message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",
+        )
+
+    access_token: Optional[str] = None
+
+    if client is None:
+        _credentials, cred_project_id = VertexLLM().load_auth(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+        vertex_ai_client = AnthropicVertex(
+            project_id=vertex_project or cred_project_id,
+            region=vertex_location or "us-central1",
+            access_token=_credentials.token,
+        )
+    else:
+        vertex_ai_client = client
+
+    return vertex_ai_client, access_token
+
+
 def completion(
     model: str,
     messages: list,
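The helper centralizes client construction and auth, replacing the inline `AnthropicVertex` setup deleted further down. A hypothetical standalone call (project and region are placeholders; with `client=None` it resolves credentials via `VertexLLM().load_auth`):

```python
from litellm.llms.vertex_ai_anthropic import get_vertex_client

vertex_ai_client, access_token = get_vertex_client(
    client=None,
    vertex_project="my-gcp-project",
    vertex_location="us-central1",
    vertex_credentials=None,  # fall back to application-default credentials
)
```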
@@ -151,10 +195,10 @@ def completion(
     print_verbose: Callable,
     encoding,
     logging_obj,
+    optional_params: dict,
     vertex_project=None,
     vertex_location=None,
     vertex_credentials=None,
-    optional_params=None,
     litellm_params=None,
     logger_fn=None,
     acompletion: bool = False,
@@ -178,6 +222,13 @@ def completion(
         )
     try:
+        vertex_ai_client, access_token = get_vertex_client(
+            client=client,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+        )
+
         ## Load Config
         config = litellm.VertexAIAnthropicConfig.get_config()
         for k, v in config.items():
@@ -186,6 +237,7 @@ def completion(
 
         ## Format Prompt
         _is_function_call = False
+        _is_json_schema = False
         messages = copy.deepcopy(messages)
         optional_params = copy.deepcopy(optional_params)
         # Separate system prompt from rest of message
@@ -200,6 +252,29 @@ def completion(
                 messages.pop(idx)
         if len(system_prompt) > 0:
             optional_params["system"] = system_prompt
+        # Checks for 'response_schema' support - if passed in
+        if "response_format" in optional_params:
+            response_format_chunk = ResponseFormatChunk(
+                **optional_params["response_format"]  # type: ignore
+            )
+            supports_response_schema = litellm.supports_response_schema(
+                model=model, custom_llm_provider="vertex_ai"
+            )
+            if (
+                supports_response_schema is False
+                and response_format_chunk["type"] == "json_object"
+                and "response_schema" in response_format_chunk
+            ):
+                _is_json_schema = True
+                user_response_schema_message = response_schema_prompt(
+                    model=model,
+                    response_schema=response_format_chunk["response_schema"],
+                )
+                messages.append(
+                    {"role": "user", "content": user_response_schema_message}
+                )
+                messages.append({"role": "assistant", "content": "{"})
+                optional_params.pop("response_format")
         # Format rest of message according to anthropic guidelines
         try:
             messages = prompt_factory(
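Net effect of the block above, shown on a toy request (schema and prompt text are illustrative; the exact instruction text comes from `response_schema_prompt`):

```python
# Before: caller passes an OpenAI-style response_format.
optional_params = {
    "response_format": {
        "type": "json_object",
        "response_schema": {
            "type": "object",
            "properties": {"recipe_name": {"type": "string"}},
        },
    }
}
messages = [{"role": "user", "content": "List 5 cookie recipes"}]

# After: the schema is folded into the prompt and the "{" prefill is added;
# response_format itself is popped so it is never sent to the API.
messages = [
    {"role": "user", "content": "List 5 cookie recipes"},
    {"role": "user", "content": "<instructions from response_schema_prompt>"},
    {"role": "assistant", "content": "{"},
]
```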
@@ -233,32 +308,6 @@ def completion(
         print_verbose(
             f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}"
         )
-        access_token = None
-        if client is None:
-            if vertex_credentials is not None and isinstance(vertex_credentials, str):
-                import google.oauth2.service_account
-
-                try:
-                    json_obj = json.loads(vertex_credentials)
-                except json.JSONDecodeError:
-                    json_obj = json.load(open(vertex_credentials))
-
-                creds = (
-                    google.oauth2.service_account.Credentials.from_service_account_info(
-                        json_obj,
-                        scopes=["https://www.googleapis.com/auth/cloud-platform"],
-                    )
-                )
-                ### CHECK IF ACCESS
-                access_token = refresh_auth(credentials=creds)
-
-            vertex_ai_client = AnthropicVertex(
-                project_id=vertex_project,
-                region=vertex_location,
-                access_token=access_token,
-            )
-        else:
-            vertex_ai_client = client
 
         if acompletion == True:
             """
@@ -315,7 +364,16 @@ def completion(
             )
 
             message = vertex_ai_client.messages.create(**data)  # type: ignore
-            text_content = message.content[0].text
+
+            ## LOGGING
+            logging_obj.post_call(
+                input=messages,
+                api_key="",
+                original_response=message,
+                additional_args={"complete_input_dict": data},
+            )
+
+            text_content: str = message.content[0].text
             ## TOOL CALLING - OUTPUT PARSE
             if text_content is not None and contains_tag("invoke", text_content):
                 function_name = extract_between_tags("tool_name", text_content)[0]
@@ -339,7 +397,13 @@ def completion(
                 )
                 model_response.choices[0].message = _message  # type: ignore
             else:
-                model_response.choices[0].message.content = text_content  # type: ignore
+                if (
+                    _is_json_schema
+                ):  # follows https://github.com/anthropics/anthropic-cookbook/blob/main/misc/how_to_enable_json_mode.ipynb
+                    json_response = "{" + text_content[: text_content.rfind("}") + 1]
+                    model_response.choices[0].message.content = json_response  # type: ignore
+                else:
+                    model_response.choices[0].message.content = text_content  # type: ignore
             model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)
 
             ## CALCULATING USAGE
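Because the prefilled `"{"` belongs to the prompt rather than the completion, the returned text starts mid-object; the slice up to `rfind("}")` also trims any trailing prose after the closing brace. A worked example (input string is made up):

```python
text_content = '"recipe_name": "Snickerdoodles"} Hope that helps!'

json_response = "{" + text_content[: text_content.rfind("}") + 1]
assert json_response == '{"recipe_name": "Snickerdoodles"}'
```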
litellm/tests/test_amazing_vertex_completion.py

@@ -993,10 +993,10 @@ def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
 
 
 @pytest.mark.parametrize(
-    "model, supports_response_schema",
+    "model, vertex_location, supports_response_schema",
     [
-        ("vertex_ai_beta/gemini-1.5-pro-001", True),
-        ("vertex_ai_beta/gemini-1.5-flash", False),
+        ("vertex_ai_beta/gemini-1.5-pro-001", "us-central1", True),
+        ("vertex_ai_beta/gemini-1.5-flash", "us-central1", False),
     ],
 )
 @pytest.mark.parametrize(
@@ -1005,7 +1005,7 @@ def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
 )
 @pytest.mark.asyncio
 async def test_gemini_pro_json_schema_args_sent_httpx(
-    model, supports_response_schema, invalid_response
+    model, supports_response_schema, vertex_location, invalid_response
 ):
     load_vertex_ai_credentials()
     os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
@@ -1015,17 +1015,27 @@ async def test_gemini_pro_json_schema_args_sent_httpx(
     messages = [{"role": "user", "content": "List 5 cookie recipes"}]
     from litellm.llms.custom_httpx.http_handler import HTTPHandler
 
+    # response_schema = {
+    #     "type": "array",
+    #     "items": {
+    #         "type": "object",
+    #         "properties": {
+    #             "recipe_name": {
+    #                 "type": "string",
+    #             },
+    #         },
+    #         "required": ["recipe_name"],
+    #     },
+    # }
     response_schema = {
-        "type": "array",
-        "items": {
-            "type": "object",
-            "properties": {
-                "recipe_name": {
-                    "type": "string",
-                },
-            },
-            "required": ["recipe_name"],
-        },
+        "type": "object",
+        "properties": {
+            "recipe_name": {"type": "string"},
+            "ingredients": {"type": "array", "items": {"type": "string"}},
+            "prep_time": {"type": "number"},
+            "difficulty": {"type": "string", "enum": ["easy", "medium", "hard"]},
+        },
+        "required": ["recipe_name", "ingredients", "prep_time"],
     }
 
     client = HTTPHandler()
@@ -1044,12 +1054,13 @@ async def test_gemini_pro_json_schema_args_sent_httpx(
                 "type": "json_object",
                 "response_schema": response_schema,
             },
+            vertex_location=vertex_location,
             client=client,
         )
         if invalid_response is True:
             pytest.fail("Expected this to fail")
     except litellm.JSONSchemaValidationError as e:
-        if invalid_response is False:
+        if invalid_response is False and "claude-3" not in model:
             pytest.fail("Expected this to pass. Got={}".format(e))
 
     mock_call.assert_called_once()
litellm/types/utils.py

@@ -995,3 +995,8 @@ class GenericImageParsingChunk(TypedDict):
     type: str
     media_type: str
     data: str
+
+
+class ResponseFormatChunk(TypedDict, total=False):
+    type: Required[Literal["json_object", "text"]]
+    response_schema: dict
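The new TypedDict documents the dict shape callers already pass as `response_format`. A hypothetical call matching that shape (model name is a placeholder):

```python
import litellm

response = litellm.completion(
    model="vertex_ai/claude-3-sonnet@20240229",
    messages=[{"role": "user", "content": "List 5 cookie recipes"}],
    response_format={
        "type": "json_object",
        "response_schema": {
            "type": "object",
            "properties": {"recipe_name": {"type": "string"}},
            "required": ["recipe_name"],
        },
    },
)
```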
litellm/utils.py

@@ -1351,7 +1351,9 @@ def client(original_function):
             ).total_seconds() * 1000  # return response latency in ms like openai
 
             ### POST-CALL RULES ###
-            post_call_processing(original_response=result, model=model)
+            post_call_processing(
+                original_response=result, model=model, optional_params=kwargs
+            )
 
             # [OPTIONAL] ADD TO CACHE
             if (
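Forwarding `kwargs` gives the post-call rules access to `response_format`, which is what lets schema validation run after the response arrives; the test above expects a mismatch to surface as `litellm.JSONSchemaValidationError`. A sketch of handling it (model and schema are illustrative):

```python
import litellm

try:
    litellm.completion(
        model="vertex_ai_beta/gemini-1.5-pro-001",
        messages=[{"role": "user", "content": "List 5 cookie recipes"}],
        response_format={
            "type": "json_object",
            "response_schema": {"type": "object"},
        },
    )
except litellm.JSONSchemaValidationError as e:
    print("response did not match the requested schema:", e)
```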