feat(databricks.py): adds databricks support - completion, async, streaming
Closes https://github.com/BerriAI/litellm/issues/2160
commit d2229dcd21 (parent 54591e3920)
9 changed files with 691 additions and 5 deletions
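
For context, here is a minimal sketch (not part of the diff) of the call pattern this commit enables. It assumes litellm's standard `completion` API and that Databricks credentials are supplied via the `DATABRICKS_API_KEY` and `DATABRICKS_API_BASE` environment variables, per litellm's provider conventions:

```python
# Hedged usage sketch for the new Databricks provider support.
# Assumes DATABRICKS_API_KEY and DATABRICKS_API_BASE are set in the environment.
import litellm

response = litellm.completion(
    model="databricks/databricks-dbrx-instruct",
    messages=[{"role": "user", "content": "Hello from litellm"}],
    max_tokens=10,
)
print(response.choices[0].message.content)
```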
```diff
@@ -951,6 +951,62 @@ def test_vertex_ai_stream():
 # test_completion_vertexai_stream_bad_key()
 
 
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_completion_databricks_streaming(sync_mode):
+    litellm.set_verbose = True
+    model_name = "databricks/databricks-dbrx-instruct"
+    try:
+        if sync_mode:
+            final_chunk: Optional[litellm.ModelResponse] = None
+            response: litellm.CustomStreamWrapper = completion(  # type: ignore
+                model=model_name,
+                messages=messages,
+                max_tokens=10,  # type: ignore
+                stream=True,
+            )
+            complete_response = ""
+            # Add any assertions here to check the response
+            has_finish_reason = False
+            for idx, chunk in enumerate(response):
+                final_chunk = chunk
+                chunk, finished = streaming_format_tests(idx, chunk)
+                if finished:
+                    has_finish_reason = True
+                    break
+                complete_response += chunk
+            if has_finish_reason == False:
+                raise Exception("finish reason not set")
+            if complete_response.strip() == "":
+                raise Exception("Empty response received")
+        else:
+            response: litellm.CustomStreamWrapper = await litellm.acompletion(  # type: ignore
+                model=model_name,
+                messages=messages,
+                max_tokens=100,  # type: ignore
+                stream=True,
+            )
+            complete_response = ""
+            # Add any assertions here to check the response
+            has_finish_reason = False
+            idx = 0
+            final_chunk: Optional[litellm.ModelResponse] = None
+            async for chunk in response:
+                final_chunk = chunk
+                chunk, finished = streaming_format_tests(idx, chunk)
+                if finished:
+                    has_finish_reason = True
+                    break
+                complete_response += chunk
+                idx += 1
+            if has_finish_reason == False:
+                raise Exception("finish reason not set")
+            if complete_response.strip() == "":
+                raise Exception("Empty response received")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 @pytest.mark.parametrize("sync_mode", [False, True])
 @pytest.mark.asyncio
 async def test_completion_replicate_llama3_streaming(sync_mode):
```
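
For reference, a hedged sketch of how a caller might consume the async streaming path this test exercises. The chunk shape (`choices[0].delta.content`) follows litellm's OpenAI-compatible streaming format; the model name simply mirrors the one used in the test:

```python
# Hedged sketch of consuming litellm's async streaming response.
# delta.content can be None on some chunks, so fall back to "".
import asyncio
import litellm

async def main():
    response = await litellm.acompletion(
        model="databricks/databricks-dbrx-instruct",
        messages=[{"role": "user", "content": "Stream a short reply"}],
        stream=True,
        max_tokens=50,
    )
    async for chunk in response:
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(main())
```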