feat(databricks.py): adds databricks support - completion, async, streaming

Closes https://github.com/BerriAI/litellm/issues/2160
Krrish Dholakia 2024-05-23 16:29:46 -07:00
parent 54591e3920
commit d2229dcd21
9 changed files with 691 additions and 5 deletions
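
For context, a minimal usage sketch of the new Databricks provider (not part of the diff below, and only an assumption about how the feature is exercised); it assumes credentials are supplied via the DATABRICKS_API_KEY and DATABRICKS_API_BASE environment variables, following litellm's usual provider conventions:

    import asyncio

    import litellm
    from litellm import completion

    messages = [{"role": "user", "content": "Hey, how's it going?"}]

    # Sync completion
    response = completion(
        model="databricks/databricks-dbrx-instruct",
        messages=messages,
        max_tokens=10,
    )
    print(response.choices[0].message.content)

    # Sync streaming: iterate over chunks as they arrive
    for chunk in completion(
        model="databricks/databricks-dbrx-instruct",
        messages=messages,
        stream=True,
    ):
        print(chunk)

    # Async completion
    async def main():
        response = await litellm.acompletion(
            model="databricks/databricks-dbrx-instruct",
            messages=messages,
        )
        print(response)

    asyncio.run(main())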


@@ -951,6 +951,62 @@ def test_vertex_ai_stream():
# test_completion_vertexai_stream_bad_key()
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_completion_databricks_streaming(sync_mode):
    litellm.set_verbose = True
    model_name = "databricks/databricks-dbrx-instruct"
    try:
        if sync_mode:
            final_chunk: Optional[litellm.ModelResponse] = None
            response: litellm.CustomStreamWrapper = completion(  # type: ignore
                model=model_name,
                messages=messages,
                max_tokens=10,  # type: ignore
                stream=True,
            )
            complete_response = ""
            # Add any assertions here to check the response
            has_finish_reason = False
            for idx, chunk in enumerate(response):
                final_chunk = chunk
                chunk, finished = streaming_format_tests(idx, chunk)
                if finished:
                    has_finish_reason = True
                    break
                complete_response += chunk
            if has_finish_reason == False:
                raise Exception("finish reason not set")
            if complete_response.strip() == "":
                raise Exception("Empty response received")
        else:
            response: litellm.CustomStreamWrapper = await litellm.acompletion(  # type: ignore
                model=model_name,
                messages=messages,
                max_tokens=100,  # type: ignore
                stream=True,
            )
            complete_response = ""
            # Add any assertions here to check the response
            has_finish_reason = False
            idx = 0
            final_chunk: Optional[litellm.ModelResponse] = None
            async for chunk in response:
                final_chunk = chunk
                chunk, finished = streaming_format_tests(idx, chunk)
                if finished:
                    has_finish_reason = True
                    break
                complete_response += chunk
                idx += 1
            if has_finish_reason == False:
                raise Exception("finish reason not set")
            if complete_response.strip() == "":
                raise Exception("Empty response received")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


@pytest.mark.parametrize("sync_mode", [False, True])
@pytest.mark.asyncio
async def test_completion_replicate_llama3_streaming(sync_mode):