Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 03:34:10 +00:00)

Commit d1692e89dd (parent 0d47663b8b): palm streaming

2 changed files with 24 additions and 3 deletions
@@ -800,6 +800,7 @@ def completion(
             or litellm.api_key
         )

+        # palm does not support streaming as yet :(
         model_response = palm.completion(
             model=model,
             messages=messages,
@@ -812,10 +813,12 @@ def completion(
             api_key=palm_api_key,
             logging_obj=logging
         )
-        if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
-            # don't try to access stream object,
+        # fake palm streaming
+        if stream == True:
+            # fake streaming for palm
+            resp_string = model_response["choices"][0]["message"]["content"]
             response = CustomStreamWrapper(
-                model_response, model, custom_llm_provider="palm", logging_obj=logging
+                resp_string, model, custom_llm_provider="palm", logging_obj=logging
             )
             return response
         response = model_response
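Taken together, the two completion() hunks mean a palm request with stream=True no longer expects a provider-side stream: the finished response text is handed to CustomStreamWrapper and replayed in small pieces. A minimal usage sketch of that behaviour, assuming a palm-prefixed model alias and litellm's usual OpenAI-style delta chunk shape (both are assumptions here, not taken from this diff):

import litellm

# Hypothetical example: the model alias "palm/chat-bison" and a PALM_API_KEY
# environment variable are assumptions for illustration, not part of this commit.
response = litellm.completion(
    model="palm/chat-bison",
    messages=[{"role": "user", "content": "Say hello in five words."}],
    stream=True,
)

for chunk in response:
    # Assumption: chunks follow litellm's OpenAI-style delta shape; with fake
    # streaming each piece carries up to ~30 characters of the final text.
    print(chunk["choices"][0]["delta"]["content"], end="")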
@@ -985,6 +985,11 @@ def get_optional_params( # use the openai defaults
         optional_params["frequency_penalty"] = frequency_penalty # TODO: Check if should be repetition penalty
         if stop != None:
             optional_params["stop"] = stop #TG AI expects a list, example ["\n\n\n\n","<|endoftext|>"]
+    elif custom_llm_provider == "palm":
+        if temperature != 1:
+            optional_params["temperature"] = temperature
+        if top_p != 1:
+            optional_params["top_p"] = top_p
     elif (
         model in litellm.vertex_chat_models or model in litellm.vertex_code_chat_models
     ): # chat-bison has diff args from chat-bison@001, ty Google :)
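The get_optional_params() hunk follows the pattern used for the other providers in that function: sampling parameters are forwarded only when the caller has moved them off the OpenAI defaults (temperature=1, top_p=1), so PaLM's own defaults apply otherwise. A standalone restatement of that filter, with a function name invented here purely for illustration:

def palm_optional_params(temperature=1, top_p=1):
    # Illustrative only: mirrors the new palm branch, which drops parameters
    # that still sit at the OpenAI defaults.
    optional_params = {}
    if temperature != 1:
        optional_params["temperature"] = temperature
    if top_p != 1:
        optional_params["top_p"] = top_p
    return optional_params


print(palm_optional_params())                            # {}
print(palm_optional_params(temperature=0.2))             # {'temperature': 0.2}
print(palm_optional_params(temperature=0.2, top_p=0.9))  # {'temperature': 0.2, 'top_p': 0.9}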
@@ -3088,6 +3093,19 @@ class CustomStreamWrapper:
                 completion_obj["content"] = new_chunk
                 self.completion_stream = self.completion_stream[chunk_size:]
                 time.sleep(0.05)
+            elif self.custom_llm_provider == "palm":
+                # fake streaming
+                if len(self.completion_stream)==0:
+                    if self.sent_last_chunk:
+                        raise StopIteration
+                    else:
+                        model_response.choices[0].finish_reason = "stop"
+                        self.sent_last_chunk = True
+                chunk_size = 30
+                new_chunk = self.completion_stream[:chunk_size]
+                completion_obj["content"] = new_chunk
+                self.completion_stream = self.completion_stream[chunk_size:]
+                time.sleep(0.05)
             else: # openai chat/azure models
                 chunk = next(self.completion_stream)
                 model_response = chunk
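The CustomStreamWrapper hunk is where the fake streaming actually happens: the wrapper holds the complete response string, slices off 30 characters per iteration, sleeps 50 ms between slices, and marks the final chunk with finish_reason="stop" before raising StopIteration on the next call. A simplified standalone sketch of the same loop, using a plain generator and dicts instead of the wrapper's real response objects:

import time

def fake_palm_stream(text, chunk_size=30, delay=0.05):
    # Illustrative generator only: replays an already-complete response string
    # in chunk_size pieces, the way the palm branch of CustomStreamWrapper does.
    remaining = text
    while remaining:
        piece, remaining = remaining[:chunk_size], remaining[chunk_size:]
        time.sleep(delay)
        yield {
            "content": piece,
            "finish_reason": "stop" if not remaining else None,
        }

if __name__ == "__main__":
    for chunk in fake_palm_stream("PaLM has no native streaming, so litellm replays the finished text."):
        print(chunk)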