fix handle user message

This commit is contained in:
Ishaan Jaff 2024-09-12 14:34:32 -07:00
parent ded40e4d41
commit 0f24f339f3
2 changed files with 8 additions and 4 deletions

View file

@ -15,6 +15,7 @@ import types
from typing import Any, List, Optional, Union from typing import Any, List, Optional, Union
import litellm import litellm
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
from .openai import OpenAIConfig from .openai import OpenAIConfig
@ -78,16 +79,19 @@ class OpenAIO1Config(OpenAIConfig):
return True return True
return False return False
def o1_prompt_factory(self, messages: List[Any]): def o1_prompt_factory(self, messages: List[AllMessageValues]):
""" """
Handles limitations of O-1 model family. Handles limitations of O-1 model family.
- modalities: image => drop param (if user opts in to dropping param) - modalities: image => drop param (if user opts in to dropping param)
- role: system ==> translate to role 'user' - role: system ==> translate to role 'user'
""" """
for message in messages: for i, message in enumerate(messages):
if message["role"] == "system": if message["role"] == "system":
message["role"] = "user" new_message = ChatCompletionUserMessage(
content=message["content"], role="user"
)
messages[i] = new_message # Replace the old message with the new one
if isinstance(message["content"], list): if isinstance(message["content"], list):
new_content = [] new_content = []

View file

@ -6,7 +6,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from openai._models import BaseModel as OpenAIObject from openai._models import BaseModel as OpenAIObject
from openai.types.audio.transcription_create_params import FileTypes from openai.types.audio.transcription_create_params import FileTypes
from openai.types.completion_usage import CompletionUsage from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage
from pydantic import ConfigDict, Field, PrivateAttr from pydantic import ConfigDict, Field, PrivateAttr
from typing_extensions import Callable, Dict, Required, TypedDict, override from typing_extensions import Callable, Dict, Required, TypedDict, override