diff --git a/litellm/llms/databricks.py b/litellm/llms/databricks.py
index 18dd4ab65..1a276f52c 100644
--- a/litellm/llms/databricks.py
+++ b/litellm/llms/databricks.py
@@ -344,6 +344,7 @@ class DatabricksChatCompletion(BaseLLM):
         self,
         model: str,
         messages: list,
+        custom_llm_provider: str,
         api_base: str,
         custom_prompt_dict: dict,
         model_response: ModelResponse,
@@ -373,7 +374,7 @@ class DatabricksChatCompletion(BaseLLM):
                 logging_obj=logging_obj,
             ),
             model=model,
-            custom_llm_provider="databricks",
+            custom_llm_provider=custom_llm_provider,
             logging_obj=logging_obj,
         )
         return streamwrapper
@@ -426,6 +427,7 @@ class DatabricksChatCompletion(BaseLLM):
         model: str,
         messages: list,
         api_base: str,
+        custom_llm_provider: str,
         custom_prompt_dict: dict,
         model_response: ModelResponse,
         print_verbose: Callable,
@@ -499,6 +501,7 @@ class DatabricksChatCompletion(BaseLLM):
                     logger_fn=logger_fn,
                     headers=headers,
                     client=client,
+                    custom_llm_provider=custom_llm_provider,
                 )
             else:
                 return self.acompletion_function(
@@ -537,7 +540,7 @@ class DatabricksChatCompletion(BaseLLM):
                         logging_obj=logging_obj,
                     ),
                     model=model,
-                    custom_llm_provider="vertex_ai_beta",
+                    custom_llm_provider=custom_llm_provider,
                     logging_obj=logging_obj,
                 )
             else:
diff --git a/litellm/llms/vertex_ai_partner.py b/litellm/llms/vertex_ai_partner.py
index 66f8a1740..eb24c4d26 100644
--- a/litellm/llms/vertex_ai_partner.py
+++ b/litellm/llms/vertex_ai_partner.py
@@ -177,7 +177,6 @@ class VertexAIPartnerModels(BaseLLM):
                 credentials=vertex_credentials, project_id=vertex_project
             )

-            openai_chat_completions = OpenAIChatCompletion()
             openai_like_chat_completions = DatabricksChatCompletion()

             ## Load Config
@@ -223,6 +222,7 @@ class VertexAIPartnerModels(BaseLLM):
                 client=client,
                 timeout=timeout,
                 encoding=encoding,
+                custom_llm_provider="vertex_ai_beta",
             )

         except Exception as e:
diff --git a/litellm/main.py b/litellm/main.py
index c88119df9..4abd44707 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1867,6 +1867,7 @@ def completion(
                 custom_prompt_dict=custom_prompt_dict,
                 client=client,  # pass AsyncOpenAI, OpenAI client
                 encoding=encoding,
+                custom_llm_provider="databricks",
             )
         except Exception as e:
             ## LOGGING - log the original exception returned
diff --git a/litellm/utils.py b/litellm/utils.py
index 27b3f60c1..7df136846 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -9271,11 +9271,20 @@ class CustomStreamWrapper:
         try:
             # return this for all models
             completion_obj = {"content": ""}
-            if self.custom_llm_provider and (
-                self.custom_llm_provider == "anthropic"
-                or self.custom_llm_provider in litellm._custom_providers
+            from litellm.types.utils import GenericStreamingChunk as GChunk
+
+            if (
+                isinstance(chunk, dict)
+                and all(
+                    key in chunk for key in GChunk.__annotations__
+                )  # check if chunk is a generic streaming chunk
+            ) or (
+                self.custom_llm_provider
+                and (
+                    self.custom_llm_provider == "anthropic"
+                    or self.custom_llm_provider in litellm._custom_providers
+                )
             ):
-                from litellm.types.utils import GenericStreamingChunk as GChunk

                 if self.received_finish_reason is not None:
                     raise StopIteration
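
Note: the net effect of this patch is that the shared DatabricksChatCompletion handler no longer hard-codes the provider name ("databricks" on the sync streaming path, "vertex_ai_beta" on the async path); each caller now threads the real provider through to CustomStreamWrapper, and the wrapper additionally recognizes GenericStreamingChunk-shaped dicts by their keys rather than only by provider name. Below is a minimal sketch of the two call paths the patch distinguishes, assuming litellm's public completion() API; the model names are illustrative examples from the litellm docs, and credential/env setup is omitted.

import litellm

# Databricks route: main.py passes custom_llm_provider="databricks" into
# the shared DatabricksChatCompletion handler, so streamed chunks stay
# attributed to "databricks" on both the sync and async paths.
stream = litellm.completion(
    model="databricks/databricks-dbrx-instruct",
    messages=[{"role": "user", "content": "hello"}],
    stream=True,
)

# Vertex AI partner-model route: vertex_ai_partner.py reuses the same
# handler but passes custom_llm_provider="vertex_ai_beta", so these
# chunks are no longer labeled with the other path's hard-coded value.
stream = litellm.completion(
    model="vertex_ai/meta/llama3-405b-instruct-maas",
    messages=[{"role": "user", "content": "hello"}],
    stream=True,
)

for chunk in stream:
    print(chunk)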