mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
fix(xai/chat/transformation.py): filter out 'name' param for xai non-… (#9761)
* fix(xai/chat/transformation.py): filter out 'name' param for xai non-user roles Fixes https://github.com/BerriAI/litellm/issues/9720 * test fix test_hf_chat_template --------- Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
This commit is contained in:
parent
d66db2207b
commit
90a4dfab3c
3 changed files with 43 additions and 3 deletions
|
@ -35,7 +35,7 @@ def handle_messages_with_content_list_to_str_conversion(
|
||||||
|
|
||||||
|
|
||||||
def strip_name_from_messages(
|
def strip_name_from_messages(
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues], allowed_name_roles: List[str] = ["user"]
|
||||||
) -> List[AllMessageValues]:
|
) -> List[AllMessageValues]:
|
||||||
"""
|
"""
|
||||||
Removes 'name' from messages
|
Removes 'name' from messages
|
||||||
|
@ -44,7 +44,7 @@ def strip_name_from_messages(
|
||||||
for message in messages:
|
for message in messages:
|
||||||
msg_role = message.get("role")
|
msg_role = message.get("role")
|
||||||
msg_copy = message.copy()
|
msg_copy = message.copy()
|
||||||
if msg_role == "user":
|
if msg_role not in allowed_name_roles:
|
||||||
msg_copy.pop("name", None) # type: ignore
|
msg_copy.pop("name", None) # type: ignore
|
||||||
new_messages.append(msg_copy)
|
new_messages.append(msg_copy)
|
||||||
return new_messages
|
return new_messages
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
from typing import Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||||
|
strip_name_from_messages,
|
||||||
|
)
|
||||||
from litellm.secret_managers.main import get_secret_str
|
from litellm.secret_managers.main import get_secret_str
|
||||||
|
from litellm.types.llms.openai import AllMessageValues
|
||||||
|
|
||||||
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
|
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||||
|
|
||||||
|
@ -51,3 +55,21 @@ class XAIChatConfig(OpenAIGPTConfig):
|
||||||
if value is not None:
|
if value is not None:
|
||||||
optional_params[param] = value
|
optional_params[param] = value
|
||||||
return optional_params
|
return optional_params
|
||||||
|
|
||||||
|
def transform_request(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
headers: dict,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Handle https://github.com/BerriAI/litellm/issues/9720
|
||||||
|
|
||||||
|
Filter out 'name' from messages
|
||||||
|
"""
|
||||||
|
messages = strip_name_from_messages(messages)
|
||||||
|
return super().transform_request(
|
||||||
|
model, messages, optional_params, litellm_params, headers
|
||||||
|
)
|
||||||
|
|
|
@ -142,3 +142,21 @@ def test_completion_xai(stream):
|
||||||
assert response.choices[0].message.content is not None
|
assert response.choices[0].message.content is not None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_xai_message_name_filtering():
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "*I press the green button*",
|
||||||
|
"name": "example_user"
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "Hello", "name": "John"},
|
||||||
|
{"role": "assistant", "content": "Hello", "name": "Jane"},
|
||||||
|
]
|
||||||
|
response = completion(
|
||||||
|
model="xai/grok-beta",
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
assert response is not None
|
||||||
|
assert response.choices[0].message.content is not None
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue