diff --git a/llama_stack/providers/adapters/inference/databricks/databricks.py b/llama_stack/providers/adapters/inference/databricks/databricks.py index 4752e3fe4..f12ecb7f5 100644 --- a/llama_stack/providers/adapters/inference/databricks/databricks.py +++ b/llama_stack/providers/adapters/inference/databricks/databricks.py @@ -116,7 +116,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): "model": self.map_to_provider_model(request.model), "prompt": chat_completion_request_to_prompt(request, self.formatter), "stream": request.stream, - **get_sampling_options(request), + **get_sampling_options(request.sampling_params), } async def embeddings( diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py index 441f32166..69535cd3c 100644 --- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py +++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py @@ -116,7 +116,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): if prompt.startswith("<|begin_of_text|>"): prompt = prompt[len("<|begin_of_text|>") :] - options = get_sampling_options(request) + options = get_sampling_options(request.sampling_params) options.setdefault("max_tokens", 512) if fmt := request.response_format: diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py index d4fe75cfa..916241a7c 100644 --- a/llama_stack/providers/adapters/inference/ollama/ollama.py +++ b/llama_stack/providers/adapters/inference/ollama/ollama.py @@ -110,7 +110,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): return await self._nonstream_completion(request) def _get_params_for_completion(self, request: CompletionRequest) -> dict: - sampling_options = get_sampling_options(request) + sampling_options = get_sampling_options(request.sampling_params) # This is needed since the Ollama API expects num_predict to be set # for early truncation instead of max_tokens. if sampling_options["max_tokens"] is not None: @@ -187,7 +187,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): return { "model": OLLAMA_SUPPORTED_MODELS[request.model], "prompt": chat_completion_request_to_prompt(request, self.formatter), - "options": get_sampling_options(request), + "options": get_sampling_options(request.sampling_params), "raw": True, "stream": request.stream, } diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py index f19181320..a7fa6ba00 100644 --- a/llama_stack/providers/adapters/inference/tgi/tgi.py +++ b/llama_stack/providers/adapters/inference/tgi/tgi.py @@ -24,9 +24,12 @@ from llama_stack.providers.utils.inference.openai_compat import ( OpenAICompatCompletionResponse, process_chat_completion_response, process_chat_completion_stream_response, + process_completion_response, + process_completion_stream_response, ) from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_model_input_info, + completion_request_to_prompt_model_input_info, ) from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig @@ -75,7 +78,98 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: - raise NotImplementedError() + request = CompletionRequest( + model=model, + content=content, + sampling_params=sampling_params, + response_format=response_format, + stream=stream, + logprobs=logprobs, + ) + if stream: + return self._stream_completion(request) + else: + return await self._nonstream_completion(request) + + def _get_max_new_tokens(self, sampling_params, input_tokens): + return min( + sampling_params.max_tokens or (self.max_tokens - input_tokens), + self.max_tokens - input_tokens - 1, + ) + + def _build_options( + self, + sampling_params: Optional[SamplingParams] = None, + fmt: ResponseFormat = None, + ): + options = get_sampling_options(sampling_params) + # delete key "max_tokens" from options since its not supported by the API + options.pop("max_tokens", None) + if fmt: + if fmt.type == ResponseFormatType.json_schema.value: + options["grammar"] = { + "type": "json", + "value": fmt.schema, + } + elif fmt.type == ResponseFormatType.grammar.value: + raise ValueError("Grammar response format not supported yet") + else: + raise ValueError(f"Unexpected response format: {fmt.type}") + + return options + + def _get_params_for_completion(self, request: CompletionRequest) -> dict: + prompt, input_tokens = completion_request_to_prompt_model_input_info( + request, self.formatter + ) + + return dict( + prompt=prompt, + stream=request.stream, + details=True, + max_new_tokens=self._get_max_new_tokens( + request.sampling_params, input_tokens + ), + stop_sequences=["<|eom_id|>", "<|eot_id|>"], + **self._build_options(request.sampling_params, request.response_format), + ) + + async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: + params = self._get_params_for_completion(request) + + async def _generate_and_convert_to_openai_compat(): + s = await self.client.text_generation(**params) + async for chunk in s: + token_result = chunk.token + finish_reason = None + if chunk.details: + finish_reason = chunk.details.finish_reason + + choice = OpenAICompatCompletionChoice( + text=token_result.text, finish_reason=finish_reason + ) + yield OpenAICompatCompletionResponse( + choices=[choice], + ) + + stream = _generate_and_convert_to_openai_compat() + async for chunk in process_completion_stream_response(stream, self.formatter): + yield chunk + + async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator: + params = self._get_params_for_completion(request) + r = await self.client.text_generation(**params) + + choice = OpenAICompatCompletionChoice( + finish_reason=r.details.finish_reason, + text="".join(t.text for t in r.details.tokens), + ) + + response = OpenAICompatCompletionResponse( + choices=[choice], + ) + + return process_completion_response(response, self.formatter) async def chat_completion( self, @@ -146,29 +240,15 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): prompt, input_tokens = chat_completion_request_to_model_input_info( request, self.formatter ) - max_new_tokens = min( - request.sampling_params.max_tokens or (self.max_tokens - input_tokens), - self.max_tokens - input_tokens - 1, - ) - options = get_sampling_options(request) - if fmt := request.response_format: - if fmt.type == ResponseFormatType.json_schema.value: - options["grammar"] = { - "type": "json", - "value": fmt.schema, - } - elif fmt.type == ResponseFormatType.grammar.value: - raise ValueError("Grammar response format not supported yet") - else: - raise ValueError(f"Unexpected response format: {fmt.type}") - return dict( prompt=prompt, stream=request.stream, details=True, - max_new_tokens=max_new_tokens, + max_new_tokens=self._get_max_new_tokens( + request.sampling_params, input_tokens + ), stop_sequences=["<|eom_id|>", "<|eot_id|>"], - **options, + **self._build_options(request.sampling_params, request.response_format), ) async def embeddings( diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py index 2f258e620..daf57497a 100644 --- a/llama_stack/providers/adapters/inference/together/together.py +++ b/llama_stack/providers/adapters/inference/together/together.py @@ -131,7 +131,7 @@ class TogetherInferenceAdapter( yield chunk def _get_params(self, request: ChatCompletionRequest) -> dict: - options = get_sampling_options(request) + options = get_sampling_options(request.sampling_params) if fmt := request.response_format: if fmt.type == ResponseFormatType.json_schema.value: options["response_format"] = { diff --git a/llama_stack/providers/adapters/inference/vllm/vllm.py b/llama_stack/providers/adapters/inference/vllm/vllm.py index dacf646b0..4687618fa 100644 --- a/llama_stack/providers/adapters/inference/vllm/vllm.py +++ b/llama_stack/providers/adapters/inference/vllm/vllm.py @@ -143,7 +143,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): "model": VLLM_SUPPORTED_MODELS[request.model], "prompt": chat_completion_request_to_prompt(request, self.formatter), "stream": request.stream, - **get_sampling_options(request), + **get_sampling_options(request.sampling_params), } async def embeddings( diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py index ad49448e2..c7cbdd592 100644 --- a/llama_stack/providers/tests/inference/test_inference.py +++ b/llama_stack/providers/tests/inference/test_inference.py @@ -137,6 +137,7 @@ async def test_completion(inference_settings): if provider.__provider_spec__.provider_type not in ( "meta-reference", "remote::ollama", + "remote::tgi", ): pytest.skip("Other inference providers don't support completion() yet") @@ -170,6 +171,46 @@ async def test_completion(inference_settings): assert last.stop_reason == StopReason.out_of_tokens +@pytest.mark.asyncio +async def test_completions_structured_output(inference_settings): + inference_impl = inference_settings["impl"] + params = inference_settings["common_params"] + + provider = inference_impl.routing_table.get_provider_impl(params["model"]) + if provider.__provider_spec__.provider_type not in ( + "meta-reference", + "remote::tgi", + ): + pytest.skip( + "Other inference providers don't support structured output in completions yet" + ) + + class Output(BaseModel): + name: str + year_born: str + year_retired: str + + user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003." + response = await inference_impl.completion( + content=f"input: '{user_input}'. the schema for json: {Output.schema()}, the json is: ", + stream=False, + model=params["model"], + sampling_params=SamplingParams( + max_tokens=50, + ), + response_format=JsonResponseFormat( + schema=Output.model_json_schema(), + ), + ) + assert isinstance(response, CompletionResponse) + assert isinstance(response.content, str) + + answer = Output.parse_raw(response.content) + assert answer.name == "Michael Jordan" + assert answer.year_born == "1963" + assert answer.year_retired == "2003" + + @pytest.mark.asyncio async def test_chat_completion_non_streaming(inference_settings, sample_messages): inference_impl = inference_settings["impl"] diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 22ae8a717..086227c73 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -29,9 +29,9 @@ class OpenAICompatCompletionResponse(BaseModel): choices: List[OpenAICompatCompletionChoice] -def get_sampling_options(request: ChatCompletionRequest) -> dict: +def get_sampling_options(params: SamplingParams) -> dict: options = {} - if params := request.sampling_params: + if params: for attr in {"temperature", "top_p", "top_k", "max_tokens"}: if getattr(params, attr): options[attr] = getattr(params, attr) @@ -64,7 +64,18 @@ def process_completion_response( response: OpenAICompatCompletionResponse, formatter: ChatFormat ) -> CompletionResponse: choice = response.choices[0] - + # drop suffix if present and return stop reason as end of turn + if choice.text.endswith("<|eot_id|>"): + return CompletionResponse( + stop_reason=StopReason.end_of_turn, + content=choice.text[: -len("<|eot_id|>")], + ) + # drop suffix if present and return stop reason as end of message + if choice.text.endswith("<|eom_id|>"): + return CompletionResponse( + stop_reason=StopReason.end_of_message, + content=choice.text[: -len("<|eom_id|>")], + ) return CompletionResponse( stop_reason=get_stop_reason(choice.finish_reason), content=choice.text, @@ -95,13 +106,6 @@ async def process_completion_stream_response( choice = chunk.choices[0] finish_reason = choice.finish_reason - if finish_reason: - if finish_reason in ["stop", "eos", "eos_token"]: - stop_reason = StopReason.end_of_turn - elif finish_reason == "length": - stop_reason = StopReason.out_of_tokens - break - text = text_from_choice(choice) if text == "<|eot_id|>": stop_reason = StopReason.end_of_turn @@ -115,6 +119,12 @@ async def process_completion_stream_response( delta=text, stop_reason=stop_reason, ) + if finish_reason: + if finish_reason in ["stop", "eos", "eos_token"]: + stop_reason = StopReason.end_of_turn + elif finish_reason == "length": + stop_reason = StopReason.out_of_tokens + break yield CompletionResponseStreamChunk( delta="", diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 48f1df02f..d204ab728 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -31,6 +31,13 @@ def completion_request_to_prompt( return formatter.tokenizer.decode(model_input.tokens) +def completion_request_to_prompt_model_input_info( + request: CompletionRequest, formatter: ChatFormat +) -> Tuple[str, int]: + model_input = formatter.encode_content(request.content) + return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens)) + + def chat_completion_request_to_prompt( request: ChatCompletionRequest, formatter: ChatFormat ) -> str: