fix(utils.py): fix streaming to not return usage dict

Fixes https://github.com/BerriAI/litellm/issues/3237
This commit is contained in:
Krrish Dholakia 2024-04-24 08:06:07 -07:00
parent 70c98617da
commit 48c2c3d78a
24 changed files with 107 additions and 83 deletions

View file

@ -16,11 +16,11 @@ repos:
name: Check if files match
entry: python3 ci_cd/check_files_match.py
language: system
- repo: local
hooks:
- id: mypy
name: mypy
entry: python3 -m mypy --ignore-missing-imports
language: system
types: [python]
files: ^litellm/
# - repo: local
# hooks:
# - id: mypy
# name: mypy
# entry: python3 -m mypy --ignore-missing-imports
# language: system
# types: [python]
# files: ^litellm/

View file

@ -298,7 +298,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -137,7 +137,8 @@ class AnthropicTextCompletion(BaseLLM):
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -55,9 +55,11 @@ def completion(
"inputs": prompt,
"prompt": prompt,
"parameters": optional_params,
"stream": True
if "stream" in optional_params and optional_params["stream"] == True
else False,
"stream": (
True
if "stream" in optional_params and optional_params["stream"] == True
else False
),
}
## LOGGING
@ -71,9 +73,11 @@ def completion(
completion_url_fragment_1 + model + completion_url_fragment_2,
headers=headers,
data=json.dumps(data),
stream=True
if "stream" in optional_params and optional_params["stream"] == True
else False,
stream=(
True
if "stream" in optional_params and optional_params["stream"] == True
else False
),
)
if "text/event-stream" in response.headers["Content-Type"] or (
"stream" in optional_params and optional_params["stream"] == True
@ -102,28 +106,28 @@ def completion(
and "data" in completion_response["model_output"]
and isinstance(completion_response["model_output"]["data"], list)
):
model_response["choices"][0]["message"][
"content"
] = completion_response["model_output"]["data"][0]
model_response["choices"][0]["message"]["content"] = (
completion_response["model_output"]["data"][0]
)
elif isinstance(completion_response["model_output"], str):
model_response["choices"][0]["message"][
"content"
] = completion_response["model_output"]
model_response["choices"][0]["message"]["content"] = (
completion_response["model_output"]
)
elif "completion" in completion_response and isinstance(
completion_response["completion"], str
):
model_response["choices"][0]["message"][
"content"
] = completion_response["completion"]
model_response["choices"][0]["message"]["content"] = (
completion_response["completion"]
)
elif isinstance(completion_response, list) and len(completion_response) > 0:
if "generated_text" not in completion_response:
raise BasetenError(
message=f"Unable to parse response. Original response: {response.text}",
status_code=response.status_code,
)
model_response["choices"][0]["message"][
"content"
] = completion_response[0]["generated_text"]
model_response["choices"][0]["message"]["content"] = (
completion_response[0]["generated_text"]
)
## GETTING LOGPROBS
if (
"details" in completion_response[0]
@ -155,7 +159,8 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -1028,7 +1028,7 @@ def completion(
total_tokens=response_body["usage"]["input_tokens"]
+ response_body["usage"]["output_tokens"],
)
model_response.usage = _usage
setattr(model_response, "usage", _usage)
else:
outputText = response_body["completion"]
model_response["finish_reason"] = response_body["stop_reason"]
@ -1071,8 +1071,10 @@ def completion(
status_code=response_metadata.get("HTTPStatusCode", 500),
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
if getattr(model_response.usage, "total_tokens", None) is None:
## CALCULATING USAGE - bedrock charges on time, not tokens - have some mapping of cost here.
if not hasattr(model_response, "usage"):
setattr(model_response, "usage", Usage())
if getattr(model_response.usage, "total_tokens", None) is None: # type: ignore
prompt_tokens = response_metadata.get(
"x-amzn-bedrock-input-token-count", len(encoding.encode(prompt))
)
@ -1089,7 +1091,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
model_response["created"] = int(time.time())
model_response["model"] = model

View file

@ -167,7 +167,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -237,7 +237,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -305,5 +305,5 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -311,7 +311,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -152,9 +152,9 @@ def completion(
else:
try:
if len(completion_response["answer"]) > 0:
model_response["choices"][0]["message"][
"content"
] = completion_response["answer"]
model_response["choices"][0]["message"]["content"] = (
completion_response["answer"]
)
except Exception as e:
raise MaritalkError(
message=response.text, status_code=response.status_code
@ -174,7 +174,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -185,9 +185,9 @@ def completion(
else:
try:
if len(completion_response["generated_text"]) > 0:
model_response["choices"][0]["message"][
"content"
] = completion_response["generated_text"]
model_response["choices"][0]["message"]["content"] = (
completion_response["generated_text"]
)
except:
raise NLPCloudError(
message=json.dumps(completion_response),
@ -205,7 +205,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -99,9 +99,9 @@ def completion(
)
else:
try:
model_response["choices"][0]["message"][
"content"
] = completion_response["choices"][0]["message"]["content"]
model_response["choices"][0]["message"]["content"] = (
completion_response["choices"][0]["message"]["content"]
)
except:
raise OobaboogaError(
message=json.dumps(completion_response),
@ -115,7 +115,7 @@ def completion(
completion_tokens=completion_response["usage"]["completion_tokens"],
total_tokens=completion_response["usage"]["total_tokens"],
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -223,7 +223,7 @@ class OpenAITextCompletionConfig:
model_response_object.choices = choice_list
if "usage" in response_object:
model_response_object.usage = response_object["usage"]
setattr(model_response_object, "usage", response_object["usage"])
if "id" in response_object:
model_response_object.id = response_object["id"]

View file

@ -191,7 +191,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -41,9 +41,9 @@ class PetalsConfig:
"""
max_length: Optional[int] = None
max_new_tokens: Optional[
int
] = litellm.max_tokens # petals requires max tokens to be set
max_new_tokens: Optional[int] = (
litellm.max_tokens
) # petals requires max tokens to be set
do_sample: Optional[bool] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
@ -203,7 +203,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -345,7 +345,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -399,7 +399,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response
@ -617,7 +617,7 @@ async def async_completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -2,6 +2,7 @@
Deprecated. We now do together ai calls via the openai client.
Reference: https://docs.together.ai/docs/openai-api-compatibility
"""
import os, types
import json
from enum import Enum
@ -225,7 +226,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -789,7 +789,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))
@ -996,7 +996,7 @@ async def async_completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))

View file

@ -349,7 +349,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))
@ -422,7 +422,7 @@ async def async_completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -104,7 +104,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response
@ -186,7 +186,7 @@ def batch_completions(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
final_outputs.append(model_response)
return final_outputs

View file

@ -407,8 +407,10 @@ def mock_completion(
model_response["created"] = int(time.time())
model_response["model"] = model
model_response.usage = Usage(
prompt_tokens=10, completion_tokens=20, total_tokens=30
setattr(
model_response,
"usage",
Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
)
try:
@ -652,6 +654,7 @@ def completion(
model
] # update the model to the actual value if an alias has been passed in
model_response = ModelResponse()
setattr(model_response, "usage", litellm.Usage())
if (
kwargs.get("azure", False) == True
): # don't remove flag check, to remain backwards compatible for repos like Codium

View file

@ -61,6 +61,7 @@ def validate_first_format(chunk):
assert isinstance(chunk["created"], int), "'created' should be an integer."
assert isinstance(chunk["model"], str), "'model' should be a string."
assert isinstance(chunk["choices"], list), "'choices' should be a list."
assert not hasattr(chunk, "usage"), "Chunk cannot contain usage"
for choice in chunk["choices"]:
assert isinstance(choice["index"], int), "'index' should be an integer."
@ -90,6 +91,7 @@ def validate_second_format(chunk):
assert isinstance(chunk["created"], int), "'created' should be an integer."
assert isinstance(chunk["model"], str), "'model' should be a string."
assert isinstance(chunk["choices"], list), "'choices' should be a list."
assert not hasattr(chunk, "usage"), "Chunk cannot contain usage"
for choice in chunk["choices"]:
assert isinstance(choice["index"], int), "'index' should be an integer."
@ -127,6 +129,7 @@ def validate_last_format(chunk):
assert isinstance(chunk["created"], int), "'created' should be an integer."
assert isinstance(chunk["model"], str), "'model' should be a string."
assert isinstance(chunk["choices"], list), "'choices' should be a list."
assert not hasattr(chunk, "usage"), "Chunk cannot contain usage"
for choice in chunk["choices"]:
assert isinstance(choice["index"], int), "'index' should be an integer."

View file

@ -529,9 +529,6 @@ class ModelResponse(OpenAIObject):
backend changes have been made that might impact determinism.
"""
usage: Optional[Usage] = None
"""Usage statistics for the completion request."""
_hidden_params: dict = {}
def __init__(
@ -586,20 +583,27 @@ class ModelResponse(OpenAIObject):
else:
created = created
model = model
if usage:
if usage is not None:
usage = usage
else:
elif stream is None or stream == False:
usage = Usage()
if hidden_params:
self._hidden_params = hidden_params
init_values = {
"id": id,
"choices": choices,
"created": created,
"model": model,
"object": object,
"system_fingerprint": system_fingerprint,
}
if usage is not None:
init_values["usage"] = usage
super().__init__(
id=id,
choices=choices,
created=created,
model=model,
object=object,
system_fingerprint=system_fingerprint,
usage=usage,
**init_values,
**params,
)
@ -6852,10 +6856,14 @@ async def convert_to_streaming_response_async(response_object: Optional[dict] =
model_response_object.choices = choice_list
if "usage" in response_object and response_object["usage"] is not None:
model_response_object.usage = Usage(
completion_tokens=response_object["usage"].get("completion_tokens", 0),
prompt_tokens=response_object["usage"].get("prompt_tokens", 0),
total_tokens=response_object["usage"].get("total_tokens", 0),
setattr(
model_response_object,
"usage",
Usage(
completion_tokens=response_object["usage"].get("completion_tokens", 0),
prompt_tokens=response_object["usage"].get("prompt_tokens", 0),
total_tokens=response_object["usage"].get("total_tokens", 0),
),
)
if "id" in response_object:
@ -10042,6 +10050,7 @@ class CustomStreamWrapper:
"content" in completion_obj
and isinstance(completion_obj["content"], str)
and len(completion_obj["content"]) == 0
and hasattr(model_response, "usage")
and hasattr(model_response.usage, "prompt_tokens")
):
if self.sent_first_chunk == False: