From 19c982d0f945be84ba42107f117bf8a497d3590f Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 3 Jul 2024 21:55:00 -0700
Subject: [PATCH] fix: linting fixes

---
 litellm/litellm_core_utils/litellm_logging.py  | 21 +++++++---
 litellm/llms/anthropic.py                      |  5 ++-
 litellm/llms/vertex_ai_anthropic.py            | 40 ++++++++++++++++---
 litellm/main.py                                |  2 +
 .../tests/test_amazing_vertex_completion.py    | 10 +++--
 5 files changed, 62 insertions(+), 16 deletions(-)

diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index add281e43..4edbce5e1 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -426,13 +426,22 @@ class Logging:
             self.model_call_details["additional_args"] = additional_args
             self.model_call_details["log_event_type"] = "post_api_call"
 
-            verbose_logger.debug(
-                "RAW RESPONSE:\n{}\n\n".format(
-                    self.model_call_details.get(
-                        "original_response", self.model_call_details
+            if json_logs:
+                verbose_logger.debug(
+                    "RAW RESPONSE:\n{}\n\n".format(
+                        self.model_call_details.get(
+                            "original_response", self.model_call_details
+                        )
+                    ),
+                )
+            else:
+                print_verbose(
+                    "RAW RESPONSE:\n{}\n\n".format(
+                        self.model_call_details.get(
+                            "original_response", self.model_call_details
+                        )
                     )
-                ),
-            )
+                )
             if self.logger_fn and callable(self.logger_fn):
                 try:
                     self.logger_fn(
diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py
index b077a31dc..cff0dad35 100644
--- a/litellm/llms/anthropic.py
+++ b/litellm/llms/anthropic.py
@@ -601,13 +601,16 @@ class AnthropicChatCompletion(BaseLLM):
             optional_params["tools"] = anthropic_tools
 
         stream = optional_params.pop("stream", None)
+        is_vertex_request: bool = optional_params.pop("is_vertex_request", False)
 
         data = {
-            "model": model,
             "messages": messages,
             **optional_params,
         }
 
+        if is_vertex_request is False:
+            data["model"] = model
+
         ## LOGGING
         logging_obj.pre_call(
             input=messages,
diff --git a/litellm/llms/vertex_ai_anthropic.py b/litellm/llms/vertex_ai_anthropic.py
index 99418695b..44a7a448e 100644
--- a/litellm/llms/vertex_ai_anthropic.py
+++ b/litellm/llms/vertex_ai_anthropic.py
@@ -15,6 +15,7 @@ import requests  # type: ignore
 
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.llms.anthropic import AnthropicMessagesToolChoice
 from litellm.types.utils import ResponseFormatChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
@@ -121,6 +122,17 @@ class VertexAIAnthropicConfig:
                 optional_params["max_tokens"] = value
             if param == "tools":
                 optional_params["tools"] = value
+            if param == "tool_choice":
+                _tool_choice: Optional[AnthropicMessagesToolChoice] = None
+                if value == "auto":
+                    _tool_choice = {"type": "auto"}
+                elif value == "required":
+                    _tool_choice = {"type": "any"}
+                elif isinstance(value, dict):
+                    _tool_choice = {"type": "tool", "name": value["function"]["name"]}
+
+                if _tool_choice is not None:
+                    optional_params["tool_choice"] = _tool_choice
             if param == "stream":
                 optional_params["stream"] = value
             if param == "stop":
@@ -177,17 +189,29 @@ def get_vertex_client(
         _credentials, cred_project_id = VertexLLM().load_auth(
             credentials=vertex_credentials, project_id=vertex_project
         )
+
         vertex_ai_client = AnthropicVertex(
             project_id=vertex_project or cred_project_id,
             region=vertex_location or "us-central1",
             access_token=_credentials.token,
         )
+
         access_token = _credentials.token
     else:
        vertex_ai_client = client
+        access_token = client.access_token
 
     return vertex_ai_client, access_token
 
 
+def create_vertex_anthropic_url(
+    vertex_location: str, vertex_project: str, model: str, stream: bool
+) -> str:
+    if stream is True:
+        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict"
+    else:
+        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict"
+
+
 def completion(
     model: str,
     messages: list,
@@ -196,6 +220,8 @@ def completion(
     encoding,
     logging_obj,
     optional_params: dict,
+    custom_prompt_dict: dict,
+    headers: Optional[dict],
     vertex_project=None,
     vertex_location=None,
     vertex_credentials=None,
@@ -207,6 +233,9 @@ def completion(
     try:
         import vertexai
        from anthropic import AnthropicVertex
+
+        from litellm.llms.anthropic import AnthropicChatCompletion
+        from litellm.llms.vertex_httpx import VertexLLM
     except:
         raise VertexAIError(
             status_code=400,
@@ -222,13 +251,14 @@
         )
 
     try:
-        vertex_ai_client, access_token = get_vertex_client(
-            client=client,
-            vertex_project=vertex_project,
-            vertex_location=vertex_location,
-            vertex_credentials=vertex_credentials,
+        vertex_httpx_logic = VertexLLM()
+
+        access_token, project_id = vertex_httpx_logic._ensure_access_token(
+            credentials=vertex_credentials, project_id=vertex_project
         )
 
+        anthropic_chat_completions = AnthropicChatCompletion()
+
         ## Load Config
         config = litellm.VertexAIAnthropicConfig.get_config()
         for k, v in config.items():
diff --git a/litellm/main.py b/litellm/main.py
index ad91c19ad..9b42b0d07 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2008,6 +2008,8 @@
                 vertex_credentials=vertex_credentials,
                 logging_obj=logging,
                 acompletion=acompletion,
+                headers=headers,
+                custom_prompt_dict=custom_prompt_dict,
             )
         else:
             model_response = vertex_ai.completion(
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index 7c95c52a5..c4a5ec7ca 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -640,11 +640,13 @@ def test_gemini_pro_vision_base64():
         pytest.fail(f"An exception occurred - {str(e)}")
 
 
-@pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
-@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
+# @pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
+@pytest.mark.parametrize(
+    "model", ["vertex_ai_beta/gemini-1.5-pro", "vertex_ai/claude-3-sonnet@20240229"]
+)  # "vertex_ai",
 @pytest.mark.parametrize("sync_mode", [True])  # "vertex_ai",
 @pytest.mark.asyncio
-async def test_gemini_pro_function_calling_httpx(provider, sync_mode):
+async def test_gemini_pro_function_calling_httpx(model, sync_mode):
     try:
         load_vertex_ai_credentials()
         litellm.set_verbose = True
@@ -682,7 +684,7 @@
         ]
 
         data = {
-            "model": "{}/gemini-1.5-pro".format(provider),
+            "model": model,
            "messages": messages,
            "tools": tools,
            "tool_choice": "required",
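
Reviewer note (illustrative, not part of the patch): the sketch below mirrors two behaviours introduced above -- the OpenAI-style tool_choice values translated to Anthropic's format in VertexAIAnthropicConfig, and the endpoint selection in create_vertex_anthropic_url. The helper names, project, and region values here are stand-ins for demonstration, not litellm internals.

    # Standalone sketch, assuming only the logic shown in the diff above.
    from typing import Optional


    def map_tool_choice(value) -> Optional[dict]:
        # "auto" -> {"type": "auto"}; "required" -> {"type": "any"};
        # an OpenAI function dict -> {"type": "tool", "name": <function name>}.
        if value == "auto":
            return {"type": "auto"}
        if value == "required":
            return {"type": "any"}
        if isinstance(value, dict):
            return {"type": "tool", "name": value["function"]["name"]}
        return None


    def build_vertex_anthropic_url(location: str, project: str, model: str, stream: bool) -> str:
        # Streaming requests hit :streamRawPredict, non-streaming requests hit :rawPredict,
        # matching create_vertex_anthropic_url in the diff above.
        verb = "streamRawPredict" if stream else "rawPredict"
        return (
            f"https://{location}-aiplatform.googleapis.com/v1/projects/{project}"
            f"/locations/{location}/publishers/anthropic/models/{model}:{verb}"
        )


    if __name__ == "__main__":
        print(map_tool_choice({"type": "function", "function": {"name": "get_weather"}}))
        # -> {'type': 'tool', 'name': 'get_weather'}
        print(build_vertex_anthropic_url("us-east5", "my-project", "claude-3-sonnet@20240229", stream=False))

Design note: dropping "model" from the request body when is_vertex_request is True (anthropic.py hunk) reflects that the Vertex endpoint already encodes the model in the URL built by create_vertex_anthropic_url.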