feat(utils.py): fix openai-like streaming

This commit is contained in:
Krrish Dholakia 2024-07-27 15:32:57 -07:00
parent 3c77f39751
commit c85ed01756
4 changed files with 20 additions and 7 deletions

View file

@@ -344,6 +344,7 @@ class DatabricksChatCompletion(BaseLLM):
self, self,
model: str, model: str,
messages: list, messages: list,
custom_llm_provider: str,
api_base: str, api_base: str,
custom_prompt_dict: dict, custom_prompt_dict: dict,
model_response: ModelResponse, model_response: ModelResponse,
@@ -373,7 +374,7 @@ class DatabricksChatCompletion(BaseLLM):
logging_obj=logging_obj, logging_obj=logging_obj,
), ),
model=model, model=model,
custom_llm_provider="databricks", custom_llm_provider=custom_llm_provider,
logging_obj=logging_obj, logging_obj=logging_obj,
) )
return streamwrapper return streamwrapper
@@ -426,6 +427,7 @@ class DatabricksChatCompletion(BaseLLM):
model: str, model: str,
messages: list, messages: list,
api_base: str, api_base: str,
custom_llm_provider: str,
custom_prompt_dict: dict, custom_prompt_dict: dict,
model_response: ModelResponse, model_response: ModelResponse,
print_verbose: Callable, print_verbose: Callable,
@@ -499,6 +501,7 @@ class DatabricksChatCompletion(BaseLLM):
logger_fn=logger_fn, logger_fn=logger_fn,
headers=headers, headers=headers,
client=client, client=client,
custom_llm_provider=custom_llm_provider,
) )
else: else:
return self.acompletion_function( return self.acompletion_function(
@@ -537,7 +540,7 @@ class DatabricksChatCompletion(BaseLLM):
logging_obj=logging_obj, logging_obj=logging_obj,
), ),
model=model, model=model,
custom_llm_provider="vertex_ai_beta", custom_llm_provider=custom_llm_provider,
logging_obj=logging_obj, logging_obj=logging_obj,
) )
else: else:

View file

@@ -177,7 +177,6 @@ class VertexAIPartnerModels(BaseLLM):
credentials=vertex_credentials, project_id=vertex_project credentials=vertex_credentials, project_id=vertex_project
) )
openai_chat_completions = OpenAIChatCompletion()
openai_like_chat_completions = DatabricksChatCompletion() openai_like_chat_completions = DatabricksChatCompletion()
## Load Config ## Load Config
@@ -223,6 +222,7 @@ class VertexAIPartnerModels(BaseLLM):
client=client, client=client,
timeout=timeout, timeout=timeout,
encoding=encoding, encoding=encoding,
custom_llm_provider="vertex_ai_beta",
) )
except Exception as e: except Exception as e:

View file

@@ -1867,6 +1867,7 @@ def completion(
custom_prompt_dict=custom_prompt_dict, custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client client=client, # pass AsyncOpenAI, OpenAI client
encoding=encoding, encoding=encoding,
custom_llm_provider="databricks",
) )
except Exception as e: except Exception as e:
## LOGGING - log the original exception returned ## LOGGING - log the original exception returned

View file

@@ -9271,11 +9271,20 @@ class CustomStreamWrapper:
try: try:
# return this for all models # return this for all models
completion_obj = {"content": ""} completion_obj = {"content": ""}
if self.custom_llm_provider and ( from litellm.types.utils import GenericStreamingChunk as GChunk
self.custom_llm_provider == "anthropic"
or self.custom_llm_provider in litellm._custom_providers if (
isinstance(chunk, dict)
and all(
key in chunk for key in GChunk.__annotations__
) # check if chunk is a generic streaming chunk
) or (
self.custom_llm_provider
and (
self.custom_llm_provider == "anthropic"
or self.custom_llm_provider in litellm._custom_providers
)
): ):
from litellm.types.utils import GenericStreamingChunk as GChunk
if self.received_finish_reason is not None: if self.received_finish_reason is not None:
raise StopIteration raise StopIteration