Merge pull request #4531 from BerriAI/litellm_vertex_anthropic_tools

fix(vertex_anthropic.py): Vertex Anthropic tool calling - native params
Krish Dholakia 2024-07-03 17:25:32 -07:00 committed by GitHub
commit 318077edea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 139 additions and 38 deletions


@@ -426,6 +426,7 @@ class Logging:
self.model_call_details["additional_args"] = additional_args
self.model_call_details["log_event_type"] = "post_api_call"
if json_logs:
verbose_logger.debug(
"RAW RESPONSE:\n{}\n\n".format(
self.model_call_details.get(
@@ -433,6 +434,14 @@ class Logging:
)
),
)
else:
print_verbose(
"RAW RESPONSE:\n{}\n\n".format(
self.model_call_details.get(
"original_response", self.model_call_details
)
)
)
if self.logger_fn and callable(self.logger_fn):
try:
self.logger_fn(

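In short, the hunk above gates raw-response logging on a json_logs flag. A minimal, self-contained sketch of the resulting behaviour (json_logs, verbose_logger and print_verbose are stand-ins for litellm's module-level helpers, shown for illustration only):

import logging

verbose_logger = logging.getLogger("litellm")
json_logs = False  # assumed module-level flag controlling structured output

def print_verbose(msg: str) -> None:
    print(msg)

def log_raw_response(model_call_details: dict) -> None:
    raw = model_call_details.get("original_response", model_call_details)
    if json_logs:
        # structured-logger path: raw response stays inside the JSON log stream
        verbose_logger.debug("RAW RESPONSE:\n{}\n\n".format(raw))
    else:
        # plain stdout path: unchanged behaviour when json_logs is off
        print_verbose("RAW RESPONSE:\n{}\n\n".format(raw))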

@@ -431,20 +431,6 @@ class AnthropicChatCompletion(BaseLLM):
headers={},
):
data["stream"] = True
# async_handler = AsyncHTTPHandler(
# timeout=httpx.Timeout(timeout=600.0, connect=20.0)
# )
# response = await async_handler.post(
# api_base, headers=headers, json=data, stream=True
# )
# if response.status_code != 200:
# raise AnthropicError(
# status_code=response.status_code, message=response.text
# )
# completion_stream = response.aiter_lines()
streamwrapper = CustomStreamWrapper(
completion_stream=None,
@@ -484,7 +470,17 @@ class AnthropicChatCompletion(BaseLLM):
headers={},
) -> Union[ModelResponse, CustomStreamWrapper]:
async_handler = _get_async_httpx_client()
try:
response = await async_handler.post(api_base, headers=headers, json=data)
except Exception as e:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=str(e),
additional_args={"complete_input_dict": data},
)
raise e
if stream and _is_function_call:
return self.process_streaming_response(
model=model,
@@ -588,13 +584,16 @@ class AnthropicChatCompletion(BaseLLM):
optional_params["tools"] = anthropic_tools
stream = optional_params.pop("stream", None)
is_vertex_request: bool = optional_params.pop("is_vertex_request", False)
data = {
"model": model,
"messages": messages,
**optional_params,
}
if is_vertex_request is False:
data["model"] = model
## LOGGING
logging_obj.pre_call(
input=messages,
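The is_vertex_request flag controls whether "model" is sent in the request body: the direct Anthropic API expects it there, while Vertex AI encodes the model in the URL. A rough sketch of that difference (the helper function and the parameter values are illustrative, not litellm's actual API):

def build_anthropic_request_body(model: str, messages: list, optional_params: dict) -> dict:
    # hypothetical helper; mirrors the logic in the hunk above
    is_vertex_request = optional_params.pop("is_vertex_request", False)
    data = {"messages": messages, **optional_params}
    if is_vertex_request is False:
        # the direct Anthropic API expects "model" in the body;
        # Vertex AI carries it in the :rawPredict / :streamRawPredict URL instead
        data["model"] = model
    return data

body = build_anthropic_request_body(
    "claude-3-sonnet@20240229",
    [{"role": "user", "content": "hi"}],
    {"is_vertex_request": True, "max_tokens": 256, "anthropic_version": "vertex-2023-10-16"},
)
print("model" in body)  # False - the model is carried by the Vertex URL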
@@ -678,10 +677,27 @@ class AnthropicChatCompletion(BaseLLM):
return streaming_response
else:
try:
response = requests.post(
api_base, headers=headers, data=json.dumps(data)
)
except Exception as e:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=str(e),
additional_args={"complete_input_dict": data},
)
raise e
if response.status_code != 200:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
raise AnthropicError(
status_code=response.status_code, message=response.text
)

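Both the async and sync request paths now report failures through logging_obj.post_call before raising, so callbacks still see the failed request. A self-contained sketch of the pattern for the synchronous path (AnthropicError is redefined here as a minimal stand-in; logging_obj is assumed to expose post_call as in the diff):

import json
import requests

class AnthropicError(Exception):  # minimal stand-in for the exception used in the diff
    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        super().__init__(message)

def post_with_logging(api_base: str, headers: dict, data: dict, messages: list, api_key: str, logging_obj):
    try:
        response = requests.post(api_base, headers=headers, data=json.dumps(data))
    except Exception as e:
        # transport-level failure: surface it to logging callbacks, then re-raise
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=str(e),
            additional_args={"complete_input_dict": data},
        )
        raise
    if response.status_code != 200:
        # HTTP-level failure: log the body before raising a typed error
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=response.text,
            additional_args={"complete_input_dict": data},
        )
        raise AnthropicError(status_code=response.status_code, message=response.text)
    return response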

@@ -15,6 +15,7 @@ import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.anthropic import AnthropicMessagesToolChoice
from litellm.types.utils import ResponseFormatChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
@@ -121,6 +122,17 @@ class VertexAIAnthropicConfig:
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "tool_choice":
_tool_choice: Optional[AnthropicMessagesToolChoice] = None
if value == "auto":
_tool_choice = {"type": "auto"}
elif value == "required":
_tool_choice = {"type": "any"}
elif isinstance(value, dict):
_tool_choice = {"type": "tool", "name": value["function"]["name"]}
if _tool_choice is not None:
optional_params["tool_choice"] = _tool_choice
if param == "stream":
optional_params["stream"] = value
if param == "stop":
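The new tool_choice branch maps OpenAI-style values onto Anthropic's schema. A small standalone illustration of the same mapping (the wrapper function exists only for the example):

from typing import Optional, Union

def map_tool_choice(value: Union[str, dict]) -> Optional[dict]:
    # OpenAI "auto"     -> Anthropic {"type": "auto"}
    # OpenAI "required" -> Anthropic {"type": "any"}
    # OpenAI {"type": "function", "function": {"name": "get_weather"}}
    #                   -> Anthropic {"type": "tool", "name": "get_weather"}
    if value == "auto":
        return {"type": "auto"}
    if value == "required":
        return {"type": "any"}
    if isinstance(value, dict):
        return {"type": "tool", "name": value["function"]["name"]}
    return None

print(map_tool_choice("required"))  # {'type': 'any'}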
@@ -177,17 +189,29 @@ def get_vertex_client(
_credentials, cred_project_id = VertexLLM().load_auth(
credentials=vertex_credentials, project_id=vertex_project
)
vertex_ai_client = AnthropicVertex(
project_id=vertex_project or cred_project_id,
region=vertex_location or "us-central1",
access_token=_credentials.token,
)
access_token = _credentials.token
else:
vertex_ai_client = client
access_token = client.access_token
return vertex_ai_client, access_token
def create_vertex_anthropic_url(
vertex_location: str, vertex_project: str, model: str, stream: bool
) -> str:
if stream is True:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict"
else:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict"
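For example, the helper above builds URLs like the following (project and location values are placeholders):

url = create_vertex_anthropic_url(
    vertex_location="us-east5",        # placeholder region
    vertex_project="my-gcp-project",   # placeholder project id
    model="claude-3-sonnet@20240229",
    stream=False,
)
# -> https://us-east5-aiplatform.googleapis.com/v1/projects/my-gcp-project/locations/us-east5/publishers/anthropic/models/claude-3-sonnet@20240229:rawPredict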
def completion(
model: str,
messages: list,
@@ -196,6 +220,8 @@ def completion(
encoding,
logging_obj,
optional_params: dict,
custom_prompt_dict: dict,
headers: Optional[dict],
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
@@ -207,6 +233,9 @@ def completion(
try:
import vertexai
from anthropic import AnthropicVertex
from litellm.llms.anthropic import AnthropicChatCompletion
from litellm.llms.vertex_httpx import VertexLLM
except:
raise VertexAIError(
status_code=400,
@@ -222,19 +251,58 @@ def completion(
)
try:
vertex_ai_client, access_token = get_vertex_client(
client=client,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
)
vertex_httpx_logic = VertexLLM()
access_token, project_id = vertex_httpx_logic._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)
anthropic_chat_completions = AnthropicChatCompletion()
## Load Config
config = litellm.VertexAIAnthropicConfig.get_config()
for k, v in config.items():
if k not in optional_params:
optional_params[k] = v
## CONSTRUCT API BASE
stream = optional_params.get("stream", False)
api_base = create_vertex_anthropic_url(
vertex_location=vertex_location or "us-central1",
vertex_project=vertex_project or project_id,
model=model,
stream=stream,
)
if headers is not None:
vertex_headers = headers
else:
vertex_headers = {}
vertex_headers.update({"Authorization": "Bearer {}".format(access_token)})
optional_params.update(
{"anthropic_version": "vertex-2023-10-16", "is_vertex_request": True}
)
return anthropic_chat_completions.completion(
model=model,
messages=messages,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=access_token,
logging_obj=logging_obj,
optional_params=optional_params,
acompletion=acompletion,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=vertex_headers,
)
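Taken together, this routes Vertex Anthropic traffic through AnthropicChatCompletion with native tool-calling params. A hedged end-to-end usage sketch (the model name matches the test change below; project, location and the tool schema are illustrative):

import litellm

# assumes Vertex credentials are available (e.g. GOOGLE_APPLICATION_CREDENTIALS)
response = litellm.completion(
    model="vertex_ai/claude-3-sonnet@20240229",
    messages=[{"role": "user", "content": "What is the weather in Boston?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ],
    tool_choice="required",          # mapped to Anthropic's {"type": "any"} by the config above
    vertex_project="my-gcp-project", # placeholder
    vertex_location="us-east5",      # placeholder
)
print(response.choices[0].message.tool_calls)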
## Format Prompt
_is_function_call = False
_is_json_schema = False
@@ -363,6 +431,9 @@ def completion(
},
)
vertex_ai_client: Optional[AnthropicVertex] = None
vertex_ai_client = AnthropicVertex()
if vertex_ai_client is not None:
message = vertex_ai_client.messages.create(**data) # type: ignore
## LOGGING


@@ -729,6 +729,9 @@ class VertexLLM(BaseLLM):
def load_auth(
self, credentials: Optional[str], project_id: Optional[str]
) -> Tuple[Any, str]:
"""
Returns Credentials, project_id
"""
import google.auth as google_auth
from google.auth.credentials import Credentials # type: ignore[import-untyped]
from google.auth.transport.requests import (
@@ -1035,9 +1038,7 @@ class VertexLLM(BaseLLM):
safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
"safety_settings", None
) # type: ignore
cached_content: Optional[str] = optional_params.pop(
"cached_content", None
)
cached_content: Optional[str] = optional_params.pop("cached_content", None)
generation_config: Optional[GenerationConfig] = GenerationConfig(
**optional_params
)


@@ -2008,6 +2008,8 @@ def completion(
vertex_credentials=vertex_credentials,
logging_obj=logging,
acompletion=acompletion,
headers=headers,
custom_prompt_dict=custom_prompt_dict,
)
else:
model_response = vertex_ai.completion(


@@ -637,11 +637,13 @@ def test_gemini_pro_vision_base64():
pytest.fail(f"An exception occurred - {str(e)}")
@pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
# @pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
@pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai",
@pytest.mark.parametrize(
"model", ["vertex_ai_beta/gemini-1.5-pro", "vertex_ai/claude-3-sonnet@20240229"]
) # "vertex_ai",
@pytest.mark.parametrize("sync_mode", [True]) # "vertex_ai",
@pytest.mark.asyncio
async def test_gemini_pro_function_calling_httpx(provider, sync_mode):
async def test_gemini_pro_function_calling_httpx(model, sync_mode):
try:
load_vertex_ai_credentials()
litellm.set_verbose = True
@@ -679,7 +681,7 @@ async def test_gemini_pro_function_calling_httpx(provider, sync_mode):
]
data = {
"model": "{}/gemini-1.5-pro".format(provider),
"model": model,
"messages": messages,
"tools": tools,
"tool_choice": "required",