forked from phoenix/litellm-mirror
fix(openai.py): return logprobs for text completion calls
This commit is contained in:
parent 80f8645e1a
commit b07788d2a5
6 changed files with 50459 additions and 82 deletions
@@ -520,6 +520,9 @@ def completion(
     eos_token = kwargs.get("eos_token", None)
     preset_cache_key = kwargs.get("preset_cache_key", None)
     hf_model_name = kwargs.get("hf_model_name", None)
+    ### TEXT COMPLETION CALLS ###
+    text_completion = kwargs.get("text_completion", False)
+    atext_completion = kwargs.get("atext_completion", False)
     ### ASYNC CALLS ###
     acompletion = kwargs.get("acompletion", False)
     client = kwargs.get("client", None)
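The two kwargs added here are internal flags: text_completion() and atext_completion() set them before delegating to completion(), so the rest of the function can tell a text-completion-style call apart from an ordinary chat call. A minimal sketch of the default handling (values are illustrative, not the library's literal code):

```python
# Sketch only: how the flags look for a call coming from text_completion()
# versus a plain chat completion call (an absent key means False).
wrapper_kwargs = {"text_completion": True}
chat_kwargs: dict = {}

for kwargs in (wrapper_kwargs, chat_kwargs):
    text_completion = kwargs.get("text_completion", False)
    atext_completion = kwargs.get("atext_completion", False)
    print(text_completion, atext_completion)
# -> True False
# -> False False
```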
@@ -561,6 +564,8 @@ def completion(
     litellm_params = [
         "metadata",
         "acompletion",
+        "atext_completion",
+        "text_completion",
         "caching",
         "mock_response",
         "api_key",
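These two new entries extend the list of litellm-internal kwargs, which are handled by litellm itself rather than forwarded to the provider. A hypothetical illustration of that separation (the helper name and sample values are made up for clarity; only the key names come from the diff):

```python
# Hypothetical illustration, not the library's actual code: split caller
# kwargs into litellm-internal flags and provider-specific parameters.
litellm_params = [
    "metadata",
    "acompletion",
    "atext_completion",
    "text_completion",
    "caching",
    "mock_response",
    "api_key",
]

def split_kwargs(kwargs: dict) -> tuple[dict, dict]:
    internal = {k: v for k, v in kwargs.items() if k in litellm_params}
    provider = {k: v for k, v in kwargs.items() if k not in litellm_params}
    return internal, provider

internal, provider = split_kwargs(
    {"text_completion": True, "temperature": 0.2, "metadata": {"run": "demo"}}
)
print(internal)  # {'text_completion': True, 'metadata': {'run': 'demo'}}
print(provider)  # {'temperature': 0.2}
```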
@@ -1043,8 +1048,9 @@ def completion(
                 prompt = messages[0]["content"]
             else:
                 prompt = " ".join([message["content"] for message in messages]) # type: ignore
+
             ## COMPLETION CALL
-            model_response = openai_text_completions.completion(
+            _response = openai_text_completions.completion(
                 model=model,
                 messages=messages,
                 model_response=model_response,
@@ -1059,15 +1065,25 @@ def completion(
                 timeout=timeout,
             )
+
+            if (
+                optional_params.get("stream", False) == False
+                and acompletion == False
+                and text_completion == False
+            ):
+                # convert to chat completion response
+                _response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object(
+                    response_object=_response, model_response_object=model_response
+                )

             if optional_params.get("stream", False) or acompletion == True:
                 ## LOGGING
                 logging.post_call(
                     input=messages,
                     api_key=api_key,
-                    original_response=model_response,
+                    original_response=_response,
                     additional_args={"headers": headers},
                 )
-            response = model_response
+            response = _response
         elif (
             "replicate" in model
             or custom_llm_provider == "replicate"
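Renaming the handler result to _response (here and in the hunk above) keeps the raw OpenAI text-completion output separate from the pre-built chat ModelResponse. The new branch means plain synchronous chat-style callers still get a chat-shaped object, while streaming, async, and text_completion callers receive the raw result, so fields such as logprobs are not dropped. A toy model of that decision (a simplification of the branch above, not the library's code):

```python
def response_shape(stream: bool, acompletion: bool, text_completion: bool) -> str:
    """Mirror the convert-or-pass-through branch in the diff above."""
    if not stream and not acompletion and not text_completion:
        return "chat ModelResponse (converted)"
    return "raw text-completion response (logprobs preserved)"

# text_completion()/atext_completion() set their flag, so they fall into
# the pass-through case; ordinary non-streaming chat calls get converted.
print(response_shape(stream=False, acompletion=False, text_completion=False))
print(response_shape(stream=False, acompletion=False, text_completion=True))
print(response_shape(stream=True, acompletion=False, text_completion=False))
```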
@@ -2960,6 +2976,11 @@ async def atext_completion(*args, **kwargs):
                 transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
             except Exception as e:
                 print_verbose(f"LiteLLM non blocking exception: {e}")
+
+            ## TRANSLATE CHAT TO TEXT FORMAT ##
+            if isinstance(response, TextCompletionResponse):
+                return response
+
             text_completion_response = TextCompletionResponse()
             text_completion_response["id"] = response.get("id", None)
             text_completion_response["object"] = "text_completion"
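With the early return above, a TextCompletionResponse produced by the lower layer is handed back untouched; otherwise the chat-shaped response is translated into text-completion format, with token logprobs recovered via litellm.utils.transform_logprobs. A hedged usage sketch of the async path (model name, prompt, and parameter values are illustrative, and a valid OpenAI API key in the environment is assumed):

```python
import asyncio
import litellm

async def main() -> None:
    # Illustrative call: any OpenAI text-completion model should behave similarly.
    response = await litellm.atext_completion(
        model="gpt-3.5-turbo-instruct",
        prompt="Say this is a test",
        max_tokens=10,
        logprobs=1,  # ask the provider for per-token logprobs
    )
    choice = response.choices[0]
    print(choice.text)
    # With this commit, logprobs should be returned instead of being dropped.
    print(choice.logprobs)

asyncio.run(main())
```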
@@ -3156,7 +3177,7 @@ def text_completion(
                 concurrent.futures.as_completed(futures)
             ):
                 responses[i] = future.result()
-            text_completion_response.choices = responses
+            text_completion_response.choices = responses # type: ignore

            return text_completion_response
    # else:
@@ -3193,6 +3214,7 @@ def text_completion(
        )

    kwargs.pop("prompt", None)
+   kwargs["text_completion"] = True
    response = completion(
        model=model,
        messages=messages,
@@ -3213,6 +3235,9 @@ def text_completion(
    except Exception as e:
        print_verbose(f"LiteLLM non blocking exception: {e}")

+   if isinstance(response, TextCompletionResponse):
+       return response
+
    text_completion_response["id"] = response.get("id", None)
    text_completion_response["object"] = "text_completion"
    text_completion_response["created"] = response.get("created", None)
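The synchronous path gets the same guard: when completion() already returned a TextCompletionResponse (the text_completion=True pass-through added above), it is returned as-is instead of being rebuilt field by field. A hedged end-to-end sketch of the synchronous call (model, prompt, and parameter values are illustrative; an OpenAI API key in the environment is assumed):

```python
import litellm

# Illustrative synchronous call exercising the new pass-through path.
response = litellm.text_completion(
    model="gpt-3.5-turbo-instruct",
    prompt="The capital of France is",
    max_tokens=5,
    logprobs=2,  # request per-token logprobs from the provider
)

choice = response.choices[0]
print(choice.text)
print(choice.logprobs)  # should now carry the provider's logprobs (the point of this fix)
```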