Merge pull request #3552 from BerriAI/litellm_predibase_support

feat(predibase.py): add support for predibase provider
This commit is contained in:
Krish Dholakia 2024-05-09 22:21:16 -07:00 committed by GitHub
commit a671046b45
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 7661 additions and 73 deletions

View file

@@ -74,6 +74,7 @@ from .llms.azure_text import AzureTextCompletion
from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.predibase import PredibaseChatCompletion
from .llms.prompt_templates.factory import (
prompt_factory,
custom_prompt,
@@ -110,6 +111,7 @@ anthropic_text_completions = AnthropicTextCompletion()
azure_chat_completions = AzureChatCompletion()
azure_text_completions = AzureTextCompletion()
huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
####### COMPLETION ENDPOINTS ################
@@ -318,6 +320,7 @@ async def acompletion(
or custom_llm_provider == "gemini"
or custom_llm_provider == "sagemaker"
or custom_llm_provider == "anthropic"
or custom_llm_provider == "predibase"
or custom_llm_provider in litellm.openai_compatible_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context)
@@ -1785,6 +1788,52 @@ def completion(
)
return response
response = model_response
elif custom_llm_provider == "predibase":
tenant_id = (
optional_params.pop("tenant_id", None)
or optional_params.pop("predibase_tenant_id", None)
or litellm.predibase_tenant_id
or get_secret("PREDIBASE_TENANT_ID")
)
api_base = (
optional_params.pop("api_base", None)
or optional_params.pop("base_url", None)
or litellm.api_base
or get_secret("PREDIBASE_API_BASE")
)
api_key = (
api_key
or litellm.api_key
or litellm.predibase_key
or get_secret("PREDIBASE_API_KEY")
)
_model_response = predibase_chat_completions.completion(
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
acompletion=acompletion,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
api_key=api_key,
tenant_id=tenant_id,
)
if (
"stream" in optional_params
and optional_params["stream"] == True
and acompletion == False
):
return _model_response
response = _model_response
elif custom_llm_provider == "ai21":
custom_llm_provider = "ai21"
ai21_key = (