fix: bug fix when n>1 is passed in

This commit is contained in:
Krrish Dholakia 2023-10-09 16:46:18 -07:00
parent 2004b449e8
commit 253e8d27db
8 changed files with 119 additions and 43 deletions

View file

@@ -2,9 +2,9 @@ import os, types
import json
from enum import Enum
import requests
import time
import time, traceback
from typing import Callable, Optional
from litellm.utils import ModelResponse
from litellm.utils import ModelResponse, Choices, Message
import litellm
class CohereError(Exception):
@@ -156,11 +156,16 @@ def completion(
)
else:
try:
model_response["choices"][0]["message"]["content"] = completion_response["generations"][0]["text"]
except:
raise CohereError(message=json.dumps(completion_response), status_code=response.status_code)
choices_list = []
for idx, item in enumerate(completion_response["generations"]):
message_obj = Message(content=item["text"])
choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
raise CohereError(message=traceback.format_exc(), status_code=response.status_code)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
## CALCULATING USAGE
prompt_tokens = len(
encoding.encode(prompt)
)