add fake streaming for petals

This commit is contained in:
ishaan-jaff 2023-09-30 10:22:04 -07:00
parent 141c9c5bac
commit 8d1f5ba69d
2 changed files with 16 additions and 3 deletions

View file

@ -1099,10 +1099,11 @@ def completion(
encoding=encoding, encoding=encoding,
logging_obj=logging logging_obj=logging
) )
if inspect.isgenerator(model_response) or (stream == True): if stream==True: ## [BETA]
# don't try to access stream object, # Fake streaming for petals
resp_string = model_response["choices"][0]["message"]["content"]
response = CustomStreamWrapper( response = CustomStreamWrapper(
model_response, model, custom_llm_provider="petals", logging_obj=logging resp_string, model, custom_llm_provider="petals", logging_obj=logging
) )
return response return response
response = model_response response = model_response

View file

@ -3158,6 +3158,18 @@ class CustomStreamWrapper:
completion_obj["content"] = new_chunk completion_obj["content"] = new_chunk
self.completion_stream = self.completion_stream[chunk_size:] self.completion_stream = self.completion_stream[chunk_size:]
time.sleep(0.05) time.sleep(0.05)
elif self.custom_llm_provider == "petals":
if len(self.completion_stream)==0:
if self.sent_last_chunk:
raise StopIteration
else:
model_response.choices[0].finish_reason = "stop"
self.sent_last_chunk = True
chunk_size = 30
new_chunk = self.completion_stream[:chunk_size]
completion_obj["content"] = new_chunk
self.completion_stream = self.completion_stream[chunk_size:]
time.sleep(0.05)
elif self.custom_llm_provider == "palm": elif self.custom_llm_provider == "palm":
# fake streaming # fake streaming
if len(self.completion_stream)==0: if len(self.completion_stream)==0: