Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-27 03:34:10 +00:00
Merge pull request #3552 from BerriAI/litellm_predibase_support

feat(predibase.py): add support for predibase provider

This commit is contained in: commit a671046b45
9 changed files with 7661 additions and 73 deletions
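For orientation, the new provider slots into the standard litellm call path. A minimal non-streaming sketch, assuming PREDIBASE_API_KEY is set in the environment and that the non-streaming path behaves as it does for other providers; the model name, tenant_id, and api_base below simply mirror the sample values used in the test added by this PR and are illustrative, not real credentials:

import os
import litellm

# Sketch only: parameters mirror the sample values in this PR's test diff.
response = litellm.completion(
    model="predibase/llama-3-8b-instruct",  # predibase/<deployment> routing
    tenant_id="c4768f95",                   # sample tenant id from the test
    api_base="https://serving.app.predibase.com",
    api_key=os.getenv("PREDIBASE_API_KEY"),
    messages=[{"role": "user", "content": "What is the meaning of life?"}],
)
print(response.choices[0].message.content)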
@@ -5,6 +5,7 @@ import sys, os, asyncio
 import traceback
 import time, pytest
 from pydantic import BaseModel
+from typing import Tuple
 
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -142,7 +143,7 @@ def validate_last_format(chunk):
     ), "'finish_reason' should be a string."
 
 
-def streaming_format_tests(idx, chunk):
+def streaming_format_tests(idx, chunk) -> Tuple[str, bool]:
     extracted_chunk = ""
     finished = False
     print(f"chunk: {chunk}")
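The annotation above makes the helper's contract explicit: given a chunk index and a streamed chunk, it returns the text extracted from that chunk plus a flag for whether the stream has finished. The real implementation lives elsewhere in the test module and is not shown in this diff; a hypothetical sketch of a helper honoring that contract might look like:

from typing import Tuple

def extract_streamed_text(idx: int, chunk) -> Tuple[str, bool]:
    # Hypothetical stand-in for the streaming_format_tests contract:
    # pull the delta text out of one chunk and report stream completion.
    choice = chunk.choices[0]
    text = choice.delta.content or ""
    # The final chunk of a stream carries a non-None finish_reason.
    return text, choice.finish_reason is not None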
@@ -306,6 +307,70 @@ def test_completion_azure_stream():
 
 
 # test_completion_azure_stream()
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_completion_predibase_streaming(sync_mode):
+    try:
+        litellm.set_verbose = True
+
+        if sync_mode:
+            response = completion(
+                model="predibase/llama-3-8b-instruct",
+                tenant_id="c4768f95",
+                api_base="https://serving.app.predibase.com",
+                api_key=os.getenv("PREDIBASE_API_KEY"),
+                messages=[{"role": "user", "content": "What is the meaning of life?"}],
+                stream=True,
+            )
+
+            complete_response = ""
+            for idx, init_chunk in enumerate(response):
+                chunk, finished = streaming_format_tests(idx, init_chunk)
+                complete_response += chunk
+                custom_llm_provider = init_chunk._hidden_params["custom_llm_provider"]
+                print(f"custom_llm_provider: {custom_llm_provider}")
+                assert custom_llm_provider == "predibase"
+                if finished:
+                    assert isinstance(
+                        init_chunk.choices[0], litellm.utils.StreamingChoices
+                    )
+                    break
+            if complete_response.strip() == "":
+                raise Exception("Empty response received")
+        else:
+            response = await litellm.acompletion(
+                model="predibase/llama-3-8b-instruct",
+                tenant_id="c4768f95",
+                api_base="https://serving.app.predibase.com",
+                api_key=os.getenv("PREDIBASE_API_KEY"),
+                messages=[{"role": "user", "content": "What is the meaning of life?"}],
+                stream=True,
+            )
+
+            # await response
+
+            complete_response = ""
+            idx = 0
+            async for init_chunk in response:
+                chunk, finished = streaming_format_tests(idx, init_chunk)
+                complete_response += chunk
+                custom_llm_provider = init_chunk._hidden_params["custom_llm_provider"]
+                print(f"custom_llm_provider: {custom_llm_provider}")
+                assert custom_llm_provider == "predibase"
+                idx += 1
+                if finished:
+                    assert isinstance(
+                        init_chunk.choices[0], litellm.utils.StreamingChoices
+                    )
+                    break
+            if complete_response.strip() == "":
+                raise Exception("Empty response received")
+
+        print(f"complete_response: {complete_response}")
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
 
 
 def test_completion_azure_function_calling_stream():
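The async branch above is also straightforward to drive outside pytest. A standalone sketch (not part of the PR), again assuming PREDIBASE_API_KEY is set and reusing the sample tenant_id from the test:

import asyncio
import os
import litellm

async def main():
    # Same call as the test's async branch, consumed as a plain text stream.
    response = await litellm.acompletion(
        model="predibase/llama-3-8b-instruct",
        tenant_id="c4768f95",  # sample tenant id from the test
        api_base="https://serving.app.predibase.com",
        api_key=os.getenv("PREDIBASE_API_KEY"),
        messages=[{"role": "user", "content": "What is the meaning of life?"}],
        stream=True,
    )
    async for chunk in response:
        delta = chunk.choices[0].delta
        if delta.content:
            print(delta.content, end="", flush=True)

asyncio.run(main())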