diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 449fff50d..5895e528e 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -88,7 +88,8 @@ class CompletionRequest(BaseModel):
 class CompletionResponse(BaseModel):
     """Completion response."""
 
-    completion_message: CompletionMessage
+    content: str
+    stop_reason: StopReason
     logprobs: Optional[List[TokenLogProbs]] = None
 
 
@@ -113,7 +114,7 @@ class BatchCompletionRequest(BaseModel):
 class BatchCompletionResponse(BaseModel):
     """Batch completion response."""
 
-    completion_message_batch: List[CompletionMessage]
+    batch: List[CompletionResponse]
 
 
 @json_schema_type
@@ -165,7 +166,7 @@ class BatchChatCompletionRequest(BaseModel):
 
 @json_schema_type
 class BatchChatCompletionResponse(BaseModel):
-    completion_message_batch: List[CompletionMessage]
+    batch: List[ChatCompletionResponse]
 
 
 @json_schema_type
diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py
index 46a409ebe..9ca128176 100644
--- a/llama_stack/providers/impls/meta_reference/inference/generation.py
+++ b/llama_stack/providers/impls/meta_reference/inference/generation.py
@@ -301,6 +301,7 @@ class Llama:
         request: CompletionRequest,
     ) -> Generator:
         sampling_params = request.sampling_params
+        max_gen_len = sampling_params.max_tokens
         if (
             max_gen_len is None
             or max_gen_len == 0
@@ -315,6 +316,7 @@ class Llama:
             temperature=sampling_params.temperature,
             top_p=sampling_params.top_p,
             logprobs=bool(request.logprobs),
+            include_stop_token=True,
             echo=False,
         )
 
diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index bdccb4f03..e155ffd34 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -86,6 +86,101 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         )
         self.check_model(request)
 
+        if request.stream:
+            return self._stream_completion(request)
+        else:
+            return await self._nonstream_completion(request)
+
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        def impl():
+            stop_reason = None
+
+            for token_result in self.generator.completion(request):
+                if token_result.text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                    text = ""
+                elif token_result.text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
+                    text = ""
+                else:
+                    text = token_result.text
+
+                logprobs = None
+                if stop_reason is None:
+                    if request.logprobs:
+                        assert len(token_result.logprobs) == 1
+
+                        logprobs = [
+                            TokenLogProbs(
+                                logprobs_by_token={
+                                    token_result.text: token_result.logprobs[0]
+                                }
+                            )
+                        ]
+
+                yield CompletionResponseStreamChunk(
+                    delta=text,
+                    stop_reason=stop_reason,
+                    logprobs=logprobs if request.logprobs else None,
+                )
+
+            if stop_reason is None:
+                yield CompletionResponseStreamChunk(
+                    delta="",
+                    stop_reason=StopReason.out_of_tokens,
+                )
+
+        if self.config.create_distributed_process_group:
+            async with SEMAPHORE:
+                for x in impl():
+                    yield x
+        else:
+            for x in impl():
+                yield x
+
+    async def _nonstream_completion(
+        self, request: CompletionRequest
+    ) -> CompletionResponse:
+        def impl():
+            tokens = []
+            logprobs = []
+            stop_reason = None
+
+            tokenizer = self.generator.formatter.tokenizer
+            for token_result in self.generator.completion(request):
+                tokens.append(token_result.token)
+
+                if token_result.token in tokenizer.stop_tokens:
+                    # not quite right semantically
+                    stop_reason = StopReason.end_of_turn
+
+                if request.logprobs:
+                    assert len(token_result.logprobs) == 1
+
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )
+
+            if stop_reason is None:
+                stop_reason = StopReason.out_of_tokens
+
+            content = self.generator.formatter.tokenizer.decode(tokens)
+            return CompletionResponse(
+                content=content,
+                stop_reason=stop_reason,
+                logprobs=logprobs if request.logprobs else None,
+            )
+
+        if self.config.create_distributed_process_group:
+            async with SEMAPHORE:
+                return impl()
+        else:
+            return impl()
+
     async def chat_completion(
         self,
         model: str,
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index 385e7efb9..60ebe1766 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -126,6 +126,42 @@ async def test_model_list(inference_settings):
         assert model_def.identifier == params["model"]
 
 
+@pytest.mark.asyncio
+async def test_completion(inference_settings):
+    inference_impl = inference_settings["impl"]
+    params = inference_settings["common_params"]
+
+    provider = inference_impl.routing_table.get_provider_impl(params["model"])
+    if provider.__provider_id__ != "meta-reference":
+        pytest.skip("Other inference providers don't support completion() yet")
+
+    response = await inference_impl.completion(
+        content="Roses are red,",
+        stream=False,
+        model=params["model"],
+        sampling_params=SamplingParams(
+            max_tokens=50,
+        ),
+    )
+
+    assert isinstance(response, CompletionResponse)
+    assert "violets are blue" in response.content
+
+    chunks = [
+        r
+        async for r in await inference_impl.completion(
+            content="Roses are red,",
+            stream=True,
+            model=params["model"],
+            sampling_params=SamplingParams(
+                max_tokens=50,
+            ),
+        )
+    ]
+
+    print(chunks)
+
+
 @pytest.mark.asyncio
 async def test_chat_completion_non_streaming(inference_settings, sample_messages):
     inference_impl = inference_settings["impl"]
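Usage sketch (not part of the diff): a minimal example of how a caller might exercise the reworked completion() API, mirroring the new test above. The `demo_completion` helper, the already-constructed `inference_impl`, and the `model` identifier are assumptions for illustration; the import path is assumed to re-export `CompletionResponse` and `SamplingParams`, and the field names (`content`, `stop_reason`, `delta`) come from the schema changes in this diff.

# Sketch only: `inference_impl` and `model` are assumed to exist already.
from llama_stack.apis.inference import CompletionResponse, SamplingParams


async def demo_completion(inference_impl, model: str) -> None:
    # Non-streaming: the response now carries `content` and `stop_reason`
    # directly instead of a wrapped CompletionMessage.
    response = await inference_impl.completion(
        content="Roses are red,",
        stream=False,
        model=model,
        sampling_params=SamplingParams(max_tokens=50),
    )
    assert isinstance(response, CompletionResponse)
    print(response.content, response.stop_reason)

    # Streaming: each chunk carries a text `delta`; `stop_reason` is set on
    # the chunk that ends the generation.
    text = ""
    async for chunk in await inference_impl.completion(
        content="Roses are red,",
        stream=True,
        model=model,
        sampling_params=SamplingParams(max_tokens=50),
    ):
        text += chunk.delta
        if chunk.stop_reason is not None:
            print("stop reason:", chunk.stop_reason)
    print(text)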