fix(triton/completion/transformation.py): remove bad_words / stop words from triton call

Fixes the error: "parameter 'bad_words' has invalid type. It should be either 'int', 'bool', or 'string'."
This commit is contained in:
Krrish Dholakia 2025-04-16 16:13:48 -07:00
parent f08a4e3c06
commit d171ac624c
2 changed files with 21 additions and 2 deletions

View file

@ -201,8 +201,6 @@ class TritonGenerateConfig(TritonConfig):
"max_tokens": int(
optional_params.get("max_tokens", DEFAULT_MAX_TOKENS_FOR_TRITON)
),
"bad_words": [""],
"stop_words": [""],
},
"stream": bool(stream),
}

View file

@ -20,6 +20,7 @@ from litellm.llms.triton.embedding.transformation import TritonEmbeddingConfig
import litellm
def test_split_embedding_by_shape_passes():
try:
data = [
@ -230,3 +231,23 @@ async def test_triton_embeddings():
assert response.data[0]["embedding"] == [0.1, 0.2]
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_triton_generate_raw_request():
    """Check the raw Triton generate request body omits the word-list params.

    Builds the raw request for a completion call against a ``/generate``
    api_base and asserts that neither ``bad_words`` nor ``stop_words``
    appears anywhere in the JSON-serialised request body.
    """
    from litellm.utils import return_raw_request
    from litellm.types.utils import CallTypes

    try:
        request_kwargs = dict(
            model="triton/llama-3-8b-instruct",
            messages=[{"role": "user", "content": "who are u?"}],
            api_base="http://localhost:8000/generate",
        )
        raw_request = return_raw_request(
            endpoint=CallTypes.completion,
            kwargs=request_kwargs,
        )
        print("raw_request", raw_request)
        assert raw_request is not None

        # Serialise once; a substring search over the JSON text also catches
        # the keys when they are nested inside the body.
        serialized_body = json.dumps(raw_request["raw_request_body"])
        assert "bad_words" not in serialized_body
        assert "stop_words" not in serialized_body
    except Exception as e:
        # Surface any failure (including assertion errors) as a pytest failure.
        pytest.fail(f"Error occurred: {e}")