Add completion() impl for meta-reference

Ashwin Bharambe 2024-10-18 20:41:21 -07:00
parent bcaf639dd6
commit 072d1b7205
4 changed files with 137 additions and 3 deletions

View file

@@ -88,7 +88,8 @@ class CompletionRequest(BaseModel):
 class CompletionResponse(BaseModel):
     """Completion response."""

-    completion_message: CompletionMessage
+    content: str
+    stop_reason: StopReason
     logprobs: Optional[List[TokenLogProbs]] = None
@@ -113,7 +114,7 @@ class BatchCompletionRequest(BaseModel):
 class BatchCompletionResponse(BaseModel):
     """Batch completion response."""

-    completion_message_batch: List[CompletionMessage]
+    batch: List[CompletionResponse]

 @json_schema_type
@@ -165,7 +166,7 @@ class BatchChatCompletionRequest(BaseModel):
 @json_schema_type
 class BatchChatCompletionResponse(BaseModel):
-    completion_message_batch: List[CompletionMessage]
+    batch: List[ChatCompletionResponse]

 @json_schema_type
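For orientation, the reshaped response models can be constructed roughly as in the sketch below. The field names come straight from the diff above; the import path and the particular StopReason value are assumptions for illustration, not part of this change.

# Sketch only: field names are from the diff above, the import path is an assumption.
from llama_stack.apis.inference import (
    BatchCompletionResponse,
    CompletionResponse,
    StopReason,
)

# A single completion result now carries raw text plus a stop reason.
single = CompletionResponse(
    content="violets are blue",
    stop_reason=StopReason.end_of_turn,  # illustrative value
    logprobs=None,  # only populated when the request asks for logprobs
)

# Batch responses now wrap a list of the per-item responses.
batch = BatchCompletionResponse(batch=[single])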

View file

@@ -301,6 +301,7 @@ class Llama:
         request: CompletionRequest,
     ) -> Generator:
         sampling_params = request.sampling_params
+        max_gen_len = sampling_params.max_tokens
         if (
             max_gen_len is None
             or max_gen_len == 0
@@ -315,6 +316,7 @@ class Llama:
             temperature=sampling_params.temperature,
             top_p=sampling_params.top_p,
             logprobs=bool(request.logprobs),
+            include_stop_token=True,
             echo=False,
         )
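The first hunk now derives max_gen_len from the request's sampling parameters. The guard that follows is partially cut off in this view; the standalone sketch below shows the clamping pattern it presumably feeds, with the fallback to the model's maximum sequence length being an assumption rather than the verbatim code.

from typing import Optional

def resolve_max_gen_len(max_tokens: Optional[int], max_seq_len: int) -> int:
    # Hypothetical helper: treat None/0 (or an over-long request) as
    # "generate up to the model limit". The -1 leaves room for the prompt's
    # final token and is an assumption, not taken from the diff.
    if max_tokens is None or max_tokens == 0 or max_tokens >= max_seq_len:
        return max_seq_len - 1
    return max_tokens

print(resolve_max_gen_len(None, 4096))  # 4095
print(resolve_max_gen_len(50, 4096))    # 50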

View file

@@ -86,6 +86,101 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         )
         self.check_model(request)

+        if request.stream:
+            return self._stream_completion(request)
+        else:
+            return await self._nonstream_completion(request)
+
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        def impl():
+            stop_reason = None
+
+            for token_result in self.generator.completion(request):
+                if token_result.text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                    text = ""
+                elif token_result.text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
+                    text = ""
+                else:
+                    text = token_result.text
+
+                logprobs = None
+                if stop_reason is None:
+                    if request.logprobs:
+                        assert len(token_result.logprobs) == 1
+
+                        logprobs = [
+                            TokenLogProbs(
+                                logprobs_by_token={
+                                    token_result.text: token_result.logprobs[0]
+                                }
+                            )
+                        ]
+
+                yield CompletionResponseStreamChunk(
+                    delta=text,
+                    stop_reason=stop_reason,
+                    logprobs=logprobs if request.logprobs else None,
+                )
+
+            if stop_reason is None:
+                yield CompletionResponseStreamChunk(
+                    delta="",
+                    stop_reason=StopReason.out_of_tokens,
+                )
+
+        if self.config.create_distributed_process_group:
+            async with SEMAPHORE:
+                for x in impl():
+                    yield x
+        else:
+            for x in impl():
+                yield x
+
+    async def _nonstream_completion(
+        self, request: CompletionRequest
+    ) -> CompletionResponse:
+        def impl():
+            tokens = []
+            logprobs = []
+            stop_reason = None
+
+            tokenizer = self.generator.formatter.tokenizer
+            for token_result in self.generator.completion(request):
+                tokens.append(token_result.token)
+                if token_result.token in tokenizer.stop_tokens:
+                    # not quite right semantically
+                    stop_reason = StopReason.end_of_turn
+
+                if request.logprobs:
+                    assert len(token_result.logprobs) == 1
+
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )
+
+            if stop_reason is None:
+                stop_reason = StopReason.out_of_tokens
+
+            content = self.generator.formatter.tokenizer.decode(tokens)
+            return CompletionResponse(
+                content=content,
+                stop_reason=stop_reason,
+                logprobs=logprobs if request.logprobs else None,
+            )
+
+        if self.config.create_distributed_process_group:
+            async with SEMAPHORE:
+                return impl()
+        else:
+            return impl()
+
     async def chat_completion(
         self,
         model: str,
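Taken together, the new provider methods can be exercised roughly as sketched below, mirroring the test added in the last file. The inference_impl object, the model identifier, and the SamplingParams import path are placeholders and assumptions, not part of this commit.

# Rough usage sketch; `inference_impl` and the model id are stand-ins for
# whatever the test fixtures construct.
from llama_models.llama3.api.datatypes import SamplingParams  # assumed import path

async def demo(inference_impl, model_id: str) -> None:
    # Non-streaming: a single CompletionResponse with content and stop_reason.
    response = await inference_impl.completion(
        model=model_id,
        content="Roses are red,",
        stream=False,
        sampling_params=SamplingParams(max_tokens=50),
    )
    print(response.content, response.stop_reason)

    # Streaming: an async generator of CompletionResponseStreamChunk deltas.
    stream = await inference_impl.completion(
        model=model_id,
        content="Roses are red,",
        stream=True,
        sampling_params=SamplingParams(max_tokens=50),
    )
    async for chunk in stream:
        print(chunk.delta, end="")

# Drive it with e.g. asyncio.run(demo(inference_impl, "Llama3.1-8B-Instruct")).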

View file

@@ -126,6 +126,42 @@ async def test_model_list(inference_settings):
     assert model_def.identifier == params["model"]

+
+@pytest.mark.asyncio
+async def test_completion(inference_settings):
+    inference_impl = inference_settings["impl"]
+    params = inference_settings["common_params"]
+
+    provider = inference_impl.routing_table.get_provider_impl(params["model"])
+    if provider.__provider_id__ != "meta-reference":
+        pytest.skip("Other inference providers don't support completion() yet")
+
+    response = await inference_impl.completion(
+        content="Roses are red,",
+        stream=False,
+        model=params["model"],
+        sampling_params=SamplingParams(
+            max_tokens=50,
+        ),
+    )
+
+    assert isinstance(response, CompletionResponse)
+    assert "violets are blue" in response.content
+
+    chunks = [
+        r
+        async for r in await inference_impl.completion(
+            content="Roses are red,",
+            stream=True,
+            model=params["model"],
+            sampling_params=SamplingParams(
+                max_tokens=50,
+            ),
+        )
+    ]
+
+    print(chunks)
+
+
 @pytest.mark.asyncio
 async def test_chat_completion_non_streaming(inference_settings, sample_messages):
     inference_impl = inference_settings["impl"]