palm streaming

This commit is contained in:
ishaan-jaff 2023-09-26 12:27:49 -07:00
parent 0d47663b8b
commit d1692e89dd
2 changed files with 24 additions and 3 deletions

View file

@@ -800,6 +800,7 @@ def completion(
             or litellm.api_key
         )
+        # palm does not support streaming as yet :(
         model_response = palm.completion(
             model=model,
             messages=messages,
@@ -812,10 +813,12 @@ def completion(
             api_key=palm_api_key,
             logging_obj=logging
         )
-        if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
-            # don't try to access stream object,
+        # fake palm streaming
+        if stream == True:
+            # fake streaming for palm
+            resp_string = model_response["choices"][0]["message"]["content"]
             response = CustomStreamWrapper(
-                model_response, model, custom_llm_provider="palm", logging_obj=logging
+                resp_string, model, custom_llm_provider="palm", logging_obj=logging
             )
             return response
         response = model_response

View file

@@ -985,6 +985,11 @@ def get_optional_params(  # use the openai defaults
             optional_params["frequency_penalty"] = frequency_penalty # TODO: Check if should be repetition penalty
         if stop != None:
             optional_params["stop"] = stop #TG AI expects a list, example ["\n\n\n\n","<|endoftext|>"]
+    elif custom_llm_provider == "palm":
+        if temperature != 1:
+            optional_params["temperature"] = temperature
+        if top_p != 1:
+            optional_params["top_p"] = top_p
     elif (
         model in litellm.vertex_chat_models or model in litellm.vertex_code_chat_models
     ): # chat-bison has diff args from chat-bison@001, ty Google :)
@@ -3088,6 +3093,19 @@ class CustomStreamWrapper:
                 completion_obj["content"] = new_chunk
                 self.completion_stream = self.completion_stream[chunk_size:]
                 time.sleep(0.05)
+            elif self.custom_llm_provider == "palm":
+                # fake streaming
+                if len(self.completion_stream)==0:
+                    if self.sent_last_chunk:
+                        raise StopIteration
+                    else:
+                        model_response.choices[0].finish_reason = "stop"
+                        self.sent_last_chunk = True
+                chunk_size = 30
+                new_chunk = self.completion_stream[:chunk_size]
+                completion_obj["content"] = new_chunk
+                self.completion_stream = self.completion_stream[chunk_size:]
+                time.sleep(0.05)
             else: # openai chat/azure models
                 chunk = next(self.completion_stream)
                 model_response = chunk