diff --git a/litellm/main.py b/litellm/main.py
index 4dd13efef..518da5afa 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -659,6 +659,14 @@ def completion(
                 chat_model = CodeChatModel.from_pretrained(model)
 
             chat = chat_model.start_chat()
+
+            if stream:
+                model_response = chat.send_message_streaming(prompt, **optional_params)
+                response = CustomStreamWrapper(
+                    model_response, model, custom_llm_provider="vertexai", logging_obj=logging
+                )
+                return response
+
             completion_response = chat.send_message(prompt, **optional_params)
 
             ## LOGGING
@@ -692,6 +700,13 @@ def completion(
             else:
                 vertex_model = CodeGenerationModel.from_pretrained(model)
 
+            if stream:
+                model_response = vertex_model.predict_streaming(prompt, **optional_params)
+                response = CustomStreamWrapper(
+                    model_response, model, custom_llm_provider="vertexai", logging_obj=logging
+                )
+                return response
+
             completion_response = vertex_model.predict(prompt, **optional_params)
 
             ## LOGGING
diff --git a/litellm/utils.py b/litellm/utils.py
index 5f8bdbb75..c23b1e495 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -885,8 +885,8 @@ def get_optional_params( # use the openai defaults
         if stop != None:
             optional_params["stop"] = stop #TG AI expects a list, example ["\n\n\n\n","<|endoftext|>"]
     elif (
-        model == "chat-bison"
-    ): # chat-bison has diff args from chat-bison@001 ty Google
+        model in litellm.vertex_chat_models or model in litellm.vertex_code_chat_models
+    ): # chat-bison has diff args from chat-bison@001, ty Google :)
         if temperature != 1:
             optional_params["temperature"] = temperature
         if top_p != 1:
@@ -900,6 +900,12 @@ def get_optional_params( # use the openai defaults
         optional_params["temperature"] = temperature
         optional_params["top_p"] = top_p
         optional_params["top_k"] = top_k
+        if max_tokens != float("inf"):
+            optional_params["max_output_tokens"] = max_tokens
+    elif model in litellm.vertex_code_text_models:
+        optional_params["temperature"] = temperature
+        if max_tokens != float("inf"):
+            optional_params["max_output_tokens"] = max_tokens
     elif custom_llm_provider == "baseten":
         optional_params["temperature"] = temperature
         optional_params["stream"] = stream
@@ -2482,6 +2488,9 @@ class CustomStreamWrapper:
             elif self.model in litellm.nlp_cloud_models or self.custom_llm_provider == "nlp_cloud":
                 chunk = next(self.completion_stream)
                 completion_obj["content"] = self.handle_nlp_cloud_chunk(chunk)
+            elif self.model in (litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models):
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = str(chunk)
             elif self.model in litellm.cohere_models or self.custom_llm_provider == "cohere":
                 chunk = next(self.completion_stream)
                 completion_obj["content"] = self.handle_cohere_chunk(chunk)
diff --git a/pyproject.toml b/pyproject.toml
index 4e1d51ec4..aa40a40ef 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.639"
+version = "0.1.640"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
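Usage note: a minimal sketch of how the Vertex AI streaming path added above could be exercised through `litellm.completion`. The GCP project, location, model name, and prompt below are illustrative assumptions, not part of this diff; the chunk access pattern assumes litellm's streaming response shape at this version.

```python
# Hedged usage sketch for the new Vertex AI streaming support.
# Assumptions: Google Cloud Application Default Credentials are configured,
# and the project/location values are placeholders.
import litellm

litellm.vertex_project = "my-gcp-project"  # placeholder
litellm.vertex_location = "us-central1"    # placeholder

# stream=True now routes chat-bison through chat.send_message_streaming(...)
# wrapped in CustomStreamWrapper instead of a single send_message call.
response = litellm.completion(
    model="chat-bison",
    messages=[{"role": "user", "content": "Write a haiku about rivers."}],
    stream=True,
    max_tokens=256,  # mapped to Vertex's max_output_tokens by get_optional_params
)

for chunk in response:
    # Per the CustomStreamWrapper change, each Vertex chunk's content is the
    # raw str(chunk) from the Vertex SDK stream.
    print(chunk["choices"][0]["delta"]["content"], end="")
```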