mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-13 05:17:26 +00:00
Remove redundant defensive checks; FastAPI performs the appropriate request validation.
This commit is contained in:
parent
9b7b0c4c8d
commit
d4c3aee490
2 changed files with 36 additions and 4 deletions
|
@ -154,8 +154,6 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
TextTruncation.end: "END",
|
||||
TextTruncation.start: "START",
|
||||
}
|
||||
if text_truncation not in text_truncation_options:
|
||||
raise ValueError(f"Invalid text_truncation: {text_truncation}")
|
||||
extra_body["truncate"] = text_truncation_options[text_truncation]
|
||||
|
||||
if output_dimension is not None:
|
||||
|
@ -166,8 +164,6 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
EmbeddingTaskType.document: "passage",
|
||||
EmbeddingTaskType.query: "query",
|
||||
}
|
||||
if task_type not in task_type_options:
|
||||
raise ValueError(f"Invalid task_type: {task_type}")
|
||||
extra_body["input_type"] = task_type_options[task_type]
|
||||
|
||||
try:
|
||||
|
|
|
@ -188,3 +188,39 @@ def test_embedding_task_type(llama_stack_client, embedding_model_id):
|
|||
model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="document"
|
||||
)
|
||||
assert query_embedding.embeddings != document_embedding.embeddings
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "text_truncation",
    [
        None,
        "none",
        "end",
        "start",
    ],
)
def test_embedding_text_truncation(llama_stack_client, embedding_model_id, text_truncation):
    """Each accepted text_truncation value yields one embedding vector of floats.

    Covers the unset case (None) plus every supported option string.
    """
    result = llama_stack_client.inference.embeddings(
        model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation
    )
    assert isinstance(result, EmbeddingsResponse)
    # Exactly one input string was sent, so exactly one vector comes back.
    assert len(result.embeddings) == 1
    vector = result.embeddings[0]
    assert isinstance(vector, list)
    assert isinstance(vector[0], float)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "text_truncation",
    [
        "NONE",
        "END",
        "START",
        "left",
        "right",
    ],
)
def test_embedding_text_truncation_error(llama_stack_client, embedding_model_id, text_truncation):
    """Invalid text_truncation values are rejected with BadRequestError.

    Exercises wrong-case variants of the valid options ("NONE", "END",
    "START") and unsupported names ("left", "right"); all should fail
    server-side request validation.
    """
    # NOTE: the original bound the context manager `as excinfo` but never
    # inspected it — pytest.raises alone is sufficient here.
    with pytest.raises(BadRequestError):
        llama_stack_client.inference.embeddings(
            model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation
        )
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue