mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
* LiteLLM Minor Fixes & Improvements (09/16/2024) (#5723) * coverage (#5713) Signed-off-by: dbczumar <corey.zumar@databricks.com> * Move (#5714) Signed-off-by: dbczumar <corey.zumar@databricks.com> * fix(litellm_logging.py): fix logging client re-init (#5710) Fixes https://github.com/BerriAI/litellm/issues/5695 * fix(presidio.py): Fix logging_hook response and add support for additional presidio variables in guardrails config Fixes https://github.com/BerriAI/litellm/issues/5682 * feat(o1_handler.py): fake streaming for openai o1 models Fixes https://github.com/BerriAI/litellm/issues/5694 * docs: deprecated traceloop integration in favor of native otel (#5249) * fix: fix linting errors * fix: fix linting errors * fix(main.py): fix o1 import --------- Signed-off-by: dbczumar <corey.zumar@databricks.com> Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com> Co-authored-by: Nir Gazit <nirga@users.noreply.github.com> * feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating material view (#5730) * feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating material view Supports having `MonthlyGlobalSpend` view be a material view, and exposes an endpoint to refresh it * fix(custom_logger.py): reset calltype * fix: fix linting errors * fix: fix linting error * fix: fix import * test(test_databricks.py): fix databricks tests --------- Signed-off-by: dbczumar <corey.zumar@databricks.com> Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com> Co-authored-by: Nir Gazit <nirga@users.noreply.github.com>
This commit is contained in:
parent
1e59395280
commit
234185ec13
34 changed files with 1387 additions and 502 deletions
502
tests/llm_translation/test_databricks.py
Normal file
502
tests/llm_translation/test_databricks.py
Normal file
|
@ -0,0 +1,502 @@
|
|||
import asyncio
|
||||
import httpx
|
||||
import json
|
||||
import pytest
|
||||
import sys
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import litellm
|
||||
from litellm.exceptions import BadRequestError, InternalServerError
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
|
||||
def mock_chat_response() -> Dict[str, Any]:
|
||||
return {
|
||||
"id": "chatcmpl_3f78f09a-489c-4b8d-a587-f162c7497891",
|
||||
"object": "chat.completion",
|
||||
"created": 1726285449,
|
||||
"model": "dbrx-instruct-071224",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello! I'm an AI assistant. I'm doing well. How can I help?",
|
||||
"function_call": None,
|
||||
"tool_calls": None,
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 230,
|
||||
"completion_tokens": 38,
|
||||
"total_tokens": 268,
|
||||
"completion_tokens_details": None,
|
||||
},
|
||||
"system_fingerprint": None,
|
||||
}
|
||||
|
||||
|
||||
def mock_chat_streaming_response_chunks() -> List[str]:
|
||||
return [
|
||||
json.dumps(
|
||||
{
|
||||
"id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1726469651,
|
||||
"model": "dbrx-instruct-071224",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant", "content": "Hello"},
|
||||
"finish_reason": None,
|
||||
"logprobs": None,
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 230,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 231,
|
||||
},
|
||||
}
|
||||
),
|
||||
json.dumps(
|
||||
{
|
||||
"id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1726469651,
|
||||
"model": "dbrx-instruct-071224",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": " world"},
|
||||
"finish_reason": None,
|
||||
"logprobs": None,
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 230,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 231,
|
||||
},
|
||||
}
|
||||
),
|
||||
json.dumps(
|
||||
{
|
||||
"id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1726469651,
|
||||
"model": "dbrx-instruct-071224",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": "!"},
|
||||
"finish_reason": "stop",
|
||||
"logprobs": None,
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 230,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 231,
|
||||
},
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def mock_chat_streaming_response_chunks_bytes() -> List[bytes]:
|
||||
string_chunks = mock_chat_streaming_response_chunks()
|
||||
bytes_chunks = [chunk.encode("utf-8") + b"\n" for chunk in string_chunks]
|
||||
# Simulate the end of the stream
|
||||
bytes_chunks.append(b"")
|
||||
return bytes_chunks
|
||||
|
||||
|
||||
def mock_http_handler_chat_streaming_response() -> MagicMock:
|
||||
mock_stream_chunks = mock_chat_streaming_response_chunks()
|
||||
|
||||
def mock_iter_lines():
|
||||
for chunk in mock_stream_chunks:
|
||||
for line in chunk.splitlines():
|
||||
yield line
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_lines.side_effect = mock_iter_lines
|
||||
mock_response.status_code = 200
|
||||
|
||||
return mock_response
|
||||
|
||||
|
||||
def mock_http_handler_chat_async_streaming_response() -> MagicMock:
|
||||
mock_stream_chunks = mock_chat_streaming_response_chunks()
|
||||
|
||||
async def mock_iter_lines():
|
||||
for chunk in mock_stream_chunks:
|
||||
for line in chunk.splitlines():
|
||||
yield line
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.aiter_lines.return_value = mock_iter_lines()
|
||||
mock_response.status_code = 200
|
||||
|
||||
return mock_response
|
||||
|
||||
|
||||
def mock_databricks_client_chat_streaming_response() -> MagicMock:
|
||||
mock_stream_chunks = mock_chat_streaming_response_chunks_bytes()
|
||||
|
||||
def mock_read_from_stream(size=-1):
|
||||
if mock_stream_chunks:
|
||||
return mock_stream_chunks.pop(0)
|
||||
return b""
|
||||
|
||||
mock_response = MagicMock()
|
||||
streaming_response_mock = MagicMock()
|
||||
streaming_response_iterator_mock = MagicMock()
|
||||
# Mock the __getitem__("content") method to return the streaming response
|
||||
mock_response.__getitem__.return_value = streaming_response_mock
|
||||
# Mock the streaming response __enter__ method to return the streaming response iterator
|
||||
streaming_response_mock.__enter__.return_value = streaming_response_iterator_mock
|
||||
|
||||
streaming_response_iterator_mock.read1.side_effect = mock_read_from_stream
|
||||
streaming_response_iterator_mock.closed = False
|
||||
|
||||
return mock_response
|
||||
|
||||
|
||||
def mock_embedding_response() -> Dict[str, Any]:
|
||||
return {
|
||||
"object": "list",
|
||||
"model": "bge-large-en-v1.5",
|
||||
"data": [
|
||||
{
|
||||
"index": 0,
|
||||
"object": "embedding",
|
||||
"embedding": [
|
||||
0.06768798828125,
|
||||
-0.01291656494140625,
|
||||
-0.0501708984375,
|
||||
0.0245361328125,
|
||||
-0.030364990234375,
|
||||
],
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 8,
|
||||
"total_tokens": 8,
|
||||
"completion_tokens": 0,
|
||||
"completion_tokens_details": None,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("set_base", [True, False])
|
||||
def test_throws_if_only_one_of_api_base_or_api_key_set(monkeypatch, set_base):
|
||||
if set_base:
|
||||
monkeypatch.setenv(
|
||||
"DATABRICKS_API_BASE",
|
||||
"https://my.workspace.cloud.databricks.com/serving-endpoints",
|
||||
)
|
||||
monkeypatch.delenv(
|
||||
"DATABRICKS_API_KEY",
|
||||
)
|
||||
err_msg = "A call is being made to LLM Provider but no key is set"
|
||||
else:
|
||||
monkeypatch.setenv("DATABRICKS_API_KEY", "dapimykey")
|
||||
monkeypatch.delenv("DATABRICKS_API_BASE")
|
||||
err_msg = "A call is being made to LLM Provider but no api base is set"
|
||||
|
||||
with pytest.raises(BadRequestError) as exc:
|
||||
litellm.completion(
|
||||
model="databricks/dbrx-instruct-071224",
|
||||
messages={"role": "user", "content": "How are you?"},
|
||||
)
|
||||
assert err_msg in str(exc)
|
||||
|
||||
with pytest.raises(BadRequestError) as exc:
|
||||
litellm.embedding(
|
||||
model="databricks/bge-12312",
|
||||
input=["Hello", "World"],
|
||||
)
|
||||
assert err_msg in str(exc)
|
||||
|
||||
|
||||
def test_completions_with_sync_http_handler(monkeypatch):
|
||||
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
|
||||
api_key = "dapimykey"
|
||||
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
|
||||
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
|
||||
|
||||
sync_handler = HTTPHandler()
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_chat_response()
|
||||
|
||||
expected_response_json = {
|
||||
**mock_chat_response(),
|
||||
**{
|
||||
"model": "databricks/dbrx-instruct-071224",
|
||||
},
|
||||
}
|
||||
|
||||
messages = [{"role": "user", "content": "How are you?"}]
|
||||
|
||||
with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
|
||||
response = litellm.completion(
|
||||
model="databricks/dbrx-instruct-071224",
|
||||
messages=messages,
|
||||
client=sync_handler,
|
||||
temperature=0.5,
|
||||
extraparam="testpassingextraparam",
|
||||
)
|
||||
assert response.to_dict() == expected_response_json
|
||||
|
||||
mock_post.assert_called_once_with(
|
||||
f"{base_url}/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
data=json.dumps(
|
||||
{
|
||||
"model": "dbrx-instruct-071224",
|
||||
"messages": messages,
|
||||
"temperature": 0.5,
|
||||
"extraparam": "testpassingextraparam",
|
||||
"stream": False,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_completions_with_async_http_handler(monkeypatch):
|
||||
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
|
||||
api_key = "dapimykey"
|
||||
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
|
||||
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
|
||||
|
||||
async_handler = AsyncHTTPHandler()
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_chat_response()
|
||||
|
||||
expected_response_json = {
|
||||
**mock_chat_response(),
|
||||
**{
|
||||
"model": "databricks/dbrx-instruct-071224",
|
||||
},
|
||||
}
|
||||
|
||||
messages = [{"role": "user", "content": "How are you?"}]
|
||||
|
||||
with patch.object(
|
||||
AsyncHTTPHandler, "post", return_value=mock_response
|
||||
) as mock_post:
|
||||
response = asyncio.run(
|
||||
litellm.acompletion(
|
||||
model="databricks/dbrx-instruct-071224",
|
||||
messages=messages,
|
||||
client=async_handler,
|
||||
temperature=0.5,
|
||||
extraparam="testpassingextraparam",
|
||||
)
|
||||
)
|
||||
assert response.to_dict() == expected_response_json
|
||||
|
||||
mock_post.assert_called_once_with(
|
||||
f"{base_url}/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
data=json.dumps(
|
||||
{
|
||||
"model": "dbrx-instruct-071224",
|
||||
"messages": messages,
|
||||
"temperature": 0.5,
|
||||
"extraparam": "testpassingextraparam",
|
||||
"stream": False,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_completions_streaming_with_sync_http_handler(monkeypatch):
|
||||
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
|
||||
api_key = "dapimykey"
|
||||
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
|
||||
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
|
||||
|
||||
sync_handler = HTTPHandler()
|
||||
|
||||
messages = [{"role": "user", "content": "How are you?"}]
|
||||
mock_response = mock_http_handler_chat_streaming_response()
|
||||
|
||||
with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
|
||||
response_stream: CustomStreamWrapper = litellm.completion(
|
||||
model="databricks/dbrx-instruct-071224",
|
||||
messages=messages,
|
||||
client=sync_handler,
|
||||
temperature=0.5,
|
||||
extraparam="testpassingextraparam",
|
||||
stream=True,
|
||||
)
|
||||
response = list(response_stream)
|
||||
assert "dbrx-instruct-071224" in str(response)
|
||||
assert "chatcmpl" in str(response)
|
||||
assert len(response) == 4
|
||||
|
||||
mock_post.assert_called_once_with(
|
||||
f"{base_url}/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
data=json.dumps(
|
||||
{
|
||||
"model": "dbrx-instruct-071224",
|
||||
"messages": messages,
|
||||
"temperature": 0.5,
|
||||
"stream": True,
|
||||
"extraparam": "testpassingextraparam",
|
||||
}
|
||||
),
|
||||
stream=True,
|
||||
)
|
||||
|
||||
|
||||
def test_completions_streaming_with_async_http_handler(monkeypatch):
|
||||
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
|
||||
api_key = "dapimykey"
|
||||
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
|
||||
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
|
||||
|
||||
async_handler = AsyncHTTPHandler()
|
||||
|
||||
messages = [{"role": "user", "content": "How are you?"}]
|
||||
mock_response = mock_http_handler_chat_async_streaming_response()
|
||||
|
||||
with patch.object(
|
||||
AsyncHTTPHandler, "post", return_value=mock_response
|
||||
) as mock_post:
|
||||
response_stream: CustomStreamWrapper = asyncio.run(
|
||||
litellm.acompletion(
|
||||
model="databricks/dbrx-instruct-071224",
|
||||
messages=messages,
|
||||
client=async_handler,
|
||||
temperature=0.5,
|
||||
extraparam="testpassingextraparam",
|
||||
stream=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Use async list gathering for the response
|
||||
async def gather_responses():
|
||||
return [item async for item in response_stream]
|
||||
|
||||
response = asyncio.run(gather_responses())
|
||||
assert "dbrx-instruct-071224" in str(response)
|
||||
assert "chatcmpl" in str(response)
|
||||
assert len(response) == 4
|
||||
|
||||
mock_post.assert_called_once_with(
|
||||
f"{base_url}/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
data=json.dumps(
|
||||
{
|
||||
"model": "dbrx-instruct-071224",
|
||||
"messages": messages,
|
||||
"temperature": 0.5,
|
||||
"stream": True,
|
||||
"extraparam": "testpassingextraparam",
|
||||
}
|
||||
),
|
||||
stream=True,
|
||||
)
|
||||
|
||||
|
||||
def test_embeddings_with_sync_http_handler(monkeypatch):
|
||||
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
|
||||
api_key = "dapimykey"
|
||||
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
|
||||
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
|
||||
|
||||
sync_handler = HTTPHandler()
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_embedding_response()
|
||||
|
||||
inputs = ["Hello", "World"]
|
||||
|
||||
with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
|
||||
response = litellm.embedding(
|
||||
model="databricks/bge-large-en-v1.5",
|
||||
input=inputs,
|
||||
client=sync_handler,
|
||||
extraparam="testpassingextraparam",
|
||||
)
|
||||
assert response.to_dict() == mock_embedding_response()
|
||||
|
||||
mock_post.assert_called_once_with(
|
||||
f"{base_url}/embeddings",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
data=json.dumps(
|
||||
{
|
||||
"model": "bge-large-en-v1.5",
|
||||
"input": inputs,
|
||||
"extraparam": "testpassingextraparam",
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_embeddings_with_async_http_handler(monkeypatch):
|
||||
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
|
||||
api_key = "dapimykey"
|
||||
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
|
||||
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
|
||||
|
||||
async_handler = AsyncHTTPHandler()
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_embedding_response()
|
||||
|
||||
inputs = ["Hello", "World"]
|
||||
|
||||
with patch.object(
|
||||
AsyncHTTPHandler, "post", return_value=mock_response
|
||||
) as mock_post:
|
||||
response = asyncio.run(
|
||||
litellm.aembedding(
|
||||
model="databricks/bge-large-en-v1.5",
|
||||
input=inputs,
|
||||
client=async_handler,
|
||||
extraparam="testpassingextraparam",
|
||||
)
|
||||
)
|
||||
assert response.to_dict() == mock_embedding_response()
|
||||
|
||||
mock_post.assert_called_once_with(
|
||||
f"{base_url}/embeddings",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
data=json.dumps(
|
||||
{
|
||||
"model": "bge-large-en-v1.5",
|
||||
"input": inputs,
|
||||
"extraparam": "testpassingextraparam",
|
||||
}
|
||||
),
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue