fix(triton/completion/transformation.py): remove bad_words / stop words from triton call

parameter 'bad_words' has invalid type. It should be either 'int', 'bool', or 'string'.
This commit is contained in:
Krrish Dholakia 2025-04-16 16:13:48 -07:00
parent f08a4e3c06
commit d171ac624c
2 changed files with 21 additions and 2 deletions

View file

@ -201,8 +201,6 @@ class TritonGenerateConfig(TritonConfig):
"max_tokens": int( "max_tokens": int(
optional_params.get("max_tokens", DEFAULT_MAX_TOKENS_FOR_TRITON) optional_params.get("max_tokens", DEFAULT_MAX_TOKENS_FOR_TRITON)
), ),
"bad_words": [""],
"stop_words": [""],
}, },
"stream": bool(stream), "stream": bool(stream),
} }

View file

@ -20,6 +20,7 @@ from litellm.llms.triton.embedding.transformation import TritonEmbeddingConfig
import litellm import litellm
def test_split_embedding_by_shape_passes(): def test_split_embedding_by_shape_passes():
try: try:
data = [ data = [
@ -230,3 +231,23 @@ async def test_triton_embeddings():
assert response.data[0]["embedding"] == [0.1, 0.2] assert response.data[0]["embedding"] == [0.1, 0.2]
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
def test_triton_generate_raw_request():
from litellm.utils import return_raw_request
from litellm.types.utils import CallTypes
try:
kwargs = {
"model": "triton/llama-3-8b-instruct",
"messages": [{"role": "user", "content": "who are u?"}],
"api_base": "http://localhost:8000/generate",
}
raw_request = return_raw_request(endpoint=CallTypes.completion, kwargs=kwargs)
print("raw_request", raw_request)
assert raw_request is not None
assert "bad_words" not in json.dumps(raw_request["raw_request_body"])
assert "stop_words" not in json.dumps(raw_request["raw_request_body"])
except Exception as e:
pytest.fail(f"Error occurred: {e}")