diff --git a/litellm/main.py b/litellm/main.py
index bd40e1285..4dd13efef 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -638,12 +638,12 @@ def completion(
                 )
                 return response
             response = model_response
-        elif model in litellm.vertex_chat_models:
+        elif model in litellm.vertex_chat_models or model in litellm.vertex_code_chat_models:
             try:
                 import vertexai
             except:
                 raise Exception("vertexai import failed please run `pip install google-cloud-aiplatform`")
-            from vertexai.preview.language_models import ChatModel, InputOutputTextPair
+            from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
 
             vertexai.init(
                 project=litellm.vertex_project, location=litellm.vertex_location
@@ -653,8 +653,10 @@ def completion(
             prompt = " ".join([message["content"] for message in messages])
             ## LOGGING
             logging.pre_call(input=prompt, api_key=None)
-
-            chat_model = ChatModel.from_pretrained(model)
+            if model in litellm.vertex_chat_models:
+                chat_model = ChatModel.from_pretrained(model)
+            else: # vertex_code_chat_models
+                chat_model = CodeChatModel.from_pretrained(model)
 
             chat = chat_model.start_chat()
             completion_response = chat.send_message(prompt, **optional_params)
@@ -669,12 +671,12 @@ def completion(
             model_response["created"] = time.time()
             model_response["model"] = model
             response = model_response
-        elif model in litellm.vertex_text_models:
+        elif model in litellm.vertex_text_models or model in litellm.vertex_code_text_models:
             try:
                 import vertexai
             except:
                 raise Exception("vertexai import failed please run `pip install google-cloud-aiplatform`")
-            from vertexai.language_models import TextGenerationModel
+            from vertexai.language_models import TextGenerationModel, CodeGenerationModel
 
             vertexai.init(
                 project=litellm.vertex_project, location=litellm.vertex_location
@@ -684,8 +686,12 @@ def completion(
             prompt = " ".join([message["content"] for message in messages])
             ## LOGGING
             logging.pre_call(input=prompt, api_key=None)
+
+            if model in litellm.vertex_text_models:
+                vertex_model = TextGenerationModel.from_pretrained(model)
+            else:
+                vertex_model = CodeGenerationModel.from_pretrained(model)
 
-            vertex_model = TextGenerationModel.from_pretrained(model)
             completion_response = vertex_model.predict(prompt, **optional_params)
 
             ## LOGGING
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 084fcc03a..2946e4593 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -12,9 +12,6 @@ import pytest
 import litellm
 from litellm import embedding, completion, text_completion, completion_cost
 
-litellm.vertex_project = "pathrise-convert-1606954137718"
-litellm.vertex_location = "us-central1"
-
 user_message = "Write a short poem about the sky"
 messages = [{"content": user_message, "role": "user"}]
 
@@ -647,7 +644,7 @@ def test_completion_bedrock_titan():
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_bedrock_titan()
+# test_completion_bedrock_titan()
 
 
 def test_completion_bedrock_ai21():
@@ -663,7 +660,7 @@ def test_completion_bedrock_ai21():
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_bedrock_ai21()
+# test_completion_bedrock_ai21()
 
 
 # test_completion_sagemaker()
@@ -725,28 +722,17 @@ test_completion_bedrock_ai21()
 # test_completion_custom_api_base()
 
 # def test_vertex_ai():
-#     model_name = "chat-bison"
-#     try:
-#         response = completion(model=model_name, messages=messages, logger_fn=logger_fn)
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
-
-
-# def test_petals():
-#     model_name = "stabilityai/StableBeluga2"
-#     try:
-#         response = completion(
-#             model=model_name,
-#             messages=messages,
-#             custom_llm_provider="petals",
-#             force_timeout=120,
-#         )
-#         # Add any assertions here to check the response
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
-
+#     litellm.vertex_project = "hardy-device-386718"
+#     litellm.vertex_location = "us-central1"
+#     test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
+#     for model in test_models:
+#         try:
+#             print("making request", model)
+#             response = completion(model=model, messages=[{"role": "user", "content": "write code for saying hi"}])
+#             print(response)
+#         except Exception as e:
+#             pytest.fail(f"Error occurred: {e}")
+# test_vertex_ai()
 
 
 def test_completion_with_fallbacks():
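
Usage sketch (not part of the patch above): a minimal example of how the new routing could be exercised once this change lands, assuming GCP credentials are available to the vertexai SDK and that model names such as "codechat-bison" and "code-bison" appear in litellm.vertex_code_chat_models and litellm.vertex_code_text_models respectively; the project id below is a placeholder.

import litellm
from litellm import completion

litellm.vertex_project = "my-gcp-project"  # placeholder: your GCP project id
litellm.vertex_location = "us-central1"

# A model found in vertex_code_chat_models is routed through CodeChatModel.from_pretrained
response = completion(
    model="codechat-bison",
    messages=[{"role": "user", "content": "write code for saying hi"}],
)
print(response)

# A model found in vertex_code_text_models is routed through CodeGenerationModel.from_pretrained
response = completion(
    model="code-bison",
    messages=[{"role": "user", "content": "write a function that reverses a string"}],
)
print(response)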