fix streaming objects

Krrish Dholakia 2023-09-14 17:16:43 -07:00
parent 29c3247753
commit 0b436cf767
4 changed files with 25 additions and 11 deletions

View file

@@ -40,6 +40,7 @@ def test_completion_cohere_stream():
         # Add any assertions here to check the response
         for chunk in response:
             print(f"chunk: {chunk}")
+            if "content" in chunk["choices"][0]["delta"]:
                 complete_response += chunk["choices"][0]["delta"]["content"]
         if complete_response == "":
             raise Exception("Empty response received")
@@ -79,6 +80,7 @@ def test_openai_text_completion_call():
         for chunk in response:
             chunk_time = time.time()
             print(f"chunk: {chunk}")
+            if "content" in chunk["choices"][0]["delta"]:
                 complete_response += chunk["choices"][0]["delta"]["content"]
         if complete_response == "":
             raise Exception("Empty response received")
@@ -98,7 +100,8 @@ def ai21_completion_call():
         for chunk in response:
             chunk_time = time.time()
             print(f"time since initial request: {chunk_time - start_time:.5f}")
-            print(chunk["choices"][0]["delta"])
+            print(chunk)
+            if "content" in chunk["choices"][0]["delta"]:
                 complete_response += chunk["choices"][0]["delta"]["content"]
         if complete_response == "":
             raise Exception("Empty response received")
@@ -106,7 +109,6 @@ def ai21_completion_call():
         print(f"error occurred: {traceback.format_exc()}")
         pass
 
-ai21_completion_call()
 # test on openai completion call
 def test_openai_chat_completion_call():
     try:
@@ -118,7 +120,8 @@ def test_openai_chat_completion_call():
         for chunk in response:
             chunk_time = time.time()
             print(f"time since initial request: {chunk_time - start_time:.5f}")
-            print(chunk["choices"][0]["delta"])
+            print(chunk)
+            if "content" in chunk["choices"][0]["delta"]:
                 complete_response += chunk["choices"][0]["delta"]["content"]
         if complete_response == "":
             raise Exception("Empty response received")
@@ -126,6 +129,7 @@ def test_openai_chat_completion_call():
         print(f"error occurred: {traceback.format_exc()}")
         pass
 
+# test_openai_chat_completion_call()
 async def completion_call():
     try:
         response = completion(
@@ -139,6 +143,7 @@ async def completion_call():
             chunk_time = time.time()
             print(f"time since initial request: {chunk_time - start_time:.5f}")
             print(chunk["choices"][0]["delta"])
+            if "content" in chunk["choices"][0]["delta"]:
                 complete_response += chunk["choices"][0]["delta"]["content"]
         if complete_response == "":
             raise Exception("Empty response received")
@@ -205,6 +210,8 @@ def test_together_ai_completion_call_replit():
             )
         if complete_response == "":
             raise Exception("Empty response received")
+    except KeyError as e:
+        pass
     except:
         print(f"error occurred: {traceback.format_exc()}")
         pass
@@ -232,6 +239,8 @@ def test_together_ai_completion_call_starcoder():
         print(complete_response)
         if complete_response == "":
             raise Exception("Empty response received")
+    except KeyError as e:
+        pass
     except:
         print(f"error occurred: {traceback.format_exc()}")
         pass
@@ -281,6 +290,8 @@ async def ai21_async_completion_call():
             complete_response += chunk["choices"][0]["delta"]["content"]
         if complete_response == "":
             raise Exception("Empty response received")
+    except KeyError as e:
+        pass
     except:
         print(f"error occurred: {traceback.format_exc()}")
        pass

View file

@@ -103,7 +103,7 @@ class Choices(OpenAIObject):
         self.message = message
 
 class StreamingChoices(OpenAIObject):
-    def __init__(self, finish_reason="stop", index=0, delta=Delta(), **params):
+    def __init__(self, finish_reason=None, index=0, delta: Optional[Delta]={}, **params):
         super(StreamingChoices, self).__init__(**params)
         self.finish_reason = finish_reason
         self.index = index
@@ -2493,7 +2493,10 @@ class CustomStreamWrapper:
             model_response.choices[0].delta = completion_obj
             return model_response
         except Exception as e:
-            raise StopIteration
+            model_response = ModelResponse(stream=True)
+            model_response.choices[0].finish_reason = "stop"
+            return model_response
+            # raise StopIteration
 
     async def __anext__(self):
         try:
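
With this change, StreamingChoices defaults finish_reason to None for ordinary chunks, and CustomStreamWrapper.__next__ no longer raises a bare StopIteration when chunk handling fails; it returns one final ModelResponse whose first choice has finish_reason set to "stop" and no content. A rough sketch of a consumer that relies on this, with a placeholder model name and an illustrative helper that is not part of the library:

# Illustrative consumer: normal chunks have finish_reason None, and the wrapper now
# ends the stream with a content-less chunk whose finish_reason is "stop".
from litellm import completion

def stream_to_string(prompt):
    # "gpt-3.5-turbo" is a placeholder; any model litellm supports would do.
    response = completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )
    pieces = []
    for chunk in response:
        choice = chunk.choices[0]
        if "content" in choice["delta"]:
            pieces.append(choice["delta"]["content"])
        # The terminal chunk carries only finish_reason == "stop".
        if choice.finish_reason == "stop":
            break
    return "".join(pieces)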

View file

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.632"
+version = "0.1.633"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"