fix databricks streaming test

This commit is contained in:
Ishaan Jaff 2024-08-16 16:56:08 -07:00
parent 5d1fcc545b
commit 51da6ab64e
2 changed files with 21 additions and 2 deletions

View file

@ -0,0 +1,16 @@
from litellm.types.utils import GenericStreamingChunk as GChunk
def generic_chunk_has_all_required_fields(chunk: dict) -> bool:
"""
Checks if the provided chunk dictionary contains all required fields for GenericStreamingChunk.
:param chunk: The dictionary to check.
:return: True if all required fields are present, False otherwise.
"""
_all_fields = GChunk.__annotations__
# this is an optional field in GenericStreamingChunk, it's not required to be present
_all_fields.pop("provider_specific_fields", None)
return all(key in chunk for key in _all_fields)

View file

@ -9565,12 +9565,15 @@ class CustomStreamWrapper:
try:
# return this for all models
completion_obj = {"content": ""}
from litellm.litellm_core_utils.streaming_utils import (
generic_chunk_has_all_required_fields,
)
from litellm.types.utils import GenericStreamingChunk as GChunk
if (
isinstance(chunk, dict)
and all(
key in chunk for key in GChunk.__annotations__
and generic_chunk_has_all_required_fields(
chunk=chunk
) # check if chunk is a generic streaming chunk
) or (
self.custom_llm_provider