fix(openai.py): return logprobs for text completion calls

Krrish Dholakia 2024-04-02 14:05:56 -07:00
parent 80f8645e1a
commit b07788d2a5
6 changed files with 50459 additions and 82 deletions
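In short: the OpenAI text-completion path always converted the provider response into a chat-completion object, which dropped the text-completion-style `logprobs`. The diff below threads a `text_completion` flag through `completion()` so the original `TextCompletionResponse`, logprobs included, can be returned as-is. A minimal usage sketch of the fixed behavior (assumes litellm is installed and `OPENAI_API_KEY` is set; the model name is just an example):

    # Sketch, not part of the diff: exercises the behavior this commit fixes.
    import litellm

    response = litellm.text_completion(
        model="gpt-3.5-turbo-instruct",   # an OpenAI text-completion model
        prompt="The capital of France is",
        max_tokens=5,
        logprobs=3,  # ask for the top-3 logprobs per token
    )

    # With this fix, the text-completion logprobs survive the round trip.
    print(response.choices[0].logprobs)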


@@ -520,6 +520,9 @@ def completion(
     eos_token = kwargs.get("eos_token", None)
     preset_cache_key = kwargs.get("preset_cache_key", None)
     hf_model_name = kwargs.get("hf_model_name", None)
+    ### TEXT COMPLETION CALLS ###
+    text_completion = kwargs.get("text_completion", False)
+    atext_completion = kwargs.get("atext_completion", False)
     ### ASYNC CALLS ###
     acompletion = kwargs.get("acompletion", False)
     client = kwargs.get("client", None)
@@ -561,6 +564,8 @@ def completion(
     litellm_params = [
         "metadata",
         "acompletion",
+        "atext_completion",
+        "text_completion",
         "caching",
         "mock_response",
         "api_key",
@@ -1043,8 +1048,9 @@ def completion(
                 prompt = messages[0]["content"]
             else:
                 prompt = " ".join([message["content"] for message in messages])  # type: ignore
+
             ## COMPLETION CALL
-            model_response = openai_text_completions.completion(
+            _response = openai_text_completions.completion(
                 model=model,
                 messages=messages,
                 model_response=model_response,
@@ -1059,15 +1065,25 @@ def completion(
                 timeout=timeout,
             )
+
+            if (
+                optional_params.get("stream", False) == False
+                and acompletion == False
+                and text_completion == False
+            ):
+                # convert to chat completion response
+                _response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object(
+                    response_object=_response, model_response_object=model_response
+                )

             if optional_params.get("stream", False) or acompletion == True:
                 ## LOGGING
                 logging.post_call(
                     input=messages,
                     api_key=api_key,
-                    original_response=model_response,
+                    original_response=_response,
                     additional_args={"headers": headers},
                 )
-            response = model_response
+            response = _response
         elif (
             "replicate" in model
             or custom_llm_provider == "replicate"
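This hunk is the core of the fix: previously the raw text-completion response was always converted into a chat-format object via `convert_to_chat_model_response_object`; now the conversion is skipped for streaming, async, and `text_completion=True` calls, so the original object reaches the caller intact. Schematically, the two choice shapes differ roughly as follows (field names abridged, values invented; the chat shape has no slot for the text-completion logprobs structure, which is why the unconditional conversion lost them):

    # Schematic only; the real objects are litellm's TextCompletionResponse
    # and ModelResponse.
    text_completion_choice = {
        "text": " Paris",
        "index": 0,
        "logprobs": {"tokens": [...], "token_logprobs": [...]},  # kept by this fix
        "finish_reason": "stop",
    }
    chat_completion_choice = {
        "index": 0,
        "message": {"role": "assistant", "content": " Paris"},
        "finish_reason": "stop",  # no text-completion-style logprobs slot
    }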
@@ -2960,6 +2976,11 @@ async def atext_completion(*args, **kwargs):
             transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
         except Exception as e:
             print_verbose(f"LiteLLM non blocking exception: {e}")
+
+    ## TRANSLATE CHAT TO TEXT FORMAT ##
+    if isinstance(response, TextCompletionResponse):
+        return response
+
     text_completion_response = TextCompletionResponse()
     text_completion_response["id"] = response.get("id", None)
     text_completion_response["object"] = "text_completion"
@@ -3156,7 +3177,7 @@ def text_completion(
                 concurrent.futures.as_completed(futures)
             ):
                 responses[i] = future.result()
-            text_completion_response.choices = responses
+            text_completion_response.choices = responses  # type: ignore

             return text_completion_response
     # else:
@@ -3193,6 +3214,7 @@ def text_completion(
     )

     kwargs.pop("prompt", None)
+    kwargs["text_completion"] = True
     response = completion(
         model=model,
         messages=messages,
@@ -3213,6 +3235,9 @@ def text_completion(
         except Exception as e:
             print_verbose(f"LiteLLM non blocking exception: {e}")

+    if isinstance(response, TextCompletionResponse):
+        return response
+
     text_completion_response["id"] = response.get("id", None)
     text_completion_response["object"] = "text_completion"
     text_completion_response["created"] = response.get("created", None)