From d1692e89dd8abd615cad2781944f5dcef497f52f Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Tue, 26 Sep 2023 12:27:49 -0700
Subject: [PATCH] palm streaming

---
 litellm/main.py  |  9 ++++++---
 litellm/utils.py | 18 ++++++++++++++++++
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/litellm/main.py b/litellm/main.py
index 3779df87c3..2ce471902a 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -800,6 +800,7 @@ def completion(
             or litellm.api_key
         )
 
+        # palm does not support streaming as yet :(
         model_response = palm.completion(
             model=model,
             messages=messages,
@@ -812,10 +813,12 @@ def completion(
             api_key=palm_api_key,
             logging_obj=logging
         )
-        if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
-            # don't try to access stream object,
+        # fake palm streaming
+        if stream == True:
+            # fake streaming for palm
+            resp_string = model_response["choices"][0]["message"]["content"]
             response = CustomStreamWrapper(
-                model_response, model, custom_llm_provider="palm", logging_obj=logging
+                resp_string, model, custom_llm_provider="palm", logging_obj=logging
             )
             return response
         response = model_response
diff --git a/litellm/utils.py b/litellm/utils.py
index 96217f4162..dd6c2378a5 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -985,6 +985,11 @@ def get_optional_params(  # use the openai defaults
             optional_params["frequency_penalty"] = frequency_penalty # TODO: Check if should be repetition penalty
         if stop != None:
             optional_params["stop"] = stop #TG AI expects a list, example ["\n\n\n\n","<|endoftext|>"]
+    elif custom_llm_provider == "palm":
+        if temperature != 1:
+            optional_params["temperature"] = temperature
+        if top_p != 1:
+            optional_params["top_p"] = top_p
     elif (
         model in litellm.vertex_chat_models or model in litellm.vertex_code_chat_models
     ):  # chat-bison has diff args from chat-bison@001, ty Google :)
@@ -3088,6 +3093,19 @@ class CustomStreamWrapper:
                 completion_obj["content"] = new_chunk
                 self.completion_stream = self.completion_stream[chunk_size:]
                 time.sleep(0.05)
+            elif self.custom_llm_provider == "palm":
+                # fake streaming
+                if len(self.completion_stream)==0:
+                    if self.sent_last_chunk:
+                        raise StopIteration
+                    else:
+                        model_response.choices[0].finish_reason = "stop"
+                        self.sent_last_chunk = True
+                chunk_size = 30
+                new_chunk = self.completion_stream[:chunk_size]
+                completion_obj["content"] = new_chunk
+                self.completion_stream = self.completion_stream[chunk_size:]
+                time.sleep(0.05)
             else: # openai chat/azure models
                 chunk = next(self.completion_stream)
                 model_response = chunk
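
The patch fakes streaming for PaLM: since the provider returns the whole completion at once, the finished response string is handed to CustomStreamWrapper, which slices it into 30-character chunks and marks the final chunk with finish_reason = "stop". Below is a minimal standalone sketch of that chunking behaviour; it uses a hypothetical fake_stream generator for illustration only and is not litellm's actual CustomStreamWrapper class.

```python
import time
from typing import Dict, Iterator

def fake_stream(full_text: str, chunk_size: int = 30) -> Iterator[Dict]:
    """Yield an already-complete response string as fixed-size chunks,
    mimicking the fake-streaming approach in the patch (hypothetical helper)."""
    remaining = full_text
    while remaining:
        chunk, remaining = remaining[:chunk_size], remaining[chunk_size:]
        yield {"choices": [{"delta": {"content": chunk}, "finish_reason": None}]}
        time.sleep(0.05)  # same small delay the patch uses between chunks
    # final empty chunk signals completion, mirroring finish_reason = "stop"
    yield {"choices": [{"delta": {"content": ""}, "finish_reason": "stop"}]}

if __name__ == "__main__":
    text = "PaLM does not stream natively, so the client slices the finished text."
    for part in fake_stream(text):
        print(part["choices"][0]["delta"]["content"], end="", flush=True)
```

With the patch applied, callers should be able to pass stream=True for PaLM models and iterate the returned wrapper the same way they would for a natively streaming provider, even though the text was generated in a single request.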