diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py
index ea5d80565..320096826 100644
--- a/llama_stack/providers/tests/agents/test_agents.py
+++ b/llama_stack/providers/tests/agents/test_agents.py
@@ -7,6 +7,7 @@
 import os
 
 import pytest
+from llama_models.datatypes import SamplingParams, TopPSamplingStrategy
 from llama_models.llama3.api.datatypes import BuiltinTool
 
 from llama_stack.apis.agents import (
@@ -22,12 +23,8 @@ from llama_stack.apis.agents import (
     ToolExecutionStep,
     Turn,
 )
-from llama_stack.apis.inference import (
-    CompletionMessage,
-    SamplingParams,
-    TopPSamplingStrategy,
-    UserMessage,
-)
+
+from llama_stack.apis.inference import CompletionMessage, UserMessage
 from llama_stack.apis.safety import ViolationLevel
 from llama_stack.providers.datatypes import Api
 
diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py
index 932ae36e6..037e99819 100644
--- a/llama_stack/providers/tests/inference/test_text_inference.py
+++ b/llama_stack/providers/tests/inference/test_text_inference.py
@@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
 from llama_stack.apis.models import Model
+
 from .utils import group_chunks
 
 
@@ -476,7 +477,7 @@ class TestInference:
         last = grouped[ChatCompletionResponseEventType.progress][-1]
         # assert last.event.stop_reason == expected_stop_reason
         assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
-        assert last.event.delta.content.type == "tool_call"
+        assert isinstance(last.event.delta.content, ToolCall)
 
         call = last.event.delta.content
         assert call.tool_name == "get_weather"
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 1eb89f21d..694212a02 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -74,7 +74,8 @@ def get_sampling_options(params: SamplingParams) -> dict:
     options = {}
     if params:
         options.update(get_sampling_strategy_options(params))
-        options["max_tokens"] = params.max_tokens
+        if params.max_tokens:
+            options["max_tokens"] = params.max_tokens
 
     if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
         options["repeat_penalty"] = params.repetition_penalty