Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
fix(fixes-for-text-completion-streaming): fixes for text completion streaming

Parent: 39fb3f2a74
Commit: ff12e023ae

2 changed files with 38 additions and 18 deletions
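For context, this is the streaming text-completion path the commit fixes. A rough usage sketch (the model name and prompt are placeholders, not part of the commit):

    import litellm

    # stream a text completion; each chunk is adapted by
    # TextCompletionStreamWrapper into an OpenAI text_completion-style chunk
    response = litellm.text_completion(
        model="gpt-3.5-turbo-instruct",  # placeholder model
        prompt="Say this is a test",
        stream=True,
    )
    for chunk in response:
        print(chunk["choices"][0]["text"], end="")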
@@ -469,6 +469,7 @@ def completion(
         "caching_groups",
         "ttl",
         "cache",
+        "parent_call"
     ]
     default_params = openai_params + litellm_params
     non_default_params = {
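Registering "parent_call" in the litellm-internal params list keeps it from being forwarded to the provider as an unknown kwarg. A minimal self-contained sketch of that filtering idea (the lists and the dict comprehension here are illustrative, not litellm's exact code):

    # illustrative subsets of the real param lists
    openai_params = ["model", "prompt", "stream", "temperature"]
    litellm_params = ["caching_groups", "ttl", "cache", "parent_call"]
    default_params = openai_params + litellm_params

    kwargs = {"prompt": "hi", "parent_call": "text_completion", "top_k": 40}
    # only keys unknown to litellm/OpenAI are treated as provider-specific
    non_default_params = {k: v for k, v in kwargs.items() if k not in default_params}
    print(non_default_params)  # {'top_k': 40}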
@@ -2619,7 +2620,7 @@ def text_completion(
         # only use engine when model not passed
         model = kwargs["engine"]
         kwargs.pop("engine")
+    kwargs["parent_call"] = kwargs.get("parent_call", "text_completion")
     text_completion_response = TextCompletionResponse()

     optional_params: Dict[str, Any] = {}
@@ -2726,6 +2727,7 @@ def text_completion(
     if kwargs.get("acompletion", False) == True:
         return response
     if stream == True or kwargs.get("stream", False) == True:
+        print(f"original model response: {response}")
        response = TextCompletionStreamWrapper(completion_stream=response, model=model)
        return response
    transformed_logprobs = None
@@ -3162,22 +3164,23 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]=
     else:
         completion_output = ""
     # # Update usage information if needed
-    try:
-        response["usage"]["prompt_tokens"] = token_counter(
-            model=model, messages=messages
+    print(f"INSIDE TEXT COMPLETION STREAM CHUNK BUILDER")
+    _usage = litellm.Usage
+    print(f"messages: {messages}")
+    _usage.prompt_tokens = token_counter(
+        model=model, messages=messages, count_response_tokens=True
     )
-    except:  # don't allow this failing to block a complete streaming response from being returned
-        print_verbose(f"token_counter failed, assuming prompt tokens is 0")
-        response["usage"]["prompt_tokens"] = 0
-    response["usage"]["completion_tokens"] = token_counter(
+    print(f"received prompt tokens: {_usage.prompt_tokens}")
+    _usage.completion_tokens = token_counter(
         model=model,
         text=combined_content,
         count_response_tokens=True,  # count_response_tokens is a Flag to tell token counter this is a response, No need to add extra tokens we do for input messages
     )
-    response["usage"]["total_tokens"] = (
-        response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
+    _usage.total_tokens = (
+        _usage.prompt_tokens + _usage.completion_tokens
     )
-    return response
+    response["usage"] = _usage
+    return litellm.TextCompletionResponse(**response)


 def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
     id = chunks[0]["id"]
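The rebuilt usage accounting relies on `token_counter` over the concatenated streamed text. A small sketch of the same accounting, assuming `litellm.token_counter` with `count_response_tokens=True` skips the per-message overhead used for inputs:

    import litellm

    prompt_messages = [{"role": "user", "content": "Say something"}]
    combined_content = "Hello, this is a streamed completion."

    prompt_tokens = litellm.token_counter(model="gpt-3.5-turbo", messages=prompt_messages)
    completion_tokens = litellm.token_counter(
        model="gpt-3.5-turbo",
        text=combined_content,
        count_response_tokens=True,  # response text: no extra message tokens added
    )
    total_tokens = prompt_tokens + completion_tokens
    print(prompt_tokens, completion_tokens, total_tokens)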
@@ -2038,12 +2038,19 @@ def client(original_function):
             if (
                 "complete_response" in kwargs
                 and kwargs["complete_response"] == True
+                and kwargs.get("parent_call", None) is None
             ):
                 chunks = []
                 for idx, chunk in enumerate(result):
                     chunks.append(chunk)
-                return litellm.stream_chunk_builder(
-                    chunks, messages=kwargs.get("messages", None)
+                call_type = original_function.__name__
+                if call_type == CallTypes.completion.value:
+                    return litellm.stream_chunk_builder(
+                        chunks, messages=kwargs.get("messages")
+                    )
+                elif call_type == CallTypes.text_completion.value:
+                    return litellm.stream_chunk_builder(
+                        chunks, messages=[{"role": "user", "content": kwargs.get("prompt")}]
                 )
             else:
                 return result
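With the call-type dispatch above, a streamed call made with `complete_response=True` is collected and rebuilt into one full response; text completions get their prompt wrapped as a single user message for token accounting. A hedged usage sketch (placeholder model):

    import litellm

    full_response = litellm.text_completion(
        model="gpt-3.5-turbo-instruct",  # placeholder
        prompt="Write a haiku about streams",
        stream=True,
        complete_response=True,  # the flag checked in client() above
    )
    print(full_response)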
@@ -2695,9 +2702,10 @@ def token_counter(
         raise ValueError("text and messages cannot both be None")
     elif isinstance(text, List):
         text = "".join(t for t in text if isinstance(t, str))
+    print_verbose(f"text: {text}")
     if model is not None:
         tokenizer_json = _select_tokenizer(model=model)
+        print(f"tokenizer_json['type']: {tokenizer_json['type']}")
         if tokenizer_json["type"] == "huggingface_tokenizer":
             print_verbose(
                 f"Token Counter - using hugging face token counter, for model={model}"
@@ -2731,6 +2739,7 @@ def token_counter(
             num_tokens = len(enc)
         else:
             num_tokens = len(encoding.encode(text))  # type: ignore
+    print_verbose(f"final num tokens returned: {num_tokens}")
     return num_tokens

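For reference, `token_counter` can be fed either plain text or chat messages; a short sketch:

    import litellm

    # plain text: tokenizer is chosen from the model name
    n_text = litellm.token_counter(model="gpt-3.5-turbo", text="hello world")

    # chat messages: per-message formatting tokens are typically added
    n_msgs = litellm.token_counter(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello world"}],
    )
    print(n_text, n_msgs)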
@@ -7760,6 +7769,8 @@ class TextCompletionStreamWrapper:

     def convert_to_text_completion_object(self, chunk: ModelResponse):
         try:
+            if not isinstance(chunk, ModelResponse):
+                return
             response = TextCompletionResponse()
             response["id"] = chunk.get("id", None)
             response["object"] = "text_completion"
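The added guard simply skips anything that is not a ModelResponse before conversion. For illustration only, a standalone sketch of that kind of chat-chunk-to-text-completion mapping (the field handling here is an assumption, not litellm's exact code):

    from typing import Any, Dict

    def to_text_completion_dict(chunk: Dict[str, Any]) -> Dict[str, Any]:
        # map a chat-style streaming chunk onto a text_completion-shaped payload
        choice = chunk.get("choices", [{}])[0]
        text = (choice.get("delta") or {}).get("content") or choice.get("text", "")
        return {
            "id": chunk.get("id"),
            "object": "text_completion",
            "choices": [{"index": 0, "text": text, "finish_reason": choice.get("finish_reason")}],
        }

    print(to_text_completion_dict(
        {"id": "cmpl-123", "choices": [{"delta": {"content": "Hi"}, "finish_reason": None}]}
    ))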
@@ -7784,12 +7795,18 @@ class TextCompletionStreamWrapper:
         # model_response = ModelResponse(stream=True, model=self.model)
         response = TextCompletionResponse()
         try:
-            for chunk in self.completion_stream:
+            while True:
+                if isinstance(self.completion_stream, str) or isinstance(
+                    self.completion_stream, bytes
+                ) or isinstance(self.completion_stream, ModelResponse):
+                    chunk = self.completion_stream
+                else:
+                    chunk = next(self.completion_stream)
+
                 if chunk == "None" or chunk is None:
                     raise Exception
                 processed_chunk = self.convert_to_text_completion_object(chunk=chunk)
                 return processed_chunk
-            raise StopIteration
         except StopIteration:
             raise StopIteration
         except Exception as e:
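The new `__next__` has to handle `completion_stream` being either a real iterator of chunks or a single, already-complete object (str/bytes/ModelResponse). A minimal standalone sketch of that pattern (generic, not litellm's exact wrapper; the lone object is yielded exactly once here):

    from typing import Any, Iterator, Union

    class SingleOrStream:
        """Iterate a stream of chunks, or yield a lone object exactly once."""

        def __init__(self, source: Union[str, bytes, Iterator[Any]]):
            self.source = source
            self._done = False

        def __iter__(self):
            return self

        def __next__(self):
            if isinstance(self.source, (str, bytes)):
                if self._done:
                    raise StopIteration
                self._done = True
                return self.source
            return next(self.source)  # normal iterator case

    print(list(SingleOrStream(iter(["a", "b"]))))   # ['a', 'b']
    print(list(SingleOrStream("whole response")))   # ['whole response']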