# litellm/tests/llm_translation/test_text_completion_unit_tests.py
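"""Unit tests for LiteLLM text-completion translation.

Covers building a TextCompletionResponse from a raw OpenAI-style completion
dict, plus a (currently skipped) Hugging Face logprobs test.
"""
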
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest
from respx import MockRouter

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import litellm
from litellm.types.utils import TextCompletionResponse


def test_convert_dict_to_text_completion_response():
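    """Convert a raw OpenAI-style completion dict into a TextCompletionResponse."""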
    input_dict = {
        "id": "cmpl-ALVLPJgRkqpTomotoOMi3j0cAaL4L",
        "choices": [
            {
                "finish_reason": "length",
                "index": 0,
                "logprobs": {
                    "text_offset": [0, 5],
                    "token_logprobs": [None, -12.203847],
                    "tokens": ["hello", " crisp"],
                    "top_logprobs": [None, {",": -2.1568563}],
                },
                "text": "hello crisp",
            }
        ],
        "created": 1729688739,
        "model": "davinci-002",
        "object": "text_completion",
        "system_fingerprint": None,
        "usage": {
            "completion_tokens": 1,
            "prompt_tokens": 1,
            "total_tokens": 2,
            "completion_tokens_details": None,
            "prompt_tokens_details": None,
        },
    }

    response = TextCompletionResponse(**input_dict)

    assert response.id == "cmpl-ALVLPJgRkqpTomotoOMi3j0cAaL4L"
    assert len(response.choices) == 1
    assert response.choices[0].finish_reason == "length"
    assert response.choices[0].index == 0
    assert response.choices[0].text == "hello crisp"
    assert response.created == 1729688739
    assert response.model == "davinci-002"
    assert response.object == "text_completion"
    assert response.system_fingerprint is None
    assert response.usage.completion_tokens == 1
    assert response.usage.prompt_tokens == 1
    assert response.usage.total_tokens == 2
    assert response.usage.completion_tokens_details is None
    assert response.usage.prompt_tokens_details is None

    # Test logprobs
    assert response.choices[0].logprobs.text_offset == [0, 5]
    assert response.choices[0].logprobs.token_logprobs == [None, -12.203847]
    assert response.choices[0].logprobs.tokens == ["hello", " crisp"]
    assert response.choices[0].logprobs.top_logprobs == [None, {",": -2.1568563}]


@pytest.mark.skip(
    reason="need to migrate huggingface to support httpx client being passed in"
)
@pytest.mark.asyncio
@pytest.mark.respx
async def test_huggingface_text_completion_logprobs():
"""Test text completion with Hugging Face, focusing on logprobs structure"""
    litellm.set_verbose = True
    from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
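
    # A minimal Hugging Face TGI-style generation payload, including the
    # per-token `details` needed to reconstruct logprobs.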
    mock_response = [
        {
            "generated_text": ",\n\nI have a question...",  # truncated for brevity
            "details": {
                "finish_reason": "length",
                "generated_tokens": 100,
                "seed": None,
                "prefill": [],
                "tokens": [
                    {"id": 28725, "text": ",", "logprob": -1.7626953, "special": False},
                    {"id": 13, "text": "\n", "logprob": -1.7314453, "special": False},
                ],
            },
        }
    ]

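    # Stub the async HTTP client's post() so no real network request is made.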
    return_val = AsyncMock()
    return_val.json.return_value = mock_response

    client = AsyncHTTPHandler()

    with patch.object(client, "post", return_value=return_val) as mock_post:
        response = await litellm.atext_completion(
            model="huggingface/mistralai/Mistral-7B-v0.1",
            prompt="good morning",
            client=client,
        )

        # Verify the request
        mock_post.assert_called_once()
        request_body = json.loads(mock_post.call_args.kwargs["data"])
        assert request_body == {
            "inputs": "good morning",
            "parameters": {"details": True, "return_full_text": False},
            "stream": False,
        }

        print("response=", response)

        # Verify response structure
        assert isinstance(response, TextCompletionResponse)
        assert response.object == "text_completion"
        assert response.model == "mistralai/Mistral-7B-v0.1"

        # Verify logprobs structure
        choice = response.choices[0]
        assert choice.finish_reason == "length"
        assert choice.index == 0
        assert isinstance(choice.logprobs.tokens, list)
        assert isinstance(choice.logprobs.token_logprobs, list)
        assert isinstance(choice.logprobs.text_offset, list)
        assert isinstance(choice.logprobs.top_logprobs, list)
        assert choice.logprobs.tokens == [",", "\n"]
        assert choice.logprobs.token_logprobs == [-1.7626953, -1.7314453]
        assert choice.logprobs.text_offset == [0, 1]
        assert choice.logprobs.top_logprobs == [{}, {}]

        # Verify usage
        assert response.usage["completion_tokens"] > 0
        assert response.usage["prompt_tokens"] > 0
        assert response.usage["total_tokens"] > 0
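

# To run these tests locally (assuming pytest, respx, and httpx are installed
# and you invoke pytest from the repo's litellm/ directory):
#   pytest tests/llm_translation/test_text_completion_unit_tests.py -vv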