ensure streaming format is exactly the same as openai

Krrish Dholakia 2023-09-16 10:34:20 -07:00
parent ebd4688fec
commit 21cd55ab26
6 changed files with 275 additions and 169 deletions


@@ -80,6 +80,8 @@ last_fetched_at_keys = None
 # 'usage': {'prompt_tokens': 18, 'completion_tokens': 23, 'total_tokens': 41}
 # }
+def _generate_id(): # private helper function
+    return 'chatcmpl-' + str(uuid.uuid4())
 class Message(OpenAIObject):
     def __init__(self, content="default", role="assistant", logprobs=None, **params):
@@ -89,9 +91,9 @@ class Message(OpenAIObject):
         self.logprobs = logprobs
 class Delta(OpenAIObject):
-    def __init__(self, content="<special_litellm_token>", logprobs=None, role=None, **params):
+    def __init__(self, content=None, logprobs=None, role=None, **params):
         super(Delta, self).__init__(**params)
-        if content != "<special_litellm_token>":
+        if content is not None:
             self.content = content
         if role:
             self.role = role
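Replacing the "<special_litellm_token>" sentinel with None means a Delta only carries a content field when content was actually supplied, matching OpenAI's stream chunks, where the first delta may hold only a role. A rough dict-based stand-in for the OpenAIObject-backed class above (illustrative only):

class FakeDelta(dict):
    # dict stand-in: keys are set only when the corresponding value is present
    def __init__(self, content=None, role=None):
        super().__init__()
        if content is not None:
            self["content"] = content
        if role:
            self["role"] = role

print(FakeDelta(role="assistant"))  # {'role': 'assistant'} -- no placeholder content key
print(FakeDelta(content="Hello"))   # {'content': 'Hello'}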
@@ -105,20 +107,35 @@ class Choices(OpenAIObject):
         self.message = message
 class StreamingChoices(OpenAIObject):
-    def __init__(self, finish_reason=None, index=0, delta=Delta(), **params):
+    def __init__(self, finish_reason=None, index=0, delta: Optional[Delta]=None, **params):
         super(StreamingChoices, self).__init__(**params)
         self.finish_reason = finish_reason
         self.index = index
-        self.delta = delta
+        if delta:
+            print(f"delta passed in: {delta}")
+            self.delta = delta
+        else:
+            self.delta = Delta()
 class ModelResponse(OpenAIObject):
-    def __init__(self, choices=None, created=None, model=None, usage=None, stream=False, **params):
-        super(ModelResponse, self).__init__(**params)
+    def __init__(self, id=None, choices=None, created=None, model=None, usage=None, stream=False, **params):
         if stream:
-            self.choices = self.choices = choices if choices else [StreamingChoices()]
             self.object = "chat.completion.chunk"
+            self.choices = [StreamingChoices()]
         else:
             if model in litellm.open_ai_embedding_models:
                 self.object = "embedding"
             else:
                 self.object = "chat.completion"
             self.choices = self.choices = choices if choices else [Choices()]
-        self.created = created
+        if id is None:
+            self.id = _generate_id()
+        else:
+            self.id = id
+        if created is None:
+            self.created = int(time.time())
+        else:
+            self.created = created
         self.model = model
         self.usage = (
             usage
@@ -129,6 +146,7 @@ class ModelResponse(OpenAIObject):
                 "total_tokens": None,
             }
         )
+        super(ModelResponse, self).__init__(**params)
     def to_dict_recursive(self):
        d = super().to_dict_recursive()
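With the id and created fallbacks above, a streaming ModelResponse serializes to the same envelope OpenAI returns for chat.completion.chunk objects. A hand-built illustration of that target shape (field values are made up, but the keys and defaults mirror the constructor above):

import time
import uuid

chunk = {
    "id": "chatcmpl-" + str(uuid.uuid4()),  # same default _generate_id() produces
    "object": "chat.completion.chunk",
    "created": int(time.time()),            # same default as the created fallback
    "model": "gpt-3.5-turbo",
    "choices": [
        {"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None}
    ],
}
print(chunk)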
@@ -1041,8 +1059,10 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None):
         # check if model in known model provider list
         ## openai - chatcompletion + text completion
-        if model in litellm.open_ai_chat_completion_models or model in litellm.open_ai_text_completion_models:
+        if model in litellm.open_ai_chat_completion_models:
             custom_llm_provider = "openai"
+        elif model in litellm.open_ai_text_completion_models:
+            custom_llm_provider = "text-completion-openai"
         ## anthropic
         elif model in litellm.anthropic_models:
             custom_llm_provider = "anthropic"
@@ -2359,6 +2379,7 @@ class CustomStreamWrapper:
         self.custom_llm_provider = custom_llm_provider
         self.logging_obj = logging_obj
         self.completion_stream = completion_stream
+        self.sent_first_chunk = False
         if self.logging_obj:
             # Log the type of the received item
             self.logging_obj.post_call(str(type(completion_stream)))
@@ -2413,7 +2434,6 @@ class CustomStreamWrapper:
         chunk = chunk.decode("utf-8")
         data_json = json.loads(chunk)
         try:
-            print(f"data json: {data_json}")
             return data_json["generated_text"]
         except:
             raise ValueError(f"Unable to parse response. Original response: {chunk}")
@@ -2430,7 +2450,6 @@ class CustomStreamWrapper:
         chunk = chunk.decode("utf-8")
         data_json = json.loads(chunk)
         try:
-            print(f"data json: {data_json}")
             return data_json["text"]
         except:
             raise ValueError(f"Unable to parse response. Original response: {chunk}")
@@ -2485,8 +2504,12 @@ class CustomStreamWrapper:
         return ""
     def __next__(self):
         model_response = ModelResponse(stream=True, model=self.model)
         try:
+            # return this for all models
+            if self.sent_first_chunk == False:
+                model_response.choices[0].delta.role = "assistant"
+                self.sent_first_chunk = True
             completion_obj = {"content": ""} # default to role being assistant
             if self.model in litellm.anthropic_models:
                 chunk = next(self.completion_stream)
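The sent_first_chunk flag reproduces OpenAI's streaming behavior: only the first chunk's delta carries role="assistant", and later chunks carry content alone. A minimal generator sketch of that pattern (not litellm code, purely illustrative):

def stream_chunks(pieces):
    # illustrative reduction of the wrapper's first-chunk logic
    sent_first_chunk = False
    for piece in pieces:
        delta = {"content": piece}
        if not sent_first_chunk:
            delta["role"] = "assistant"  # role appears only on the first chunk
            sent_first_chunk = True
        yield {"choices": [{"index": 0, "delta": delta, "finish_reason": None}]}

for c in stream_chunks(["Hel", "lo"]):
    print(c)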
@@ -2544,7 +2567,7 @@ class CustomStreamWrapper:
             model_response.choices[0].delta = completion_obj
             model_response.model = self.model
-            if model_response.choices[0].delta['content'] == "<special_litellm_token>":
+            if model_response.choices[0].delta.content == "<special_litellm_token>":
                 model_response.choices[0].delta = {
                     "content": completion_obj["content"],
                 }
@@ -2552,8 +2575,6 @@ class CustomStreamWrapper:
         except StopIteration:
             raise StopIteration
         except Exception as e:
-            print(e)
-            model_response = ModelResponse(stream=True)
             model_response.choices[0].finish_reason = "stop"
             return model_response
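Taken together, these changes mean a stream from litellm can be consumed exactly like one from the OpenAI SDK. A usage sketch, assuming an OpenAI key is configured in the environment (first chunk carries the role, the rest carry content):

import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi"}],
    stream=True,
)
for chunk in response:
    delta = chunk.choices[0].delta
    # first chunk: delta.role == "assistant"; later chunks carry content only
    print(getattr(delta, "content", "") or "", end="")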