diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 4edbce5e1..add281e43 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -426,22 +426,13 @@ class Logging:
         self.model_call_details["additional_args"] = additional_args
         self.model_call_details["log_event_type"] = "post_api_call"
-        if json_logs:
-            verbose_logger.debug(
-                "RAW RESPONSE:\n{}\n\n".format(
-                    self.model_call_details.get(
-                        "original_response", self.model_call_details
-                    )
-                ),
-            )
-        else:
-            print_verbose(
-                "RAW RESPONSE:\n{}\n\n".format(
-                    self.model_call_details.get(
-                        "original_response", self.model_call_details
-                    )
+        verbose_logger.debug(
+            "RAW RESPONSE:\n{}\n\n".format(
+                self.model_call_details.get(
+                    "original_response", self.model_call_details
                 )
-            )
+            ),
+        )
 
         if self.logger_fn and callable(self.logger_fn):
             try:
                 self.logger_fn(
diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py
index ce15dd359..1051a56b7 100644
--- a/litellm/llms/anthropic.py
+++ b/litellm/llms/anthropic.py
@@ -431,6 +431,20 @@ class AnthropicChatCompletion(BaseLLM):
         headers={},
     ):
         data["stream"] = True
+        # async_handler = AsyncHTTPHandler(
+        #     timeout=httpx.Timeout(timeout=600.0, connect=20.0)
+        # )
+
+        # response = await async_handler.post(
+        #     api_base, headers=headers, json=data, stream=True
+        # )
+
+        # if response.status_code != 200:
+        #     raise AnthropicError(
+        #         status_code=response.status_code, message=response.text
+        #     )
+
+        # completion_stream = response.aiter_lines()
 
         streamwrapper = CustomStreamWrapper(
             completion_stream=None,
@@ -470,17 +484,7 @@ class AnthropicChatCompletion(BaseLLM):
         headers={},
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         async_handler = _get_async_httpx_client()
-        try:
-            response = await async_handler.post(api_base, headers=headers, json=data)
-        except Exception as e:
-            ## LOGGING
-            logging_obj.post_call(
-                input=messages,
-                api_key=api_key,
-                original_response=str(e),
-                additional_args={"complete_input_dict": data},
-            )
-            raise e
+        response = await async_handler.post(api_base, headers=headers, json=data)
         if stream and _is_function_call:
             return self.process_streaming_response(
                 model=model,
@@ -584,16 +588,13 @@ class AnthropicChatCompletion(BaseLLM):
             optional_params["tools"] = anthropic_tools
 
         stream = optional_params.pop("stream", None)
-        is_vertex_request: bool = optional_params.pop("is_vertex_request", False)
 
         data = {
+            "model": model,
            "messages": messages,
            **optional_params,
        }
 
-        if is_vertex_request is False:
-            data["model"] = model
-
         ## LOGGING
         logging_obj.pre_call(
             input=messages,
@@ -677,27 +678,10 @@ class AnthropicChatCompletion(BaseLLM):
                 return streaming_response
 
             else:
-                try:
-                    response = requests.post(
-                        api_base, headers=headers, data=json.dumps(data)
-                    )
-                except Exception as e:
-                    ## LOGGING
-                    logging_obj.post_call(
-                        input=messages,
-                        api_key=api_key,
-                        original_response=str(e),
-                        additional_args={"complete_input_dict": data},
-                    )
-                    raise e
+                response = requests.post(
+                    api_base, headers=headers, data=json.dumps(data)
+                )
             if response.status_code != 200:
-                ## LOGGING
-                logging_obj.post_call(
-                    input=messages,
-                    api_key=api_key,
-                    original_response=response.text,
-                    additional_args={"complete_input_dict": data},
-                )
                 raise AnthropicError(
                     status_code=response.status_code, message=response.text
                 )
diff --git a/litellm/llms/vertex_ai_anthropic.py b/litellm/llms/vertex_ai_anthropic.py
index 71dc2aacd..6b39716f1 100644
--- a/litellm/llms/vertex_ai_anthropic.py
+++ b/litellm/llms/vertex_ai_anthropic.py
@@ -15,7 +15,6 @@ import requests  # type: ignore
 
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.llms.anthropic import AnthropicMessagesToolChoice
 from litellm.types.utils import ResponseFormatChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
@@ -122,17 +121,6 @@ class VertexAIAnthropicConfig:
                 optional_params["max_tokens"] = value
             if param == "tools":
                 optional_params["tools"] = value
-            if param == "tool_choice":
-                _tool_choice: Optional[AnthropicMessagesToolChoice] = None
-                if value == "auto":
-                    _tool_choice = {"type": "auto"}
-                elif value == "required":
-                    _tool_choice = {"type": "any"}
-                elif isinstance(value, dict):
-                    _tool_choice = {"type": "tool", "name": value["function"]["name"]}
-
-                if _tool_choice is not None:
-                    optional_params["tool_choice"] = _tool_choice
             if param == "stream":
                 optional_params["stream"] = value
             if param == "stop":
@@ -189,29 +177,17 @@ def get_vertex_client(
         _credentials, cred_project_id = VertexLLM().load_auth(
             credentials=vertex_credentials, project_id=vertex_project
         )
-
         vertex_ai_client = AnthropicVertex(
             project_id=vertex_project or cred_project_id,
             region=vertex_location or "us-central1",
             access_token=_credentials.token,
         )
-
         access_token = _credentials.token
     else:
         vertex_ai_client = client
-
         access_token = client.access_token
 
     return vertex_ai_client, access_token
 
 
-def create_vertex_anthropic_url(
-    vertex_location: str, vertex_project: str, model: str, stream: bool
-) -> str:
-    if stream is True:
-        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict"
-    else:
-        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict"
-
-
 def completion(
     model: str,
     messages: list,
@@ -220,8 +196,6 @@ def completion(
     encoding,
     logging_obj,
     optional_params: dict,
-    custom_prompt_dict: dict,
-    headers: Optional[dict],
     vertex_project=None,
     vertex_location=None,
     vertex_credentials=None,
@@ -233,9 +207,6 @@ def completion(
     try:
         import vertexai
         from anthropic import AnthropicVertex
-
-        from litellm.llms.anthropic import AnthropicChatCompletion
-        from litellm.llms.vertex_httpx import VertexLLM
     except:
         raise VertexAIError(
            status_code=400,
@@ -251,58 +222,19 @@ def completion(
         )
 
     try:
-        vertex_httpx_logic = VertexLLM()
-
-        access_token, project_id = vertex_httpx_logic._ensure_access_token(
-            credentials=vertex_credentials, project_id=vertex_project
+        vertex_ai_client, access_token = get_vertex_client(
+            client=client,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
         )
-        anthropic_chat_completions = AnthropicChatCompletion()
-
         ## Load Config
         config = litellm.VertexAIAnthropicConfig.get_config()
         for k, v in config.items():
             if k not in optional_params:
                 optional_params[k] = v
 
-        ## CONSTRUCT API BASE
-        stream = optional_params.get("stream", False)
-
-        api_base = create_vertex_anthropic_url(
-            vertex_location=vertex_location or "us-central1",
-            vertex_project=vertex_project or project_id,
-            model=model,
-            stream=stream,
-        )
-
-        if headers is not None:
-            vertex_headers = headers
-        else:
-            vertex_headers = {}
-
-        vertex_headers.update({"Authorization": "Bearer {}".format(access_token)})
-
-        optional_params.update(
-            {"anthropic_version": "vertex-2023-10-16", "is_vertex_request": True}
-        )
-
-        return anthropic_chat_completions.completion(
-            model=model,
-            messages=messages,
-            api_base=api_base,
-            custom_prompt_dict=custom_prompt_dict,
-            model_response=model_response,
-            print_verbose=print_verbose,
-            encoding=encoding,
-            api_key=access_token,
-            logging_obj=logging_obj,
-            optional_params=optional_params,
-            acompletion=acompletion,
-            litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            headers=vertex_headers,
-        )
-
         ## Format Prompt
         _is_function_call = False
         _is_json_schema = False
@@ -431,10 +363,7 @@ def completion(
             },
         )
 
-        vertex_ai_client: Optional[AnthropicVertex] = None
-        vertex_ai_client = AnthropicVertex()
-        if vertex_ai_client is not None:
-            message = vertex_ai_client.messages.create(**data)  # type: ignore
+        message = vertex_ai_client.messages.create(**data)  # type: ignore
 
         ## LOGGING
         logging_obj.post_call(
diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index af114f8d8..2ea0e199e 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -729,9 +729,6 @@ class VertexLLM(BaseLLM):
     def load_auth(
         self, credentials: Optional[str], project_id: Optional[str]
     ) -> Tuple[Any, str]:
-        """
-        Returns Credentials, project_id
-        """
         import google.auth as google_auth
         from google.auth.credentials import Credentials  # type: ignore[import-untyped]
         from google.auth.transport.requests import (
@@ -1038,7 +1035,9 @@ class VertexLLM(BaseLLM):
         safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
             "safety_settings", None
         )  # type: ignore
-        cached_content: Optional[str] = optional_params.pop("cached_content", None)
+        cached_content: Optional[str] = optional_params.pop(
+            "cached_content", None
+        )
         generation_config: Optional[GenerationConfig] = GenerationConfig(
             **optional_params
         )
diff --git a/litellm/main.py b/litellm/main.py
index 72eeff262..d6819b5ec 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2008,8 +2008,6 @@ def completion(
                 vertex_credentials=vertex_credentials,
                 logging_obj=logging,
                 acompletion=acompletion,
-                headers=headers,
-                custom_prompt_dict=custom_prompt_dict,
             )
         else:
             model_response = vertex_ai.completion(
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index d8bb6d432..c4705325b 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -637,13 +637,11 @@ def test_gemini_pro_vision_base64():
         pytest.fail(f"An exception occurred - {str(e)}")
 
 
-# @pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
-@pytest.mark.parametrize(
-    "model", ["vertex_ai_beta/gemini-1.5-pro", "vertex_ai/claude-3-sonnet@20240229"]
-)  # "vertex_ai",
+@pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
+@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
 @pytest.mark.parametrize("sync_mode", [True])  # "vertex_ai",
 @pytest.mark.asyncio
-async def test_gemini_pro_function_calling_httpx(model, sync_mode):
+async def test_gemini_pro_function_calling_httpx(provider, sync_mode):
     try:
         load_vertex_ai_credentials()
         litellm.set_verbose = True
@@ -681,7 +679,7 @@ async def test_gemini_pro_function_calling_httpx(model, sync_mode):
         ]
 
         data = {
-            "model": model,
+            "model": "{}/gemini-1.5-pro".format(provider),
             "messages": messages,
             "tools": tools,
             "tool_choice": "required",