feat(utils.py): fix openai-like streaming

Krrish Dholakia 2024-07-27 15:32:57 -07:00
parent 3c77f39751
commit c85ed01756
4 changed files with 20 additions and 7 deletions

litellm/llms/databricks.py

@@ -344,6 +344,7 @@ class DatabricksChatCompletion(BaseLLM):
         self,
         model: str,
         messages: list,
+        custom_llm_provider: str,
         api_base: str,
         custom_prompt_dict: dict,
         model_response: ModelResponse,
@@ -373,7 +374,7 @@ class DatabricksChatCompletion(BaseLLM):
                 logging_obj=logging_obj,
             ),
             model=model,
-            custom_llm_provider="databricks",
+            custom_llm_provider=custom_llm_provider,
             logging_obj=logging_obj,
         )
         return streamwrapper
@@ -426,6 +427,7 @@ class DatabricksChatCompletion(BaseLLM):
         model: str,
         messages: list,
         api_base: str,
+        custom_llm_provider: str,
         custom_prompt_dict: dict,
         model_response: ModelResponse,
         print_verbose: Callable,
@@ -499,6 +501,7 @@ class DatabricksChatCompletion(BaseLLM):
                 logger_fn=logger_fn,
                 headers=headers,
                 client=client,
+                custom_llm_provider=custom_llm_provider,
             )
         else:
             return self.acompletion_function(
@@ -537,7 +540,7 @@ class DatabricksChatCompletion(BaseLLM):
                     logging_obj=logging_obj,
                 ),
                 model=model,
-                custom_llm_provider="vertex_ai_beta",
+                custom_llm_provider=custom_llm_provider,
                 logging_obj=logging_obj,
             )
         else:
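Taken together, these hunks make the provider tag a parameter of the OpenAI-like (Databricks) handler instead of a hardcoded string; the vertex_ai_partner.py and main.py hunks below are the two call sites, passing "vertex_ai_beta" and "databricks" explicitly. A minimal self-contained sketch of the pattern, with illustrative names standing in for DatabricksChatCompletion and CustomStreamWrapper rather than litellm's actual signatures:

# Sketch only: illustrative stand-ins, not litellm's real classes.
from typing import Iterator

class StreamWrapper:
    def __init__(self, stream: Iterator[str], model: str, custom_llm_provider: str):
        self.stream = stream
        self.model = model
        # the provider tag decides how downstream code parses each raw chunk
        self.custom_llm_provider = custom_llm_provider

class OpenAILikeHandler:
    def completion(
        self, model: str, stream: Iterator[str], custom_llm_provider: str
    ) -> StreamWrapper:
        # before this commit, the equivalent branches hardcoded the tag
        # ("databricks" in one, "vertex_ai_beta" in another), so reusing the
        # handler for a second provider mislabeled every streamed chunk
        return StreamWrapper(stream, model, custom_llm_provider)

handler = OpenAILikeHandler()
db_stream = handler.completion("dbrx", iter(()), custom_llm_provider="databricks")
vx_stream = handler.completion("llama3-405b", iter(()), custom_llm_provider="vertex_ai_beta")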

litellm/llms/vertex_ai_partner.py

@@ -177,7 +177,6 @@ class VertexAIPartnerModels(BaseLLM):
                 credentials=vertex_credentials, project_id=vertex_project
             )
-            openai_chat_completions = OpenAIChatCompletion()
            openai_like_chat_completions = DatabricksChatCompletion()

            ## Load Config
@@ -223,6 +222,7 @@ class VertexAIPartnerModels(BaseLLM):
                client=client,
                timeout=timeout,
                encoding=encoding,
+                custom_llm_provider="vertex_ai_beta",
            )

        except Exception as e:

litellm/main.py

@@ -1867,6 +1867,7 @@ def completion(
                custom_prompt_dict=custom_prompt_dict,
                client=client,  # pass AsyncOpenAI, OpenAI client
                encoding=encoding,
+                custom_llm_provider="databricks",
            )
        except Exception as e:
            ## LOGGING - log the original exception returned
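For context, a streaming call through litellm's public API is what exercises the patched path. A hedged usage sketch (the model name is illustrative):

import litellm

# stream=True routes chunks through CustomStreamWrapper, which now carries
# the correct provider tag for OpenAI-like backends
response = litellm.completion(
    model="databricks/databricks-dbrx-instruct",  # illustrative model name
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in response:
    # delta content can be None on control chunks (e.g. the finish chunk)
    print(chunk.choices[0].delta.content or "", end="")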

litellm/utils.py

@@ -9271,11 +9271,20 @@ class CustomStreamWrapper:
        try:
            # return this for all models
            completion_obj = {"content": ""}
-            if self.custom_llm_provider and (
-                self.custom_llm_provider == "anthropic"
-                or self.custom_llm_provider in litellm._custom_providers
+            from litellm.types.utils import GenericStreamingChunk as GChunk
+
+            if (
+                isinstance(chunk, dict)
+                and all(
+                    key in chunk for key in GChunk.__annotations__
+                )  # check if chunk is a generic streaming chunk
+            ) or (
+                self.custom_llm_provider
+                and (
+                    self.custom_llm_provider == "anthropic"
+                    or self.custom_llm_provider in litellm._custom_providers
+                )
            ):
-                from litellm.types.utils import GenericStreamingChunk as GChunk
                if self.received_finish_reason is not None:
                    raise StopIteration
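The reworked condition duck-types the chunk before falling back to the provider check: any dict that contains every key declared on the GenericStreamingChunk TypedDict is treated as a generic streaming chunk, whatever the provider. A self-contained sketch of that check, with illustrative field names rather than litellm's exact schema:

from typing import TypedDict

class GChunk(TypedDict):
    # illustrative fields; litellm's GenericStreamingChunk declares its own set
    text: str
    is_finished: bool
    finish_reason: str

def is_generic_chunk(chunk: object) -> bool:
    # mirrors the new condition: a dict carrying every declared key qualifies
    return isinstance(chunk, dict) and all(
        key in chunk for key in GChunk.__annotations__
    )

print(is_generic_chunk({"text": "hi", "is_finished": False, "finish_reason": ""}))  # True
print(is_generic_chunk({"content": "hi"}))  # False: missing the declared keys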