diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index 06c25e4b6..218762d75 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -325,21 +325,8 @@ class AzureChatCompletion(BaseLLM):
                     original_response=response,
                 )
-            embedding_response = json.loads(response.model_dump_json())
-            output_data = []
-            for idx, embedding in enumerate(embedding_response["data"]):
-                output_data.append(
-                    {
-                        "object": embedding["object"],
-                        "index": embedding["index"],
-                        "embedding": embedding["embedding"]
-                    }
-                )
-            model_response["object"] = "list"
-            model_response["data"] = output_data
-            model_response["model"] = "azure/" + model
-            model_response["usage"] = embedding_response["usage"]
-            return model_response
+
+            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding")
         except AzureOpenAIError as e:
             exception_mapping_worked = True
             raise e
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index d635da131..6adc69f57 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -343,11 +343,8 @@ class OpenAIChatCompletion(BaseLLM):
                 additional_args={"complete_input_dict": data},
                 original_response=response,
             )
-            model_response.data = response.data
-            model_response.model = model
-            model_response.usage = response.usage
-            model_response.object = "list"
-            return model_response
+
+            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding")
         except OpenAIError as e:
             exception_mapping_worked = True
             raise e
diff --git a/litellm/utils.py b/litellm/utils.py
index 806ef1f7b..922076575 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -336,14 +336,45 @@ class ModelResponse(OpenAIObject):
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)
 
+class Embedding(OpenAIObject):
+    embedding: list = []
+    index: int
+    object: str
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
 class EmbeddingResponse(OpenAIObject):
-    def __init__(self, model=None, usage=None, stream=False, response_ms=None):
+    data: Optional[List[Embedding]] = None
+    """The actual embedding values."""
+
+    usage: Optional[Usage] = None
+    """Usage statistics for the embedding request."""
+
+    def __init__(self, model=None, usage=None, stream=False, response_ms=None, data=None):
         object = "list"
         if response_ms:
             _response_ms = response_ms
         else:
             _response_ms = None
-        data = []
+        if data:
+            data = data
+        else:
+            data = None
+
+        if usage:
+            usage = usage
+        else:
+            usage = Usage()
+
         model = model
         super().__init__(model=model, object=object, data=data, usage=usage)
@@ -1239,9 +1270,12 @@ def client(original_function):
                         pass
                     else:
                         call_type = original_function.__name__
+                        print_verbose(f"Cache Response Object routing: call_type - {call_type}; cached_result instance: {type(cached_result)}")
                         if call_type == CallTypes.completion.value and isinstance(cached_result, dict):
                             return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
-                        else:
+                        elif call_type == CallTypes.embedding.value and isinstance(cached_result, dict):
+                            return convert_to_model_response_object(response_object=cached_result, response_type="embedding")
+                        else:
                             return cached_result
             # MODEL CALL
             result = original_function(*args, **kwargs)
@@ -3234,40 +3268,71 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, kwargs):
         pass
 
-def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[ModelResponse]=None):
+def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse]]=None, response_type: Literal["completion", "embedding"] = "completion"):
     try:
-        if response_object is None or model_response_object is None:
-            raise Exception("Error in response object format")
-        choice_list=[]
-        for idx, choice in enumerate(response_object["choices"]):
-            message = Message(
-                content=choice["message"].get("content", None),
-                role=choice["message"]["role"],
-                function_call=choice["message"].get("function_call", None),
-                tool_calls=choice["message"].get("tool_calls", None)
-            )
-            finish_reason = choice.get("finish_reason", None)
-            if finish_reason == None:
-                # gpt-4 vision can return 'finish_reason' or 'finish_details'
-                finish_reason = choice.get("finish_details")
-            choice = Choices(finish_reason=finish_reason, index=idx, message=message)
-            choice_list.append(choice)
-        model_response_object.choices = choice_list
+        if response_type == "completion":
+            if response_object is None or model_response_object is None:
+                raise Exception("Error in response object format")
+            choice_list=[]
+            for idx, choice in enumerate(response_object["choices"]):
+                message = Message(
+                    content=choice["message"].get("content", None),
+                    role=choice["message"]["role"],
+                    function_call=choice["message"].get("function_call", None),
+                    tool_calls=choice["message"].get("tool_calls", None)
+                )
+                finish_reason = choice.get("finish_reason", None)
+                if finish_reason == None:
+                    # gpt-4 vision can return 'finish_reason' or 'finish_details'
+                    finish_reason = choice.get("finish_details")
+                choice = Choices(finish_reason=finish_reason, index=idx, message=message)
+                choice_list.append(choice)
+            model_response_object.choices = choice_list
 
-        if "usage" in response_object and response_object["usage"] is not None:
-            model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore
-            model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore
-            model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
+            if "usage" in response_object and response_object["usage"] is not None:
+                model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore
+                model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore
+                model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
 
-        if "id" in response_object:
-            model_response_object.id = response_object["id"]
-
-        if "system_fingerprint" in response_object:
-            model_response_object.system_fingerprint = response_object["system_fingerprint"]
+            if "id" in response_object:
+                model_response_object.id = response_object["id"]
+
+            if "system_fingerprint" in response_object:
+                model_response_object.system_fingerprint = response_object["system_fingerprint"]
 
-        if "model" in response_object:
-            model_response_object.model = response_object["model"]
-        return model_response_object
+            if "model" in response_object:
+                model_response_object.model = response_object["model"]
+            return model_response_object
+        elif response_type == "embedding":
+            if response_object is None:
+                raise Exception("Error in response object format")
+
+            if model_response_object is None:
+                model_response_object = EmbeddingResponse()
+
+            if "model" in response_object:
+                model_response_object.model = response_object["model"]
+
+            if "object" in response_object:
+                model_response_object.object = response_object["object"]
+
+            data = []
+            for idx, embedding in enumerate(response_object["data"]):
+                embedding_obj = Embedding(
+                    embedding=embedding.get("embedding", None),
+                    index=embedding.get("index", idx),
+                    object=embedding.get("object", "embedding")
+                )
+                data.append(embedding_obj)
+            model_response_object.data = data
+
+            if "usage" in response_object and response_object["usage"] is not None:
+                model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore
+                model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore
+                model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
+
+            return model_response_object
     except Exception as e:
         raise Exception(f"Invalid response object {e}")
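Note: a minimal sketch of the unified embedding conversion path this diff introduces, assuming EmbeddingResponse and convert_to_model_response_object are importable from litellm.utils as defined above; the raw payload below is a hypothetical example shaped like an OpenAI /embeddings response body, not captured API output.

from litellm.utils import EmbeddingResponse, convert_to_model_response_object

# Hypothetical raw provider payload (dict form, as produced by
# json.loads(response.model_dump_json()) in the provider code above).
raw_response = {
    "object": "list",
    "model": "text-embedding-ada-002",
    "data": [
        {"object": "embedding", "index": 0, "embedding": [0.1, -0.2, 0.3]},
    ],
    "usage": {"prompt_tokens": 5, "total_tokens": 5},
}

# Azure, OpenAI, and the embedding cache-hit path all funnel through this
# one helper now; model_response_object may also be omitted for embeddings,
# since the helper constructs an EmbeddingResponse() when it receives None.
response = convert_to_model_response_object(
    response_object=raw_response,
    model_response_object=EmbeddingResponse(),
    response_type="embedding",
)

assert response.model == "text-embedding-ada-002"
assert response.usage.prompt_tokens == 5
# Embedding supports both attribute and dict-style access via __getitem__.
assert response.data[0].index == 0
assert response.data[0]["embedding"] == [0.1, -0.2, 0.3]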