Make all methods async def again; add completion() for meta-reference (#270)

PR #201 made several changes while trying to get the stream=False branches of the inference and agents APIs working. As part of this, it made one change that was slightly gratuitous: turning chat_completion() and its brethren into plain "def" methods instead of "async def".

The rationale was that this allowed callers within llama-stack to use it as:

```
async for chunk in api.chat_completion(params)
```

However, it caused unnecessary confusion for several folks. Given that clients (e.g., llama-stack-apps) use the SDK methods anyway (which are completely isolated), this choice was not ideal. Let's revert so the call now looks like:

```
async for chunk in await api.chat_completion(params)
```
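
When stream=False, the awaited call resolves directly to the full response instead of an async generator. A minimal sketch of both call shapes, assuming a hypothetical `api` object that implements the Inference protocol (the model name and `messages` are placeholders):

```
# Sketch only: `api`, the model name, and `messages` are illustrative placeholders.
# Streaming: awaiting the call returns an async generator of stream chunks.
chunks = [
    chunk
    async for chunk in await api.chat_completion(
        model="Llama3.1-8B-Instruct",
        messages=messages,
        stream=True,
    )
]

# Non-streaming: awaiting the call resolves to the response object directly.
response = await api.chat_completion(
    model="Llama3.1-8B-Instruct",
    messages=messages,
    stream=False,
)
```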

Bonus: added a completion() implementation for the meta-reference provider. Technically, this should have been a separate PR :)
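
The new completion() path can be exercised the same way; a hedged sketch mirroring the test added in this commit (the fixture name and model identifier are placeholders):

```
# Sketch based on the test added in this commit; `inference_impl` and the model
# identifier are placeholders for whatever the test fixtures provide.
response = await inference_impl.completion(
    model="Llama3.1-8B-Instruct",
    content="Roses are red,",
    stream=False,
    sampling_params=SamplingParams(max_tokens=50),
)
assert isinstance(response, CompletionResponse)

chunks = [
    chunk
    async for chunk in await inference_impl.completion(
        model="Llama3.1-8B-Instruct",
        content="Roses are red,",
        stream=True,
        sampling_params=SamplingParams(max_tokens=50),
    )
]
assert chunks[-1].stop_reason == StopReason.out_of_tokens
```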
Ashwin Bharambe 2024-10-18 20:50:59 -07:00 committed by GitHub
parent 95a96afe34
commit 2089427d60
23 changed files with 330 additions and 213 deletions


```
@@ -21,7 +21,7 @@
     "info": {
         "title": "[DRAFT] Llama Stack Specification",
         "version": "0.0.1",
-        "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-10 15:29:56.831109"
+        "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-18 20:48:17.730988"
     },
     "servers": [
         {
@@ -2830,8 +2830,11 @@
             "CompletionResponse": {
                 "type": "object",
                 "properties": {
-                    "completion_message": {
-                        "$ref": "#/components/schemas/CompletionMessage"
+                    "content": {
+                        "type": "string"
+                    },
+                    "stop_reason": {
+                        "$ref": "#/components/schemas/StopReason"
                     },
                     "logprobs": {
                         "type": "array",
@@ -2842,7 +2845,8 @@
                 },
                 "additionalProperties": false,
                 "required": [
-                    "completion_message"
+                    "content",
+                    "stop_reason"
                 ],
                 "title": "Completion response."
             },
@@ -6075,49 +6079,49 @@
    ],
    "tags": [
        {
-            "name": "Evaluations"
-        },
-        {
-            "name": "Inspect"
+            "name": "Models"
        },
        {
            "name": "RewardScoring"
        },
        {
-            "name": "Datasets"
-        },
-        {
-            "name": "Models"
-        },
-        {
-            "name": "Telemetry"
-        },
-        {
-            "name": "PostTraining"
-        },
-        {
-            "name": "SyntheticDataGeneration"
-        },
-        {
-            "name": "BatchInference"
-        },
-        {
-            "name": "Inference"
-        },
-        {
-            "name": "Agents"
-        },
-        {
-            "name": "Memory"
-        },
-        {
-            "name": "Safety"
+            "name": "MemoryBanks"
        },
        {
            "name": "Shields"
        },
        {
-            "name": "MemoryBanks"
+            "name": "SyntheticDataGeneration"
+        },
+        {
+            "name": "Inference"
+        },
+        {
+            "name": "Inspect"
+        },
+        {
+            "name": "BatchInference"
+        },
+        {
+            "name": "Memory"
+        },
+        {
+            "name": "Datasets"
+        },
+        {
+            "name": "Agents"
+        },
+        {
+            "name": "PostTraining"
+        },
+        {
+            "name": "Telemetry"
+        },
+        {
+            "name": "Safety"
+        },
+        {
+            "name": "Evaluations"
        },
        {
            "name": "BuiltinTool",
```


```
@@ -501,14 +501,17 @@ components:
    CompletionResponse:
      additionalProperties: false
      properties:
-        completion_message:
-          $ref: '#/components/schemas/CompletionMessage'
+        content:
+          type: string
        logprobs:
          items:
            $ref: '#/components/schemas/TokenLogProbs'
          type: array
+        stop_reason:
+          $ref: '#/components/schemas/StopReason'
      required:
-      - completion_message
+      - content
+      - stop_reason
      title: Completion response.
      type: object
    CompletionResponseStreamChunk:
@@ -2507,7 +2510,7 @@ info:
  description: "This is the specification of the llama stack that provides\n \
    \ a set of endpoints and their corresponding interfaces that are tailored\
    \ to\n best leverage Llama Models. The specification is still in\
-    \ draft and subject to change.\n Generated at 2024-10-10 15:29:56.831109"
+    \ draft and subject to change.\n Generated at 2024-10-18 20:48:17.730988"
  title: '[DRAFT] Llama Stack Specification'
  version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -3712,21 +3715,21 @@ security:
 servers:
 - url: http://any-hosted-llama-stack.com
 tags:
-- name: Evaluations
-- name: Inspect
-- name: RewardScoring
-- name: Datasets
 - name: Models
-- name: Telemetry
-- name: PostTraining
-- name: SyntheticDataGeneration
-- name: BatchInference
-- name: Inference
-- name: Agents
-- name: Memory
-- name: Safety
-- name: Shields
+- name: RewardScoring
 - name: MemoryBanks
+- name: Shields
+- name: SyntheticDataGeneration
+- name: Inference
+- name: Inspect
+- name: BatchInference
+- name: Memory
+- name: Datasets
+- name: Agents
+- name: PostTraining
+- name: Telemetry
+- name: Safety
+- name: Evaluations
 - description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
   name: BuiltinTool
 - description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
```


```
@@ -421,10 +421,8 @@ class Agents(Protocol):
         agent_config: AgentConfig,
     ) -> AgentCreateResponse: ...

-    # This method is not `async def` because it can result in either an
-    # `AsyncGenerator` or a `AgentTurnCreateResponse` depending on the value of `stream`.
     @webmethod(route="/agents/turn/create")
-    def create_agent_turn(
+    async def create_agent_turn(
         self,
         agent_id: str,
         session_id: str,
```


```
@@ -67,14 +67,14 @@ class AgentsClient(Agents):
         response.raise_for_status()
         return AgentSessionCreateResponse(**response.json())

-    def create_agent_turn(
+    async def create_agent_turn(
         self,
         request: AgentTurnCreateRequest,
     ) -> AsyncGenerator:
         if request.stream:
             return self._stream_agent_turn(request)
         else:
-            return self._nonstream_agent_turn(request)
+            return await self._nonstream_agent_turn(request)

     async def _stream_agent_turn(
         self, request: AgentTurnCreateRequest
@@ -126,7 +126,7 @@ async def _run_agent(
     for content in user_prompts:
         cprint(f"User> {content}", color="white", attrs=["bold"])
-        iterator = api.create_agent_turn(
+        iterator = await api.create_agent_turn(
             AgentTurnCreateRequest(
                 agent_id=create_response.agent_id,
                 session_id=session_response.session_id,
```


```
@@ -42,10 +42,10 @@ class InferenceClient(Inference):
     async def shutdown(self) -> None:
         pass

-    def completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
         raise NotImplementedError()

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -139,7 +139,8 @@ async def run_main(
     else:
         logprobs_config = None

-    iterator = client.chat_completion(
+    assert stream, "Non streaming not supported here"
+    iterator = await client.chat_completion(
         model=model,
         messages=[message],
         stream=stream,
```


```
@@ -88,7 +88,8 @@ class CompletionRequest(BaseModel):
 class CompletionResponse(BaseModel):
     """Completion response."""

-    completion_message: CompletionMessage
+    content: str
+    stop_reason: StopReason
     logprobs: Optional[List[TokenLogProbs]] = None

@@ -113,7 +114,7 @@ class BatchCompletionRequest(BaseModel):
 class BatchCompletionResponse(BaseModel):
     """Batch completion response."""

-    completion_message_batch: List[CompletionMessage]
+    batch: List[CompletionResponse]


 @json_schema_type
@@ -165,7 +166,7 @@ class BatchChatCompletionRequest(BaseModel):
 @json_schema_type
 class BatchChatCompletionResponse(BaseModel):
-    completion_message_batch: List[CompletionMessage]
+    batch: List[ChatCompletionResponse]


 @json_schema_type
@@ -181,10 +182,8 @@ class ModelStore(Protocol):
 class Inference(Protocol):
     model_store: ModelStore

-    # This method is not `async def` because it can result in either an
-    # `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`.
     @webmethod(route="/inference/completion")
-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -196,7 +195,7 @@ class Inference(Protocol):
     # This method is not `async def` because it can result in either an
     # `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
     @webmethod(route="/inference/chat_completion")
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
```


```
@@ -70,7 +70,7 @@ class InferenceRouter(Inference):
     async def register_model(self, model: ModelDef) -> None:
         await self.routing_table.register_model(model)

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -93,11 +93,11 @@ class InferenceRouter(Inference):
         )
         provider = self.routing_table.get_provider_impl(model)
         if stream:
-            return (chunk async for chunk in provider.chat_completion(**params))
+            return (chunk async for chunk in await provider.chat_completion(**params))
         else:
-            return provider.chat_completion(**params)
+            return await provider.chat_completion(**params)

-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -114,9 +114,9 @@ class InferenceRouter(Inference):
             logprobs=logprobs,
         )
         if stream:
-            return (chunk async for chunk in provider.completion(**params))
+            return (chunk async for chunk in await provider.completion(**params))
         else:
-            return provider.completion(**params)
+            return await provider.completion(**params)

     async def embeddings(
         self,
```


```
@@ -47,7 +47,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         self.client.close()

-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -283,7 +283,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         )
         return tool_config

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
```


```
@@ -48,7 +48,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         pass

-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -58,7 +58,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -84,7 +84,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         if stream:
             return self._stream_chat_completion(request, client)
         else:
-            return self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request, client)

     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: OpenAI
```


```
@@ -51,7 +51,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         pass

-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -61,7 +61,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -87,7 +87,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         if stream:
             return self._stream_chat_completion(request, client)
         else:
-            return self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request, client)

     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: Fireworks
```


```
@@ -84,7 +84,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         return ret

-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -94,7 +94,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -118,7 +118,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         if stream:
             return self._stream_chat_completion(request)
         else:
-            return self._nonstream_chat_completion(request)
+            return await self._nonstream_chat_completion(request)

     def _get_params(self, request: ChatCompletionRequest) -> dict:
         return {
```


```
@@ -66,7 +66,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def shutdown(self) -> None:
         pass

-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -76,7 +76,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -101,7 +101,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         if stream:
             return self._stream_chat_completion(request)
         else:
-            return self._nonstream_chat_completion(request)
+            return await self._nonstream_chat_completion(request)

     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
```


```
@@ -64,7 +64,7 @@ class TogetherInferenceAdapter(
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -101,7 +101,7 @@ class TogetherInferenceAdapter(
         if stream:
             return self._stream_chat_completion(request, client)
         else:
-            return self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request, client)

     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: Together
```


```
@@ -424,7 +424,7 @@ class ChatAgent(ShieldRunnerMixin):
         stop_reason = None

         with tracing.span("inference"):
-            async for chunk in self.inference_api.chat_completion(
+            async for chunk in await self.inference_api.chat_completion(
                 self.agent_config.model,
                 input_messages,
                 tools=self._get_tools(),
```


```
@@ -105,7 +105,7 @@ class MetaReferenceAgentsImpl(Agents):
             session_id=session_id,
         )

-    def create_agent_turn(
+    async def create_agent_turn(
         self,
         agent_id: str,
         session_id: str,
```


```
@@ -23,11 +23,6 @@ from fairscale.nn.model_parallel.initialize import (
 )
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
-from llama_models.llama3.api.datatypes import (
-    InterleavedTextMedia,
-    Message,
-    ToolPromptFormat,
-)
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (
@@ -38,7 +33,11 @@ from llama_models.sku_list import resolve_model
 from pydantic import BaseModel
 from termcolor import cprint

+from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.utils.model_utils import model_local_dir
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_messages,
+)

 from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
@@ -297,15 +296,12 @@ class Llama:
             if all(eos_reached):
                 break

-    def text_completion(
+    def completion(
         self,
-        content: InterleavedTextMedia,
-        temperature: float = 0.6,
-        top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
-        logprobs: bool = False,
-        echo: bool = False,
+        request: CompletionRequest,
     ) -> Generator:
+        sampling_params = request.sampling_params
+        max_gen_len = sampling_params.max_tokens
         if (
             max_gen_len is None
             or max_gen_len == 0
@@ -313,26 +309,25 @@ class Llama:
         ):
             max_gen_len = self.model.params.max_seq_len - 1

-        model_input = self.formatter.encode_content(content)
+        model_input = self.formatter.encode_content(request.content)
         yield from self.generate(
             model_input=model_input,
             max_gen_len=max_gen_len,
-            temperature=temperature,
-            top_p=top_p,
-            logprobs=logprobs,
-            echo=echo,
+            temperature=sampling_params.temperature,
+            top_p=sampling_params.top_p,
+            logprobs=bool(request.logprobs),
+            include_stop_token=True,
+            echo=False,
         )

     def chat_completion(
         self,
-        messages: List[Message],
-        temperature: float = 0.6,
-        top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
-        logprobs: bool = False,
-        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
+        request: ChatCompletionRequest,
     ) -> Generator:
+        messages = chat_completion_request_to_messages(request)
+        sampling_params = request.sampling_params
+        max_gen_len = sampling_params.max_tokens
         if (
             max_gen_len is None
             or max_gen_len == 0
@@ -343,12 +338,12 @@ class Llama:
         yield from self.generate(
             model_input=self.formatter.encode_dialog_prompt(
                 messages,
-                tool_prompt_format,
+                request.tool_prompt_format,
             ),
             max_gen_len=max_gen_len,
-            temperature=temperature,
-            top_p=top_p,
-            logprobs=logprobs,
+            temperature=sampling_params.temperature,
+            top_p=sampling_params.top_p,
+            logprobs=bool(request.logprobs),
             include_stop_token=True,
         )
```


```
@@ -13,9 +13,6 @@ from llama_models.sku_list import resolve_model
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_messages,
-)

 from .config import MetaReferenceInferenceConfig
 from .generation import Llama
@@ -58,7 +55,18 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         if self.config.create_distributed_process_group:
             self.generator.stop()

-    def completion(
+    def check_model(self, request) -> None:
+        model = resolve_model(request.model)
+        if model is None:
+            raise RuntimeError(
+                f"Unknown model: {request.model}, Run `llama model list`"
+            )
+        elif model.descriptor() != self.model.descriptor():
+            raise RuntimeError(
+                f"Model mismatch: {request.model} != {self.model.descriptor()}"
+            )
+
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -66,9 +74,114 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
-        raise NotImplementedError()
+        if logprobs:
+            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
+
+        request = CompletionRequest(
+            model=model,
+            content=content,
+            sampling_params=sampling_params,
+            stream=stream,
+            logprobs=logprobs,
+        )
+        self.check_model(request)
+
+        if request.stream:
+            return self._stream_completion(request)
+        else:
+            return await self._nonstream_completion(request)
+
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        def impl():
+            stop_reason = None
+
+            for token_result in self.generator.completion(request):
+                if token_result.text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                    text = ""
+                elif token_result.text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
+                    text = ""
+                else:
+                    text = token_result.text
+
+                logprobs = None
+                if stop_reason is None:
+                    if request.logprobs:
+                        assert len(token_result.logprobs) == 1
+
+                        logprobs = [
+                            TokenLogProbs(
+                                logprobs_by_token={
+                                    token_result.text: token_result.logprobs[0]
+                                }
+                            )
+                        ]
+
+                yield CompletionResponseStreamChunk(
+                    delta=text,
+                    stop_reason=stop_reason,
+                    logprobs=logprobs if request.logprobs else None,
+                )
+
+            if stop_reason is None:
+                yield CompletionResponseStreamChunk(
+                    delta="",
+                    stop_reason=StopReason.out_of_tokens,
+                )
+
+        if self.config.create_distributed_process_group:
+            async with SEMAPHORE:
+                for x in impl():
+                    yield x
+        else:
+            for x in impl():
+                yield x
+
+    async def _nonstream_completion(
+        self, request: CompletionRequest
+    ) -> CompletionResponse:
+        def impl():
+            tokens = []
+            logprobs = []
+            stop_reason = None
+
+            tokenizer = self.generator.formatter.tokenizer
+            for token_result in self.generator.completion(request):
+                tokens.append(token_result.token)
+                if token_result.token in tokenizer.stop_tokens:
+                    # not quite right semantically
+                    stop_reason = StopReason.end_of_turn
+
+                if request.logprobs:
+                    assert len(token_result.logprobs) == 1
+
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )
+
+            if stop_reason is None:
+                stop_reason = StopReason.out_of_tokens
+
+            content = self.generator.formatter.tokenizer.decode(tokens)
+            return CompletionResponse(
+                content=content,
+                stop_reason=stop_reason,
+                logprobs=logprobs if request.logprobs else None,
+            )
+
+        if self.config.create_distributed_process_group:
+            async with SEMAPHORE:
+                return impl()
+        else:
+            return impl()

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -93,16 +206,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             stream=stream,
             logprobs=logprobs,
         )
-
-        model = resolve_model(request.model)
-        if model is None:
-            raise RuntimeError(
-                f"Unknown model: {request.model}, Run `llama model list`"
-            )
-        elif model.descriptor() != self.model.descriptor():
-            raise RuntimeError(
-                f"Model mismatch: {request.model} != {self.model.descriptor()}"
-            )
+        self.check_model(request)

         if self.config.create_distributed_process_group:
             if SEMAPHORE.locked():
@@ -111,26 +215,17 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         if request.stream:
             return self._stream_chat_completion(request)
         else:
-            return self._nonstream_chat_completion(request)
+            return await self._nonstream_chat_completion(request)

     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
         def impl():
-            messages = chat_completion_request_to_messages(request)
-
             tokens = []
             logprobs = []
             stop_reason = None

-            for token_result in self.generator.chat_completion(
-                messages=messages,
-                temperature=request.sampling_params.temperature,
-                top_p=request.sampling_params.top_p,
-                max_gen_len=request.sampling_params.max_tokens,
-                logprobs=request.logprobs,
-                tool_prompt_format=request.tool_prompt_format,
-            ):
+            for token_result in self.generator.chat_completion(request):
                 tokens.append(token_result.token)

                 if token_result.text == "<|eot_id|>":
@@ -170,8 +265,6 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
         def impl():
-            messages = chat_completion_request_to_messages(request)
-
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=ChatCompletionResponseEventType.start,
@@ -184,14 +277,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             stop_reason = None
             ipython = False

-            for token_result in self.generator.chat_completion(
-                messages=messages,
-                temperature=request.sampling_params.temperature,
-                top_p=request.sampling_params.top_p,
-                max_gen_len=request.sampling_params.max_tokens,
-                logprobs=request.logprobs,
-                tool_prompt_format=request.tool_prompt_format,
-            ):
+            for token_result in self.generator.chat_completion(request):
                 tokens.append(token_result.token)

                 if not ipython and token_result.text.startswith("<|python_tag|>"):
```


```
@@ -7,16 +7,17 @@
 import os
 from copy import deepcopy
 from functools import partial
-from typing import Generator, List, Optional
+from typing import Any, Generator

 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model

+from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
+
 from .config import MetaReferenceInferenceConfig
 from .generation import Llama, model_checkpoint_dir
-from .parallel_utils import InferenceArgs, ModelParallelProcessGroup
+from .parallel_utils import ModelParallelProcessGroup
@@ -24,15 +25,13 @@ class ModelRunner:
         self.llama = llama

     # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
-    def __call__(self, task: InferenceArgs):
-        return self.llama.chat_completion(
-            task.messages,
-            task.temperature,
-            task.top_p,
-            task.max_gen_len,
-            task.logprobs,
-            task.tool_prompt_format,
-        )
+    def __call__(self, req: Any):
+        if isinstance(req, ChatCompletionRequest):
+            return self.llama.chat_completion(req)
+        elif isinstance(req, CompletionRequest):
+            return self.llama.completion(req)
+        else:
+            raise ValueError(f"Unexpected task type {type(req)}")


 def init_model_cb(config: MetaReferenceInferenceConfig):
@@ -77,23 +76,18 @@ class LlamaModelParallelGenerator:
     def __exit__(self, exc_type, exc_value, exc_traceback):
         self.group.stop()

-    def chat_completion(
+    def completion(
         self,
-        messages: List[Message],
-        temperature: float = 0.6,
-        top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
-        logprobs: bool = False,
-        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
+        request: CompletionRequest,
     ) -> Generator:
-        req_obj = InferenceArgs(
-            messages=deepcopy(messages),
-            temperature=temperature,
-            top_p=top_p,
-            max_gen_len=max_gen_len,
-            logprobs=logprobs or False,
-            tool_prompt_format=tool_prompt_format,
-        )
-
+        req_obj = deepcopy(request)
+        gen = self.group.run_inference(req_obj)
+        yield from gen
+
+    def chat_completion(
+        self,
+        request: ChatCompletionRequest,
+    ) -> Generator:
+        req_obj = deepcopy(request)
         gen = self.group.run_inference(req_obj)
         yield from gen
```


```
@@ -4,6 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
 import json
 import multiprocessing
 import os
@@ -11,10 +17,9 @@ import tempfile
 import time
 import uuid
 from enum import Enum
-from typing import Callable, Generator, List, Literal, Optional, Union
+from typing import Callable, Generator, Literal, Optional, Union

 import torch
 import zmq
 from fairscale.nn.model_parallel.initialize import (
@@ -23,25 +28,16 @@ from fairscale.nn.model_parallel.initialize import (
     get_model_parallel_src_rank,
 )
-from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from pydantic import BaseModel, Field
 from torch.distributed.launcher.api import elastic_launch, LaunchConfig
 from typing_extensions import Annotated

+from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
+
 from .generation import TokenResult


-class InferenceArgs(BaseModel):
-    messages: List[Message]
-    temperature: float
-    top_p: float
-    max_gen_len: int
-    logprobs: bool
-    tool_prompt_format: ToolPromptFormat
-
-
 class ProcessingMessageName(str, Enum):
     ready_request = "ready_request"
     ready_response = "ready_response"
@@ -80,7 +76,7 @@ class TaskRequest(BaseModel):
     type: Literal[ProcessingMessageName.task_request] = (
         ProcessingMessageName.task_request
     )
-    task: InferenceArgs
+    task: Union[CompletionRequest, ChatCompletionRequest]


 class TaskResponse(BaseModel):
@@ -349,11 +345,13 @@ class ModelParallelProcessGroup:
         self.process.join()
         self.started = False

-    def run_inference(self, inference_args: InferenceArgs) -> Generator:
+    def run_inference(
+        self, req: Union[CompletionRequest, ChatCompletionRequest]
+    ) -> Generator:
         assert not self.running, "inference already running"

         self.running = True
-        self.request_socket.send(encode_msg(TaskRequest(task=inference_args)))
+        self.request_socket.send(encode_msg(TaskRequest(task=req)))
         try:
             while True:
                 obj_json = self.request_socket.recv()
```


```
@@ -184,7 +184,7 @@ class LlamaGuardShield(ShieldBase):
         # TODO: llama-stack inference protocol has issues with non-streaming inference code
         content = ""
-        async for chunk in self.inference_api.chat_completion(
+        async for chunk in await self.inference_api.chat_completion(
             model=self.model,
             messages=[shield_input_message],
             stream=True,
```


```
@@ -134,7 +134,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
         if self.engine:
             self.engine.shutdown_background_loop()

-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -152,7 +152,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
             logprobs=logprobs,
         )

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: list[Message],
@@ -189,7 +189,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
         if stream:
             return self._stream_chat_completion(request, results_generator)
         else:
-            return self._nonstream_chat_completion(request, results_generator)
+            return await self._nonstream_chat_completion(request, results_generator)

     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, results_generator: AsyncGenerator
```


```
@@ -116,7 +116,7 @@ async def test_create_agent_turn(agents_settings, sample_messages):
     )

     turn_response = [
-        chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
+        chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
     ]

     assert len(turn_response) > 0
@@ -204,7 +204,7 @@ async def test_rag_agent_as_attachments(
     )

     turn_response = [
-        chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
+        chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
     ]

     assert len(turn_response) > 0
@@ -218,7 +218,7 @@ async def test_rag_agent_as_attachments(
     )

     turn_response = [
-        chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
+        chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
     ]

     assert len(turn_response) > 0
@@ -270,7 +270,7 @@ async def test_create_agent_turn_with_brave_search(
     )

     turn_response = [
-        chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
+        chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
     ]

     assert len(turn_response) > 0
```


```
@@ -126,6 +126,45 @@ async def test_model_list(inference_settings):
     assert model_def.identifier == params["model"]


+@pytest.mark.asyncio
+async def test_completion(inference_settings):
+    inference_impl = inference_settings["impl"]
+    params = inference_settings["common_params"]
+
+    provider = inference_impl.routing_table.get_provider_impl(params["model"])
+    if provider.__provider_id__ != "meta-reference":
+        pytest.skip("Other inference providers don't support completion() yet")
+
+    response = await inference_impl.completion(
+        content="Roses are red,",
+        stream=False,
+        model=params["model"],
+        sampling_params=SamplingParams(
+            max_tokens=50,
+        ),
+    )
+
+    assert isinstance(response, CompletionResponse)
+    assert "violets are blue" in response.content
+
+    chunks = [
+        r
+        async for r in await inference_impl.completion(
+            content="Roses are red,",
+            stream=True,
+            model=params["model"],
+            sampling_params=SamplingParams(
+                max_tokens=50,
+            ),
+        )
+    ]
+
+    assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
+    assert len(chunks) == 51
+    last = chunks[-1]
+    assert last.stop_reason == StopReason.out_of_tokens
+
+
 @pytest.mark.asyncio
 async def test_chat_completion_non_streaming(inference_settings, sample_messages):
     inference_impl = inference_settings["impl"]
@@ -146,7 +185,7 @@ async def test_chat_completion_streaming(inference_settings, sample_messages):
     inference_impl = inference_settings["impl"]
     response = [
         r
-        async for r in inference_impl.chat_completion(
+        async for r in await inference_impl.chat_completion(
            messages=sample_messages,
            stream=True,
            **inference_settings["common_params"],
@@ -217,7 +256,7 @@ async def test_chat_completion_with_tool_calling_streaming(
     response = [
         r
-        async for r in inference_impl.chat_completion(
+        async for r in await inference_impl.chat_completion(
             messages=messages,
             tools=[sample_tool_definition],
             stream=True,
```