forked from phoenix/litellm-mirror

Merge pull request #4531 from BerriAI/litellm_vertex_anthropic_tools

fix(vertex_anthropic.py): Vertex Anthropic tool calling - native params

commit 318077edea (6 changed files with 139 additions and 38 deletions)
```diff
@@ -426,13 +426,22 @@ class Logging:
         self.model_call_details["additional_args"] = additional_args
         self.model_call_details["log_event_type"] = "post_api_call"

-        verbose_logger.debug(
-            "RAW RESPONSE:\n{}\n\n".format(
-                self.model_call_details.get(
-                    "original_response", self.model_call_details
-                )
-            )
-        )
+        if json_logs:
+            verbose_logger.debug(
+                "RAW RESPONSE:\n{}\n\n".format(
+                    self.model_call_details.get(
+                        "original_response", self.model_call_details
+                    )
+                ),
+            )
+        else:
+            print_verbose(
+                "RAW RESPONSE:\n{}\n\n".format(
+                    self.model_call_details.get(
+                        "original_response", self.model_call_details
+                    )
+                )
+            )

         if self.logger_fn and callable(self.logger_fn):
             try:
                 self.logger_fn(
```
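For reference, `json_logs` is litellm's module-level flag for structured logging (set from the `JSON_LOGS` environment variable); when it is on, the raw response goes through the standard logger instead of `print_verbose`. A minimal sketch of the toggle, with illustrative stand-ins for the two helpers:

```python
# Minimal sketch of the two-path logging toggle; the helpers here are
# illustrative stand-ins for litellm's internals.
import logging

json_logs = True  # litellm sets this from the JSON_LOGS env var
verbose_logger = logging.getLogger(__name__)


def print_verbose(msg: str) -> None:
    # stand-in for litellm's print-based verbose helper
    print(msg)


def log_raw_response(model_call_details: dict) -> None:
    msg = "RAW RESPONSE:\n{}\n\n".format(
        model_call_details.get("original_response", model_call_details)
    )
    if json_logs:
        verbose_logger.debug(msg)  # structured logger path
    else:
        print_verbose(msg)  # plain stdout path
```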
```diff
@@ -431,20 +431,6 @@ class AnthropicChatCompletion(BaseLLM):
         headers={},
     ):
         data["stream"] = True
-        # async_handler = AsyncHTTPHandler(
-        #     timeout=httpx.Timeout(timeout=600.0, connect=20.0)
-        # )
-
-        # response = await async_handler.post(
-        #     api_base, headers=headers, json=data, stream=True
-        # )
-
-        # if response.status_code != 200:
-        #     raise AnthropicError(
-        #         status_code=response.status_code, message=response.text
-        #     )
-
-        # completion_stream = response.aiter_lines()
-
         streamwrapper = CustomStreamWrapper(
             completion_stream=None,
```
```diff
@@ -484,7 +470,17 @@ class AnthropicChatCompletion(BaseLLM):
         headers={},
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         async_handler = _get_async_httpx_client()
-        response = await async_handler.post(api_base, headers=headers, json=data)
+        try:
+            response = await async_handler.post(api_base, headers=headers, json=data)
+        except Exception as e:
+            ## LOGGING
+            logging_obj.post_call(
+                input=messages,
+                api_key=api_key,
+                original_response=str(e),
+                additional_args={"complete_input_dict": data},
+            )
+            raise e
         if stream and _is_function_call:
             return self.process_streaming_response(
                 model=model,
```
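This hunk (and the matching synchronous `requests.post` hunk further down) wraps the HTTP call so that a transport-level failure is recorded via `logging_obj.post_call` before the exception propagates; previously only non-200 responses were surfaced. A minimal sketch of the pattern, assuming only that `logging_obj` exposes a `post_call(...)` hook like the one above:

```python
# Sketch of the log-then-reraise pattern around an async HTTP call.
import httpx


async def post_with_failure_logging(
    api_base: str, headers: dict, data: dict, logging_obj
) -> httpx.Response:
    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(api_base, headers=headers, json=data)
        except Exception as e:
            # record the raw transport error against this request,
            # then surface it unchanged to the caller
            logging_obj.post_call(
                input=data.get("messages"),
                api_key=None,
                original_response=str(e),
                additional_args={"complete_input_dict": data},
            )
            raise
    return response
```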
```diff
@@ -588,13 +584,16 @@ class AnthropicChatCompletion(BaseLLM):
             optional_params["tools"] = anthropic_tools

         stream = optional_params.pop("stream", None)
+        is_vertex_request: bool = optional_params.pop("is_vertex_request", False)

         data = {
-            "model": model,
             "messages": messages,
             **optional_params,
         }

+        if is_vertex_request is False:
+            data["model"] = model
+
         ## LOGGING
         logging_obj.pre_call(
             input=messages,
```
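The new `is_vertex_request` flag controls whether `model` is placed in the JSON body: the direct Anthropic API expects it there, while Vertex encodes the model in the `rawPredict` URL, so the body must omit it. A small sketch of the body construction, mirroring the hunk:

```python
# Sketch of the body construction: Vertex requests omit "model"
# (it lives in the URL) and instead carry anthropic_version, which the
# Vertex caller puts into optional_params.
def build_anthropic_body(model: str, messages: list, optional_params: dict) -> dict:
    is_vertex_request: bool = optional_params.pop("is_vertex_request", False)
    data = {"messages": messages, **optional_params}
    if is_vertex_request is False:
        data["model"] = model  # direct Anthropic API wants the model in the body
    return data
```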
```diff
@@ -678,10 +677,27 @@ class AnthropicChatCompletion(BaseLLM):
             return streaming_response

         else:
-            response = requests.post(
-                api_base, headers=headers, data=json.dumps(data)
-            )
+            try:
+                response = requests.post(
+                    api_base, headers=headers, data=json.dumps(data)
+                )
+            except Exception as e:
+                ## LOGGING
+                logging_obj.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=str(e),
+                    additional_args={"complete_input_dict": data},
+                )
+                raise e
             if response.status_code != 200:
+                ## LOGGING
+                logging_obj.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=response.text,
+                    additional_args={"complete_input_dict": data},
+                )
                 raise AnthropicError(
                     status_code=response.status_code, message=response.text
                 )
```
```diff
@@ -15,6 +15,7 @@ import requests  # type: ignore
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.llms.anthropic import AnthropicMessagesToolChoice
 from litellm.types.utils import ResponseFormatChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
```
```diff
@@ -121,6 +122,17 @@ class VertexAIAnthropicConfig:
                 optional_params["max_tokens"] = value
             if param == "tools":
                 optional_params["tools"] = value
+            if param == "tool_choice":
+                _tool_choice: Optional[AnthropicMessagesToolChoice] = None
+                if value == "auto":
+                    _tool_choice = {"type": "auto"}
+                elif value == "required":
+                    _tool_choice = {"type": "any"}
+                elif isinstance(value, dict):
+                    _tool_choice = {"type": "tool", "name": value["function"]["name"]}
+
+                if _tool_choice is not None:
+                    optional_params["tool_choice"] = _tool_choice
             if param == "stream":
                 optional_params["stream"] = value
             if param == "stop":
```
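This is the core of the PR: OpenAI-style `tool_choice` values are mapped onto Anthropic's native schema (`"auto"` stays `auto`, `"required"` becomes `any`, and a specific function becomes `{"type": "tool", "name": ...}`). The same mapping as a standalone, testable sketch:

```python
# Standalone version of the tool_choice mapping added above.
from typing import Optional, Union


def map_tool_choice(value: Union[str, dict]) -> Optional[dict]:
    if value == "auto":
        return {"type": "auto"}  # model decides whether to call a tool
    if value == "required":
        return {"type": "any"}  # model must call one of the tools
    if isinstance(value, dict):
        # OpenAI form: {"type": "function", "function": {"name": "..."}}
        return {"type": "tool", "name": value["function"]["name"]}
    return None


assert map_tool_choice("required") == {"type": "any"}
assert map_tool_choice(
    {"type": "function", "function": {"name": "get_weather"}}
) == {"type": "tool", "name": "get_weather"}
```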
```diff
@@ -177,17 +189,29 @@ def get_vertex_client(
         _credentials, cred_project_id = VertexLLM().load_auth(
             credentials=vertex_credentials, project_id=vertex_project
         )

         vertex_ai_client = AnthropicVertex(
             project_id=vertex_project or cred_project_id,
             region=vertex_location or "us-central1",
             access_token=_credentials.token,
         )
+        access_token = _credentials.token
     else:
         vertex_ai_client = client
+        access_token = client.access_token

     return vertex_ai_client, access_token


+def create_vertex_anthropic_url(
+    vertex_location: str, vertex_project: str, model: str, stream: bool
+) -> str:
+    if stream is True:
+        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict"
+    else:
+        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict"
+
+
 def completion(
     model: str,
     messages: list,
```
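A quick illustration of the new URL helper: non-streaming requests target `:rawPredict` and streaming requests target `:streamRawPredict`. Project, location, and model values below are illustrative:

```python
# Illustrative values; the helper is the one defined in the hunk above.
url = create_vertex_anthropic_url(
    vertex_location="us-east5",
    vertex_project="my-gcp-project",
    model="claude-3-sonnet@20240229",
    stream=False,
)
# -> https://us-east5-aiplatform.googleapis.com/v1/projects/my-gcp-project
#    /locations/us-east5/publishers/anthropic/models/claude-3-sonnet@20240229:rawPredict
```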
```diff
@@ -196,6 +220,8 @@ def completion(
     encoding,
     logging_obj,
     optional_params: dict,
+    custom_prompt_dict: dict,
+    headers: Optional[dict],
     vertex_project=None,
     vertex_location=None,
     vertex_credentials=None,
```
```diff
@@ -207,6 +233,9 @@ def completion(
     try:
         import vertexai
         from anthropic import AnthropicVertex
+
+        from litellm.llms.anthropic import AnthropicChatCompletion
+        from litellm.llms.vertex_httpx import VertexLLM
     except:
         raise VertexAIError(
             status_code=400,
```
```diff
@@ -222,19 +251,58 @@ def completion(
         )
     try:
+        vertex_httpx_logic = VertexLLM()

-        vertex_ai_client, access_token = get_vertex_client(
-            client=client,
-            vertex_project=vertex_project,
-            vertex_location=vertex_location,
-            vertex_credentials=vertex_credentials,
+        access_token, project_id = vertex_httpx_logic._ensure_access_token(
+            credentials=vertex_credentials, project_id=vertex_project
         )

+        anthropic_chat_completions = AnthropicChatCompletion()
+
         ## Load Config
         config = litellm.VertexAIAnthropicConfig.get_config()
         for k, v in config.items():
             if k not in optional_params:
                 optional_params[k] = v

+        ## CONSTRUCT API BASE
+        stream = optional_params.get("stream", False)
+
+        api_base = create_vertex_anthropic_url(
+            vertex_location=vertex_location or "us-central1",
+            vertex_project=vertex_project or project_id,
+            model=model,
+            stream=stream,
+        )
+
+        if headers is not None:
+            vertex_headers = headers
+        else:
+            vertex_headers = {}
+
+        vertex_headers.update({"Authorization": "Bearer {}".format(access_token)})
+
+        optional_params.update(
+            {"anthropic_version": "vertex-2023-10-16", "is_vertex_request": True}
+        )
+
+        return anthropic_chat_completions.completion(
+            model=model,
+            messages=messages,
+            api_base=api_base,
+            custom_prompt_dict=custom_prompt_dict,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            encoding=encoding,
+            api_key=access_token,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            acompletion=acompletion,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            headers=vertex_headers,
+        )
+
         ## Format Prompt
         _is_function_call = False
         _is_json_schema = False
```
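With this rewrite, the Vertex Anthropic happy path no longer goes through the `AnthropicVertex` SDK client: it mints a Google access token, builds the `rawPredict` URL, attaches the Bearer header plus `anthropic_version`, and delegates to the shared `AnthropicChatCompletion.completion`. A hedged end-to-end usage sketch (the model id matches the test below; the tool definition is illustrative):

```python
# End-to-end usage sketch of the new Vertex Anthropic path.
import litellm

response = litellm.completion(
    model="vertex_ai/claude-3-sonnet@20240229",
    messages=[{"role": "user", "content": "What is the weather in Boston?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ],
    tool_choice="required",  # translated to Anthropic's {"type": "any"}
)
print(response.choices[0].message.tool_calls)
```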
```diff
@@ -363,7 +431,10 @@ def completion(
                 },
             )

-            message = vertex_ai_client.messages.create(**data)  # type: ignore
+            vertex_ai_client: Optional[AnthropicVertex] = None
+            vertex_ai_client = AnthropicVertex()
+            if vertex_ai_client is not None:
+                message = vertex_ai_client.messages.create(**data)  # type: ignore

             ## LOGGING
             logging_obj.post_call(
```
```diff
@@ -729,6 +729,9 @@ class VertexLLM(BaseLLM):
     def load_auth(
         self, credentials: Optional[str], project_id: Optional[str]
     ) -> Tuple[Any, str]:
+        """
+        Returns Credentials, project_id
+        """
         import google.auth as google_auth
         from google.auth.credentials import Credentials  # type: ignore[import-untyped]
         from google.auth.transport.requests import (
```
```diff
@@ -1035,9 +1038,7 @@ class VertexLLM(BaseLLM):
         safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
             "safety_settings", None
         )  # type: ignore
-        cached_content: Optional[str] = optional_params.pop(
-            "cached_content", None
-        )
+        cached_content: Optional[str] = optional_params.pop("cached_content", None)
         generation_config: Optional[GenerationConfig] = GenerationConfig(
             **optional_params
         )
```
```diff
@@ -2008,6 +2008,8 @@ def completion(
                 vertex_credentials=vertex_credentials,
                 logging_obj=logging,
                 acompletion=acompletion,
+                headers=headers,
+                custom_prompt_dict=custom_prompt_dict,
             )
         else:
             model_response = vertex_ai.completion(
```
```diff
@@ -637,11 +637,13 @@ def test_gemini_pro_vision_base64():
         pytest.fail(f"An exception occurred - {str(e)}")


-@pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
-@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
+# @pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
+@pytest.mark.parametrize(
+    "model", ["vertex_ai_beta/gemini-1.5-pro", "vertex_ai/claude-3-sonnet@20240229"]
+)  # "vertex_ai",
 @pytest.mark.parametrize("sync_mode", [True])  # "vertex_ai",
 @pytest.mark.asyncio
-async def test_gemini_pro_function_calling_httpx(provider, sync_mode):
+async def test_gemini_pro_function_calling_httpx(model, sync_mode):
     try:
         load_vertex_ai_credentials()
         litellm.set_verbose = True
```
```diff
@@ -679,7 +681,7 @@ async def test_gemini_pro_function_calling_httpx(provider, sync_mode):
         ]

         data = {
-            "model": "{}/gemini-1.5-pro".format(provider),
+            "model": model,
             "messages": messages,
             "tools": tools,
             "tool_choice": "required",
```
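For the Claude parametrization, the body that reaches `:rawPredict` would look roughly like this, assembling the hunks above (a sketch, not a captured request):

```python
# Approximate outbound payload for the Vertex Claude case.
vertex_payload = {
    "anthropic_version": "vertex-2023-10-16",  # injected for Vertex requests
    "messages": [{"role": "user", "content": "..."}],
    "tools": [],  # Anthropic-format tools mapped from the OpenAI definitions
    "tool_choice": {"type": "any"},  # from tool_choice="required"
    # no "model" key: the model id is part of the rawPredict URL
}
```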