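# Tests for litellm's Databricks provider: non-streaming and streaming chat
# completions plus embeddings, exercised against mocked HTTP handlers and a
# mocked Databricks SDK WorkspaceClient.
# Source repository: https://github.com/BerriAI/litellm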
import asyncio
import httpx
import json
import pytest
import sys
from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
from litellm.exceptions import BadRequestError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import CustomStreamWrapper

try:
    import databricks.sdk

    databricks_sdk_installed = True
except ImportError:
    databricks_sdk_installed = False
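

# --- Canned Databricks serving-endpoint payloads and mocked response objects ---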


def mock_chat_response() -> Dict[str, Any]:
    return {
        "id": "chatcmpl_3f78f09a-489c-4b8d-a587-f162c7497891",
        "object": "chat.completion",
        "created": 1726285449,
        "model": "dbrx-instruct-071224",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Hello! I'm an AI assistant. I'm doing well. How can I help?",
                    "function_call": None,
                    "tool_calls": None,
                },
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": 230,
            "completion_tokens": 38,
            "completion_tokens_details": None,
            "total_tokens": 268,
            "prompt_tokens_details": None,
        },
        "system_fingerprint": None,
    }


def mock_chat_streaming_response_chunks() -> List[str]:
    return [
        json.dumps(
            {
                "id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
                "object": "chat.completion.chunk",
                "created": 1726469651,
                "model": "dbrx-instruct-071224",
                "choices": [
                    {
                        "index": 0,
                        "delta": {"role": "assistant", "content": "Hello"},
                        "finish_reason": None,
                        "logprobs": None,
                    }
                ],
                "usage": {
                    "prompt_tokens": 230,
                    "completion_tokens": 1,
                    "total_tokens": 231,
                },
            }
        ),
        json.dumps(
            {
                "id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
                "object": "chat.completion.chunk",
                "created": 1726469651,
                "model": "dbrx-instruct-071224",
                "choices": [
                    {
                        "index": 0,
                        "delta": {"content": " world"},
                        "finish_reason": None,
                        "logprobs": None,
                    }
                ],
                "usage": {
                    "prompt_tokens": 230,
                    "completion_tokens": 1,
                    "total_tokens": 231,
                },
            }
        ),
        json.dumps(
            {
                "id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
                "object": "chat.completion.chunk",
                "created": 1726469651,
                "model": "dbrx-instruct-071224",
                "choices": [
                    {
                        "index": 0,
                        "delta": {"content": "!"},
                        "finish_reason": "stop",
                        "logprobs": None,
                    }
                ],
                "usage": {
                    "prompt_tokens": 230,
                    "completion_tokens": 1,
                    "total_tokens": 231,
                },
            }
        ),
    ]


def mock_chat_streaming_response_chunks_bytes() -> List[bytes]:
    string_chunks = mock_chat_streaming_response_chunks()
    bytes_chunks = [chunk.encode("utf-8") + b"\n" for chunk in string_chunks]
    # Simulate the end of the stream
    bytes_chunks.append(b"")
    return bytes_chunks


def mock_http_handler_chat_streaming_response() -> MagicMock:
    mock_stream_chunks = mock_chat_streaming_response_chunks()

    def mock_iter_lines():
        for chunk in mock_stream_chunks:
            for line in chunk.splitlines():
                yield line

    mock_response = MagicMock()
    mock_response.iter_lines.side_effect = mock_iter_lines
    mock_response.status_code = 200

    return mock_response


def mock_http_handler_chat_async_streaming_response() -> MagicMock:
    mock_stream_chunks = mock_chat_streaming_response_chunks()

    async def mock_iter_lines():
        for chunk in mock_stream_chunks:
            for line in chunk.splitlines():
                yield line

    mock_response = MagicMock()
    mock_response.aiter_lines.return_value = mock_iter_lines()
    mock_response.status_code = 200

    return mock_response
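

# Mocks the shape of a Databricks SDK streaming response: response["content"] is a
# context manager whose reader exposes read1() and a `closed` flag.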


def mock_databricks_client_chat_streaming_response() -> MagicMock:
    mock_stream_chunks = mock_chat_streaming_response_chunks_bytes()

    def mock_read_from_stream(size=-1):
        if mock_stream_chunks:
            return mock_stream_chunks.pop(0)
        return b""

    mock_response = MagicMock()
    streaming_response_mock = MagicMock()
    streaming_response_iterator_mock = MagicMock()
    # Mock the __getitem__("content") method to return the streaming response
    mock_response.__getitem__.return_value = streaming_response_mock
    # Mock the streaming response __enter__ method to return the streaming response iterator
    streaming_response_mock.__enter__.return_value = streaming_response_iterator_mock

    streaming_response_iterator_mock.read1.side_effect = mock_read_from_stream
    streaming_response_iterator_mock.closed = False

    return mock_response


def mock_embedding_response() -> Dict[str, Any]:
    return {
        "object": "list",
        "model": "bge-large-en-v1.5",
        "data": [
            {
                "index": 0,
                "object": "embedding",
                "embedding": [
                    0.06768798828125,
                    -0.01291656494140625,
                    -0.0501708984375,
                    0.0245361328125,
                    -0.030364990234375,
                ],
            }
        ],
        "usage": {
            "prompt_tokens": 8,
            "total_tokens": 8,
            "completion_tokens": 0,
            "completion_tokens_details": None,
            "prompt_tokens_details": None,
        },
    }
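

# Without the Databricks SDK importable, completion/embedding calls must fail fast
# when either the API base or the API key is missing.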


@pytest.mark.parametrize("set_base", [True, False])
def test_throws_if_api_base_or_api_key_not_set_without_databricks_sdk(
    monkeypatch, set_base
):
    # Simulate that the databricks SDK is not installed
    monkeypatch.setitem(sys.modules, "databricks.sdk", None)

    err_msg = "the Databricks base URL and API key are not set"

    if set_base:
        monkeypatch.setenv(
            "DATABRICKS_API_BASE",
            "https://my.workspace.cloud.databricks.com/serving-endpoints",
        )
        monkeypatch.delenv(
            "DATABRICKS_API_KEY",
        )
    else:
        monkeypatch.setenv("DATABRICKS_API_KEY", "dapimykey")
        monkeypatch.delenv(
            "DATABRICKS_API_BASE",
        )

    with pytest.raises(BadRequestError) as exc:
        litellm.completion(
            model="databricks/dbrx-instruct-071224",
            messages=[{"role": "user", "content": "How are you?"}],
        )
    assert err_msg in str(exc)

    with pytest.raises(BadRequestError) as exc:
        litellm.embedding(
            model="databricks/bge-12312",
            input=["Hello", "World"],
        )
    assert err_msg in str(exc)
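

# Non-streaming chat completion through a caller-supplied sync HTTPHandler; verifies
# the URL, auth header, and JSON body posted to /chat/completions.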


def test_completions_with_sync_http_handler(monkeypatch):
    base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
    api_key = "dapimykey"
    monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
    monkeypatch.setenv("DATABRICKS_API_KEY", api_key)

    sync_handler = HTTPHandler()
    mock_response = Mock(spec=httpx.Response)
    mock_response.status_code = 200
    mock_response.json.return_value = mock_chat_response()

    expected_response_json = {
        **mock_chat_response(),
        **{
            "model": "databricks/dbrx-instruct-071224",
        },
    }

    messages = [{"role": "user", "content": "How are you?"}]

    with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
        response = litellm.completion(
            model="databricks/dbrx-instruct-071224",
            messages=messages,
            client=sync_handler,
            temperature=0.5,
            extraparam="testpassingextraparam",
        )
        assert response.to_dict() == expected_response_json

    mock_post.assert_called_once_with(
        f"{base_url}/chat/completions",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        data=json.dumps(
            {
                "model": "dbrx-instruct-071224",
                "messages": messages,
                "temperature": 0.5,
                "extraparam": "testpassingextraparam",
                "stream": False,
            }
        ),
    )
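

# Async counterpart of the test above, going through litellm.acompletion and
# AsyncHTTPHandler.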


def test_completions_with_async_http_handler(monkeypatch):
    base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
    api_key = "dapimykey"
    monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
    monkeypatch.setenv("DATABRICKS_API_KEY", api_key)

    async_handler = AsyncHTTPHandler()
    mock_response = Mock(spec=httpx.Response)
    mock_response.status_code = 200
    mock_response.json.return_value = mock_chat_response()

    expected_response_json = {
        **mock_chat_response(),
        **{
            "model": "databricks/dbrx-instruct-071224",
        },
    }

    messages = [{"role": "user", "content": "How are you?"}]

    with patch.object(
        AsyncHTTPHandler, "post", return_value=mock_response
    ) as mock_post:
        response = asyncio.run(
            litellm.acompletion(
                model="databricks/dbrx-instruct-071224",
                messages=messages,
                client=async_handler,
                temperature=0.5,
                extraparam="testpassingextraparam",
            )
        )
        assert response.to_dict() == expected_response_json

    mock_post.assert_called_once_with(
        f"{base_url}/chat/completions",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        data=json.dumps(
            {
                "model": "dbrx-instruct-071224",
                "messages": messages,
                "temperature": 0.5,
                "extraparam": "testpassingextraparam",
                "stream": False,
            }
        ),
    )
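

# Streaming chat completion over a sync HTTPHandler; consumes the mocked chunk
# stream via CustomStreamWrapper and checks that stream=True is forwarded.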


def test_completions_streaming_with_sync_http_handler(monkeypatch):
    base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
    api_key = "dapimykey"
    monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
    monkeypatch.setenv("DATABRICKS_API_KEY", api_key)

    sync_handler = HTTPHandler()

    messages = [{"role": "user", "content": "How are you?"}]
    mock_response = mock_http_handler_chat_streaming_response()

    with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
        response_stream: CustomStreamWrapper = litellm.completion(
            model="databricks/dbrx-instruct-071224",
            messages=messages,
            client=sync_handler,
            temperature=0.5,
            extraparam="testpassingextraparam",
            stream=True,
        )
        response = list(response_stream)
        assert "dbrx-instruct-071224" in str(response)
        assert "chatcmpl" in str(response)
        assert len(response) == 4

    mock_post.assert_called_once_with(
        f"{base_url}/chat/completions",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        data=json.dumps(
            {
                "model": "dbrx-instruct-071224",
                "messages": messages,
                "temperature": 0.5,
                "stream": True,
                "extraparam": "testpassingextraparam",
            }
        ),
        stream=True,
    )
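

# Async streaming variant: chunks are gathered with an async list comprehension over
# the CustomStreamWrapper returned by litellm.acompletion.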


def test_completions_streaming_with_async_http_handler(monkeypatch):
    base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
    api_key = "dapimykey"
    monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
    monkeypatch.setenv("DATABRICKS_API_KEY", api_key)

    async_handler = AsyncHTTPHandler()

    messages = [{"role": "user", "content": "How are you?"}]
    mock_response = mock_http_handler_chat_async_streaming_response()

    with patch.object(
        AsyncHTTPHandler, "post", return_value=mock_response
    ) as mock_post:
        response_stream: CustomStreamWrapper = asyncio.run(
            litellm.acompletion(
                model="databricks/dbrx-instruct-071224",
                messages=messages,
                client=async_handler,
                temperature=0.5,
                extraparam="testpassingextraparam",
                stream=True,
            )
        )

        # Use async list gathering for the response
        async def gather_responses():
            return [item async for item in response_stream]

        response = asyncio.run(gather_responses())
        assert "dbrx-instruct-071224" in str(response)
        assert "chatcmpl" in str(response)
        assert len(response) == 4

    mock_post.assert_called_once_with(
        f"{base_url}/chat/completions",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        data=json.dumps(
            {
                "model": "dbrx-instruct-071224",
                "messages": messages,
                "temperature": 0.5,
                "stream": True,
                "extraparam": "testpassingextraparam",
            }
        ),
        stream=True,
    )
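

# When no API base/key is configured, host and credentials should come from the
# (mocked) Databricks SDK WorkspaceClient config, and requests should target
# <host>/serving-endpoints/chat/completions.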


@pytest.mark.skipif(not databricks_sdk_installed, reason="Databricks SDK not installed")
def test_completions_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch):
    from databricks.sdk import WorkspaceClient
    from databricks.sdk.config import Config

    sync_handler = HTTPHandler()
    mock_response = Mock(spec=httpx.Response)
    mock_response.status_code = 200
    mock_response.json.return_value = mock_chat_response()

    expected_response_json = {
        **mock_chat_response(),
        **{
            "model": "databricks/dbrx-instruct-071224",
        },
    }

    base_url = "https://my.workspace.cloud.databricks.com"
    api_key = "dapimykey"
    headers = {
        "Authorization": f"Bearer {api_key}",
    }
    messages = [{"role": "user", "content": "How are you?"}]

    mock_workspace_client: WorkspaceClient = MagicMock()
    mock_config: Config = MagicMock()
    # Simulate the behavior of the config property and its methods
    mock_config.authenticate.side_effect = lambda: headers
    mock_config.host = base_url  # Assign directly as if it's a property
    mock_workspace_client.config = mock_config

    with patch(
        "databricks.sdk.WorkspaceClient", return_value=mock_workspace_client
    ), patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
        response = litellm.completion(
            model="databricks/dbrx-instruct-071224",
            messages=messages,
            client=sync_handler,
            temperature=0.5,
            extraparam="testpassingextraparam",
        )
        assert response.to_dict() == expected_response_json

    mock_post.assert_called_once_with(
        f"{base_url}/serving-endpoints/chat/completions",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        data=json.dumps(
            {
                "model": "dbrx-instruct-071224",
                "messages": messages,
                "temperature": 0.5,
                "extraparam": "testpassingextraparam",
                "stream": False,
            }
        ),
    )
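

# Embeddings through a sync HTTPHandler; verifies the payload posted to /embeddings.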


def test_embeddings_with_sync_http_handler(monkeypatch):
    base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
    api_key = "dapimykey"
    monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
    monkeypatch.setenv("DATABRICKS_API_KEY", api_key)

    sync_handler = HTTPHandler()
    mock_response = Mock(spec=httpx.Response)
    mock_response.status_code = 200
    mock_response.json.return_value = mock_embedding_response()

    inputs = ["Hello", "World"]

    with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
        response = litellm.embedding(
            model="databricks/bge-large-en-v1.5",
            input=inputs,
            client=sync_handler,
            extraparam="testpassingextraparam",
        )
        assert response.to_dict() == mock_embedding_response()

    mock_post.assert_called_once_with(
        f"{base_url}/embeddings",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        data=json.dumps(
            {
                "model": "bge-large-en-v1.5",
                "input": inputs,
                "extraparam": "testpassingextraparam",
            }
        ),
    )
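

# Async embeddings variant via litellm.aembedding and AsyncHTTPHandler.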


def test_embeddings_with_async_http_handler(monkeypatch):
    base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
    api_key = "dapimykey"
    monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
    monkeypatch.setenv("DATABRICKS_API_KEY", api_key)

    async_handler = AsyncHTTPHandler()
    mock_response = Mock(spec=httpx.Response)
    mock_response.status_code = 200
    mock_response.json.return_value = mock_embedding_response()

    inputs = ["Hello", "World"]

    with patch.object(
        AsyncHTTPHandler, "post", return_value=mock_response
    ) as mock_post:
        response = asyncio.run(
            litellm.aembedding(
                model="databricks/bge-large-en-v1.5",
                input=inputs,
                client=async_handler,
                extraparam="testpassingextraparam",
            )
        )
        assert response.to_dict() == mock_embedding_response()

    mock_post.assert_called_once_with(
        f"{base_url}/embeddings",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        data=json.dumps(
            {
                "model": "bge-large-en-v1.5",
                "input": inputs,
                "extraparam": "testpassingextraparam",
            }
        ),
    )
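

# Embeddings counterpart of the SDK-based test: asserts the request is sent to
# <base>/serving-endpoints/embeddings with the expected auth header and payload.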


@pytest.mark.skipif(not databricks_sdk_installed, reason="Databricks SDK not installed")
def test_embeddings_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch):
    from databricks.sdk import WorkspaceClient
    from databricks.sdk.config import Config

    base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
    api_key = "dapimykey"
    monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
    monkeypatch.setenv("DATABRICKS_API_KEY", api_key)

    sync_handler = HTTPHandler()
    mock_response = Mock(spec=httpx.Response)
    mock_response.status_code = 200
    mock_response.json.return_value = mock_embedding_response()

    base_url = "https://my.workspace.cloud.databricks.com"
    api_key = "dapimykey"
    headers = {
        "Authorization": f"Bearer {api_key}",
    }
    inputs = ["Hello", "World"]

    mock_workspace_client: WorkspaceClient = MagicMock()
    mock_config: Config = MagicMock()
    # Simulate the behavior of the config property and its methods
    mock_config.authenticate.side_effect = lambda: headers
    mock_config.host = base_url  # Assign directly as if it's a property
    mock_workspace_client.config = mock_config

    with patch(
        "databricks.sdk.WorkspaceClient", return_value=mock_workspace_client
    ), patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
        response = litellm.embedding(
            model="databricks/bge-large-en-v1.5",
            input=inputs,
            client=sync_handler,
            extraparam="testpassingextraparam",
        )
        assert response.to_dict() == mock_embedding_response()

    mock_post.assert_called_once_with(
        f"{base_url}/serving-endpoints/embeddings",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        data=json.dumps(
            {
                "model": "bge-large-en-v1.5",
                "input": inputs,
                "extraparam": "testpassingextraparam",
            }
        ),
    )