mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
fix(triton/completion/transformation.py): remove bad_words / stop words from triton call
parameter 'bad_words' has invalid type. It should be either 'int', 'bool', or 'string'.
This commit is contained in:
parent
f08a4e3c06
commit
d171ac624c
2 changed files with 21 additions and 2 deletions
|
@ -201,8 +201,6 @@ class TritonGenerateConfig(TritonConfig):
|
|||
"max_tokens": int(
|
||||
optional_params.get("max_tokens", DEFAULT_MAX_TOKENS_FOR_TRITON)
|
||||
),
|
||||
"bad_words": [""],
|
||||
"stop_words": [""],
|
||||
},
|
||||
"stream": bool(stream),
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ from litellm.llms.triton.embedding.transformation import TritonEmbeddingConfig
|
|||
import litellm
|
||||
|
||||
|
||||
|
||||
def test_split_embedding_by_shape_passes():
|
||||
try:
|
||||
data = [
|
||||
|
@ -230,3 +231,23 @@ async def test_triton_embeddings():
|
|||
assert response.data[0]["embedding"] == [0.1, 0.2]
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
|
||||
def test_triton_generate_raw_request():
|
||||
from litellm.utils import return_raw_request
|
||||
from litellm.types.utils import CallTypes
|
||||
try:
|
||||
kwargs = {
|
||||
"model": "triton/llama-3-8b-instruct",
|
||||
"messages": [{"role": "user", "content": "who are u?"}],
|
||||
"api_base": "http://localhost:8000/generate",
|
||||
}
|
||||
raw_request = return_raw_request(endpoint=CallTypes.completion, kwargs=kwargs)
|
||||
print("raw_request", raw_request)
|
||||
assert raw_request is not None
|
||||
assert "bad_words" not in json.dumps(raw_request["raw_request_body"])
|
||||
assert "stop_words" not in json.dumps(raw_request["raw_request_body"])
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue