diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py
index 55736772af..f5731618a3 100644
--- a/litellm/cost_calculator.py
+++ b/litellm/cost_calculator.py
@@ -828,11 +828,14 @@ def get_response_cost_from_hidden_params(
         _hidden_params_dict = hidden_params

     additional_headers = _hidden_params_dict.get("additional_headers", {})
-    if additional_headers and "x-litellm-response-cost" in additional_headers:
-        response_cost = additional_headers["x-litellm-response-cost"]
+    if (
+        additional_headers
+        and "llm_provider-x-litellm-response-cost" in additional_headers
+    ):
+        response_cost = additional_headers["llm_provider-x-litellm-response-cost"]
         if response_cost is None:
             return None
-        return float(additional_headers["x-litellm-response-cost"])
+        return float(additional_headers["llm_provider-x-litellm-response-cost"])
     return None


diff --git a/litellm/llms/vertex_ai/common_utils.py b/litellm/llms/vertex_ai/common_utils.py
index a3f91fbacc..0d792527b4 100644
--- a/litellm/llms/vertex_ai/common_utils.py
+++ b/litellm/llms/vertex_ai/common_utils.py
@@ -55,7 +55,9 @@ def get_supports_response_schema(

 from typing import Literal, Optional

-all_gemini_url_modes = Literal["chat", "embedding", "batch_embedding"]
+all_gemini_url_modes = Literal[
+    "chat", "embedding", "batch_embedding", "image_generation"
+]


 def _get_vertex_url(
@@ -91,7 +93,11 @@ def _get_vertex_url(
         if model.isdigit():
             # https://us-central1-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/us-central1/endpoints/$ENDPOINT_ID:predict
             url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
-
+    elif mode == "image_generation":
+        endpoint = "predict"
+        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
+        if model.isdigit():
+            url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
     if not url or not endpoint:
         raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}")
     return url, endpoint
@@ -127,6 +133,10 @@ def _get_gemini_url(
         url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
             _gemini_model_name, endpoint, gemini_api_key
         )
+    elif mode == "image_generation":
+        raise ValueError(
+            "LiteLLM's `gemini/` route does not support image generation yet. Let us know if you need this feature by opening an issue at https://github.com/BerriAI/litellm/issues"
+        )

     return url, endpoint
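For context, here is a minimal sketch of what the new `image_generation` branch of `_get_vertex_url` is expected to return. The project, location, and model values are illustrative, and the keyword arguments assume the parameter names visible in the hunk above (`mode`, `model`, `vertex_project`, `vertex_location`, `vertex_api_version`, plus a `stream` flag):

# Illustrative sketch, not part of the diff: exercising the new
# "image_generation" mode of _get_vertex_url with made-up values.
from litellm.llms.vertex_ai.common_utils import _get_vertex_url

url, endpoint = _get_vertex_url(
    mode="image_generation",
    model="imagegeneration",
    stream=False,
    vertex_project="my-project",
    vertex_location="us-central1",
    vertex_api_version="v1",
)
assert endpoint == "predict"
assert url == (
    "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project"
    "/locations/us-central1/publishers/google/models/imagegeneration:predict"
)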
diff --git a/litellm/llms/vertex_ai/image_generation/image_generation_handler.py b/litellm/llms/vertex_ai/image_generation/image_generation_handler.py
index 1d5322c08d..e83f4b6f03 100644
--- a/litellm/llms/vertex_ai/image_generation/image_generation_handler.py
+++ b/litellm/llms/vertex_ai/image_generation/image_generation_handler.py
@@ -43,22 +43,23 @@ class VertexImageGeneration(VertexLLM):
     def image_generation(
         self,
         prompt: str,
+        api_base: Optional[str],
         vertex_project: Optional[str],
         vertex_location: Optional[str],
         vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
         model_response: ImageResponse,
         logging_obj: Any,
-        model: Optional[
-            str
-        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
+        model: str = "imagegeneration",  # vertex ai uses imagegeneration as the default model
         client: Optional[Any] = None,
         optional_params: Optional[dict] = None,
         timeout: Optional[int] = None,
         aimg_generation=False,
+        extra_headers: Optional[dict] = None,
     ) -> ImageResponse:
         if aimg_generation is True:
             return self.aimage_generation(  # type: ignore
                 prompt=prompt,
+                api_base=api_base,
                 vertex_project=vertex_project,
                 vertex_location=vertex_location,
                 vertex_credentials=vertex_credentials,
@@ -83,13 +84,27 @@ class VertexImageGeneration(VertexLLM):
         else:
             sync_handler = client  # type: ignore

-        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
+        # url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"

+        auth_header: Optional[str] = None
         auth_header, _ = self._ensure_access_token(
             credentials=vertex_credentials,
             project_id=vertex_project,
             custom_llm_provider="vertex_ai",
         )
+        auth_header, api_base = self._get_token_and_url(
+            model=model,
+            gemini_api_key=None,
+            auth_header=auth_header,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=False,
+            custom_llm_provider="vertex_ai",
+            api_base=api_base,
+            should_use_v1beta1_features=False,
+            mode="image_generation",
+        )
         optional_params = optional_params or {
             "sampleCount": 1
         }  # default optional params
@@ -99,31 +114,21 @@ class VertexImageGeneration(VertexLLM):
             "parameters": optional_params,
         }

-        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
-        logging_obj.pre_call(
-            input=prompt,
-            api_key=None,
-            additional_args={
-                "complete_input_dict": optional_params,
-                "request_str": request_str,
-            },
-        )
+        headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)

         logging_obj.pre_call(
             input=prompt,
-            api_key=None,
+            api_key="",
             additional_args={
                 "complete_input_dict": optional_params,
-                "request_str": request_str,
+                "api_base": api_base,
+                "headers": headers,
             },
         )

         response = sync_handler.post(
-            url=url,
-            headers={
-                "Content-Type": "application/json; charset=utf-8",
-                "Authorization": f"Bearer {auth_header}",
-            },
+            url=api_base,
+            headers=headers,
             data=json.dumps(request_data),
         )

@@ -138,17 +143,17 @@ class VertexImageGeneration(VertexLLM):
     async def aimage_generation(
         self,
         prompt: str,
+        api_base: Optional[str],
         vertex_project: Optional[str],
         vertex_location: Optional[str],
         vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
         model_response: litellm.ImageResponse,
         logging_obj: Any,
-        model: Optional[
-            str
-        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
+        model: str = "imagegeneration",  # vertex ai uses imagegeneration as the default model
         client: Optional[AsyncHTTPHandler] = None,
         optional_params: Optional[dict] = None,
         timeout: Optional[int] = None,
+        extra_headers: Optional[dict] = None,
     ):
         response = None
         if client is None:
@@ -169,7 +174,6 @@ class VertexImageGeneration(VertexLLM):

         # make POST request to
         # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
-        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"

         """
         Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
@@ -188,11 +192,25 @@ class VertexImageGeneration(VertexLLM):
        }  \
        "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
        """
+        auth_header: Optional[str] = None
         auth_header, _ = self._ensure_access_token(
             credentials=vertex_credentials,
             project_id=vertex_project,
             custom_llm_provider="vertex_ai",
         )
+        auth_header, api_base = self._get_token_and_url(
+            model=model,
+            gemini_api_key=None,
+            auth_header=auth_header,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=False,
+            custom_llm_provider="vertex_ai",
+            api_base=api_base,
+            should_use_v1beta1_features=False,
+            mode="image_generation",
+        )
         optional_params = optional_params or {
             "sampleCount": 1
         }  # default optional params
@@ -202,22 +220,21 @@ class VertexImageGeneration(VertexLLM):
             "parameters": optional_params,
         }

-        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
+        headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
+
         logging_obj.pre_call(
             input=prompt,
-            api_key=None,
+            api_key="",
             additional_args={
                 "complete_input_dict": optional_params,
-                "request_str": request_str,
+                "api_base": api_base,
+                "headers": headers,
             },
         )

         response = await self.async_handler.post(
-            url=url,
-            headers={
-                "Content-Type": "application/json; charset=utf-8",
-                "Authorization": f"Bearer {auth_header}",
-            },
+            url=api_base,
+            headers=headers,
             data=json.dumps(request_data),
         )
diff --git a/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py b/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py
index 3ef40703e8..2e8051d4d2 100644
--- a/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py
+++ b/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py
@@ -111,7 +111,7 @@ class VertexEmbedding(VertexBase):
         )

         try:
-            response = client.post(api_base, headers=headers, json=vertex_request)  # type: ignore
+            response = client.post(url=api_base, headers=headers, json=vertex_request)  # type: ignore
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code
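The main.py hunks below thread the same resolution order through completion, embedding, and image_generation: an explicit `api_base` argument wins, then the module-level `litellm.api_base`, then the provider's environment variable. A rough sketch of that precedence, with the env lookup simplified to `os.environ` (the real code uses `get_secret_str`, which also consults configured secret managers); `resolve_vertex_api_base` is a hypothetical helper for illustration only:

# Rough sketch of the fallback order introduced in the main.py hunks below.
import os

import litellm


def resolve_vertex_api_base(api_base=None):
    return (
        api_base  # 1. per-call argument
        or litellm.api_base  # 2. module-level setting
        or os.environ.get("VERTEXAI_API_BASE")  # 3. provider env var
        or os.environ.get("VERTEX_API_BASE")  # 4. alternate env var
    )


litellm.api_base = "https://module-proxy.example.com"
assert resolve_vertex_api_base() == "https://module-proxy.example.com"
assert (
    resolve_vertex_api_base("https://per-call.example.com")
    == "https://per-call.example.com"
)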
diff --git a/litellm/main.py b/litellm/main.py
index 3d4152d634..94e19aab0c 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2350,6 +2350,8 @@ def completion(  # type: ignore # noqa: PLR0915
                 or litellm.api_key
             )

+            api_base = api_base or litellm.api_base or get_secret("GEMINI_API_BASE")
+
             new_params = deepcopy(optional_params)
             response = vertex_chat_completion.completion(  # type: ignore
                 model=model,
@@ -2392,6 +2394,8 @@ def completion(  # type: ignore # noqa: PLR0915
                 or get_secret("VERTEXAI_CREDENTIALS")
             )

+            api_base = api_base or litellm.api_base or get_secret("VERTEXAI_API_BASE")
+
             new_params = deepcopy(optional_params)
             if (
                 model.startswith("meta/")
@@ -3657,6 +3661,8 @@ def embedding(  # noqa: PLR0915
             api_key or get_secret_str("GEMINI_API_KEY") or litellm.api_key
         )

+        api_base = api_base or litellm.api_base or get_secret_str("GEMINI_API_BASE")
+
         response = google_batch_embeddings.batch_embeddings(  # type: ignore
             model=model,
             input=input,
@@ -3671,6 +3677,8 @@ def embedding(  # noqa: PLR0915
             print_verbose=print_verbose,
             custom_llm_provider="gemini",
             api_key=gemini_api_key,
+            api_base=api_base,
+            client=client,
         )

     elif custom_llm_provider == "vertex_ai":
@@ -3695,6 +3703,13 @@ def embedding(  # noqa: PLR0915
             or get_secret_str("VERTEX_CREDENTIALS")
         )

+        api_base = (
+            api_base
+            or litellm.api_base
+            or get_secret_str("VERTEXAI_API_BASE")
+            or get_secret_str("VERTEX_API_BASE")
+        )
+
         if (
             "image" in optional_params
             or "video" in optional_params
@@ -3716,6 +3731,7 @@ def embedding(  # noqa: PLR0915
                 print_verbose=print_verbose,
                 custom_llm_provider="vertex_ai",
                 client=client,
+                api_base=api_base,
             )
         else:
             response = vertex_embedding.embedding(
@@ -3733,6 +3749,8 @@ def embedding(  # noqa: PLR0915
                 aembedding=aembedding,
                 print_verbose=print_verbose,
                 api_key=api_key,
+                api_base=api_base,
+                client=client,
             )
     elif custom_llm_provider == "oobabooga":
         response = oobabooga.embedding(
@@ -4695,6 +4713,14 @@ def image_generation(  # noqa: PLR0915
             or optional_params.pop("vertex_ai_credentials", None)
             or get_secret_str("VERTEXAI_CREDENTIALS")
         )
+
+        api_base = (
+            api_base
+            or litellm.api_base
+            or get_secret_str("VERTEXAI_API_BASE")
+            or get_secret_str("VERTEX_API_BASE")
+        )
+
         model_response = vertex_image_generation.image_generation(
             model=model,
             prompt=prompt,
@@ -4706,6 +4732,8 @@ def image_generation(  # noqa: PLR0915
             vertex_location=vertex_ai_location,
             vertex_credentials=vertex_credentials,
             aimg_generation=aimg_generation,
+            api_base=api_base,
+            client=client,
         )
     elif (
         custom_llm_provider in litellm._custom_providers
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index cf09749d81..09db9f10ee 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,5 +1,5 @@
 model_list:
-  - model_name: "gpt-3.5-turbo"
+  - model_name: "gpt-4o"
     litellm_params:
       model: azure/chatgpt-v-2
       api_key: os.environ/AZURE_API_KEY
diff --git a/tests/litellm/test_cost_calculator.py b/tests/litellm/test_cost_calculator.py
index 9c9f6d9043..c0073e2c56 100644
--- a/tests/litellm/test_cost_calculator.py
+++ b/tests/litellm/test_cost_calculator.py
@@ -15,9 +15,11 @@ from pydantic import BaseModel
 from litellm.cost_calculator import response_cost_calculator


-def test_cost_calculator():
+def test_cost_calculator_with_response_cost_in_additional_headers():
     class MockResponse(BaseModel):
-        _hidden_params = {"additional_headers": {"x-litellm-response-cost": 1000}}
+        _hidden_params = {
+            "additional_headers": {"llm_provider-x-litellm-response-cost": 1000}
+        }

     result = response_cost_calculator(
         response_object=MockResponse(),
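Both the renamed unit test above and the gateway test below rest on the same convention, visible in the cost_calculator hunk at the top of this diff: headers returned by the upstream provider are surfaced in `_hidden_params["additional_headers"]` under an `llm_provider-` prefix. A toy illustration of that mapping, with a made-up header value:

# Toy illustration of the "llm_provider-" header-prefix convention the
# tests exercise; the raw header value here is made up.
raw_provider_headers = {"x-litellm-response-cost": "120"}
additional_headers = {
    f"llm_provider-{name}": value for name, value in raw_provider_headers.items()
}
assert additional_headers["llm_provider-x-litellm-response-cost"] == "120"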
diff --git a/tests/llm_translation/test_litellm_proxy_provider.py b/tests/llm_translation/test_litellm_proxy_provider.py
index 8484a66dad..c38e386063 100644
--- a/tests/llm_translation/test_litellm_proxy_provider.py
+++ b/tests/llm_translation/test_litellm_proxy_provider.py
@@ -31,7 +31,7 @@ async def test_litellm_gateway_from_sdk():
     openai_client = OpenAI(api_key="fake-key")

     with patch.object(
-        openai_client.chat.completions, "create", new=MagicMock()
+        openai_client.chat.completions.with_raw_response, "create", new=MagicMock()
     ) as mock_call:
         try:
             completion(
@@ -374,3 +374,78 @@ async def test_litellm_gateway_from_sdk_rerank(is_async):
     assert request_body["query"] == "What is machine learning?"
     assert request_body["model"] == "rerank-english-v2.0"
     assert len(request_body["documents"]) == 2
+
+
+def test_litellm_gateway_from_sdk_with_response_cost_in_additional_headers():
+    litellm.set_verbose = True
+    litellm._turn_on_debug()
+
+    from openai import OpenAI
+
+    openai_client = OpenAI(api_key="fake-key")
+
+    # Create mock response object
+    mock_response = MagicMock()
+    mock_response.headers = {"x-litellm-response-cost": "120"}
+    mock_response.parse.return_value = litellm.ModelResponse(
+        **{
+            "id": "chatcmpl-BEkxQvRGp9VAushfAsOZCbhMFLsoy",
+            "choices": [
+                {
+                    "finish_reason": "stop",
+                    "index": 0,
+                    "logprobs": None,
+                    "message": {
+                        "content": "Hello! How can I assist you today?",
+                        "refusal": None,
+                        "role": "assistant",
+                        "annotations": [],
+                        "audio": None,
+                        "function_call": None,
+                        "tool_calls": None,
+                    },
+                }
+            ],
+            "created": 1742856796,
+            "model": "gpt-4o-2024-08-06",
+            "object": "chat.completion",
+            "service_tier": "default",
+            "system_fingerprint": "fp_6ec83003ad",
+            "usage": {
+                "completion_tokens": 10,
+                "prompt_tokens": 9,
+                "total_tokens": 19,
+                "completion_tokens_details": {
+                    "accepted_prediction_tokens": 0,
+                    "audio_tokens": 0,
+                    "reasoning_tokens": 0,
+                    "rejected_prediction_tokens": 0,
+                },
+                "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
+            },
+        }
+    )
+
+    with patch.object(
+        openai_client.chat.completions.with_raw_response,
+        "create",
+        return_value=mock_response,
+    ) as mock_call:
+        response = litellm.completion(
+            model="litellm_proxy/gpt-4o",
+            messages=[{"role": "user", "content": "Hello world"}],
+            api_base="http://0.0.0.0:4000",
+            api_key="sk-PIp1h0RekR",
+            client=openai_client,
+        )
+
+        # Assert the headers were properly passed through
+        print(f"additional_headers: {response._hidden_params['additional_headers']}")
+        assert (
+            response._hidden_params["additional_headers"][
+                "llm_provider-x-litellm-response-cost"
+            ]
+            == "120"
+        )
+
+        assert response._hidden_params["response_cost"] == 120
diff --git a/tests/load_tests/test_vertex_embeddings_load_test.py b/tests/load_tests/test_vertex_embeddings_load_test.py
index eb440c9437..24543e29d0 100644
--- a/tests/load_tests/test_vertex_embeddings_load_test.py
+++ b/tests/load_tests/test_vertex_embeddings_load_test.py
@@ -109,12 +109,13 @@ def analyze_results(vertex_times):


 @pytest.mark.asyncio
-async def test_embedding_performance():
+async def test_embedding_performance(monkeypatch):
     """
     Run load test on vertex AI embeddings to ensure vertex median response time
     is less than 300ms

     20 RPS for 20 seconds
     """
+    monkeypatch.setattr(litellm, "api_base", None)
     duration_seconds = 20
     requests_per_second = 20
     vertex_times = await run_load_test(duration_seconds, requests_per_second)
diff --git a/tests/local_testing/test_amazing_vertex_completion.py b/tests/local_testing/test_amazing_vertex_completion.py
index 25993d6d5b..5e3ebf6a66 100644
--- a/tests/local_testing/test_amazing_vertex_completion.py
+++ b/tests/local_testing/test_amazing_vertex_completion.py
@@ -31,6 +31,7 @@ from litellm import (
     completion,
     completion_cost,
     embedding,
+    image_generation,
 )
 from litellm.llms.vertex_ai.gemini.transformation import (
     _gemini_convert_messages_with_history,
@@ -3327,3 +3328,46 @@ def test_signed_s3_url_with_format():
     json_str = json.dumps(mock_client.call_args.kwargs["json"])
     assert "image/jpeg" in json_str
     assert "image/png" not in json_str
+
+
+@pytest.mark.parametrize("provider", ["vertex_ai", "gemini"])
+@pytest.mark.parametrize("route", ["completion", "embedding", "image_generation"])
+def test_litellm_api_base(monkeypatch, provider, route):
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()
+
+    import litellm
+
+    monkeypatch.setattr(litellm, "api_base", "https://litellm.com")
+
+    load_vertex_ai_credentials()
+
+    if route == "image_generation" and provider == "gemini":
+        pytest.skip("Gemini does not support image generation")
+
+    with patch.object(client, "post", new=MagicMock()) as mock_client:
+        try:
+            if route == "completion":
+                response = completion(
+                    model=f"{provider}/gemini-2.0-flash-001",
+                    messages=[{"role": "user", "content": "Hello, world!"}],
+                    client=client,
+                )
+            elif route == "embedding":
+                response = embedding(
+                    model=f"{provider}/gemini-2.0-flash-001",
+                    input=["Hello, world!"],
+                    client=client,
+                )
+            elif route == "image_generation":
+                response = image_generation(
+                    model=f"{provider}/gemini-2.0-flash-001",
+                    prompt="Hello, world!",
+                    client=client,
+                )
+        except Exception as e:
+            print(e)
+
+    mock_client.assert_called()
+    assert mock_client.call_args.kwargs["url"].startswith("https://litellm.com")
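Taken together, the changes let callers point the Vertex AI completion, embedding, and image-generation routes at a custom base URL. A hypothetical end-to-end usage, assuming Vertex credentials and project are already configured in the environment; the proxy URL and model names are illustrative:

# Hypothetical usage of the override the test above pins down. Setting
# VERTEXAI_API_BASE in the environment would work as well.
import litellm

litellm.api_base = "https://my-litellm-proxy.example.com"

image = litellm.image_generation(
    model="vertex_ai/imagegeneration",  # default Vertex image model name
    prompt="A watercolor lighthouse at dusk",
)
embeddings = litellm.embedding(
    model="vertex_ai/text-embedding-004",
    input=["hello world"],
)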