From 0f9e793daf9f81ba1c55b79f91062c58071cf6d2 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 1 Feb 2024 17:47:34 -0800 Subject: [PATCH 1/6] feat(vertex_ai.py): add support for custom models via vertex ai model garden --- litellm/llms/vertex_ai.py | 236 ++++++++++++++++++++++++++++++++------ 1 file changed, 198 insertions(+), 38 deletions(-) diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py index 56cef9de8..30e0c0e45 100644 --- a/litellm/llms/vertex_ai.py +++ b/litellm/llms/vertex_ai.py @@ -75,6 +75,41 @@ class VertexAIConfig: } +import asyncio + + +class TextStreamer: + """ + Fake streaming iterator for Vertex AI Model Garden calls + """ + + def __init__(self, text): + self.text = text.split() # let's assume words as a streaming unit + self.index = 0 + + def __iter__(self): + return self + + def __next__(self): + if self.index < len(self.text): + result = self.text[self.index] + self.index += 1 + return result + else: + raise StopIteration + + def __aiter__(self): + return self + + async def __anext__(self): + if self.index < len(self.text): + result = self.text[self.index] + self.index += 1 + return result + else: + raise StopAsyncIteration # once we run out of data to stream, we raise this error + + def _get_image_bytes_from_url(image_url: str) -> bytes: try: response = requests.get(image_url) @@ -236,12 +271,17 @@ def completion( Part, GenerationConfig, ) + from google.cloud import aiplatform + from google.protobuf import json_format # type: ignore + from google.protobuf.struct_pb2 import Value # type: ignore from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types import google.auth ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744 creds, _ = google.auth.default(quota_project_id=vertex_project) - vertexai.init(project=vertex_project, location=vertex_location, credentials=creds) + vertexai.init( + project=vertex_project, location=vertex_location, credentials=creds + ) ## Load Config config = litellm.VertexAIConfig.get_config() @@ -275,6 +315,11 @@ def completion( request_str = "" response_obj = None + async_client = None + instances = None + client_options = { + "api_endpoint": f"{vertex_location}-aiplatform.googleapis.com" + } if ( model in litellm.vertex_language_models or model in litellm.vertex_vision_models @@ -294,39 +339,51 @@ def completion( llm_model = CodeGenerationModel.from_pretrained(model) mode = "text" request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n" - else: # vertex_code_llm_models + elif model in litellm.vertex_code_chat_models: # vertex_code_llm_models llm_model = CodeChatModel.from_pretrained(model) mode = "chat" request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n" + else: # assume vertex model garden + client = aiplatform.gapic.PredictionServiceClient( + client_options=client_options + ) - if acompletion == True: # [TODO] expand support to vertex ai chat + text models + instances = [optional_params] + instances[0]["prompt"] = prompt + instances = [ + json_format.ParseDict(instance_dict, Value()) + for instance_dict in instances + ] + llm_model = client.endpoint_path( + project=vertex_project, location=vertex_location, endpoint=model + ) + + mode = "custom" + request_str += f"llm_model = client.endpoint_path(project={vertex_project}, location={vertex_location}, endpoint={model})\n" + + if acompletion == True: + data = { + "llm_model": llm_model, + "mode": mode, + "prompt": prompt, + "logging_obj": logging_obj, + "request_str": request_str, + "model": model, + "model_response": model_response, + "encoding": encoding, + "messages": messages, + "print_verbose": print_verbose, + "client_options": client_options, + "instances": instances, + "vertex_location": vertex_location, + "vertex_project": vertex_project, + **optional_params, + } if optional_params.get("stream", False) is True: # async streaming - return async_streaming( - llm_model=llm_model, - mode=mode, - prompt=prompt, - logging_obj=logging_obj, - request_str=request_str, - model=model, - model_response=model_response, - messages=messages, - print_verbose=print_verbose, - **optional_params, - ) - return async_completion( - llm_model=llm_model, - mode=mode, - prompt=prompt, - logging_obj=logging_obj, - request_str=request_str, - model=model, - model_response=model_response, - encoding=encoding, - messages=messages, - print_verbose=print_verbose, - **optional_params, - ) + return async_streaming(**data) + + return async_completion(**data) if mode == "vision": print_verbose("\nMaking VertexAI Gemini Pro Vision Call") @@ -471,7 +528,36 @@ def completion( }, ) completion_response = llm_model.predict(prompt, **optional_params).text + elif mode == "custom": + """ + Vertex AI Model Garden + """ + request_str += ( + f"client.predict(endpoint={llm_model}, instances={instances})\n" + ) + ## LOGGING + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) + response = client.predict( + endpoint=llm_model, + instances=instances, + ).predictions + completion_response = response[0] + if ( + isinstance(completion_response, str) + and "\nOutput:\n" in completion_response + ): + completion_response = completion_response.split("\nOutput:\n", 1)[1] + if "stream" in optional_params and optional_params["stream"] == True: + response = TextStreamer(completion_response) + return response ## LOGGING logging_obj.post_call( input=prompt, api_key=None, original_response=completion_response @@ -539,6 +625,10 @@ async def async_completion( encoding=None, messages=None, print_verbose=None, + client_options=None, + instances=None, + vertex_project=None, + vertex_location=None, **optional_params, ): """ @@ -627,7 +717,43 @@ async def async_completion( ) response_obj = await llm_model.predict_async(prompt, **optional_params) completion_response = response_obj.text + elif mode == "custom": + """ + Vertex AI Model Garden + """ + from google.cloud import aiplatform + async_client = aiplatform.gapic.PredictionServiceAsyncClient( + client_options=client_options + ) + llm_model = async_client.endpoint_path( + project=vertex_project, location=vertex_location, endpoint=model + ) + + request_str += ( + f"client.predict(endpoint={llm_model}, instances={instances})\n" + ) + ## LOGGING + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) + + response_obj = await async_client.predict( + endpoint=llm_model, + instances=instances, + ) + response = response_obj.predictions + completion_response = response[0] + if ( + isinstance(completion_response, str) + and "\nOutput:\n" in completion_response + ): + completion_response = completion_response.split("\nOutput:\n", 1)[1] ## LOGGING logging_obj.post_call( input=prompt, api_key=None, original_response=completion_response @@ -657,14 +783,12 @@ async def async_completion( # init prompt tokens # this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter prompt_tokens, completion_tokens, total_tokens = 0, 0, 0 - if response_obj is not None: - if hasattr(response_obj, "usage_metadata") and hasattr( - response_obj.usage_metadata, "prompt_token_count" - ): - prompt_tokens = response_obj.usage_metadata.prompt_token_count - completion_tokens = ( - response_obj.usage_metadata.candidates_token_count - ) + if response_obj is not None and ( + hasattr(response_obj, "usage_metadata") + and hasattr(response_obj.usage_metadata, "prompt_token_count") + ): + prompt_tokens = response_obj.usage_metadata.prompt_token_count + completion_tokens = response_obj.usage_metadata.candidates_token_count else: prompt_tokens = len(encoding.encode(prompt)) completion_tokens = len( @@ -695,6 +819,10 @@ async def async_streaming( request_str=None, messages=None, print_verbose=None, + client_options=None, + instances=None, + vertex_project=None, + vertex_location=None, **optional_params, ): """ @@ -763,15 +891,47 @@ async def async_streaming( }, ) response = llm_model.predict_streaming_async(prompt, **optional_params) + elif mode == "custom": + from google.cloud import aiplatform + async_client = aiplatform.gapic.PredictionServiceAsyncClient( + client_options=client_options + ) + llm_model = async_client.endpoint_path( + project=vertex_project, location=vertex_location, endpoint=model + ) + + request_str += f"client.predict(endpoint={llm_model}, instances={instances})\n" + ## LOGGING + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) + + response_obj = await async_client.predict( + endpoint=llm_model, + instances=instances, + ) + response = response_obj.predictions + completion_response = response[0] + if ( + isinstance(completion_response, str) + and "\nOutput:\n" in completion_response + ): + completion_response = completion_response.split("\nOutput:\n", 1)[1] + if "stream" in optional_params and optional_params["stream"] == True: + response = TextStreamer(completion_response) streamwrapper = CustomStreamWrapper( completion_stream=response, model=model, custom_llm_provider="vertex_ai", logging_obj=logging_obj, ) - async for transformed_chunk in streamwrapper: - yield transformed_chunk + return streamwrapper def embedding(): From 0072d796f65aa3aea58d471d1b9efb1acc160fc7 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 1 Feb 2024 18:09:49 -0800 Subject: [PATCH 2/6] fix(vertex_ai.py): fix params --- litellm/llms/vertex_ai.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py index 30e0c0e45..9965c037a 100644 --- a/litellm/llms/vertex_ai.py +++ b/litellm/llms/vertex_ai.py @@ -817,6 +817,7 @@ async def async_streaming( model_response: ModelResponse, logging_obj=None, request_str=None, + encoding=None, messages=None, print_verbose=None, client_options=None, From 241f0aad5ec0c74f265d04f4b64c0bc4184fed94 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 1 Feb 2024 18:46:50 -0800 Subject: [PATCH 3/6] fix(utils.py): fix deepinfra streaming --- litellm/tests/test_streaming.py | 1 + litellm/utils.py | 13 +++++++++++++ 2 files changed, 14 insertions(+) diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 239dae94d..28a9f9902 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -465,6 +465,7 @@ def test_completion_mistral_api_stream(): def test_completion_deep_infra_stream(): # deep infra currently includes role in the 2nd chunk # waiting for them to make a fix on this + litellm.set_verbose = True try: messages = [ {"role": "system", "content": "You are a helpful assistant."}, diff --git a/litellm/utils.py b/litellm/utils.py index d3107d758..95ddb433e 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -8048,6 +8048,7 @@ class CustomStreamWrapper: if len(original_chunk.choices) > 0: try: delta = dict(original_chunk.choices[0].delta) + print_verbose(f"original delta: {delta}") model_response.choices[0].delta = Delta(**delta) except Exception as e: model_response.choices[0].delta = Delta() @@ -8056,9 +8057,21 @@ class CustomStreamWrapper: model_response.system_fingerprint = ( original_chunk.system_fingerprint ) + print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}") if self.sent_first_chunk == False: model_response.choices[0].delta["role"] = "assistant" self.sent_first_chunk = True + elif self.sent_first_chunk == True and hasattr( + model_response.choices[0].delta, "role" + ): + _initial_delta = model_response.choices[ + 0 + ].delta.model_dump() + _initial_delta.pop("role", None) + model_response.choices[0].delta = Delta(**_initial_delta) + print_verbose( + f"model_response.choices[0].delta: {model_response.choices[0].delta}" + ) else: ## else completion_obj["content"] = model_response_str From 245ec2430eac0809def76ad03500cb2b4b6477b4 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 1 Feb 2024 19:05:20 -0800 Subject: [PATCH 4/6] fix(utils.py): fix azure exception mapping --- litellm/tests/test_completion.py | 3 +++ litellm/utils.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 6183c78c4..034abbb80 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -279,6 +279,9 @@ def test_completion_azure_gpt4_vision(): except openai.RateLimitError as e: print("got a rate liimt error", e) pass + except openai.APIStatusError as e: + print("got an api status error", e) + pass except Exception as e: pytest.fail(f"Error occurred: {e}") diff --git a/litellm/utils.py b/litellm/utils.py index 95ddb433e..716b52497 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -6964,6 +6964,21 @@ def exception_type( llm_provider="azure", response=original_exception.response, ) + elif original_exception.status_code == 503: + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"AzureException - {original_exception.message}", + model=model, + llm_provider="azure", + response=original_exception.response, + ) + elif original_exception.status_code == 504: # gateway timeout error + exception_mapping_worked = True + raise Timeout( + message=f"AzureException - {original_exception.message}", + model=model, + llm_provider="azure", + ) else: exception_mapping_worked = True raise APIError( From c25c48e8284c3aeee30806d562b19c42fa86d1a4 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 1 Feb 2024 19:47:32 -0800 Subject: [PATCH 5/6] test(test_key_generate_prisma.py): fix assert test --- litellm/tests/test_key_generate_prisma.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index 9d4318fe7..d8ffcf022 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -1266,9 +1266,7 @@ async def test_user_api_key_auth(prisma_client): pytest.fail(f"This should have failed!. IT's an invalid key") except ProxyException as exc: print(exc.message) - assert ( - exc.message == "Authentication Error, No API Key passed in. api_key is None" - ) + assert exc.message == "Authentication Error, No api key passed in." # Test case: Malformed API Key (missing 'Bearer ' prefix) try: From 5db2d4d895165faee7dffe213694cbdbcf060b61 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 1 Feb 2024 21:32:29 -0800 Subject: [PATCH 6/6] test(test_proxy_server.py): fix health test --- litellm/tests/test_proxy_server.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index 4e0f706eb..70fef0e06 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -225,9 +225,6 @@ def test_health(client_no_auth): try: response = client_no_auth.get("/health") assert response.status_code == 200 - result = response.json() - print("\n response from health:", result) - assert result["unhealthy_count"] == 0 except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")