forked from phoenix/litellm-mirror
fix(vertex_ai_anthropic.py): support pre-filling "{" for json mode
parent b699d9a8b9
commit 4b1e85f54e
4 changed files with 137 additions and 55 deletions
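The change teaches the Vertex AI Anthropic handler to emulate JSON mode for claude-3 models that lack native response_schema support: the requested schema is appended as a user message, the assistant turn is pre-filled with "{", and the handler stitches the returned text back into a complete JSON object. A caller-side sketch of how the new path is exercised; the model name, project, and schema below are illustrative placeholders, not values taken from this commit:

import litellm

# Placeholder schema for illustration.
response_schema = {
    "type": "object",
    "properties": {"recipe_name": {"type": "string"}},
    "required": ["recipe_name"],
}

resp = litellm.completion(
    model="vertex_ai/claude-3-sonnet@20240229",  # any claude-3 model on Vertex AI
    messages=[{"role": "user", "content": "Give me one cookie recipe"}],
    response_format={"type": "json_object", "response_schema": response_schema},
    vertex_project="my-gcp-project",  # placeholder
    vertex_location="us-east5",  # placeholder
)

# With the pre-fill fix, the content starts with "{" and parses cleanly.
print(resp.choices[0].message.content)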
@@ -1,24 +1,32 @@
 # What is this?
 ## Handler file for calling claude-3 on vertex ai
-import os, types
+import copy
 import json
+import os
+import time
+import types
+import uuid
 from enum import Enum
-import requests, copy  # type: ignore
-import time, uuid
-from typing import Callable, Optional, List
-from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
-from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from typing import Any, Callable, List, Optional, Tuple
+
+import httpx  # type: ignore
+import requests  # type: ignore
+
 import litellm
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.utils import ResponseFormatChunk
+from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
+
 from .prompt_templates.factory import (
-    contains_tag,
-    prompt_factory,
-    custom_prompt,
     construct_tool_use_system_prompt,
+    contains_tag,
+    custom_prompt,
     extract_between_tags,
     parse_xml_params,
+    prompt_factory,
+    response_schema_prompt,
 )
-import httpx  # type: ignore


 class VertexAIError(Exception):
@@ -104,6 +112,7 @@ class VertexAIAnthropicConfig:
             "stop",
             "temperature",
             "top_p",
+            "response_format",
         ]

     def map_openai_params(self, non_default_params: dict, optional_params: dict):
@@ -120,6 +129,8 @@ class VertexAIAnthropicConfig:
                 optional_params["temperature"] = value
             if param == "top_p":
                 optional_params["top_p"] = value
+            if param == "response_format" and "response_schema" in value:
+                optional_params["response_format"] = ResponseFormatChunk(**value)  # type: ignore
         return optional_params


@@ -129,7 +140,6 @@ class VertexAIAnthropicConfig:
     """


-# makes headers for API call
 def refresh_auth(
     credentials,
 ) -> str:  # used when user passes in credentials as json string
@@ -144,6 +154,40 @@ def refresh_auth(
     return credentials.token


+def get_vertex_client(
+    client: Any,
+    vertex_project: Optional[str],
+    vertex_location: Optional[str],
+    vertex_credentials: Optional[str],
+) -> Tuple[Any, Optional[str]]:
+    args = locals()
+    from litellm.llms.vertex_httpx import VertexLLM
+
+    try:
+        from anthropic import AnthropicVertex
+    except Exception:
+        raise VertexAIError(
+            status_code=400,
+            message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",
+        )
+
+    access_token: Optional[str] = None
+
+    if client is None:
+        _credentials, cred_project_id = VertexLLM().load_auth(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+        vertex_ai_client = AnthropicVertex(
+            project_id=vertex_project or cred_project_id,
+            region=vertex_location or "us-central1",
+            access_token=_credentials.token,
+        )
+    else:
+        vertex_ai_client = client
+
+    return vertex_ai_client, access_token
+
+
 def completion(
     model: str,
     messages: list,
@@ -151,10 +195,10 @@ def completion(
     print_verbose: Callable,
     encoding,
     logging_obj,
+    optional_params: dict,
     vertex_project=None,
     vertex_location=None,
     vertex_credentials=None,
-    optional_params=None,
     litellm_params=None,
     logger_fn=None,
     acompletion: bool = False,
@@ -178,6 +222,13 @@ def completion(
     )
     try:

+        vertex_ai_client, access_token = get_vertex_client(
+            client=client,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+        )
+
         ## Load Config
         config = litellm.VertexAIAnthropicConfig.get_config()
         for k, v in config.items():
@@ -186,6 +237,7 @@ def completion(

         ## Format Prompt
         _is_function_call = False
+        _is_json_schema = False
         messages = copy.deepcopy(messages)
         optional_params = copy.deepcopy(optional_params)
         # Separate system prompt from rest of message
@@ -200,6 +252,29 @@ def completion(
                 messages.pop(idx)
         if len(system_prompt) > 0:
             optional_params["system"] = system_prompt
+        # Checks for 'response_schema' support - if passed in
+        if "response_format" in optional_params:
+            response_format_chunk = ResponseFormatChunk(
+                **optional_params["response_format"]  # type: ignore
+            )
+            supports_response_schema = litellm.supports_response_schema(
+                model=model, custom_llm_provider="vertex_ai"
+            )
+            if (
+                supports_response_schema is False
+                and response_format_chunk["type"] == "json_object"
+                and "response_schema" in response_format_chunk
+            ):
+                _is_json_schema = True
+                user_response_schema_message = response_schema_prompt(
+                    model=model,
+                    response_schema=response_format_chunk["response_schema"],
+                )
+                messages.append(
+                    {"role": "user", "content": user_response_schema_message}
+                )
+                messages.append({"role": "assistant", "content": "{"})
+                optional_params.pop("response_format")
         # Format rest of message according to anthropic guidelines
         try:
             messages = prompt_factory(
@@ -233,32 +308,6 @@ def completion(
         print_verbose(
             f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}"
         )
-        access_token = None
-        if client is None:
-            if vertex_credentials is not None and isinstance(vertex_credentials, str):
-                import google.oauth2.service_account
-
-                try:
-                    json_obj = json.loads(vertex_credentials)
-                except json.JSONDecodeError:
-                    json_obj = json.load(open(vertex_credentials))
-
-                creds = (
-                    google.oauth2.service_account.Credentials.from_service_account_info(
-                        json_obj,
-                        scopes=["https://www.googleapis.com/auth/cloud-platform"],
-                    )
-                )
-                ### CHECK IF ACCESS
-                access_token = refresh_auth(credentials=creds)
-
-            vertex_ai_client = AnthropicVertex(
-                project_id=vertex_project,
-                region=vertex_location,
-                access_token=access_token,
-            )
-        else:
-            vertex_ai_client = client

         if acompletion == True:
             """
@@ -315,7 +364,16 @@ def completion(
         )

         message = vertex_ai_client.messages.create(**data)  # type: ignore
-        text_content = message.content[0].text
+
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response=message,
+            additional_args={"complete_input_dict": data},
+        )
+
+        text_content: str = message.content[0].text
         ## TOOL CALLING - OUTPUT PARSE
         if text_content is not None and contains_tag("invoke", text_content):
             function_name = extract_between_tags("tool_name", text_content)[0]
@@ -339,7 +397,13 @@ def completion(
             )
             model_response.choices[0].message = _message  # type: ignore
         else:
-            model_response.choices[0].message.content = text_content  # type: ignore
+            if (
+                _is_json_schema
+            ):  # follows https://github.com/anthropics/anthropic-cookbook/blob/main/misc/how_to_enable_json_mode.ipynb
+                json_response = "{" + text_content[: text_content.rfind("}") + 1]
+                model_response.choices[0].message.content = json_response  # type: ignore
+            else:
+                model_response.choices[0].message.content = text_content  # type: ignore
         model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)

         ## CALCULATING USAGE
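The pre-fill trick itself follows the Anthropic cookbook's JSON-mode recipe and is independent of the handler plumbing above. A self-contained sketch of the two halves the diff adds, seeding the conversation with the schema prompt plus an assistant "{" pre-fill and then rebuilding the object from the model's continuation, using a canned continuation in place of a real Vertex AI call; the helper names are illustrative, not part of litellm:

import json

def build_json_mode_messages(user_prompt: str, schema_prompt: str) -> list:
    # Ask for JSON matching the schema, then pre-fill the assistant turn with "{"
    # so the model continues the object instead of wrapping it in prose.
    return [
        {"role": "user", "content": user_prompt},
        {"role": "user", "content": schema_prompt},
        {"role": "assistant", "content": "{"},
    ]

def rebuild_json(continuation: str) -> dict:
    # The returned text is everything after the pre-filled "{", so prepend it
    # and cut at the last "}" to drop any trailing chatter, as the handler does.
    json_response = "{" + continuation[: continuation.rfind("}") + 1]
    return json.loads(json_response)

# Canned continuation a claude-3 model might return after the "{" pre-fill:
print(rebuild_json('"recipe_name": "Snickerdoodles"} Hope that helps!'))
# -> {'recipe_name': 'Snickerdoodles'}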
@@ -993,10 +993,10 @@ def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):


 @pytest.mark.parametrize(
-    "model, supports_response_schema",
+    "model, vertex_location, supports_response_schema",
     [
-        ("vertex_ai_beta/gemini-1.5-pro-001", True),
-        ("vertex_ai_beta/gemini-1.5-flash", False),
+        ("vertex_ai_beta/gemini-1.5-pro-001", "us-central1", True),
+        ("vertex_ai_beta/gemini-1.5-flash", "us-central1", False),
     ],
 )
 @pytest.mark.parametrize(
@@ -1005,7 +1005,7 @@ def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
 )
 @pytest.mark.asyncio
 async def test_gemini_pro_json_schema_args_sent_httpx(
-    model, supports_response_schema, invalid_response
+    model, supports_response_schema, vertex_location, invalid_response
 ):
     load_vertex_ai_credentials()
     os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
@@ -1015,17 +1015,27 @@ async def test_gemini_pro_json_schema_args_sent_httpx(
     messages = [{"role": "user", "content": "List 5 cookie recipes"}]
     from litellm.llms.custom_httpx.http_handler import HTTPHandler

+    # response_schema = {
+    #     "type": "array",
+    #     "items": {
+    #         "type": "object",
+    #         "properties": {
+    #             "recipe_name": {
+    #                 "type": "string",
+    #             },
+    #         },
+    #         "required": ["recipe_name"],
+    #     },
+    # }
     response_schema = {
-        "type": "array",
-        "items": {
-            "type": "object",
-            "properties": {
-                "recipe_name": {
-                    "type": "string",
-                },
-            },
-            "required": ["recipe_name"],
+        "type": "object",
+        "properties": {
+            "recipe_name": {"type": "string"},
+            "ingredients": {"type": "array", "items": {"type": "string"}},
+            "prep_time": {"type": "number"},
+            "difficulty": {"type": "string", "enum": ["easy", "medium", "hard"]},
         },
+        "required": ["recipe_name", "ingredients", "prep_time"],
     }

     client = HTTPHandler()
@@ -1044,12 +1054,13 @@ async def test_gemini_pro_json_schema_args_sent_httpx(
                     "type": "json_object",
                     "response_schema": response_schema,
                 },
+                vertex_location=vertex_location,
                 client=client,
             )
             if invalid_response is True:
                 pytest.fail("Expected this to fail")
         except litellm.JSONSchemaValidationError as e:
-            if invalid_response is False:
+            if invalid_response is False and "claude-3" not in model:
                 pytest.fail("Expected this to pass. Got={}".format(e))

         mock_call.assert_called_once()
@@ -995,3 +995,8 @@ class GenericImageParsingChunk(TypedDict):
     type: str
     media_type: str
     data: str
+
+
+class ResponseFormatChunk(TypedDict, total=False):
+    type: Required[Literal["json_object", "text"]]
+    response_schema: dict
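ResponseFormatChunk is declared with total=False, so only type is mandatory and the optional response_schema is probed with a plain membership test, which is how the handler above decides whether to fall back to the prompt-based JSON mode. A small sketch of that pattern, assuming the class lives at litellm.types.utils as the new import in the handler suggests:

from litellm.types.utils import ResponseFormatChunk

# Built from the user-supplied response_format kwargs, as the handler does.
chunk = ResponseFormatChunk(**{"type": "json_object", "response_schema": {"type": "object"}})

# At runtime a TypedDict is a plain dict, so optional keys are checked by membership.
if chunk["type"] == "json_object" and "response_schema" in chunk:
    print("no native response_schema support -> use the '{' pre-fill path")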
@@ -1351,7 +1351,9 @@ def client(original_function):
            ).total_seconds() * 1000  # return response latency in ms like openai

            ### POST-CALL RULES ###
-           post_call_processing(original_response=result, model=model)
+           post_call_processing(
+               original_response=result, model=model, optional_params=kwargs
+           )

            # [OPTIONAL] ADD TO CACHE
            if (