Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
fix(utils.py): fix embedding response output parsing
parent 62d8f9ad2a
commit dac76a4861
3 changed files with 102 additions and 53 deletions
litellm/utils.py (131 changed lines)
@@ -336,14 +336,45 @@ class ModelResponse(OpenAIObject):
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)
 
+class Embedding(OpenAIObject):
+    embedding: list = []
+    index: int
+    object: str
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
 class EmbeddingResponse(OpenAIObject):
-    def __init__(self, model=None, usage=None, stream=False, response_ms=None):
+    data: Optional[Embedding] = None
+    """The actual embedding value"""
+
+    usage: Optional[Usage] = None
+    """Usage statistics for the completion request."""
+
+    def __init__(self, model=None, usage=None, stream=False, response_ms=None, data=None):
         object = "list"
         if response_ms:
             _response_ms = response_ms
         else:
             _response_ms = None
-        data = []
+        if data:
+            data = data
+        else:
+            data = None
+
+        if usage:
+            usage = usage
+        else:
+            usage = Usage()
+
         model = model
         super().__init__(model=model, object=object, data=data, usage=usage)
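For context on the hunk above: the new Embedding class mixes attribute access with dictionary-style access. Below is a minimal, self-contained sketch of that pattern; the Base class is a hypothetical stand-in for OpenAIObject, which in litellm is a richer base class.

# Standalone sketch of the hybrid attribute/dict access pattern used by the
# new Embedding class. `Base` is a hypothetical stand-in for OpenAIObject.
class Base:
    def __init__(self, **params):
        for k, v in params.items():
            setattr(self, k, v)

class Embedding(Base):
    def get(self, key, default=None):
        # .get() with a default, mirroring dict semantics
        return getattr(self, key, default)

    def __getitem__(self, key):
        # dictionary-style read: obj["embedding"]
        return getattr(self, key)

    def __setitem__(self, key, value):
        # dictionary-style write: obj["index"] = 1
        setattr(self, key, value)

e = Embedding(embedding=[0.1, 0.2], index=0, object="embedding")
assert e["embedding"] == e.embedding == [0.1, 0.2]
assert e.get("missing", "fallback") == "fallback"
e["index"] = 1
assert e.index == 1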
@@ -1239,9 +1270,12 @@ def client(original_function):
                         pass
                     else:
                         call_type = original_function.__name__
+                        print_verbose(f"Cache Response Object routing: call_type - {call_type}; cached_result instace: {type(cached_result)}")
                         if call_type == CallTypes.completion.value and isinstance(cached_result, dict):
                             return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
-                        else:
+                        elif call_type == CallTypes.embedding.value and isinstance(cached_result, dict):
+                            return convert_to_model_response_object(response_object=cached_result, response_type="embedding")
+                        else:
                             return cached_result
             # MODEL CALL
             result = original_function(*args, **kwargs)
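The routing fix above matters because cached results are stored as plain dicts: before this commit an embedding cache hit fell through to the generic else branch and came back as a raw dict rather than an EmbeddingResponse. A minimal standalone sketch of the dispatch follows; CallTypes mirrors the enum named in the diff, while the rehydrate_* helpers are hypothetical stand-ins for convert_to_model_response_object with the two response_type values.

# Sketch of the cache-hit dispatch added above, runnable without litellm.
from enum import Enum

class CallTypes(Enum):
    completion = "completion"
    embedding = "embedding"

def rehydrate_completion(d: dict) -> str:
    return f"ModelResponse({d})"      # stand-in for the real conversion

def rehydrate_embedding(d: dict) -> str:
    return f"EmbeddingResponse({d})"  # stand-in for the real conversion

def route_cached_result(call_type: str, cached_result):
    # dicts come from the cache; already-typed objects pass straight through
    if call_type == CallTypes.completion.value and isinstance(cached_result, dict):
        return rehydrate_completion(cached_result)
    elif call_type == CallTypes.embedding.value and isinstance(cached_result, dict):
        return rehydrate_embedding(cached_result)
    return cached_result

print(route_cached_result("embedding", {"data": []}))
# -> EmbeddingResponse({'data': []}) rather than the raw dict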
@@ -3234,40 +3268,71 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, k
         pass
 
 
-def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[ModelResponse]=None):
+def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse]]=None, response_type: Literal["completion", "embedding"] = "completion"):
     try:
-        if response_object is None or model_response_object is None:
-            raise Exception("Error in response object format")
-        choice_list=[]
-        for idx, choice in enumerate(response_object["choices"]):
-            message = Message(
-                content=choice["message"].get("content", None),
-                role=choice["message"]["role"],
-                function_call=choice["message"].get("function_call", None),
-                tool_calls=choice["message"].get("tool_calls", None)
-            )
-            finish_reason = choice.get("finish_reason", None)
-            if finish_reason == None:
-                # gpt-4 vision can return 'finish_reason' or 'finish_details'
-                finish_reason = choice.get("finish_details")
-            choice = Choices(finish_reason=finish_reason, index=idx, message=message)
-            choice_list.append(choice)
-        model_response_object.choices = choice_list
+        if response_type == "completion":
+            if response_object is None or model_response_object is None:
+                raise Exception("Error in response object format")
+            choice_list=[]
+            for idx, choice in enumerate(response_object["choices"]):
+                message = Message(
+                    content=choice["message"].get("content", None),
+                    role=choice["message"]["role"],
+                    function_call=choice["message"].get("function_call", None),
+                    tool_calls=choice["message"].get("tool_calls", None)
+                )
+                finish_reason = choice.get("finish_reason", None)
+                if finish_reason == None:
+                    # gpt-4 vision can return 'finish_reason' or 'finish_details'
+                    finish_reason = choice.get("finish_details")
+                choice = Choices(finish_reason=finish_reason, index=idx, message=message)
+                choice_list.append(choice)
+            model_response_object.choices = choice_list
 
-        if "usage" in response_object and response_object["usage"] is not None:
-            model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore
-            model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore
-            model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
+            if "usage" in response_object and response_object["usage"] is not None:
+                model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore
+                model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore
+                model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
 
-        if "id" in response_object:
-            model_response_object.id = response_object["id"]
+            if "id" in response_object:
+                model_response_object.id = response_object["id"]
 
-        if "system_fingerprint" in response_object:
-            model_response_object.system_fingerprint = response_object["system_fingerprint"]
+            if "system_fingerprint" in response_object:
+                model_response_object.system_fingerprint = response_object["system_fingerprint"]
 
-        if "model" in response_object:
-            model_response_object.model = response_object["model"]
-        return model_response_object
+            if "model" in response_object:
+                model_response_object.model = response_object["model"]
+            return model_response_object
+        elif response_type == "embedding":
+            if response_object is None:
+                raise Exception("Error in response object format")
+
+            if model_response_object is None:
+                model_response_object = EmbeddingResponse()
+
+            if "model" in response_object:
+                model_response_object.model = response_object["model"]
+
+            if "object" in response_object:
+                model_response_object.object = response_object["object"]
+
+            data = []
+            for idx, embedding in enumerate(response_object["data"]):
+                embedding_obj = Embedding(
+                    embedding=embedding.get("embedding", None),
+                    index = embedding.get("index", idx),
+                    object=embedding.get("object", "embedding")
+                )
+                data.append(embedding_obj)
+            model_response_object.data = data
+
+            if "usage" in response_object and response_object["usage"] is not None:
+                model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore
+                model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore
+                model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
+
+
+            return model_response_object
     except Exception as e:
         raise Exception(f"Invalid response object {e}")
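To illustrate the new embedding branch of convert_to_model_response_object, here is a hedged end-to-end sketch: the raw dict is an invented OpenAI-style embedding payload, and the conversion logic is inlined over plain dicts so the example runs without litellm installed.

# Invented OpenAI-style embedding payload; field names follow the diff above.
raw = {
    "object": "list",
    "model": "text-embedding-ada-002",
    "data": [{"object": "embedding", "index": 0, "embedding": [0.01, -0.02]}],
    "usage": {"prompt_tokens": 5, "total_tokens": 5},
}

# Inline the embedding branch: copy model/object, wrap each entry, keep usage.
converted = {
    "model": raw.get("model"),
    "object": raw.get("object", "list"),
    "data": [
        {
            "embedding": entry.get("embedding"),
            "index": entry.get("index", idx),  # enumerate index as fallback
            "object": entry.get("object", "embedding"),
        }
        for idx, entry in enumerate(raw["data"])
    ],
    "usage": raw.get("usage") or {},
}

assert converted["data"][0]["index"] == 0
assert converted["model"] == "text-embedding-ada-002"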