ensure streaming format is exactly the same as openai

Repository: https://github.com/BerriAI/litellm.git (mirror)
Commit: 21cd55ab26 (parent: ebd4688fec)
6 changed files with 275 additions and 169 deletions
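For reference, OpenAI's chat streaming responses are `chat.completion.chunk` objects whose first chunk carries the assistant role, whose middle chunks carry content deltas, and whose final chunk carries a `finish_reason`. A rough sketch of that target shape (values are illustrative, not taken from this commit):

# Rough sketch of the OpenAI streaming shape this commit aligns litellm with.
# Field values are illustrative; only the layout matters here.
first_chunk = {
    "id": "chatcmpl-abc123",            # stable across all chunks of one stream
    "object": "chat.completion.chunk",
    "created": 1694268190,              # unix timestamp
    "model": "gpt-3.5-turbo",
    "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None}],
}
middle_chunk = {
    # ...same id/object/created/model...
    "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": None}],
}
final_chunk = {
    # ...same id/object/created/model...
    "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}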
@@ -80,6 +80,8 @@ last_fetched_at_keys = None
 # 'usage': {'prompt_tokens': 18, 'completion_tokens': 23, 'total_tokens': 41}
 # }
 
+def _generate_id(): # private helper function
+    return 'chatcmpl-' + str(uuid.uuid4())
 
 class Message(OpenAIObject):
     def __init__(self, content="default", role="assistant", logprobs=None, **params):
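The new `_generate_id` helper reproduces OpenAI's `chatcmpl-` id prefix. A quick standalone run (helper copied from the hunk above):

import uuid

def _generate_id(): # copied from the hunk above
    return 'chatcmpl-' + str(uuid.uuid4())

print(_generate_id())  # e.g. 'chatcmpl-1f7b7a51-9c4e-4f5a-8b7c-0d1e2f3a4b5c'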
@@ -89,9 +91,9 @@ class Message(OpenAIObject):
         self.logprobs = logprobs
 
 class Delta(OpenAIObject):
-    def __init__(self, content="<special_litellm_token>", logprobs=None, role=None, **params):
+    def __init__(self, content=None, logprobs=None, role=None, **params):
         super(Delta, self).__init__(**params)
-        if content != "<special_litellm_token>":
+        if content is not None:
             self.content = content
         if role:
             self.role = role
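Dropping the "<special_litellm_token>" sentinel in favour of a `None` default means `content` is only set when a real value is passed, so an empty `Delta` can serialize to `{}` the way OpenAI's closing delta does. A minimal sketch of that behaviour using a plain dict stand-in (not the real `OpenAIObject` base):

class DeltaSketch(dict):
    """Dict stand-in for litellm's Delta, showing the None-default behaviour."""
    def __init__(self, content=None, role=None):
        super().__init__()
        if content is not None:
            self["content"] = content
        if role:
            self["role"] = role

print(DeltaSketch())                                # {} -> matches OpenAI's empty final delta
print(DeltaSketch(content="Hi", role="assistant"))  # {'content': 'Hi', 'role': 'assistant'}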
@@ -105,20 +107,35 @@ class Choices(OpenAIObject):
         self.message = message
 
 class StreamingChoices(OpenAIObject):
-    def __init__(self, finish_reason=None, index=0, delta=Delta(), **params):
+    def __init__(self, finish_reason=None, index=0, delta: Optional[Delta]=None, **params):
         super(StreamingChoices, self).__init__(**params)
         self.finish_reason = finish_reason
         self.index = index
-        self.delta = delta
+        if delta:
+            print(f"delta passed in: {delta}")
+            self.delta = delta
+        else:
+            self.delta = Delta()
 
 class ModelResponse(OpenAIObject):
-    def __init__(self, choices=None, created=None, model=None, usage=None, stream=False, **params):
-        super(ModelResponse, self).__init__(**params)
+    def __init__(self, id=None, choices=None, created=None, model=None, usage=None, stream=False, **params):
         if stream:
-            self.choices = self.choices = choices if choices else [StreamingChoices()]
             self.object = "chat.completion.chunk"
+            self.choices = [StreamingChoices()]
         else:
             if model in litellm.open_ai_embedding_models:
                 self.object = "embedding"
             else:
                 self.object = "chat.completion"
             self.choices = self.choices = choices if choices else [Choices()]
-        self.created = created
+        if id is None:
+            self.id = _generate_id()
+        else:
+            self.id = id
+        if created is None:
+            self.created = int(time.time())
+        else:
+            self.created = created
         self.model = model
         self.usage = (
             usage
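A streaming `ModelResponse` now fills in the whole OpenAI chunk envelope by itself. A usage sketch, assuming `ModelResponse` is importable from `litellm.utils` at this commit:

from litellm.utils import ModelResponse  # assumed import path; the class lives in utils per this diff

resp = ModelResponse(stream=True, model="gpt-3.5-turbo")
print(resp.object)   # "chat.completion.chunk"
print(resp.id)       # auto-generated "chatcmpl-<uuid4>" when no id is passed
print(resp.created)  # defaults to int(time.time())
print(resp.choices)  # [StreamingChoices(...)] holding an empty Delta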
@@ -129,6 +146,7 @@ class ModelResponse(OpenAIObject):
                 "total_tokens": None,
             }
         )
+        super(ModelResponse, self).__init__(**params)
 
     def to_dict_recursive(self):
         d = super().to_dict_recursive()
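Note that `super(ModelResponse, self).__init__(**params)` now runs at the end of `__init__` rather than first, after `id`, `created`, `choices`, and `usage` are populated; presumably this is so anything passed through `**params` is applied after (and can override) the defaults set above.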
@@ -1041,8 +1059,10 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None):
 
     # check if model in known model provider list
     ## openai - chatcompletion + text completion
-    if model in litellm.open_ai_chat_completion_models or model in litellm.open_ai_text_completion_models:
+    if model in litellm.open_ai_chat_completion_models:
         custom_llm_provider = "openai"
+    elif model in litellm.open_ai_text_completion_models:
+        custom_llm_provider = "text-completion-openai"
     ## anthropic
     elif model in litellm.anthropic_models:
         custom_llm_provider = "anthropic"
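Text-completion models now resolve to their own provider key instead of sharing "openai". A usage sketch (model names are illustrative; the function lives in `litellm.utils` per the hunk header, and the exact return shape is not shown in this diff):

from litellm.utils import get_llm_provider  # assumed import path

get_llm_provider("gpt-3.5-turbo")     # custom_llm_provider resolves to "openai"
get_llm_provider("text-davinci-003")  # custom_llm_provider resolves to "text-completion-openai"
get_llm_provider("claude-2")          # custom_llm_provider resolves to "anthropic"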
@@ -2359,6 +2379,7 @@ class CustomStreamWrapper:
         self.custom_llm_provider = custom_llm_provider
         self.logging_obj = logging_obj
         self.completion_stream = completion_stream
+        self.sent_first_chunk = False
         if self.logging_obj:
             # Log the type of the received item
             self.logging_obj.post_call(str(type(completion_stream)))
@@ -2413,7 +2434,6 @@ class CustomStreamWrapper:
             chunk = chunk.decode("utf-8")
             data_json = json.loads(chunk)
             try:
-                print(f"data json: {data_json}")
                 return data_json["generated_text"]
             except:
                 raise ValueError(f"Unable to parse response. Original response: {chunk}")
@@ -2430,7 +2450,6 @@ class CustomStreamWrapper:
             chunk = chunk.decode("utf-8")
             data_json = json.loads(chunk)
             try:
-                print(f"data json: {data_json}")
                 return data_json["text"]
             except:
                 raise ValueError(f"Unable to parse response. Original response: {chunk}")
@@ -2485,8 +2504,12 @@ class CustomStreamWrapper:
             return ""
 
     def __next__(self):
+        model_response = ModelResponse(stream=True, model=self.model)
         try:
             # return this for all models
+            if self.sent_first_chunk == False:
+                model_response.choices[0].delta.role = "assistant"
+                self.sent_first_chunk = True
             completion_obj = {"content": ""} # default to role being assistant
             if self.model in litellm.anthropic_models:
                 chunk = next(self.completion_stream)
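Combined with the `sent_first_chunk` flag initialized above, this makes the wrapper attach `role: "assistant"` exactly once, on the first chunk, matching OpenAI. A hypothetical usage sketch:

import litellm

# hypothetical setup; with stream=True, completion returns a CustomStreamWrapper
stream = litellm.completion(model="gpt-3.5-turbo",
                            messages=[{"role": "user", "content": "Hi"}],
                            stream=True)

first = next(stream)
print(first.choices[0].delta)   # first chunk carries {'role': 'assistant', ...}

second = next(stream)
print(second.choices[0].delta)  # later chunks carry content only, no role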
@@ -2544,7 +2567,7 @@ class CustomStreamWrapper:
             model_response.choices[0].delta = completion_obj
             model_response.model = self.model
 
-            if model_response.choices[0].delta['content'] == "<special_litellm_token>":
+            if model_response.choices[0].delta.content == "<special_litellm_token>":
                 model_response.choices[0].delta = {
                     "content": completion_obj["content"],
                 }
@@ -2552,8 +2575,6 @@ class CustomStreamWrapper:
         except StopIteration:
             raise StopIteration
         except Exception as e:
-            print(e)
-            model_response = ModelResponse(stream=True)
             model_response.choices[0].finish_reason = "stop"
             return model_response
 
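On a non-StopIteration error the wrapper now reuses the chunk it already built instead of printing and constructing a fresh `ModelResponse`, so the terminal object a consumer sees keeps the stream's id and model and looks roughly like OpenAI's closing chunk:

# roughly the terminal chunk shape a consumer sees (illustrative values):
# {'id': 'chatcmpl-...', 'object': 'chat.completion.chunk', 'created': 1694268190,
#  'model': 'gpt-3.5-turbo',
#  'choices': [{'index': 0, 'delta': {...}, 'finish_reason': 'stop'}]}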