mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-28 15:02:37 +00:00
Add completion() impl for meta-reference
This commit is contained in:
parent
bcaf639dd6
commit
072d1b7205
4 changed files with 137 additions and 3 deletions
|
@ -88,7 +88,8 @@ class CompletionRequest(BaseModel):
|
|||
class CompletionResponse(BaseModel):
|
||||
"""Completion response."""
|
||||
|
||||
completion_message: CompletionMessage
|
||||
content: str
|
||||
stop_reason: StopReason
|
||||
logprobs: Optional[List[TokenLogProbs]] = None
|
||||
|
||||
|
||||
|
@ -113,7 +114,7 @@ class BatchCompletionRequest(BaseModel):
|
|||
class BatchCompletionResponse(BaseModel):
|
||||
"""Batch completion response."""
|
||||
|
||||
completion_message_batch: List[CompletionMessage]
|
||||
batch: List[CompletionResponse]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -165,7 +166,7 @@ class BatchChatCompletionRequest(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class BatchChatCompletionResponse(BaseModel):
|
||||
completion_message_batch: List[CompletionMessage]
|
||||
batch: List[ChatCompletionResponse]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -301,6 +301,7 @@ class Llama:
|
|||
request: CompletionRequest,
|
||||
) -> Generator:
|
||||
sampling_params = request.sampling_params
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if (
|
||||
max_gen_len is None
|
||||
or max_gen_len == 0
|
||||
|
@ -315,6 +316,7 @@ class Llama:
|
|||
temperature=sampling_params.temperature,
|
||||
top_p=sampling_params.top_p,
|
||||
logprobs=bool(request.logprobs),
|
||||
include_stop_token=True,
|
||||
echo=False,
|
||||
)
|
||||
|
||||
|
|
|
@ -86,6 +86,101 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
)
|
||||
self.check_model(request)
|
||||
|
||||
if request.stream:
|
||||
return self._stream_completion(request)
|
||||
else:
|
||||
return await self._nonstream_completion(request)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
def impl():
|
||||
stop_reason = None
|
||||
|
||||
for token_result in self.generator.chat_completion(request):
|
||||
if token_result.text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
text = token_result.text
|
||||
|
||||
logprobs = None
|
||||
if stop_reason is None:
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
||||
logprobs = [
|
||||
TokenLogProbs(
|
||||
logprobs_by_token={
|
||||
token_result.text: token_result.logprobs[0]
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta="",
|
||||
stop_reason=StopReason.out_of_tokens,
|
||||
)
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
async with SEMAPHORE:
|
||||
for x in impl():
|
||||
yield x
|
||||
else:
|
||||
for x in impl():
|
||||
yield x
|
||||
|
||||
async def _nonstream_completion(
|
||||
self, request: CompletionRequest
|
||||
) -> CompletionResponse:
|
||||
def impl():
|
||||
tokens = []
|
||||
logprobs = []
|
||||
stop_reason = None
|
||||
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
for token_result in self.generator.completion(request):
|
||||
tokens.append(token_result.token)
|
||||
|
||||
if token_result.token in tokenizer.stop_tokens:
|
||||
# not quite right semantically
|
||||
stop_reason = StopReason.end_of_turn
|
||||
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
||||
logprobs.append(
|
||||
TokenLogProbs(
|
||||
logprobs_by_token={
|
||||
token_result.text: token_result.logprobs[0]
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
content = self.generator.formatter.tokenizer.decode(tokens)
|
||||
return CompletionResponse(
|
||||
content=content,
|
||||
stop_reason=stop_reason,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
async with SEMAPHORE:
|
||||
return impl()
|
||||
else:
|
||||
return impl()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -126,6 +126,42 @@ async def test_model_list(inference_settings):
|
|||
assert model_def.identifier == params["model"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion(inference_settings):
|
||||
inference_impl = inference_settings["impl"]
|
||||
params = inference_settings["common_params"]
|
||||
|
||||
provider = inference_impl.routing_table.get_provider_impl(params["model"])
|
||||
if provider.__provider_id__ != "meta-reference":
|
||||
pytest.skip("Other inference providers don't support completion() yet")
|
||||
|
||||
response = await inference_impl.completion(
|
||||
content="Roses are red,",
|
||||
stream=False,
|
||||
model=params["model"],
|
||||
sampling_params=SamplingParams(
|
||||
max_tokens=50,
|
||||
),
|
||||
)
|
||||
|
||||
assert isinstance(response, CompletionResponse)
|
||||
assert "violets are blue" in response.content
|
||||
|
||||
chunks = [
|
||||
r
|
||||
async for r in await inference_impl.completion(
|
||||
content="Roses are red,",
|
||||
stream=True,
|
||||
model=params["model"],
|
||||
sampling_params=SamplingParams(
|
||||
max_tokens=50,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
print(chunks)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
|
||||
inference_impl = inference_settings["impl"]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue