Make all API methods async def again

Ashwin Bharambe 2024-10-18 16:50:57 -07:00
parent 95a96afe34
commit 627edaf407
17 changed files with 120 additions and 145 deletions


@@ -13,9 +13,6 @@ from llama_models.sku_list import resolve_model
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_messages,
-)
 
 from .config import MetaReferenceInferenceConfig
 from .generation import Llama
@@ -58,7 +55,18 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         if self.config.create_distributed_process_group:
             self.generator.stop()
 
-    def completion(
+    def check_model(self, request) -> None:
+        model = resolve_model(request.model)
+        if model is None:
+            raise RuntimeError(
+                f"Unknown model: {request.model}, Run `llama model list`"
+            )
+        elif model.descriptor() != self.model.descriptor():
+            raise RuntimeError(
+                f"Model mismatch: {request.model} != {self.model.descriptor()}"
+            )
+
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
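
Note: the new check_model helper stays a plain def even though the API surface becomes async def; a synchronous helper can be called directly from inside a coroutine. A minimal, self-contained sketch of that pattern (the Example class and the "expected-model" string below are illustrative assumptions, not llama-stack code):

import asyncio


class Example:
    # Plain synchronous helper, analogous in spirit to check_model().
    def check_model(self, model: str) -> None:
        if model != "expected-model":
            raise RuntimeError(f"Model mismatch: {model}")

    # Async API surface, analogous to the `async def completion` above.
    async def completion(self, model: str) -> None:
        # Calling a sync helper inside an async def is fine; it simply runs inline.
        self.check_model(model)


asyncio.run(Example().completion("expected-model"))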
@@ -66,9 +74,19 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
-        raise NotImplementedError()
+        if logprobs:
+            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
+
+        request = CompletionRequest(
+            model=model,
+            content=content,
+            sampling_params=sampling_params,
+            stream=stream,
+            logprobs=logprobs,
+        )
+        self.check_model(request)
 
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -93,16 +111,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             stream=stream,
             logprobs=logprobs,
         )
-
-        model = resolve_model(request.model)
-        if model is None:
-            raise RuntimeError(
-                f"Unknown model: {request.model}, Run `llama model list`"
-            )
-        elif model.descriptor() != self.model.descriptor():
-            raise RuntimeError(
-                f"Model mismatch: {request.model} != {self.model.descriptor()}"
-            )
+        self.check_model(request)
 
         if self.config.create_distributed_process_group:
            if SEMAPHORE.locked():
@@ -111,26 +120,17 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             if request.stream:
                 return self._stream_chat_completion(request)
             else:
-                return self._nonstream_chat_completion(request)
+                return await self._nonstream_chat_completion(request)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
         def impl():
-            messages = chat_completion_request_to_messages(request)
-
             tokens = []
             logprobs = []
             stop_reason = None
 
-            for token_result in self.generator.chat_completion(
-                messages=messages,
-                temperature=request.sampling_params.temperature,
-                top_p=request.sampling_params.top_p,
-                max_gen_len=request.sampling_params.max_tokens,
-                logprobs=request.logprobs,
-                tool_prompt_format=request.tool_prompt_format,
-            ):
+            for token_result in self.generator.chat_completion(request):
                 tokens.append(token_result.token)
 
                 if token_result.text == "<|eot_id|>":
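
The generator call also changes shape: instead of unpacking sampling parameters, logprobs, and tool format at every call site, the whole request object is handed to the generator, which does the unpacking internally. A rough sketch of that refactor with toy stand-ins (ToyRequest, ToyToken, and ToyGenerator are assumptions for illustration, not the real Llama generator):

from dataclasses import dataclass
from typing import Iterator, List


@dataclass
class ToyToken:
    token: int
    text: str


@dataclass
class ToyRequest:
    messages: List[str]
    max_tokens: int = 4


class ToyGenerator:
    # Before: chat_completion(messages=..., temperature=..., top_p=..., max_gen_len=..., ...)
    # After:  chat_completion(request) -- the parameter plumbing lives in one place.
    def chat_completion(self, request: ToyRequest) -> Iterator[ToyToken]:
        for i in range(request.max_tokens):
            yield ToyToken(token=i, text=f"tok{i}")


tokens = [t.token for t in ToyGenerator().chat_completion(ToyRequest(messages=["hello"]))]
print(tokens)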
@@ -170,8 +170,6 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
         def impl():
-            messages = chat_completion_request_to_messages(request)
-
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=ChatCompletionResponseEventType.start,
@@ -184,14 +182,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             stop_reason = None
             ipython = False
 
-            for token_result in self.generator.chat_completion(
-                messages=messages,
-                temperature=request.sampling_params.temperature,
-                top_p=request.sampling_params.top_p,
-                max_gen_len=request.sampling_params.max_tokens,
-                logprobs=request.logprobs,
-                tool_prompt_format=request.tool_prompt_format,
-            ):
+            for token_result in self.generator.chat_completion(request):
                 tokens.append(token_result.token)
 
                 if not ipython and token_result.text.startswith("<|python_tag|>"):
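
For callers, the net effect of the async def change is that the non-streaming path is awaited while the streaming path hands back an async generator to iterate with async for, mirroring the return await self._nonstream_chat_completion(request) line above. A hedged sketch of that calling convention (ToyInference is an illustrative stand-in, not MetaReferenceInferenceImpl):

import asyncio
from typing import AsyncGenerator, List, Union


class ToyInference:
    # Stand-in for the provider class; only the sync/async calling shape matters here.
    async def chat_completion(
        self, messages: List[str], stream: bool = False
    ) -> Union[str, AsyncGenerator[str, None]]:
        if stream:
            return self._stream_chat_completion(messages)
        else:
            return await self._nonstream_chat_completion(messages)

    async def _nonstream_chat_completion(self, messages: List[str]) -> str:
        return "full response"

    async def _stream_chat_completion(self, messages: List[str]) -> AsyncGenerator[str, None]:
        for chunk in ["partial ", "response"]:
            yield chunk


async def main() -> None:
    impl = ToyInference()
    # Non-streaming: await the coroutine and get the complete response.
    print(await impl.chat_completion(["hello"]))
    # Streaming: the awaited call returns an async generator to iterate over.
    async for chunk in await impl.chat_completion(["hello"], stream=True):
        print(chunk)


asyncio.run(main())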