Support max_completion_tokens on Sagemaker

Pranav Simha 2025-04-23 13:12:41 -07:00
parent 36ee132514
commit 59debe5ae7
2 changed files with 40 additions and 2 deletions

litellm/llms/sagemaker/completion/transformation.py

@@ -37,6 +37,7 @@ class SagemakerConfig(BaseConfig):
    """
    max_new_tokens: Optional[int] = None
+   max_completion_tokens: Optional[int] = None
    top_p: Optional[float] = None
    temperature: Optional[float] = None
    return_full_text: Optional[bool] = None
@@ -44,6 +45,7 @@ class SagemakerConfig(BaseConfig):
    def __init__(
        self,
        max_new_tokens: Optional[int] = None,
+       max_completion_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
        return_full_text: Optional[bool] = None,
@@ -65,7 +67,7 @@ class SagemakerConfig(BaseConfig):
        )

    def get_supported_openai_params(self, model: str) -> List:
-       return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+       return ["stream", "temperature", "max_tokens", "max_completion_tokens", "top_p", "stop", "n"]

    def map_openai_params(
        self,
@@ -102,6 +104,8 @@ class SagemakerConfig(BaseConfig):
                if value == 0:
                    value = 1
                optional_params["max_new_tokens"] = value
+           if param == "max_completion_tokens":
+               optional_params["max_new_tokens"] = value

        non_default_params.pop("aws_sagemaker_allow_zero_temp", None)
        return optional_params
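
For reference, a minimal sketch of the new mapping in action (the values and the model name are illustrative placeholders, not part of the diff):

    from litellm.llms.sagemaker.completion.transformation import SagemakerConfig

    config = SagemakerConfig()
    # When both are supplied, max_completion_tokens overrides max_tokens;
    # either one is translated to SageMaker's native max_new_tokens.
    mapped = config.map_openai_params(
        non_default_params={"max_tokens": 200, "max_completion_tokens": 256},
        optional_params={},
        model="my-endpoint",  # placeholder model name
        drop_params=False,
    )
    assert mapped == {"max_new_tokens": 256}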

tests/litellm/llms/sagemaker/completion/

@@ -8,7 +8,7 @@ import pytest

sys.path.insert(0, os.path.abspath("../../../../.."))

from litellm.llms.sagemaker.common_utils import AWSEventStreamDecoder
+from litellm.llms.sagemaker.completion.transformation import SagemakerConfig
+from unittest.mock import MagicMock

@pytest.mark.asyncio
async def test_aiter_bytes_unicode_decode_error():
@@ -95,3 +95,37 @@ async def test_aiter_bytes_valid_chunk_followed_by_unicode_error():
    # Verify we got our valid chunk despite the subsequent error
    assert len(chunks) == 1
    assert chunks[0]["text"] == "hello"  # Verify the content of the valid chunk
+
+
+class TestSagemakerTransform:
+    def setup_method(self):
+        self.config = SagemakerConfig()
+        self.model = "test"
+        self.logging_obj = MagicMock()
+
+    def test_map_max_completion_tokens(self):
+        """Test that max_completion_tokens takes precedence over max_tokens."""
+        test_params = {"temperature": 0.7, "max_tokens": 200, "max_completion_tokens": 256}
+        result = self.config.map_openai_params(
+            non_default_params=test_params,
+            optional_params={},
+            model=self.model,
+            drop_params=False,
+        )
+        # max_completion_tokens should map to max_new_tokens and override max_tokens
+        assert result == {"temperature": 0.7, "max_new_tokens": 256}
+
+    def test_max_tokens_backward_compat(self):
+        """Test that max_tokens still maps to max_new_tokens when max_completion_tokens is absent."""
+        test_params = {"temperature": 0.7, "max_tokens": 200}
+        result = self.config.map_openai_params(
+            non_default_params=test_params,
+            optional_params={},
+            model=self.model,
+            drop_params=False,
+        )
+        # Without max_completion_tokens, max_tokens is mapped for backward compatibility
+        assert result == {"temperature": 0.7, "max_new_tokens": 200}
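
End to end, this change lets callers pass the OpenAI-style max_completion_tokens parameter to SageMaker models through litellm's completion API. A hedged usage sketch (the endpoint name is a placeholder):

    import litellm

    # "sagemaker/my-llm-endpoint" is a placeholder; substitute a deployed endpoint name
    response = litellm.completion(
        model="sagemaker/my-llm-endpoint",
        messages=[{"role": "user", "content": "Hello"}],
        max_completion_tokens=256,  # now forwarded to SageMaker as max_new_tokens
    )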