Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-29 15:23:51 +00:00
Add completion() impl for meta-reference

parent bcaf639dd6
commit 072d1b7205

4 changed files with 137 additions and 3 deletions
@@ -88,7 +88,8 @@ class CompletionRequest(BaseModel):
 class CompletionResponse(BaseModel):
     """Completion response."""
 
-    completion_message: CompletionMessage
+    content: str
+    stop_reason: StopReason
     logprobs: Optional[List[TokenLogProbs]] = None
 
 
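The response schema now carries raw model output plus a stop reason instead of a CompletionMessage. Below is a minimal standalone sketch of the new shape; the classes are redeclared locally with pydantic for illustration only (the real definitions live in the inference API module, and TokenLogProbs is assumed to be a single logprobs_by_token mapping, as suggested by the provider code later in this diff):

# Standalone sketch mirroring the reshaped CompletionResponse; field names come
# from the diff, everything else is a simplified stand-in.
from enum import Enum
from typing import Dict, List, Optional

from pydantic import BaseModel


class StopReason(Enum):
    end_of_turn = "end_of_turn"
    end_of_message = "end_of_message"
    out_of_tokens = "out_of_tokens"


class TokenLogProbs(BaseModel):
    logprobs_by_token: Dict[str, float]


class CompletionResponse(BaseModel):
    """Completion response."""

    content: str
    stop_reason: StopReason
    logprobs: Optional[List[TokenLogProbs]] = None


# A provider now returns plain text plus a stop reason rather than a CompletionMessage:
response = CompletionResponse(
    content="violets are blue",
    stop_reason=StopReason.out_of_tokens,
    logprobs=[TokenLogProbs(logprobs_by_token={"violets": -0.42})],
)
print(response.content, response.stop_reason)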
@@ -113,7 +114,7 @@ class BatchCompletionRequest(BaseModel):
 class BatchCompletionResponse(BaseModel):
     """Batch completion response."""
 
-    completion_message_batch: List[CompletionMessage]
+    batch: List[CompletionResponse]
 
 
 @json_schema_type
@@ -165,7 +166,7 @@ class BatchChatCompletionRequest(BaseModel):
 
 @json_schema_type
 class BatchChatCompletionResponse(BaseModel):
-    completion_message_batch: List[CompletionMessage]
+    batch: List[ChatCompletionResponse]
 
 
 @json_schema_type
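The two batch-response hunks apply the same reshaping: a batch response is now a list of the per-request response objects rather than a list of CompletionMessage. A small illustrative sketch, with simplified local stand-ins for the referenced types:

# Illustrative only: simplified stand-ins for the real response types.
from typing import List, Optional

from pydantic import BaseModel


class CompletionResponse(BaseModel):
    content: str
    stop_reason: str  # the real field is a StopReason enum
    logprobs: Optional[list] = None


class ChatCompletionResponse(BaseModel):
    completion_message: dict  # simplified stand-in for CompletionMessage


class BatchCompletionResponse(BaseModel):
    """Batch completion response."""

    batch: List[CompletionResponse]


class BatchChatCompletionResponse(BaseModel):
    batch: List[ChatCompletionResponse]


resp = BatchCompletionResponse(
    batch=[CompletionResponse(content="hello", stop_reason="end_of_turn")]
)
print(len(resp.batch), resp.batch[0].content)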
@@ -301,6 +301,7 @@ class Llama:
         request: CompletionRequest,
     ) -> Generator:
         sampling_params = request.sampling_params
+        max_gen_len = sampling_params.max_tokens
         if (
             max_gen_len is None
             or max_gen_len == 0
@@ -315,6 +316,7 @@ class Llama:
             temperature=sampling_params.temperature,
             top_p=sampling_params.top_p,
             logprobs=bool(request.logprobs),
+            include_stop_token=True,
             echo=False,
         )
 
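These two hunks thread request.sampling_params.max_tokens into the generation loop and ask the generator to emit the terminating token itself (include_stop_token=True), which the provider code below uses to decide why generation ended. The diff cuts off after `or max_gen_len == 0`; a rough sketch of the fallback that typically follows, where the max_seq_len clamp is an assumption rather than something shown here:

# Hedged sketch of how max_gen_len is resolved before generation; max_seq_len
# and the >= clause are assumptions, only the None/0 checks appear in the diff.
def resolve_max_gen_len(max_tokens, max_seq_len):
    max_gen_len = max_tokens
    if (
        max_gen_len is None
        or max_gen_len == 0
        or max_gen_len >= max_seq_len  # assumed clamp to the model context window
    ):
        max_gen_len = max_seq_len - 1
    return max_gen_len


print(resolve_max_gen_len(None, 4096))  # -> 4095
print(resolve_max_gen_len(50, 4096))    # -> 50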
@@ -86,6 +86,101 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         )
         self.check_model(request)
 
+        if request.stream:
+            return self._stream_completion(request)
+        else:
+            return await self._nonstream_completion(request)
+
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        def impl():
+            stop_reason = None
+
+            for token_result in self.generator.chat_completion(request):
+                if token_result.text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                    text = ""
+                elif token_result.text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
+                    text = ""
+                else:
+                    text = token_result.text
+
+                logprobs = None
+                if stop_reason is None:
+                    if request.logprobs:
+                        assert len(token_result.logprobs) == 1
+
+                        logprobs = [
+                            TokenLogProbs(
+                                logprobs_by_token={
+                                    token_result.text: token_result.logprobs[0]
+                                }
+                            )
+                        ]
+
+                yield CompletionResponseStreamChunk(
+                    delta=text,
+                    stop_reason=stop_reason,
+                    logprobs=logprobs if request.logprobs else None,
+                )
+
+            if stop_reason is None:
+                yield CompletionResponseStreamChunk(
+                    delta="",
+                    stop_reason=StopReason.out_of_tokens,
+                )
+
+        if self.config.create_distributed_process_group:
+            async with SEMAPHORE:
+                for x in impl():
+                    yield x
+        else:
+            for x in impl():
+                yield x
+
+    async def _nonstream_completion(
+        self, request: CompletionRequest
+    ) -> CompletionResponse:
+        def impl():
+            tokens = []
+            logprobs = []
+            stop_reason = None
+
+            tokenizer = self.generator.formatter.tokenizer
+            for token_result in self.generator.completion(request):
+                tokens.append(token_result.token)
+
+                if token_result.token in tokenizer.stop_tokens:
+                    # not quite right semantically
+                    stop_reason = StopReason.end_of_turn
+
+                if request.logprobs:
+                    assert len(token_result.logprobs) == 1
+
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )
+
+            if stop_reason is None:
+                stop_reason = StopReason.out_of_tokens
+
+            content = self.generator.formatter.tokenizer.decode(tokens)
+            return CompletionResponse(
+                content=content,
+                stop_reason=stop_reason,
+                logprobs=logprobs if request.logprobs else None,
+            )
+
+        if self.config.create_distributed_process_group:
+            async with SEMAPHORE:
+                return impl()
+        else:
+            return impl()
+
     async def chat_completion(
         self,
         model: str,
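Both the streaming and non-streaming paths wrap the blocking token loop in a local impl() and, when a distributed process group is in use, serialize it behind SEMAPHORE. A self-contained sketch of that gating pattern; SEMAPHORE being an asyncio.Semaphore(1) defined at module level is an assumption, since its definition is outside this diff:

import asyncio

# Assumed to mirror the module-level SEMAPHORE used above.
SEMAPHORE = asyncio.Semaphore(1)


def impl():
    # Stand-in for the synchronous token generator.
    for text in ["Roses ", "are ", "red"]:
        yield text


async def stream_with_gate(use_gate: bool):
    # Only one in-flight generation at a time when the gate is enabled,
    # matching the create_distributed_process_group branch above.
    if use_gate:
        async with SEMAPHORE:
            for chunk in impl():
                yield chunk
    else:
        for chunk in impl():
            yield chunk


async def main():
    async for chunk in stream_with_gate(use_gate=True):
        print(chunk, end="")
    print()


asyncio.run(main())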
@@ -126,6 +126,42 @@ async def test_model_list(inference_settings):
     assert model_def.identifier == params["model"]
 
 
+@pytest.mark.asyncio
+async def test_completion(inference_settings):
+    inference_impl = inference_settings["impl"]
+    params = inference_settings["common_params"]
+
+    provider = inference_impl.routing_table.get_provider_impl(params["model"])
+    if provider.__provider_id__ != "meta-reference":
+        pytest.skip("Other inference providers don't support completion() yet")
+
+    response = await inference_impl.completion(
+        content="Roses are red,",
+        stream=False,
+        model=params["model"],
+        sampling_params=SamplingParams(
+            max_tokens=50,
+        ),
+    )
+
+    assert isinstance(response, CompletionResponse)
+    assert "violets are blue" in response.content
+
+    chunks = [
+        r
+        async for r in await inference_impl.completion(
+            content="Roses are red,",
+            stream=True,
+            model=params["model"],
+            sampling_params=SamplingParams(
+                max_tokens=50,
+            ),
+        )
+    ]
+
+    print(chunks)
+
+
 @pytest.mark.asyncio
 async def test_chat_completion_non_streaming(inference_settings, sample_messages):
     inference_impl = inference_settings["impl"]
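The streaming half of the test only prints the collected chunks. A hedged sketch of assertions one could layer on top, based on the CompletionResponseStreamChunk(delta, stop_reason) shape introduced earlier in this diff; it assumes chunks is the list built just above and is not part of the commit:

# Not part of the commit: possible follow-up assertions on the streamed chunks.
streamed_text = "".join(chunk.delta for chunk in chunks)
assert len(streamed_text) > 0

# At least one chunk (normally the last) should report why generation stopped.
stop_reasons = [c.stop_reason for c in chunks if c.stop_reason is not None]
assert len(stop_reasons) >= 1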