diff --git a/llama_stack/providers/adapters/inference/databricks/databricks.py b/llama_stack/providers/adapters/inference/databricks/databricks.py
index 4752e3fe4..f12ecb7f5 100644
--- a/llama_stack/providers/adapters/inference/databricks/databricks.py
+++ b/llama_stack/providers/adapters/inference/databricks/databricks.py
@@ -116,7 +116,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
             "model": self.map_to_provider_model(request.model),
             "prompt": chat_completion_request_to_prompt(request, self.formatter),
             "stream": request.stream,
-            **get_sampling_options(request),
+            **get_sampling_options(request.sampling_params),
         }
 
     async def embeddings(
diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
index 441f32166..69535cd3c 100644
--- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
@@ -116,7 +116,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         if prompt.startswith("<|begin_of_text|>"):
             prompt = prompt[len("<|begin_of_text|>") :]
 
-        options = get_sampling_options(request)
+        options = get_sampling_options(request.sampling_params)
         options.setdefault("max_tokens", 512)
 
         if fmt := request.response_format:
diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index d4fe75cfa..916241a7c 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -110,7 +110,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             return await self._nonstream_completion(request)
 
     def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-        sampling_options = get_sampling_options(request)
+        sampling_options = get_sampling_options(request.sampling_params)
         # This is needed since the Ollama API expects num_predict to be set
         # for early truncation instead of max_tokens.
         if sampling_options["max_tokens"] is not None:
@@ -187,7 +187,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         return {
             "model": OLLAMA_SUPPORTED_MODELS[request.model],
             "prompt": chat_completion_request_to_prompt(request, self.formatter),
-            "options": get_sampling_options(request),
+            "options": get_sampling_options(request.sampling_params),
             "raw": True,
             "stream": request.stream,
         }
diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index dd36c0ab0..25a6759b3 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -90,20 +90,21 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         else:
             return await self._nonstream_completion(request)
 
-    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-        prompt, input_tokens = completion_request_to_prompt_model_input_info(
-            request, self.formatter
-        )
-        max_new_tokens = min(
-            request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
+    def _get_max_new_tokens(self, sampling_params, input_tokens):
+        return min(
+            sampling_params.max_tokens or (self.max_tokens - input_tokens),
             self.max_tokens - input_tokens - 1,
         )
 
-        options = get_sampling_options(request)
+    def _build_options(
+        self,
+        sampling_params: Optional[SamplingParams] = None,
+        fmt: ResponseFormat = None,
+    ):
+        options = get_sampling_options(sampling_params)
         # delete key "max_tokens" from options since its not supported by the API
         options.pop("max_tokens", None)
-
-        if fmt := request.response_format:
+        if fmt:
             if fmt.type == ResponseFormatType.json_schema.value:
                 options["grammar"] = {
                     "type": "json",
@@ -114,13 +115,22 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             else:
                 raise ValueError(f"Unexpected response format: {fmt.type}")
 
+        return options
+
+    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
+        prompt, input_tokens = completion_request_to_prompt_model_input_info(
+            request, self.formatter
+        )
+
         return dict(
             prompt=prompt,
             stream=request.stream,
             details=True,
-            max_new_tokens=max_new_tokens,
+            max_new_tokens=self._get_max_new_tokens(
+                request.sampling_params, input_tokens
+            ),
             stop_sequences=["<|eom_id|>", "<|eot_id|>"],
-            **options,
+            **self._build_options(request.sampling_params, request.response_format),
         )
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
@@ -129,7 +139,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         async def _generate_and_convert_to_openai_compat():
             s = await self.client.text_generation(**params)
             async for chunk in s:
-
                 token_result = chunk.token
                 finish_reason = None
                 if chunk.details:
@@ -230,29 +239,15 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         prompt, input_tokens = chat_completion_request_to_model_input_info(
            request, self.formatter
         )
-        max_new_tokens = min(
-            request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
-            self.max_tokens - input_tokens - 1,
-        )
-        options = get_sampling_options(request)
-        if fmt := request.response_format:
-            if fmt.type == ResponseFormatType.json_schema.value:
-                options["grammar"] = {
-                    "type": "json",
-                    "value": fmt.schema,
-                }
-            elif fmt.type == ResponseFormatType.grammar.value:
-                raise ValueError("Grammar response format not supported yet")
-            else:
-                raise ValueError(f"Unexpected response format: {fmt.type}")
-
         return dict(
             prompt=prompt,
             stream=request.stream,
             details=True,
-            max_new_tokens=max_new_tokens,
+            max_new_tokens=self._get_max_new_tokens(
+                request.sampling_params, input_tokens
+            ),
             stop_sequences=["<|eom_id|>", "<|eot_id|>"],
-            **options,
+            **self._build_options(request.sampling_params, request.response_format),
         )
 
     async def embeddings(
diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py
index 2f258e620..daf57497a 100644
--- a/llama_stack/providers/adapters/inference/together/together.py
+++ b/llama_stack/providers/adapters/inference/together/together.py
@@ -131,7 +131,7 @@ class TogetherInferenceAdapter(
             yield chunk
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
-        options = get_sampling_options(request)
+        options = get_sampling_options(request.sampling_params)
         if fmt := request.response_format:
             if fmt.type == ResponseFormatType.json_schema.value:
                 options["response_format"] = {
diff --git a/llama_stack/providers/adapters/inference/vllm/vllm.py b/llama_stack/providers/adapters/inference/vllm/vllm.py
index dacf646b0..4687618fa 100644
--- a/llama_stack/providers/adapters/inference/vllm/vllm.py
+++ b/llama_stack/providers/adapters/inference/vllm/vllm.py
@@ -143,7 +143,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             "model": VLLM_SUPPORTED_MODELS[request.model],
             "prompt": chat_completion_request_to_prompt(request, self.formatter),
             "stream": request.stream,
-            **get_sampling_options(request),
+            **get_sampling_options(request.sampling_params),
         }
 
     async def embeddings(
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index f810dc40a..5a5ddbb50 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -29,9 +29,9 @@ class OpenAICompatCompletionResponse(BaseModel):
     choices: List[OpenAICompatCompletionChoice]
 
 
-def get_sampling_options(request: ChatCompletionRequest) -> dict:
+def get_sampling_options(params: SamplingParams) -> dict:
     options = {}
-    if params := request.sampling_params:
+    if params:
         for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
             if getattr(params, attr):
                 options[attr] = getattr(params, attr)
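
Note (illustrative, not part of the diff): after this change, `get_sampling_options` takes the `SamplingParams` object directly rather than the whole request, so every adapter now passes `request.sampling_params`. A minimal sketch of the new call pattern; the `SamplingParams` import path and the sampling values below are assumptions, not taken from the diff:

```python
from llama_models.llama3.api.datatypes import SamplingParams  # assumed import path

from llama_stack.providers.utils.inference.openai_compat import get_sampling_options

# Hypothetical sampling values for illustration only.
params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=256)

# Callers now pass request.sampling_params instead of the full request.
options = get_sampling_options(params)
# e.g. {"temperature": 0.7, "top_p": 0.9, "max_tokens": 256}

# With no sampling params set on the request, an empty dict comes back.
assert get_sampling_options(None) == {}
```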