Remove redundant defensive checks; FastAPI already performs the appropriate request validation.

This commit is contained in:
Matthew Farrellee 2025-02-25 07:17:57 -05:00
parent 9b7b0c4c8d
commit d4c3aee490
2 changed files with 36 additions and 4 deletions

View file

@ -154,8 +154,6 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
TextTruncation.end: "END",
TextTruncation.start: "START",
}
if text_truncation not in text_truncation_options:
raise ValueError(f"Invalid text_truncation: {text_truncation}")
extra_body["truncate"] = text_truncation_options[text_truncation]
if output_dimension is not None:
@ -166,8 +164,6 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
EmbeddingTaskType.document: "passage",
EmbeddingTaskType.query: "query",
}
if task_type not in task_type_options:
raise ValueError(f"Invalid task_type: {task_type}")
extra_body["input_type"] = task_type_options[task_type]
try:

View file

@ -188,3 +188,39 @@ def test_embedding_task_type(llama_stack_client, embedding_model_id):
model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="document"
)
assert query_embedding.embeddings != document_embedding.embeddings
@pytest.mark.parametrize(
    "text_truncation",
    [
        None,
        "none",
        "end",
        "start",
    ],
)
def test_embedding_text_truncation(llama_stack_client, embedding_model_id, text_truncation):
    """Every accepted text_truncation value yields one well-formed embedding vector."""
    result = llama_stack_client.inference.embeddings(
        model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation
    )
    assert isinstance(result, EmbeddingsResponse)
    # Exactly one input string was sent, so exactly one embedding comes back.
    assert len(result.embeddings) == 1
    vector = result.embeddings[0]
    assert isinstance(vector, list)
    assert isinstance(vector[0], float)
@pytest.mark.parametrize(
    "text_truncation",
    [
        "NONE",
        "END",
        "START",
        "left",
        "right",
    ],
)
def test_embedding_text_truncation_error(llama_stack_client, embedding_model_id, text_truncation):
    """Invalid text_truncation values must be rejected with BadRequestError.

    Covers wrong-case variants of the valid names ("NONE"/"END"/"START" — the
    accepted spellings are lowercase) and entirely unknown names ("left",
    "right"), exercising server-side request validation.
    """
    # NOTE(review): the original bound `as excinfo` but never inspected it;
    # the unused binding is dropped. Add an assertion on the error detail here
    # if the server's message format becomes part of the contract.
    with pytest.raises(BadRequestError):
        llama_stack_client.inference.embeddings(
            model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation
        )