Merge pull request #3552 from BerriAI/litellm_predibase_support

feat(predibase.py): add support for predibase provider
Krish Dholakia 2024-05-09 22:21:16 -07:00 committed by GitHub
commit a671046b45
9 changed files with 7661 additions and 73 deletions

@@ -5,6 +5,7 @@ import sys, os, asyncio
 import traceback
 import time, pytest
 from pydantic import BaseModel
+from typing import Tuple

 sys.path.insert(
     0, os.path.abspath("../..")
@@ -142,7 +143,7 @@ def validate_last_format(chunk):
     ), "'finish_reason' should be a string."


-def streaming_format_tests(idx, chunk):
+def streaming_format_tests(idx, chunk) -> Tuple[str, bool]:
     extracted_chunk = ""
     finished = False
     print(f"chunk: {chunk}")
@@ -306,6 +307,70 @@ def test_completion_azure_stream():

 # test_completion_azure_stream()


+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_completion_predibase_streaming(sync_mode):
+    try:
+        litellm.set_verbose = True
+
+        if sync_mode:
+            response = completion(
+                model="predibase/llama-3-8b-instruct",
+                tenant_id="c4768f95",
+                api_base="https://serving.app.predibase.com",
+                api_key=os.getenv("PREDIBASE_API_KEY"),
+                messages=[{"role": "user", "content": "What is the meaning of life?"}],
+                stream=True,
+            )
+
+            complete_response = ""
+            for idx, init_chunk in enumerate(response):
+                chunk, finished = streaming_format_tests(idx, init_chunk)
+                complete_response += chunk
+                custom_llm_provider = init_chunk._hidden_params["custom_llm_provider"]
+                print(f"custom_llm_provider: {custom_llm_provider}")
+                assert custom_llm_provider == "predibase"
+                if finished:
+                    assert isinstance(
+                        init_chunk.choices[0], litellm.utils.StreamingChoices
+                    )
+                    break
+            if complete_response.strip() == "":
+                raise Exception("Empty response received")
+        else:
+            response = await litellm.acompletion(
+                model="predibase/llama-3-8b-instruct",
+                tenant_id="c4768f95",
+                api_base="https://serving.app.predibase.com",
+                api_key=os.getenv("PREDIBASE_API_KEY"),
+                messages=[{"role": "user", "content": "What is the meaning of life?"}],
+                stream=True,
+            )
+
+            # await response
+
+            complete_response = ""
+            idx = 0
+            async for init_chunk in response:
+                chunk, finished = streaming_format_tests(idx, init_chunk)
+                complete_response += chunk
+                custom_llm_provider = init_chunk._hidden_params["custom_llm_provider"]
+                print(f"custom_llm_provider: {custom_llm_provider}")
+                assert custom_llm_provider == "predibase"
+                idx += 1
+                if finished:
+                    assert isinstance(
+                        init_chunk.choices[0], litellm.utils.StreamingChoices
+                    )
+                    break
+            if complete_response.strip() == "":
+                raise Exception("Empty response received")
+
+        print(f"complete_response: {complete_response}")
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")


 def test_completion_azure_function_calling_stream():
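
For reference, a minimal sketch of calling the new provider directly, outside the test harness. It assumes only what the diff itself shows: the predibase/ model prefix and the tenant_id, api_base, and api_key parameters the streaming test exercises. The tenant ID is the test fixture's value, not a general default:

    import os

    from litellm import completion

    # Sketch only: a non-streaming call against the new Predibase provider,
    # reusing the parameters from the streaming test above.
    response = completion(
        model="predibase/llama-3-8b-instruct",
        tenant_id="c4768f95",  # tenant ID from the test fixture
        api_base="https://serving.app.predibase.com",
        api_key=os.getenv("PREDIBASE_API_KEY"),
        messages=[{"role": "user", "content": "What is the meaning of life?"}],
    )
    print(response.choices[0].message.content)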