Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00.
Merge branch 'main' into litellm_fix_httpx_transport

Commit: 8661da1980
142 changed files with 6725 additions and 2086 deletions

litellm/utils.py (249 changed lines shown below)
@@ -533,6 +533,8 @@ def function_setup(
         call_type == CallTypes.aspeech.value or call_type == CallTypes.speech.value
     ):
         messages = kwargs.get("input", "speech")
+    else:
+        messages = "default-message-value"
     stream = True if "stream" in kwargs and kwargs["stream"] == True else False
     logging_obj = litellm.litellm_core_utils.litellm_logging.Logging(
         model=model,
@@ -563,10 +565,8 @@ def function_setup(
         )
         return logging_obj, kwargs
     except Exception as e:
-        import logging
-
-        logging.debug(
-            f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}"
+        verbose_logger.error(
+            f"litellm.utils.py::function_setup() - [Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}"
         )
         raise e

@@ -2433,7 +2433,10 @@ def get_optional_params(
                 non_default_params=passed_params, optional_params=optional_params
             )
         )
-    elif custom_llm_provider == "vertex_ai":
+    elif (
+        custom_llm_provider == "vertex_ai"
+        or custom_llm_provider == "vertex_ai_beta"
+    ):
         optional_params = litellm.VertexAIConfig().map_special_auth_params(
             non_default_params=passed_params, optional_params=optional_params
         )
@@ -2554,6 +2557,24 @@ def get_optional_params(
             message=f"Function calling is not supported by {custom_llm_provider}.",
         )

+    if "tools" in non_default_params:
+        tools = non_default_params["tools"]
+        for (
+            tool
+        ) in (
+            tools
+        ):  # clean out 'additionalProperties = False'. Causes vertexai/gemini OpenAI API Schema errors - https://github.com/langchain-ai/langchainjs/issues/5240
+            tool_function = tool.get("function", {})
+            parameters = tool_function.get("parameters", None)
+            if parameters is not None:
+                new_parameters = copy.deepcopy(parameters)
+                if (
+                    "additionalProperties" in new_parameters
+                    and new_parameters["additionalProperties"] is False
+                ):
+                    new_parameters.pop("additionalProperties", None)
+                tool_function["parameters"] = new_parameters
+
     def _check_valid_arg(supported_params):
         verbose_logger.debug(
             f"\nLiteLLM completion() model= {model}; provider = {custom_llm_provider}"
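Taken on its own, the new additionalProperties cleanup is easy to exercise outside of get_optional_params. A minimal standalone sketch, assuming an OpenAI-style tool definition (the tool name and schema below are invented for illustration):

import copy

# Hypothetical OpenAI-style tool definition; the "additionalProperties": False key
# is what triggers the Vertex AI / Gemini schema errors referenced in the diff comment.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "additionalProperties": False,
            },
        },
    }
]

for tool in tools:
    tool_function = tool.get("function", {})
    parameters = tool_function.get("parameters", None)
    if parameters is not None:
        new_parameters = copy.deepcopy(parameters)
        if new_parameters.get("additionalProperties") is False:
            new_parameters.pop("additionalProperties", None)
        tool_function["parameters"] = new_parameters

print(tools[0]["function"]["parameters"])  # the "additionalProperties" key is gone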
@@ -3183,7 +3204,9 @@ def get_optional_params(
         )
         _check_valid_arg(supported_params=supported_params)
         optional_params = litellm.NvidiaNimConfig().map_openai_params(
-            non_default_params=non_default_params, optional_params=optional_params
+            model=model,
+            non_default_params=non_default_params,
+            optional_params=optional_params,
         )
     elif custom_llm_provider == "fireworks_ai":
         supported_params = get_supported_openai_params(
@@ -3636,7 +3659,7 @@ def get_model_region(
         model=_model,
         api_key=litellm_params.api_key,
         api_base=litellm_params.api_base,
-        api_version=litellm_params.api_version or "2023-07-01-preview",
+        api_version=litellm_params.api_version or litellm.AZURE_DEFAULT_API_VERSION,
         timeout=10,
         mode=mode or "chat",
     )
@@ -3775,7 +3798,7 @@ def get_supported_openai_params(
     elif custom_llm_provider == "fireworks_ai":
         return litellm.FireworksAIConfig().get_supported_openai_params()
     elif custom_llm_provider == "nvidia_nim":
-        return litellm.NvidiaNimConfig().get_supported_openai_params()
+        return litellm.NvidiaNimConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "volcengine":
         return litellm.VolcEngineConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "groq":
@@ -3858,6 +3881,8 @@ def get_supported_openai_params(
             "top_logprobs",
             "response_format",
             "stop",
+            "tools",
+            "tool_choice",
         ]
     elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral":
         # mistal and codestral api have the exact same params
@@ -3916,6 +3941,11 @@ def get_supported_openai_params(
             return litellm.VertexAIConfig().get_supported_openai_params()
         elif request_type == "embeddings":
             return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
+    elif custom_llm_provider == "vertex_ai_beta":
+        if request_type == "chat_completion":
+            return litellm.VertexAIConfig().get_supported_openai_params()
+        elif request_type == "embeddings":
+            return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
     elif custom_llm_provider == "sagemaker":
         return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
     elif custom_llm_provider == "aleph_alpha":
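The practical effect of the new vertex_ai_beta branch can be checked from the public helper itself. A hedged sketch (the gemini model name is a placeholder):

import litellm

# Placeholder model name; vertex_ai_beta now resolves to the same
# VertexAIConfig / VertexAITextEmbeddingConfig parameter lists as vertex_ai.
params = litellm.get_supported_openai_params(
    model="gemini-1.5-pro", custom_llm_provider="vertex_ai_beta"
)
print(params)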
@@ -4682,6 +4712,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
             output_cost_per_character_above_128k_tokens=_model_info.get(
                 "output_cost_per_character_above_128k_tokens", None
             ),
+            output_vector_size=_model_info.get("output_vector_size", None),
             litellm_provider=_model_info.get(
                 "litellm_provider", custom_llm_provider
             ),
@@ -4696,7 +4727,9 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
             )
         except Exception:
             raise Exception(
-                "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
+                "This model isn't mapped yet. model={}, custom_llm_provider={}. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
+                    model, custom_llm_provider
+                )
             )

@@ -7536,7 +7569,7 @@ def exception_type(
                 if original_exception.status_code == 400:
                     exception_mapping_worked = True
                     raise BadRequestError(
-                        message=f"{exception_provider} - {message}",
+                        message=f"{exception_provider} - {error_str}",
                         llm_provider=custom_llm_provider,
                         model=model,
                         response=original_exception.response,
@@ -7545,7 +7578,7 @@ def exception_type(
                 elif original_exception.status_code == 401:
                     exception_mapping_worked = True
                     raise AuthenticationError(
-                        message=f"AuthenticationError: {exception_provider} - {message}",
+                        message=f"AuthenticationError: {exception_provider} - {error_str}",
                         llm_provider=custom_llm_provider,
                         model=model,
                         response=original_exception.response,
@@ -7554,7 +7587,7 @@ def exception_type(
                 elif original_exception.status_code == 404:
                     exception_mapping_worked = True
                     raise NotFoundError(
-                        message=f"NotFoundError: {exception_provider} - {message}",
+                        message=f"NotFoundError: {exception_provider} - {error_str}",
                         model=model,
                         llm_provider=custom_llm_provider,
                         response=original_exception.response,
@@ -7563,7 +7596,7 @@ def exception_type(
                 elif original_exception.status_code == 408:
                     exception_mapping_worked = True
                     raise Timeout(
-                        message=f"Timeout Error: {exception_provider} - {message}",
+                        message=f"Timeout Error: {exception_provider} - {error_str}",
                         model=model,
                         llm_provider=custom_llm_provider,
                         litellm_debug_info=extra_information,
@@ -7571,7 +7604,7 @@ def exception_type(
                 elif original_exception.status_code == 422:
                     exception_mapping_worked = True
                     raise BadRequestError(
-                        message=f"BadRequestError: {exception_provider} - {message}",
+                        message=f"BadRequestError: {exception_provider} - {error_str}",
                         model=model,
                         llm_provider=custom_llm_provider,
                         response=original_exception.response,
@@ -7580,7 +7613,7 @@ def exception_type(
                 elif original_exception.status_code == 429:
                     exception_mapping_worked = True
                     raise RateLimitError(
-                        message=f"RateLimitError: {exception_provider} - {message}",
+                        message=f"RateLimitError: {exception_provider} - {error_str}",
                         model=model,
                         llm_provider=custom_llm_provider,
                         response=original_exception.response,
@@ -7589,7 +7622,7 @@ def exception_type(
                 elif original_exception.status_code == 503:
                     exception_mapping_worked = True
                     raise ServiceUnavailableError(
-                        message=f"ServiceUnavailableError: {exception_provider} - {message}",
+                        message=f"ServiceUnavailableError: {exception_provider} - {error_str}",
                         model=model,
                         llm_provider=custom_llm_provider,
                         response=original_exception.response,
@@ -7598,7 +7631,7 @@ def exception_type(
                 elif original_exception.status_code == 504:  # gateway timeout error
                     exception_mapping_worked = True
                     raise Timeout(
-                        message=f"Timeout Error: {exception_provider} - {message}",
+                        message=f"Timeout Error: {exception_provider} - {error_str}",
                         model=model,
                         llm_provider=custom_llm_provider,
                         litellm_debug_info=extra_information,
@@ -7607,7 +7640,7 @@ def exception_type(
                     exception_mapping_worked = True
                     raise APIError(
                         status_code=original_exception.status_code,
-                        message=f"APIError: {exception_provider} - {message}",
+                        message=f"APIError: {exception_provider} - {error_str}",
                         llm_provider=custom_llm_provider,
                         model=model,
                         request=original_exception.request,
@@ -7616,7 +7649,7 @@ def exception_type(
             else:
                 # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors
                 raise APIConnectionError(
-                    message=f"APIConnectionError: {exception_provider} - {message}",
+                    message=f"APIConnectionError: {exception_provider} - {error_str}",
                     llm_provider=custom_llm_provider,
                     model=model,
                     litellm_debug_info=extra_information,
@@ -7967,6 +8000,7 @@ class CustomStreamWrapper:
         )
         self.messages = getattr(logging_obj, "messages", None)
         self.sent_stream_usage = False
+        self.tool_call = False
         self.chunks: List = (
             []
         )  # keep track of the returned chunks - used for calculating the input/output tokens for stream options
@@ -8033,6 +8067,11 @@ class CustomStreamWrapper:
         return hold, curr_chunk

     def handle_anthropic_text_chunk(self, chunk):
+        """
+        For old anthropic models - claude-1, claude-2.
+
+        Claude-3 is handled from within Anthropic.py VIA ModelResponseIterator()
+        """
         str_line = chunk
         if isinstance(chunk, bytes):  # Handle binary data
             str_line = chunk.decode("utf-8")  # Convert bytes to string
@@ -8061,44 +8100,6 @@ class CustomStreamWrapper:
             "finish_reason": finish_reason,
         }

-    def handle_anthropic_chunk(self, chunk):
-        str_line = chunk
-        if isinstance(chunk, bytes):  # Handle binary data
-            str_line = chunk.decode("utf-8")  # Convert bytes to string
-        text = ""
-        is_finished = False
-        finish_reason = None
-        if str_line.startswith("data:"):
-            data_json = json.loads(str_line[5:])
-            type_chunk = data_json.get("type", None)
-            if type_chunk == "content_block_delta":
-                """
-                Anthropic content chunk
-                chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}}
-                """
-                text = data_json.get("delta", {}).get("text", "")
-            elif type_chunk == "message_delta":
-                """
-                Anthropic
-                chunk = {'type': 'message_delta', 'delta': {'stop_reason': 'max_tokens', 'stop_sequence': None}, 'usage': {'output_tokens': 10}}
-                """
-                # TODO - get usage from this chunk, set in response
-                finish_reason = data_json.get("delta", {}).get("stop_reason", None)
-                is_finished = True
-            return {
-                "text": text,
-                "is_finished": is_finished,
-                "finish_reason": finish_reason,
-            }
-        elif "error" in str_line:
-            raise ValueError(f"Unable to parse response. Original response: {str_line}")
-        else:
-            return {
-                "text": text,
-                "is_finished": is_finished,
-                "finish_reason": finish_reason,
-            }
-
     def handle_vertexai_anthropic_chunk(self, chunk):
         """
         - MessageStartEvent(message=Message(id='msg_01LeRRgvX4gwkX3ryBVgtuYZ', content=[], model='claude-3-sonnet-20240229', role='assistant', stop_reason=None, stop_sequence=None, type='message', usage=Usage(input_tokens=8, output_tokens=1)), type='message_start'); custom_llm_provider: vertex_ai
@@ -8809,7 +8810,7 @@ class CustomStreamWrapper:
             verbose_logger.debug(traceback.format_exc())
             return ""

-    def model_response_creator(self):
+    def model_response_creator(self, chunk: Optional[dict] = None):
         _model = self.model
         _received_llm_provider = self.custom_llm_provider
         _logging_obj_llm_provider = self.logging_obj.model_call_details.get("custom_llm_provider", None)  # type: ignore
@@ -8818,13 +8819,18 @@ class CustomStreamWrapper:
             and _received_llm_provider != _logging_obj_llm_provider
         ):
             _model = "{}/{}".format(_logging_obj_llm_provider, _model)
+        if chunk is None:
+            chunk = {}
+        else:
+            # pop model keyword
+            chunk.pop("model", None)
         model_response = ModelResponse(
-            stream=True, model=_model, stream_options=self.stream_options
+            stream=True, model=_model, stream_options=self.stream_options, **chunk
         )
         if self.response_id is not None:
             model_response.id = self.response_id
         else:
-            self.response_id = model_response.id
+            self.response_id = model_response.id  # type: ignore
         if self.system_fingerprint is not None:
             model_response.system_fingerprint = self.system_fingerprint
         model_response._hidden_params["custom_llm_provider"] = _logging_obj_llm_provider
@@ -8849,10 +8855,37 @@ class CustomStreamWrapper:
             # return this for all models
             completion_obj = {"content": ""}
             if self.custom_llm_provider and self.custom_llm_provider == "anthropic":
-                response_obj = self.handle_anthropic_chunk(chunk)
-                completion_obj["content"] = response_obj["text"]
-                if response_obj["is_finished"]:
-                    self.received_finish_reason = response_obj["finish_reason"]
+                from litellm.types.utils import GenericStreamingChunk as GChunk
+
+                if self.received_finish_reason is not None:
+                    raise StopIteration
+                anthropic_response_obj: GChunk = chunk
+                completion_obj["content"] = anthropic_response_obj["text"]
+                if anthropic_response_obj["is_finished"]:
+                    self.received_finish_reason = anthropic_response_obj[
+                        "finish_reason"
+                    ]
+
+                if (
+                    self.stream_options
+                    and self.stream_options.get("include_usage", False) is True
+                    and anthropic_response_obj["usage"] is not None
+                ):
+                    model_response.usage = litellm.Usage(
+                        prompt_tokens=anthropic_response_obj["usage"]["prompt_tokens"],
+                        completion_tokens=anthropic_response_obj["usage"][
+                            "completion_tokens"
+                        ],
+                        total_tokens=anthropic_response_obj["usage"]["total_tokens"],
+                    )
+
+                if (
+                    "tool_use" in anthropic_response_obj
+                    and anthropic_response_obj["tool_use"] is not None
+                ):
+                    completion_obj["tool_calls"] = [anthropic_response_obj["tool_use"]]
+
+                response_obj = anthropic_response_obj
             elif (
                 self.custom_llm_provider
                 and self.custom_llm_provider == "anthropic_text"
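For reference, the rewritten anthropic branch above expects the provider handler to yield GenericStreamingChunk-shaped dicts instead of raw SSE lines, which is why handle_anthropic_chunk could be deleted. A minimal sketch of that shape and how the wrapper consumes it, with invented field values:

from litellm.types.utils import GenericStreamingChunk as GChunk

# Invented example chunk in the shape the new branch reads; the keys used here
# (text, is_finished, finish_reason, usage, tool_use) are the ones the diff accesses.
chunk: GChunk = {
    "text": "Hello",
    "is_finished": True,
    "finish_reason": "stop",
    "usage": {"prompt_tokens": 8, "completion_tokens": 2, "total_tokens": 10},
    "tool_use": None,
}

completion_obj = {"content": chunk["text"]}
if chunk["is_finished"]:
    received_finish_reason = chunk["finish_reason"]
if chunk.get("tool_use") is not None:
    completion_obj["tool_calls"] = [chunk["tool_use"]]
print(completion_obj, received_finish_reason)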
@@ -8961,7 +8994,6 @@ class CustomStreamWrapper:
                     and self.stream_options.get("include_usage", False) is True
                     and response_obj["usage"] is not None
                 ):
-                    self.sent_stream_usage = True
                     model_response.usage = litellm.Usage(
                         prompt_tokens=response_obj["usage"]["prompt_tokens"],
                         completion_tokens=response_obj["usage"]["completion_tokens"],
@@ -9089,7 +9121,6 @@ class CustomStreamWrapper:
                     and self.stream_options.get("include_usage", False) is True
                     and response_obj["usage"] is not None
                 ):
-                    self.sent_stream_usage = True
                     model_response.usage = litellm.Usage(
                         prompt_tokens=response_obj["usage"]["inputTokens"],
                         completion_tokens=response_obj["usage"]["outputTokens"],
@@ -9161,7 +9192,6 @@ class CustomStreamWrapper:
                     and self.stream_options.get("include_usage", False) == True
                     and response_obj["usage"] is not None
                 ):
-                    self.sent_stream_usage = True
                     model_response.usage = litellm.Usage(
                         prompt_tokens=response_obj["usage"].prompt_tokens,
                         completion_tokens=response_obj["usage"].completion_tokens,
@@ -9180,7 +9210,6 @@ class CustomStreamWrapper:
                     and self.stream_options.get("include_usage", False) == True
                     and response_obj["usage"] is not None
                 ):
-                    self.sent_stream_usage = True
                     model_response.usage = litellm.Usage(
                         prompt_tokens=response_obj["usage"].prompt_tokens,
                         completion_tokens=response_obj["usage"].completion_tokens,
@@ -9197,7 +9226,6 @@ class CustomStreamWrapper:
                     and self.stream_options.get("include_usage", False) == True
                     and response_obj["usage"] is not None
                 ):
-                    self.sent_stream_usage = True
                     model_response.usage = litellm.Usage(
                         prompt_tokens=response_obj["usage"].prompt_tokens,
                         completion_tokens=response_obj["usage"].completion_tokens,
@@ -9215,9 +9243,16 @@ class CustomStreamWrapper:
                     "is_finished": True,
                     "finish_reason": chunk.choices[0].finish_reason,
                     "original_chunk": chunk,
+                    "tool_calls": (
+                        chunk.choices[0].delta.tool_calls
+                        if hasattr(chunk.choices[0].delta, "tool_calls")
+                        else None
+                    ),
                 }

                 completion_obj["content"] = response_obj["text"]
+                if response_obj["tool_calls"] is not None:
+                    completion_obj["tool_calls"] = response_obj["tool_calls"]
                 print_verbose(f"completion obj content: {completion_obj['content']}")
                 if hasattr(chunk, "id"):
                     model_response.id = chunk.id
@@ -9239,7 +9274,9 @@ class CustomStreamWrapper:
                 if response_obj["is_finished"]:
                     if response_obj["finish_reason"] == "error":
                         raise Exception(
-                            "Mistral API raised a streaming error - finish_reason: error, no content string given."
+                            "{} raised a streaming error - finish_reason: error, no content string given. Received Chunk={}".format(
+                                self.custom_llm_provider, response_obj
+                            )
                         )
                     self.received_finish_reason = response_obj["finish_reason"]
                 if response_obj.get("original_chunk", None) is not None:
@@ -9261,7 +9298,6 @@ class CustomStreamWrapper:
                     and self.stream_options["include_usage"] == True
                     and response_obj["usage"] is not None
                 ):
-                    self.sent_stream_usage = True
                     model_response.usage = litellm.Usage(
                         prompt_tokens=response_obj["usage"].prompt_tokens,
                         completion_tokens=response_obj["usage"].completion_tokens,
@@ -9374,6 +9410,10 @@ class CustomStreamWrapper:
             )
             print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")

+            ## CHECK FOR TOOL USE
+            if "tool_calls" in completion_obj and len(completion_obj["tool_calls"]) > 0:
+                self.tool_call = True
+
             ## RETURN ARG
             if (
                 "content" in completion_obj
@@ -9552,6 +9592,12 @@ class CustomStreamWrapper:
                 )
             else:
                 model_response.choices[0].finish_reason = "stop"
+
+            ## if tool use
+            if (
+                model_response.choices[0].finish_reason == "stop" and self.tool_call
+            ):  # don't overwrite for other - potential error finish reasons
+                model_response.choices[0].finish_reason = "tool_calls"
             return model_response

     def __next__(self):
@@ -9586,11 +9632,26 @@ class CustomStreamWrapper:
                     self.rules.post_call_rules(
                         input=self.response_uptil_now, model=self.model
                     )
-                    # RETURN RESULT
+                    # HANDLE STREAM OPTIONS
+                    self.chunks.append(response)
+                    if hasattr(
+                        response, "usage"
+                    ):  # remove usage from chunk, only send on final chunk
+                        # Convert the object to a dictionary
+                        obj_dict = response.dict()
+
+                        # Remove an attribute (e.g., 'attr2')
+                        if "usage" in obj_dict:
+                            del obj_dict["usage"]
+
+                        # Create a new object without the removed attribute
+                        response = self.model_response_creator(chunk=obj_dict)
+
+                    # RETURN RESULT
                     return response

         except StopIteration:
-            if self.sent_last_chunk == True:
+            if self.sent_last_chunk is True:
                 if (
                     self.sent_stream_usage == False
                     and self.stream_options is not None
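The effect of this block is easiest to see from the caller's side: intermediate chunks lose their usage attribute, and token counts only arrive on the final chunk when stream_options requests them. A hedged usage sketch (model name and prompt are placeholders and require the matching API key):

import litellm

stream = litellm.completion(
    model="claude-3-sonnet-20240229",
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
    stream_options={"include_usage": True},
)
for chunk in stream:
    # intermediate chunks: usage stripped; final chunk: aggregated usage attached
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
    if getattr(chunk, "usage", None):
        print("\nusage:", chunk.usage)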
@@ -9716,6 +9777,18 @@ class CustomStreamWrapper:
                     )
                     print_verbose(f"final returned processed chunk: {processed_chunk}")
                     self.chunks.append(processed_chunk)
+                    if hasattr(
+                        processed_chunk, "usage"
+                    ):  # remove usage from chunk, only send on final chunk
+                        # Convert the object to a dictionary
+                        obj_dict = processed_chunk.dict()
+
+                        # Remove an attribute (e.g., 'attr2')
+                        if "usage" in obj_dict:
+                            del obj_dict["usage"]
+
+                        # Create a new object without the removed attribute
+                        processed_chunk = self.model_response_creator(chunk=obj_dict)
                     return processed_chunk
                 raise StopAsyncIteration
             else:  # temporary patch for non-aiohttp async calls
@@ -9758,11 +9831,11 @@ class CustomStreamWrapper:
                 self.chunks.append(processed_chunk)
                 return processed_chunk
         except StopAsyncIteration:
-            if self.sent_last_chunk == True:
+            if self.sent_last_chunk is True:
                 if (
-                    self.sent_stream_usage == False
+                    self.sent_stream_usage is False
                     and self.stream_options is not None
-                    and self.stream_options.get("include_usage", False) == True
+                    and self.stream_options.get("include_usage", False) is True
                 ):
                     # send the final chunk with stream options
                     complete_streaming_response = litellm.stream_chunk_builder(
@@ -9796,7 +9869,29 @@ class CustomStreamWrapper:
                 )
                 return processed_chunk
         except StopIteration:
-            if self.sent_last_chunk == True:
+            if self.sent_last_chunk is True:
+                if (
+                    self.sent_stream_usage is False
+                    and self.stream_options is not None
+                    and self.stream_options.get("include_usage", False) is True
+                ):
+                    # send the final chunk with stream options
+                    complete_streaming_response = litellm.stream_chunk_builder(
+                        chunks=self.chunks, messages=self.messages
+                    )
+                    response = self.model_response_creator()
+                    response.usage = complete_streaming_response.usage
+                    ## LOGGING
+                    threading.Thread(
+                        target=self.logging_obj.success_handler, args=(response,)
+                    ).start()  # log response
+                    asyncio.create_task(
+                        self.logging_obj.async_success_handler(
+                            response,
+                        )
+                    )
+                    self.sent_stream_usage = True
+                    return response
                 raise StopAsyncIteration
             else:
                 self.sent_last_chunk = True
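The final-usage path above delegates the aggregation to litellm.stream_chunk_builder, which is also usable directly. A minimal sketch, assuming the same placeholder model as before and a valid API key:

import litellm

messages = [{"role": "user", "content": "hi"}]
chunks = list(
    litellm.completion(
        model="claude-3-sonnet-20240229",
        messages=messages,
        stream=True,
    )
)
# Rebuild a complete response (including usage) from the collected chunks,
# mirroring what the StopIteration branch does with self.chunks and self.messages.
complete_response = litellm.stream_chunk_builder(chunks=chunks, messages=messages)
print(complete_response.usage)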