Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)
Support max_completion_tokens on Sagemaker

commit 59debe5ae7 (parent 36ee132514)
2 changed files with 40 additions and 2 deletions
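
For context, here is a minimal caller-side sketch of what this change enables, assuming a hypothetical SageMaker endpoint name ("my-endpoint"): with this commit, max_completion_tokens is accepted for sagemaker/ models and translated into the provider's max_new_tokens parameter.

# Illustrative sketch only; the endpoint name below is a placeholder.
import litellm

response = litellm.completion(
    model="sagemaker/my-endpoint",  # hypothetical endpoint name
    messages=[{"role": "user", "content": "Hello"}],
    max_completion_tokens=256,      # now mapped to max_new_tokens for SageMaker
)
print(response.choices[0].message.content)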
@@ -37,6 +37,7 @@ class SagemakerConfig(BaseConfig):
     """

     max_new_tokens: Optional[int] = None
+    max_completion_tokens: Optional[int] = None
     top_p: Optional[float] = None
     temperature: Optional[float] = None
     return_full_text: Optional[bool] = None
@@ -44,6 +45,7 @@ class SagemakerConfig(BaseConfig):
     def __init__(
         self,
         max_new_tokens: Optional[int] = None,
+        max_completion_tokens: Optional[int] = None,
         top_p: Optional[float] = None,
         temperature: Optional[float] = None,
         return_full_text: Optional[bool] = None,
@@ -65,7 +67,7 @@ class SagemakerConfig(BaseConfig):
         )

     def get_supported_openai_params(self, model: str) -> List:
-        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+        return ["stream", "temperature", "max_tokens", "max_completion_tokens", "top_p", "stop", "n"]

     def map_openai_params(
         self,
@@ -102,6 +104,8 @@ class SagemakerConfig(BaseConfig):
                 if value == 0:
                     value = 1
                 optional_params["max_new_tokens"] = value
+            if param == "max_completion_tokens":
+                optional_params["max_new_tokens"] = value
         non_default_params.pop("aws_sagemaker_allow_zero_temp", None)
         return optional_params
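
For reference, a rough sketch (not part of the diff) of what this mapping means in practice: both max_tokens and the newly supported max_completion_tokens are funneled into max_new_tokens, with max_completion_tokens winning when both are set, as the new tests below assert. The payload shape in the final comment is an assumption based on the Hugging Face TGI-style parameters this config mirrors.

# Illustrative sketch of the parameter mapping performed by SagemakerConfig.
from litellm.llms.sagemaker.completion.transformation import SagemakerConfig

mapped = SagemakerConfig().map_openai_params(
    non_default_params={"temperature": 0.7, "max_tokens": 200, "max_completion_tokens": 256},
    optional_params={},
    model="test",
    drop_params=False,
)
# mapped == {"temperature": 0.7, "max_new_tokens": 256}
# These values are then sent as the generation parameters of the endpoint request,
# e.g. (assumed TGI-style payload): {"inputs": "<prompt>", "parameters": mapped}
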
@@ -8,7 +8,7 @@ import pytest
 sys.path.insert(0, os.path.abspath("../../../../.."))
 from litellm.llms.sagemaker.common_utils import AWSEventStreamDecoder

+from litellm.llms.sagemaker.completion.transformation import SagemakerConfig

 @pytest.mark.asyncio
 async def test_aiter_bytes_unicode_decode_error():
@@ -95,3 +95,37 @@ async def test_aiter_bytes_valid_chunk_followed_by_unicode_error():
     # Verify we got our valid chunk despite the subsequent error
     assert len(chunks) == 1
     assert chunks[0]["text"] == "hello"  # Verify the content of the valid chunk
+
+
+class TestSagemakerTransform:
+    def setup_method(self):
+        self.config = SagemakerConfig()
+        self.model = "test"
+        self.logging_obj = MagicMock()
+
+    def test_map_mistral_params(self):
+        """Test that max_completion_tokens is mapped and overrides max_tokens"""
+        test_params = {"temperature": 0.7, "max_tokens": 200, "max_completion_tokens": 256}
+
+        result = self.config.map_openai_params(
+            non_default_params=test_params,
+            optional_params={},
+            model=self.model,
+            drop_params=False,
+        )
+
+        # The function should map max_completion_tokens to max_new_tokens and take precedence over max_tokens
+        assert result == {"temperature": 0.7, "max_new_tokens": 256}
+
+    def test_mistral_max_tokens_backward_compat(self):
+        """Test that max_tokens is still mapped when max_completion_tokens is not provided"""
+        test_params = {"temperature": 0.7, "max_tokens": 200}
+
+        result = self.config.map_openai_params(
+            non_default_params=test_params,
+            optional_params={},
+            model=self.model,
+            drop_params=False,
+        )
+
+        # The function should map max_tokens to max_new_tokens if max_completion_tokens is not provided
+        assert result == {"temperature": 0.7, "max_new_tokens": 200}