diff --git a/litellm/llms/triton/completion/transformation.py b/litellm/llms/triton/completion/transformation.py
index 21fcf2eefb..0db83b2d3d 100644
--- a/litellm/llms/triton/completion/transformation.py
+++ b/litellm/llms/triton/completion/transformation.py
@@ -201,8 +201,6 @@ class TritonGenerateConfig(TritonConfig):
                 "max_tokens": int(
                     optional_params.get("max_tokens", DEFAULT_MAX_TOKENS_FOR_TRITON)
                 ),
-                "bad_words": [""],
-                "stop_words": [""],
             },
             "stream": bool(stream),
         }
diff --git a/tests/llm_translation/test_triton.py b/tests/llm_translation/test_triton.py
index 7e4ba92f23..8a3bbb4661 100644
--- a/tests/llm_translation/test_triton.py
+++ b/tests/llm_translation/test_triton.py
@@ -20,6 +20,7 @@ from litellm.llms.triton.embedding.transformation import TritonEmbeddingConfig
 
 import litellm
 
+
 def test_split_embedding_by_shape_passes():
     try:
         data = [
@@ -230,3 +231,23 @@ async def test_triton_embeddings():
         assert response.data[0]["embedding"] == [0.1, 0.2]
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+
+def test_triton_generate_raw_request():
+    from litellm.utils import return_raw_request
+    from litellm.types.utils import CallTypes
+    try:
+        kwargs = {
+            "model": "triton/llama-3-8b-instruct",
+            "messages": [{"role": "user", "content": "who are u?"}],
+            "api_base": "http://localhost:8000/generate",
+        }
+        raw_request = return_raw_request(endpoint=CallTypes.completion, kwargs=kwargs)
+        print("raw_request", raw_request)
+        assert raw_request is not None
+        assert "bad_words" not in json.dumps(raw_request["raw_request_body"])
+        assert "stop_words" not in json.dumps(raw_request["raw_request_body"])
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
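
Summary: the Triton `/generate` payload built by `TritonGenerateConfig` no longer injects empty-string `bad_words` / `stop_words` entries by default; the new test verifies this via `return_raw_request` without needing a live Triton server. A minimal before/after sketch of the request body, assuming the dict structure visible in the hunk (the prompt field name and the default token count are assumptions, not part of this diff):

```python
# Before this change: every /generate request carried empty-string filters.
before = {
    "text_input": "who are u?",  # assumed prompt field name
    "parameters": {
        "max_tokens": 4096,       # DEFAULT_MAX_TOKENS_FOR_TRITON; actual value assumed
        "bad_words": [""],        # injected unconditionally
        "stop_words": [""],       # injected unconditionally
    },
    "stream": False,
}

# After this change: only the parameters litellm actually resolves are sent.
after = {
    "text_input": "who are u?",
    "parameters": {
        "max_tokens": 4096,
    },
    "stream": False,
}
```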