fixes to streaming

Krrish Dholakia 2023-09-14 19:27:12 -07:00
parent 722958a4cc
commit 46e86b7433
3 changed files with 5 additions and 8 deletions


@@ -40,15 +40,13 @@ def test_completion_cohere_stream():
# Add any assertions here to check the response
for chunk in response:
print(f"chunk: {chunk}")
-            if "content" in chunk["choices"][0]["delta"]:
-                complete_response += chunk["choices"][0]["delta"]["content"]
+            complete_response += chunk["choices"][0]["delta"]["content"]
if complete_response == "":
raise Exception("Empty response received")
print(f"completion_response: {complete_response}")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test on baseten completion call
# try:
# response = completion(
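
For reference, a minimal sketch (not part of this commit) of how a caller consumes a streamed completion after this change; the model name and message are placeholders, and it assumes every chunk's delta now carries a "content" key (see the utils change below):

    from litellm import completion

    def stream_example():
        response = completion(
            model="command-nightly",
            messages=[{"role": "user", "content": "Hello, how's it going?"}],
            stream=True,
        )
        complete_response = ""
        for chunk in response:
            # No membership check on "content" is needed once the delta is
            # always populated.
            complete_response += chunk["choices"][0]["delta"]["content"]
        if complete_response == "":
            raise Exception("Empty response received")
        return complete_response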
@@ -290,8 +288,6 @@ async def ai21_async_completion_call():
complete_response += chunk["choices"][0]["delta"]["content"]
if complete_response == "":
raise Exception("Empty response received")
-    except KeyError as e:
-        pass
except:
print(f"error occurred: {traceback.format_exc()}")
pass
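
A similar hedged sketch for the async path, assuming litellm's acompletion entry point and an async-iterable streaming response; the AI21 model name is a placeholder. With the delta always populated, the dedicated KeyError handler removed above is no longer needed:

    import traceback
    from litellm import acompletion

    async def ai21_async_stream_example():
        complete_response = ""
        try:
            response = await acompletion(
                model="j2-light",
                messages=[{"role": "user", "content": "hi"}],
                stream=True,
            )
            async for chunk in response:
                complete_response += chunk["choices"][0]["delta"]["content"]
            if complete_response == "":
                raise Exception("Empty response received")
        except Exception:
            # Broad handler mirroring the test: log the traceback and move on.
            print(f"error occurred: {traceback.format_exc()}")
        return complete_response

    # e.g. asyncio.run(ai21_async_stream_example())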


@@ -103,7 +103,7 @@ class Choices(OpenAIObject):
self.message = message
class StreamingChoices(OpenAIObject):
-    def __init__(self, finish_reason=None, index=0, delta: Optional[Union[Dict, Delta]]={}, **params):
+    def __init__(self, finish_reason=None, index=0, delta=Delta(), **params):
super(StreamingChoices, self).__init__(**params)
self.finish_reason = finish_reason
self.index = index
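
A rough, self-contained sketch (not litellm's actual classes) of why defaulting delta to a Delta() instance lets the tests above drop their "content" checks; the sketch builds the default inside __init__ rather than sharing one instance via the keyword default:

    class DeltaSketch(dict):
        # Stand-in: the real Delta subclasses OpenAIObject; the point is that a
        # freshly built delta always carries "role" and "content" keys.
        def __init__(self, content="", role="assistant", **params):
            super().__init__(content=content, role=role, **params)

    class StreamingChoicesSketch(dict):
        def __init__(self, finish_reason=None, index=0, delta=None, **params):
            super().__init__(
                finish_reason=finish_reason,
                index=index,
                delta=delta if delta is not None else DeltaSketch(),
                **params,
            )

    choice = StreamingChoicesSketch()
    assert "content" in choice["delta"]  # the membership check dropped from the test always holds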
@@ -2492,11 +2492,12 @@ class CustomStreamWrapper:
model_response = ModelResponse(stream=True)
model_response.choices[0].delta = completion_obj
return model_response
except StopIteration:
raise StopIteration
except Exception as e:
model_response = ModelResponse(stream=True)
model_response.choices[0].finish_reason = "stop"
return model_response
# raise StopIteration
async def __anext__(self):
try:
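
Finally, a simplified stand-in (not the real CustomStreamWrapper) for the __next__ pattern shown above: StopIteration is re-raised so iteration ends normally, while any other error is converted into a last chunk whose finish_reason is "stop":

    class StreamWrapperSketch:
        def __init__(self, completion_stream):
            self.completion_stream = completion_stream

        def __iter__(self):
            return self

        def __next__(self):
            try:
                text = next(self.completion_stream)
                return {"choices": [{"delta": {"content": text}, "finish_reason": None}]}
            except StopIteration:
                # Normal end of stream: let it propagate so for-loops stop cleanly.
                raise
            except Exception:
                # Any other failure becomes a final chunk signalling "stop"
                # instead of ending the iteration with an unrelated error.
                return {"choices": [{"delta": {}, "finish_reason": "stop"}]}

    for chunk in StreamWrapperSketch(iter(["Hello", " world"])):
        print(chunk["choices"][0]["delta"].get("content", ""))

Callers can then treat finish_reason == "stop" as the end-of-stream signal even when the underlying provider stream fails.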