diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index add281e43..4edbce5e1 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -426,13 +426,22 @@ class Logging:
         self.model_call_details["additional_args"] = additional_args
         self.model_call_details["log_event_type"] = "post_api_call"
 
-        verbose_logger.debug(
-            "RAW RESPONSE:\n{}\n\n".format(
-                self.model_call_details.get(
-                    "original_response", self.model_call_details
+        if json_logs:
+            verbose_logger.debug(
+                "RAW RESPONSE:\n{}\n\n".format(
+                    self.model_call_details.get(
+                        "original_response", self.model_call_details
+                    )
+                ),
+            )
+        else:
+            print_verbose(
+                "RAW RESPONSE:\n{}\n\n".format(
+                    self.model_call_details.get(
+                        "original_response", self.model_call_details
+                    )
                 )
-            ),
-        )
+            )
         if self.logger_fn and callable(self.logger_fn):
             try:
                 self.logger_fn(
diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py
index 1051a56b7..ce15dd359 100644
--- a/litellm/llms/anthropic.py
+++ b/litellm/llms/anthropic.py
@@ -431,20 +431,6 @@ class AnthropicChatCompletion(BaseLLM):
         headers={},
     ):
         data["stream"] = True
-        # async_handler = AsyncHTTPHandler(
-        #     timeout=httpx.Timeout(timeout=600.0, connect=20.0)
-        # )
-
-        # response = await async_handler.post(
-        #     api_base, headers=headers, json=data, stream=True
-        # )
-
-        # if response.status_code != 200:
-        #     raise AnthropicError(
-        #         status_code=response.status_code, message=response.text
-        #     )
-
-        # completion_stream = response.aiter_lines()
 
         streamwrapper = CustomStreamWrapper(
             completion_stream=None,
@@ -484,7 +470,17 @@ class AnthropicChatCompletion(BaseLLM):
         headers={},
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         async_handler = _get_async_httpx_client()
-        response = await async_handler.post(api_base, headers=headers, json=data)
+        try:
+            response = await async_handler.post(api_base, headers=headers, json=data)
+        except Exception as e:
+            ## LOGGING
+            logging_obj.post_call(
+                input=messages,
+                api_key=api_key,
+                original_response=str(e),
+                additional_args={"complete_input_dict": data},
+            )
+            raise e
         if stream and _is_function_call:
             return self.process_streaming_response(
                 model=model,
@@ -588,13 +584,16 @@ class AnthropicChatCompletion(BaseLLM):
             optional_params["tools"] = anthropic_tools
 
         stream = optional_params.pop("stream", None)
+        is_vertex_request: bool = optional_params.pop("is_vertex_request", False)
 
         data = {
-            "model": model,
             "messages": messages,
             **optional_params,
         }
 
+        if is_vertex_request is False:
+            data["model"] = model
+
         ## LOGGING
         logging_obj.pre_call(
             input=messages,
@@ -678,10 +677,27 @@ class AnthropicChatCompletion(BaseLLM):
                 return streaming_response
 
             else:
-                response = requests.post(
-                    api_base, headers=headers, data=json.dumps(data)
-                )
+                try:
+                    response = requests.post(
+                        api_base, headers=headers, data=json.dumps(data)
+                    )
+                except Exception as e:
+                    ## LOGGING
+                    logging_obj.post_call(
+                        input=messages,
+                        api_key=api_key,
+                        original_response=str(e),
+                        additional_args={"complete_input_dict": data},
+                    )
+                    raise e
                 if response.status_code != 200:
+                    ## LOGGING
+                    logging_obj.post_call(
+                        input=messages,
+                        api_key=api_key,
+                        original_response=response.text,
+                        additional_args={"complete_input_dict": data},
+                    )
                     raise AnthropicError(
                         status_code=response.status_code, message=response.text
                     )
diff --git a/litellm/llms/vertex_ai_anthropic.py b/litellm/llms/vertex_ai_anthropic.py
index 6b39716f1..71dc2aacd 100644
--- a/litellm/llms/vertex_ai_anthropic.py
+++ b/litellm/llms/vertex_ai_anthropic.py
@@ -15,6 +15,7 @@ import requests  # type: ignore
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.llms.anthropic import AnthropicMessagesToolChoice
 from litellm.types.utils import ResponseFormatChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
 
@@ -121,6 +122,17 @@ class VertexAIAnthropicConfig:
                 optional_params["max_tokens"] = value
             if param == "tools":
                 optional_params["tools"] = value
+            if param == "tool_choice":
+                _tool_choice: Optional[AnthropicMessagesToolChoice] = None
+                if value == "auto":
+                    _tool_choice = {"type": "auto"}
+                elif value == "required":
+                    _tool_choice = {"type": "any"}
+                elif isinstance(value, dict):
+                    _tool_choice = {"type": "tool", "name": value["function"]["name"]}
+
+                if _tool_choice is not None:
+                    optional_params["tool_choice"] = _tool_choice
             if param == "stream":
                 optional_params["stream"] = value
             if param == "stop":
@@ -177,17 +189,29 @@ def get_vertex_client(
         _credentials, cred_project_id = VertexLLM().load_auth(
             credentials=vertex_credentials, project_id=vertex_project
         )
+
         vertex_ai_client = AnthropicVertex(
             project_id=vertex_project or cred_project_id,
             region=vertex_location or "us-central1",
             access_token=_credentials.token,
         )
+        access_token = _credentials.token
     else:
         vertex_ai_client = client
+        access_token = client.access_token
 
     return vertex_ai_client, access_token
 
 
+def create_vertex_anthropic_url(
+    vertex_location: str, vertex_project: str, model: str, stream: bool
+) -> str:
+    if stream is True:
+        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict"
+    else:
+        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict"
+
+
 def completion(
     model: str,
     messages: list,
@@ -196,6 +220,8 @@ def completion(
     encoding,
     logging_obj,
     optional_params: dict,
+    custom_prompt_dict: dict,
+    headers: Optional[dict],
     vertex_project=None,
     vertex_location=None,
     vertex_credentials=None,
@@ -207,6 +233,9 @@ def completion(
     try:
         import vertexai
         from anthropic import AnthropicVertex
+
+        from litellm.llms.anthropic import AnthropicChatCompletion
+        from litellm.llms.vertex_httpx import VertexLLM
     except:
         raise VertexAIError(
             status_code=400,
@@ -222,19 +251,58 @@ def completion(
         )
     try:
 
-        vertex_ai_client, access_token = get_vertex_client(
-            client=client,
-            vertex_project=vertex_project,
-            vertex_location=vertex_location,
-            vertex_credentials=vertex_credentials,
+        vertex_httpx_logic = VertexLLM()
+
+        access_token, project_id = vertex_httpx_logic._ensure_access_token(
+            credentials=vertex_credentials, project_id=vertex_project
         )
 
+        anthropic_chat_completions = AnthropicChatCompletion()
+
         ## Load Config
         config = litellm.VertexAIAnthropicConfig.get_config()
         for k, v in config.items():
             if k not in optional_params:
                 optional_params[k] = v
 
+        ## CONSTRUCT API BASE
+        stream = optional_params.get("stream", False)
+
+        api_base = create_vertex_anthropic_url(
+            vertex_location=vertex_location or "us-central1",
+            vertex_project=vertex_project or project_id,
+            model=model,
+            stream=stream,
+        )
+
+        if headers is not None:
+            vertex_headers = headers
+        else:
+            vertex_headers = {}
+
+        vertex_headers.update({"Authorization": "Bearer {}".format(access_token)})
+
+        optional_params.update(
+            {"anthropic_version": "vertex-2023-10-16", "is_vertex_request": True}
+        )
+
+        return anthropic_chat_completions.completion(
+            model=model,
+            messages=messages,
+            api_base=api_base,
+            custom_prompt_dict=custom_prompt_dict,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            encoding=encoding,
+            api_key=access_token,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            acompletion=acompletion,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            headers=vertex_headers,
+        )
+
         ## Format Prompt
         _is_function_call = False
         _is_json_schema = False
@@ -363,7 +431,10 @@ def completion(
             },
         )
 
-        message = vertex_ai_client.messages.create(**data)  # type: ignore
+        vertex_ai_client: Optional[AnthropicVertex] = None
+        vertex_ai_client = AnthropicVertex()
+        if vertex_ai_client is not None:
+            message = vertex_ai_client.messages.create(**data)  # type: ignore
 
         ## LOGGING
         logging_obj.post_call(
diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index 2ea0e199e..af114f8d8 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -729,6 +729,9 @@ class VertexLLM(BaseLLM):
     def load_auth(
         self, credentials: Optional[str], project_id: Optional[str]
     ) -> Tuple[Any, str]:
+        """
+        Returns Credentials, project_id
+        """
         import google.auth as google_auth
         from google.auth.credentials import Credentials  # type: ignore[import-untyped]
         from google.auth.transport.requests import (
@@ -1035,9 +1038,7 @@ class VertexLLM(BaseLLM):
         safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
             "safety_settings", None
         )  # type: ignore
-        cached_content: Optional[str] = optional_params.pop(
-            "cached_content", None
-        )
+        cached_content: Optional[str] = optional_params.pop("cached_content", None)
         generation_config: Optional[GenerationConfig] = GenerationConfig(
             **optional_params
         )
diff --git a/litellm/main.py b/litellm/main.py
index d6819b5ec..72eeff262 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2008,6 +2008,8 @@ def completion(
                 vertex_credentials=vertex_credentials,
                 logging_obj=logging,
                 acompletion=acompletion,
+                headers=headers,
+                custom_prompt_dict=custom_prompt_dict,
             )
         else:
             model_response = vertex_ai.completion(
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index c4705325b..d8bb6d432 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -637,11 +637,13 @@ def test_gemini_pro_vision_base64():
         pytest.fail(f"An exception occurred - {str(e)}")
 
 
-@pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
-@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
+# @pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
+@pytest.mark.parametrize(
+    "model", ["vertex_ai_beta/gemini-1.5-pro", "vertex_ai/claude-3-sonnet@20240229"]
+)  # "vertex_ai",
 @pytest.mark.parametrize("sync_mode", [True])  # "vertex_ai",
 @pytest.mark.asyncio
-async def test_gemini_pro_function_calling_httpx(provider, sync_mode):
+async def test_gemini_pro_function_calling_httpx(model, sync_mode):
     try:
         load_vertex_ai_credentials()
         litellm.set_verbose = True
@@ -679,7 +681,7 @@ async def test_gemini_pro_function_calling_httpx(provider, sync_mode):
         ]
 
         data = {
-            "model": "{}/gemini-1.5-pro".format(provider),
+            "model": model,
             "messages": messages,
             "tools": tools,
             "tool_choice": "required",