(fix) litellm.text_completion raises a non-blocking error on simple usage (#6546)

* unit test test_huggingface_text_completion_logprobs

* fix return TextCompletionHandler convert_chat_to_text_completion

* fix hf rest api

* fix test_huggingface_text_completion_logprobs

* fix linting errors

* fix import LiteLLMResponseObjectHandler

* fix test for LiteLLMResponseObjectHandler

* fix test text completion
Ishaan Jaff 2024-11-05 05:17:48 +05:30 committed by GitHub
parent 67ddf55ebd
commit 58ce30acee
6 changed files with 374 additions and 111 deletions
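
For context, a rough sketch of the user-facing call this commit is about: a plain text completion against a TGI-served Hugging Face model, where the logprobs transformation previously surfaced as a "LiteLLM non blocking exception" log even on simple usage. Model name and prompt mirror the new test below; valid Hugging Face credentials and a TGI endpoint that returns token details are assumed.

import litellm

# Simple usage that previously tripped the non-blocking error path.
# After this change the Hugging Face token details are converted into a
# typed litellm.types.utils.Logprobs object on the response.
response = litellm.text_completion(
    model="huggingface/mistralai/Mistral-7B-v0.1",
    prompt="good morning",
)

print(response.choices[0].text)
print(response.choices[0].logprobs.tokens)          # per-token text
print(response.choices[0].logprobs.token_logprobs)  # per-token log-probabilities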

View file

@@ -14,11 +14,17 @@ from litellm.types.utils import (
     Delta,
     EmbeddingResponse,
     Function,
+    HiddenParams,
     ImageResponse,
+)
+from litellm.types.utils import Logprobs as TextCompletionLogprobs
+from litellm.types.utils import (
     Message,
     ModelResponse,
     RerankResponse,
     StreamingChoices,
+    TextChoices,
+    TextCompletionResponse,
     TranscriptionResponse,
     Usage,
 )
@@ -235,6 +241,77 @@ class LiteLLMResponseObjectHandler:
             model_response_object = ImageResponse(**model_response_dict)
             return model_response_object

+    @staticmethod
+    def convert_chat_to_text_completion(
+        response: ModelResponse,
+        text_completion_response: TextCompletionResponse,
+        custom_llm_provider: Optional[str] = None,
+    ) -> TextCompletionResponse:
+        """
+        Converts a chat completion response to a text completion response format.
+        Note: This is used for huggingface. For OpenAI / Azure Text the providers files directly return TextCompletionResponse which we then send to user
+
+        Args:
+            response (ModelResponse): The chat completion response to convert
+
+        Returns:
+            TextCompletionResponse: The converted text completion response
+
+        Example:
+            chat_response = completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hi"}])
+            text_response = convert_chat_to_text_completion(chat_response)
+        """
+        transformed_logprobs = LiteLLMResponseObjectHandler._convert_provider_response_logprobs_to_text_completion_logprobs(
+            response=response,
+            custom_llm_provider=custom_llm_provider,
+        )
+
+        text_completion_response["id"] = response.get("id", None)
+        text_completion_response["object"] = "text_completion"
+        text_completion_response["created"] = response.get("created", None)
+        text_completion_response["model"] = response.get("model", None)
+        choices_list: List[TextChoices] = []
+
+        # Convert each choice to TextChoices
+        for choice in response["choices"]:
+            text_choices = TextChoices()
+            text_choices["text"] = choice["message"]["content"]
+            text_choices["index"] = choice["index"]
+            text_choices["logprobs"] = transformed_logprobs
+            text_choices["finish_reason"] = choice["finish_reason"]
+            choices_list.append(text_choices)
+
+        text_completion_response["choices"] = choices_list
+        text_completion_response["usage"] = response.get("usage", None)
+        text_completion_response._hidden_params = HiddenParams(
+            **response._hidden_params
+        )
+
+        return text_completion_response
+
+    @staticmethod
+    def _convert_provider_response_logprobs_to_text_completion_logprobs(
+        response: ModelResponse,
+        custom_llm_provider: Optional[str] = None,
+    ) -> Optional[TextCompletionLogprobs]:
+        """
+        Convert logprobs from provider to OpenAI.Completion() format
+        Only supported for HF TGI models
+        """
+        transformed_logprobs: Optional[TextCompletionLogprobs] = None
+        if custom_llm_provider == "huggingface":
+            # only supported for TGI models
+            try:
+                raw_response = response._hidden_params.get("original_response", None)
+                transformed_logprobs = litellm.huggingface._transform_logprobs(
+                    hf_response=raw_response
+                )
+            except Exception as e:
+                verbose_logger.exception(f"LiteLLM non blocking exception: {e}")
+
+        return transformed_logprobs
+

 def convert_to_model_response_object(  # noqa: PLR0915
     response_object: Optional[dict] = None,
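
For orientation, a minimal sketch of exercising the new handler directly; the field values are illustrative and mirror the unit test added later in this commit.

from litellm.types.utils import ModelResponse, TextCompletionResponse
from litellm.utils import LiteLLMResponseObjectHandler

# A chat-style response (illustrative values).
chat_response = ModelResponse(
    id="chat123",
    created=1234567890,
    model="gpt-3.5-turbo",
    choices=[
        {"index": 0, "message": {"content": "Hello, world!"}, "finish_reason": "stop"}
    ],
)

# Translate it into the OpenAI text-completion shape. logprobs stay None
# unless custom_llm_provider="huggingface" and a raw TGI response is attached
# under _hidden_params["original_response"].
text_response = LiteLLMResponseObjectHandler.convert_chat_to_text_completion(
    response=chat_response,
    text_completion_response=TextCompletionResponse(),
)
print(text_response.object)           # "text_completion"
print(text_response.choices[0].text)  # "Hello, world!"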

View file

@@ -15,6 +15,7 @@ import litellm
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.completion import ChatCompletionMessageToolCallParam
+from litellm.types.utils import Logprobs as TextCompletionLogprobs
 from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage

 from .base import BaseLLM
@@ -1183,3 +1184,73 @@
             input=input,
             encoding=encoding,
         )
+
+    def _transform_logprobs(
+        self, hf_response: Optional[List]
+    ) -> Optional[TextCompletionLogprobs]:
+        """
+        Transform Hugging Face logprobs to OpenAI.Completion() format
+        """
+        if hf_response is None:
+            return None
+
+        # Initialize an empty list for the transformed logprobs
+        _logprob: TextCompletionLogprobs = TextCompletionLogprobs(
+            text_offset=[],
+            token_logprobs=[],
+            tokens=[],
+            top_logprobs=[],
+        )
+
+        # For each Hugging Face response, transform the logprobs
+        for response in hf_response:
+            # Extract the relevant information from the response
+            response_details = response["details"]
+            top_tokens = response_details.get("top_tokens", {})
+
+            for i, token in enumerate(response_details["prefill"]):
+                # Extract the text of the token
+                token_text = token["text"]
+
+                # Extract the logprob of the token
+                token_logprob = token["logprob"]
+
+                # Add the token information to the 'token_info' list
+                _logprob.tokens.append(token_text)
+                _logprob.token_logprobs.append(token_logprob)
+
+                # stub this to work with llm eval harness
+                top_alt_tokens = {"": -1.0, "": -2.0, "": -3.0}  # noqa: F601
+                _logprob.top_logprobs.append(top_alt_tokens)
+
+            # For each element in the 'tokens' list, extract the relevant information
+            for i, token in enumerate(response_details["tokens"]):
+                # Extract the text of the token
+                token_text = token["text"]
+
+                # Extract the logprob of the token
+                token_logprob = token["logprob"]
+
+                top_alt_tokens = {}
+                temp_top_logprobs = []
+                if top_tokens != {}:
+                    temp_top_logprobs = top_tokens[i]
+
+                # top_alt_tokens should look like this: { "alternative_1": -1, "alternative_2": -2, "alternative_3": -3 }
+                for elem in temp_top_logprobs:
+                    text = elem["text"]
+                    logprob = elem["logprob"]
+                    top_alt_tokens[text] = logprob
+
+                # Add the token information to the 'token_info' list
+                _logprob.tokens.append(token_text)
+                _logprob.token_logprobs.append(token_logprob)
+                _logprob.top_logprobs.append(top_alt_tokens)
+
+                # Add the text offset of the token
+                # This is computed as the sum of the lengths of all previous tokens
+                _logprob.text_offset.append(
+                    sum(len(t["text"]) for t in response_details["tokens"][:i])
+                )
+
+        return _logprob
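
To illustrate the shape this method expects and produces, a small sketch against the module-level handler instance referenced above (litellm.huggingface); the TGI-style payload is made up, patterned on the mocked response in the new test at the end of this commit.

import litellm

# Minimal TGI "details" payload (illustrative values only).
hf_response = [
    {
        "details": {
            "prefill": [],
            "tokens": [
                {"text": ",", "logprob": -1.7626953},
                {"text": "\n", "logprob": -1.7314453},
            ],
        }
    }
]

logprobs = litellm.huggingface._transform_logprobs(hf_response=hf_response)
# Expected OpenAI.Completion()-style shape:
#   logprobs.tokens         == [",", "\n"]
#   logprobs.token_logprobs == [-1.7626953, -1.7314453]
#   logprobs.text_offset    == [0, 1]
#   logprobs.top_logprobs   == [{}, {}]   # payload carries no top_tokens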

View file

@@ -3867,34 +3867,17 @@ async def atext_completion(
                 custom_llm_provider=custom_llm_provider,
             )
         else:
-            transformed_logprobs = None
-            # only supported for TGI models
-            try:
-                raw_response = response._hidden_params.get("original_response", None)
-                transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
-            except Exception as e:
-                print_verbose(f"LiteLLM non blocking exception: {e}")
-
-            ## TRANSLATE CHAT TO TEXT FORMAT ##
+            ## OpenAI / Azure Text Completion Returns here
             if isinstance(response, TextCompletionResponse):
                 return response
             elif asyncio.iscoroutine(response):
                 response = await response

             text_completion_response = TextCompletionResponse()
-            text_completion_response["id"] = response.get("id", None)
-            text_completion_response["object"] = "text_completion"
-            text_completion_response["created"] = response.get("created", None)
-            text_completion_response["model"] = response.get("model", None)
-            text_choices = TextChoices()
-            text_choices["text"] = response["choices"][0]["message"]["content"]
-            text_choices["index"] = response["choices"][0]["index"]
-            text_choices["logprobs"] = transformed_logprobs
-            text_choices["finish_reason"] = response["choices"][0]["finish_reason"]
-            text_completion_response["choices"] = [text_choices]
-            text_completion_response["usage"] = response.get("usage", None)
-            text_completion_response._hidden_params = HiddenParams(
-                **response._hidden_params
-            )
+            text_completion_response = litellm.utils.LiteLLMResponseObjectHandler.convert_chat_to_text_completion(
+                text_completion_response=text_completion_response,
+                response=response,
+                custom_llm_provider=custom_llm_provider,
+            )
             return text_completion_response
     except Exception as e:
@@ -4156,29 +4139,17 @@ def text_completion(  # noqa: PLR0915
         return response
     elif isinstance(response, TextCompletionStreamWrapper):
         return response
-    transformed_logprobs = None
-    # only supported for TGI models
-    try:
-        raw_response = response._hidden_params.get("original_response", None)
-        transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
-    except Exception as e:
-        verbose_logger.exception(f"LiteLLM non blocking exception: {e}")
-
+    # OpenAI Text / Azure Text will return here
     if isinstance(response, TextCompletionResponse):
         return response

-    text_completion_response["id"] = response.get("id", None)
-    text_completion_response["object"] = "text_completion"
-    text_completion_response["created"] = response.get("created", None)
-    text_completion_response["model"] = response.get("model", None)
-    text_choices = TextChoices()
-    text_choices["text"] = response["choices"][0]["message"]["content"]
-    text_choices["index"] = response["choices"][0]["index"]
-    text_choices["logprobs"] = transformed_logprobs
-    text_choices["finish_reason"] = response["choices"][0]["finish_reason"]
-    text_completion_response["choices"] = [text_choices]
-    text_completion_response["usage"] = response.get("usage", None)
-    text_completion_response._hidden_params = HiddenParams(**response._hidden_params)
+    text_completion_response = (
+        litellm.utils.LiteLLMResponseObjectHandler.convert_chat_to_text_completion(
+            response=response,
+            text_completion_response=text_completion_response,
+        )
+    )
+
     return text_completion_response
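
Both the async and sync paths now delegate the chat-to-text translation to the shared handler above. A rough sketch of the resulting behaviour when a chat-only model is called through the text-completion API (model and prompt are arbitrary; an OpenAI API key is assumed):

import litellm

# The chat response is converted by
# LiteLLMResponseObjectHandler.convert_chat_to_text_completion().
response = litellm.text_completion(
    model="gpt-3.5-turbo",
    prompt="Say hi in one word.",
)

print(response.object)               # "text_completion"
print(response.choices[0].text)      # the assistant message content
print(response.choices[0].logprobs)  # None for non-Hugging Face providers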

View file

@@ -71,6 +71,7 @@ from litellm.litellm_core_utils.get_llm_provider_logic import (
 )
 from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_safe
 from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
+    LiteLLMResponseObjectHandler,
     _handle_invalid_parallel_tool_calls,
     convert_to_model_response_object,
     convert_to_streaming_response,
@@ -8388,76 +8389,6 @@ def get_valid_models() -> List[str]:
         return []  # NON-Blocking


-# used for litellm.text_completion() to transform HF logprobs to OpenAI.Completion() format
-def transform_logprobs(hf_response):
-    # Initialize an empty list for the transformed logprobs
-    transformed_logprobs = []
-
-    # For each Hugging Face response, transform the logprobs
-    for response in hf_response:
-        # Extract the relevant information from the response
-        response_details = response["details"]
-        top_tokens = response_details.get("top_tokens", {})
-
-        # Initialize an empty list for the token information
-        token_info = {
-            "tokens": [],
-            "token_logprobs": [],
-            "text_offset": [],
-            "top_logprobs": [],
-        }
-
-        for i, token in enumerate(response_details["prefill"]):
-            # Extract the text of the token
-            token_text = token["text"]
-
-            # Extract the logprob of the token
-            token_logprob = token["logprob"]
-
-            # Add the token information to the 'token_info' list
-            token_info["tokens"].append(token_text)
-            token_info["token_logprobs"].append(token_logprob)
-
-            # stub this to work with llm eval harness
-            top_alt_tokens = {"": -1, "": -2, "": -3}  # noqa: F601
-            token_info["top_logprobs"].append(top_alt_tokens)
-
-        # For each element in the 'tokens' list, extract the relevant information
-        for i, token in enumerate(response_details["tokens"]):
-            # Extract the text of the token
-            token_text = token["text"]
-
-            # Extract the logprob of the token
-            token_logprob = token["logprob"]
-
-            top_alt_tokens = {}
-            temp_top_logprobs = []
-            if top_tokens != {}:
-                temp_top_logprobs = top_tokens[i]
-
-            # top_alt_tokens should look like this: { "alternative_1": -1, "alternative_2": -2, "alternative_3": -3 }
-            for elem in temp_top_logprobs:
-                text = elem["text"]
-                logprob = elem["logprob"]
-                top_alt_tokens[text] = logprob
-
-            # Add the token information to the 'token_info' list
-            token_info["tokens"].append(token_text)
-            token_info["token_logprobs"].append(token_logprob)
-            token_info["top_logprobs"].append(top_alt_tokens)
-
-            # Add the text offset of the token
-            # This is computed as the sum of the lengths of all previous tokens
-            token_info["text_offset"].append(
-                sum(len(t["text"]) for t in response_details["tokens"][:i])
-            )
-
-        # Add the 'token_info' list to the 'transformed_logprobs' list
-        transformed_logprobs = token_info
-
-    return transformed_logprobs
-
-
 def print_args_passed_to_litellm(original_function, args, kwargs):
     try:
         # we've already printed this for acompletion, don't print for completion
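
The module-level transform_logprobs helper removed here lives on as Huggingface._transform_logprobs (added earlier in this commit); a brief sketch of the relocated call, using an illustrative raw TGI payload:

import litellm

# Illustrative raw TGI payload, as previously fed to litellm.utils.transform_logprobs.
raw_response = [
    {"details": {"prefill": [], "tokens": [{"text": "hello", "logprob": -1.0}]}}
]

# Old call site (function deleted in this hunk):
#     transformed = litellm.utils.transform_logprobs(raw_response)
# New call site, returning a typed litellm.types.utils.Logprobs object:
transformed = litellm.huggingface._transform_logprobs(hf_response=raw_response)
print(transformed.tokens, transformed.token_logprobs)  # ['hello'] [-1.0]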

View file

@@ -0,0 +1,141 @@
+import json
+import os
+import sys
+from datetime import datetime
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+
+import litellm
+import pytest
+
+from litellm.utils import (
+    LiteLLMResponseObjectHandler,
+)
+
+from datetime import timedelta
+from litellm.types.utils import (
+    ModelResponse,
+    TextCompletionResponse,
+    TextChoices,
+    Logprobs as TextCompletionLogprobs,
+    Usage,
+)
+
+
+def test_convert_chat_to_text_completion():
+    """Test converting chat completion to text completion"""
+    chat_response = ModelResponse(
+        id="chat123",
+        created=1234567890,
+        model="gpt-3.5-turbo",
+        choices=[
+            {
+                "index": 0,
+                "message": {"content": "Hello, world!"},
+                "finish_reason": "stop",
+            }
+        ],
+        usage={"total_tokens": 10, "completion_tokens": 10},
+        _hidden_params={"api_key": "test"},
+    )
+
+    text_completion = TextCompletionResponse()
+
+    result = LiteLLMResponseObjectHandler.convert_chat_to_text_completion(
+        response=chat_response, text_completion_response=text_completion
+    )
+
+    assert isinstance(result, TextCompletionResponse)
+    assert result.id == "chat123"
+    assert result.object == "text_completion"
+    assert result.created == 1234567890
+    assert result.model == "gpt-3.5-turbo"
+    assert result.choices[0].text == "Hello, world!"
+    assert result.choices[0].finish_reason == "stop"
+    assert result.usage == Usage(
+        completion_tokens=10,
+        prompt_tokens=0,
+        total_tokens=10,
+        completion_tokens_details=None,
+        prompt_tokens_details=None,
+    )
+
+
+def test_convert_provider_response_logprobs():
+    """Test converting provider logprobs to text completion logprobs"""
+    response = ModelResponse(
+        id="test123",
+        _hidden_params={
+            "original_response": {
+                "details": {"tokens": [{"text": "hello", "logprob": -1.0}]}
+            }
+        },
+    )
+
+    result = LiteLLMResponseObjectHandler._convert_provider_response_logprobs_to_text_completion_logprobs(
+        response=response, custom_llm_provider="huggingface"
+    )
+
+    # Note: The actual assertion here depends on the implementation of
+    # litellm.huggingface._transform_logprobs, but we can at least test the function call
+    assert (
+        result is not None or result is None
+    )  # Will depend on the actual implementation
+
+
+def test_convert_provider_response_logprobs_non_huggingface():
+    """Test converting provider logprobs for non-huggingface provider"""
+    response = ModelResponse(id="test123", _hidden_params={})
+
+    result = LiteLLMResponseObjectHandler._convert_provider_response_logprobs_to_text_completion_logprobs(
+        response=response, custom_llm_provider="openai"
+    )
+
+    assert result is None
+
+
+def test_convert_chat_to_text_completion_multiple_choices():
+    """Test converting chat completion to text completion with multiple choices"""
+    chat_response = ModelResponse(
+        id="chat456",
+        created=1234567890,
+        model="gpt-3.5-turbo",
+        choices=[
+            {
+                "index": 0,
+                "message": {"content": "First response"},
+                "finish_reason": "stop",
+            },
+            {
+                "index": 1,
+                "message": {"content": "Second response"},
+                "finish_reason": "length",
+            },
+        ],
+        usage={"total_tokens": 20},
+        _hidden_params={"api_key": "test"},
+    )
+
+    text_completion = TextCompletionResponse()
+
+    result = LiteLLMResponseObjectHandler.convert_chat_to_text_completion(
+        response=chat_response, text_completion_response=text_completion
+    )
+
+    assert isinstance(result, TextCompletionResponse)
+    assert result.id == "chat456"
+    assert result.object == "text_completion"
+    assert len(result.choices) == 2
+    assert result.choices[0].text == "First response"
+    assert result.choices[0].finish_reason == "stop"
+    assert result.choices[1].text == "Second response"
+    assert result.choices[1].finish_reason == "length"
+    assert result.usage == Usage(
+        completion_tokens=0,
+        prompt_tokens=0,
+        total_tokens=20,
+        completion_tokens_details=None,
+        prompt_tokens_details=None,
+    )

View file

@@ -3,11 +3,15 @@ import os
 import sys
 from datetime import datetime
 from unittest.mock import AsyncMock

+import pytest
+import httpx
+from respx import MockRouter
+
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path

+import litellm
 from litellm.types.utils import TextCompletionResponse
@@ -62,3 +66,71 @@ def test_convert_dict_to_text_completion_response():
     assert response.choices[0].logprobs.token_logprobs == [None, -12.203847]
     assert response.choices[0].logprobs.tokens == ["hello", " crisp"]
     assert response.choices[0].logprobs.top_logprobs == [None, {",": -2.1568563}]
+
+
+@pytest.mark.asyncio
+@pytest.mark.respx
+async def test_huggingface_text_completion_logprobs(respx_mock: MockRouter):
+    """Test text completion with Hugging Face, focusing on logprobs structure"""
+    litellm.set_verbose = True
+
+    # Mock the raw response from Hugging Face
+    mock_response = [
+        {
+            "generated_text": ",\n\nI have a question...",  # truncated for brevity
+            "details": {
+                "finish_reason": "length",
+                "generated_tokens": 100,
+                "seed": None,
+                "prefill": [],
+                "tokens": [
+                    {"id": 28725, "text": ",", "logprob": -1.7626953, "special": False},
+                    {"id": 13, "text": "\n", "logprob": -1.7314453, "special": False},
+                ],
+            },
+        }
+    ]
+
+    # Mock the API request
+    mock_request = respx_mock.post(
+        "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"
+    ).mock(return_value=httpx.Response(200, json=mock_response))
+
+    response = await litellm.atext_completion(
+        model="huggingface/mistralai/Mistral-7B-v0.1",
+        prompt="good morning",
+    )
+
+    # Verify the request
+    assert mock_request.called
+    request_body = json.loads(mock_request.calls[0].request.content)
+    assert request_body == {
+        "inputs": "good morning",
+        "parameters": {"details": True, "return_full_text": False},
+        "stream": False,
+    }
+
+    print("response=", response)
+
+    # Verify response structure
+    assert isinstance(response, TextCompletionResponse)
+    assert response.object == "text_completion"
+    assert response.model == "mistralai/Mistral-7B-v0.1"
+
+    # Verify logprobs structure
+    choice = response.choices[0]
+    assert choice.finish_reason == "length"
+    assert choice.index == 0
+    assert isinstance(choice.logprobs.tokens, list)
+    assert isinstance(choice.logprobs.token_logprobs, list)
+    assert isinstance(choice.logprobs.text_offset, list)
+    assert isinstance(choice.logprobs.top_logprobs, list)
+    assert choice.logprobs.tokens == [",", "\n"]
+    assert choice.logprobs.token_logprobs == [-1.7626953, -1.7314453]
+    assert choice.logprobs.text_offset == [0, 1]
+    assert choice.logprobs.top_logprobs == [{}, {}]
+
+    # Verify usage
+    assert response.usage["completion_tokens"] > 0
+    assert response.usage["prompt_tokens"] > 0
+    assert response.usage["total_tokens"] > 0