fix(utils.py): fix embedding response output parsing

This commit is contained in:
Krrish Dholakia 2023-11-25 12:06:47 -08:00
parent 62d8f9ad2a
commit dac76a4861
3 changed files with 102 additions and 53 deletions

View file

@ -336,14 +336,45 @@ class ModelResponse(OpenAIObject):
# Allow dictionary-style assignment of attributes
setattr(self, key, value)
class Embedding(OpenAIObject):
embedding: list = []
index: int
object: str
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
return getattr(self, key, default)
def __getitem__(self, key):
# Allow dictionary-style access to attributes
return getattr(self, key)
def __setitem__(self, key, value):
# Allow dictionary-style assignment of attributes
setattr(self, key, value)
class EmbeddingResponse(OpenAIObject):
def __init__(self, model=None, usage=None, stream=False, response_ms=None):
data: Optional[Embedding] = None
"""The actual embedding value"""
usage: Optional[Usage] = None
"""Usage statistics for the completion request."""
def __init__(self, model=None, usage=None, stream=False, response_ms=None, data=None):
object = "list"
if response_ms:
_response_ms = response_ms
else:
_response_ms = None
data = []
if data:
data = data
else:
data = None
if usage:
usage = usage
else:
usage = Usage()
model = model
super().__init__(model=model, object=object, data=data, usage=usage)
@ -1239,9 +1270,12 @@ def client(original_function):
pass
else:
call_type = original_function.__name__
print_verbose(f"Cache Response Object routing: call_type - {call_type}; cached_result instace: {type(cached_result)}")
if call_type == CallTypes.completion.value and isinstance(cached_result, dict):
return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
else:
elif call_type == CallTypes.embedding.value and isinstance(cached_result, dict):
return convert_to_model_response_object(response_object=cached_result, response_type="embedding")
else:
return cached_result
# MODEL CALL
result = original_function(*args, **kwargs)
@ -3234,40 +3268,71 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, k
pass
def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[ModelResponse]=None):
def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse]]=None, response_type: Literal["completion", "embedding"] = "completion"):
try:
if response_object is None or model_response_object is None:
raise Exception("Error in response object format")
choice_list=[]
for idx, choice in enumerate(response_object["choices"]):
message = Message(
content=choice["message"].get("content", None),
role=choice["message"]["role"],
function_call=choice["message"].get("function_call", None),
tool_calls=choice["message"].get("tool_calls", None)
)
finish_reason = choice.get("finish_reason", None)
if finish_reason == None:
# gpt-4 vision can return 'finish_reason' or 'finish_details'
finish_reason = choice.get("finish_details")
choice = Choices(finish_reason=finish_reason, index=idx, message=message)
choice_list.append(choice)
model_response_object.choices = choice_list
if response_type == "completion":
if response_object is None or model_response_object is None:
raise Exception("Error in response object format")
choice_list=[]
for idx, choice in enumerate(response_object["choices"]):
message = Message(
content=choice["message"].get("content", None),
role=choice["message"]["role"],
function_call=choice["message"].get("function_call", None),
tool_calls=choice["message"].get("tool_calls", None)
)
finish_reason = choice.get("finish_reason", None)
if finish_reason == None:
# gpt-4 vision can return 'finish_reason' or 'finish_details'
finish_reason = choice.get("finish_details")
choice = Choices(finish_reason=finish_reason, index=idx, message=message)
choice_list.append(choice)
model_response_object.choices = choice_list
if "usage" in response_object and response_object["usage"] is not None:
model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore
model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore
model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
if "usage" in response_object and response_object["usage"] is not None:
model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore
model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore
model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
if "id" in response_object:
model_response_object.id = response_object["id"]
if "system_fingerprint" in response_object:
model_response_object.system_fingerprint = response_object["system_fingerprint"]
if "id" in response_object:
model_response_object.id = response_object["id"]
if "system_fingerprint" in response_object:
model_response_object.system_fingerprint = response_object["system_fingerprint"]
if "model" in response_object:
model_response_object.model = response_object["model"]
return model_response_object
if "model" in response_object:
model_response_object.model = response_object["model"]
return model_response_object
elif response_type == "embedding":
if response_object is None:
raise Exception("Error in response object format")
if model_response_object is None:
model_response_object = EmbeddingResponse()
if "model" in response_object:
model_response_object.model = response_object["model"]
if "object" in response_object:
model_response_object.object = response_object["object"]
data = []
for idx, embedding in enumerate(response_object["data"]):
embedding_obj = Embedding(
embedding=embedding.get("embedding", None),
index = embedding.get("index", idx),
object=embedding.get("object", "embedding")
)
data.append(embedding_obj)
model_response_object.data = data
if "usage" in response_object and response_object["usage"] is not None:
model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore
model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore
model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
return model_response_object
except Exception as e:
raise Exception(f"Invalid response object {e}")