Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)

Make all API methods `async def` again

This commit is contained in:
parent 95a96afe34
commit 627edaf407

17 changed files with 120 additions and 145 deletions
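The pattern this commit moves to: inference methods are declared `async def`, return a plain response object when `stream=False`, and return an `AsyncGenerator` (rather than being one) when `stream=True`; callers therefore `await` the call first and then iterate. A minimal, self-contained sketch of that convention, using illustrative stand-in names rather than the actual llama-stack types:

```python
import asyncio
from typing import AsyncGenerator, Union


class ChatCompletionResponse:
    """Stand-in for the non-streaming response type."""


async def _stream_chunks() -> AsyncGenerator:
    # Calling an async-generator function returns the generator object
    # without running any of its body yet.
    for i in range(3):
        yield f"chunk-{i}"


async def chat_completion(stream: bool = False) -> Union[ChatCompletionResponse, AsyncGenerator]:
    # An `async def` can still hand back an AsyncGenerator: it simply
    # returns the generator object instead of iterating it itself.
    if stream:
        return _stream_chunks()
    return ChatCompletionResponse()


async def main() -> None:
    # Non-streaming: await gives the response object directly.
    response = await chat_completion(stream=False)
    print(type(response).__name__)

    # Streaming: await gives the generator, which is then iterated lazily.
    async for chunk in await chat_completion(stream=True):
        print(chunk)


asyncio.run(main())
```

Because the method only returns the generator, the `await` does not buffer the stream; iteration still happens lazily in the caller.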
@@ -42,10 +42,10 @@ class InferenceClient(Inference):
     async def shutdown(self) -> None:
         pass
 
-    def completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
         raise NotImplementedError()
 
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -139,7 +139,8 @@ async def run_main(
     else:
         logprobs_config = None
 
-    iterator = client.chat_completion(
+    assert stream, "Non streaming not supported here"
+    iterator = await client.chat_completion(
         model=model,
         messages=[message],
         stream=stream,
@@ -181,10 +181,8 @@ class ModelStore(Protocol):
 class Inference(Protocol):
     model_store: ModelStore
 
-    # This method is not `async def` because it can result in either an
-    # `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`.
     @webmethod(route="/inference/completion")
-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -196,7 +194,7 @@ class Inference(Protocol):
     # This method is not `async def` because it can result in either an
     # `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
     @webmethod(route="/inference/chat_completion")
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
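The `Inference` API above is a `typing.Protocol`, so providers satisfy it structurally by defining matching `async def` methods. A small sketch of that shape, with simplified stand-in types (plain strings instead of `Message` or `ChatCompletionResponse`):

```python
from typing import AsyncGenerator, List, Protocol, Union


class Inference(Protocol):
    async def chat_completion(
        self,
        model: str,
        messages: List[str],
        stream: bool = False,
    ) -> Union[str, AsyncGenerator]:
        ...


class EchoInference:
    """Structurally satisfies the protocol: same names, async, compatible signature."""

    async def chat_completion(
        self,
        model: str,
        messages: List[str],
        stream: bool = False,
    ) -> Union[str, AsyncGenerator]:
        if stream:
            return self._stream(messages)
        return messages[-1]

    async def _stream(self, messages: List[str]) -> AsyncGenerator:
        for m in messages:
            yield m


def accepts_inference(impl: Inference) -> Inference:
    # A static checker verifies EchoInference against the Protocol structurally.
    return impl


impl = accepts_inference(EchoInference())
```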
@@ -70,7 +70,7 @@ class InferenceRouter(Inference):
     async def register_model(self, model: ModelDef) -> None:
         await self.routing_table.register_model(model)
 
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -93,11 +93,11 @@ class InferenceRouter(Inference):
         )
         provider = self.routing_table.get_provider_impl(model)
         if stream:
-            return (chunk async for chunk in provider.chat_completion(**params))
+            return (chunk async for chunk in await provider.chat_completion(**params))
         else:
-            return provider.chat_completion(**params)
+            return await provider.chat_completion(**params)
 
-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -114,9 +114,9 @@ class InferenceRouter(Inference):
             logprobs=logprobs,
         )
         if stream:
-            return (chunk async for chunk in provider.completion(**params))
+            return (chunk async for chunk in await provider.completion(**params))
         else:
-            return provider.completion(**params)
+            return await provider.completion(**params)
 
     async def embeddings(
         self,
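The router hunks above await the selected provider's coroutine and, for streaming, re-yield its chunks through an async generator expression. A self-contained sketch of the same delegation shape, with a toy provider standing in for a real backend:

```python
import asyncio
from typing import Any, AsyncGenerator


class EchoProvider:
    """Toy provider using the awaited-then-iterated interface (illustrative only)."""

    async def chat_completion(self, message: str, stream: bool = False) -> Any:
        if stream:
            return self._stream(message)
        return f"echo: {message}"

    async def _stream(self, message: str) -> AsyncGenerator:
        for word in message.split():
            yield word


class Router:
    """Routes to a provider while preserving the await-then-iterate contract."""

    def __init__(self, providers: dict):
        self.providers = providers

    async def chat_completion(self, model: str, message: str, stream: bool = False) -> Any:
        provider = self.providers[model]
        if stream:
            # Await the provider coroutine to get its async generator,
            # then re-yield its chunks through an async generator expression.
            return (
                chunk
                async for chunk in await provider.chat_completion(message, stream=True)
            )
        return await provider.chat_completion(message, stream=False)


async def main() -> None:
    router = Router({"echo": EchoProvider()})
    async for chunk in await router.chat_completion("echo", "hello async world", stream=True):
        print(chunk)
    print(await router.chat_completion("echo", "hello", stream=False))


asyncio.run(main())
```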
@@ -47,7 +47,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         self.client.close()
 
-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -283,7 +283,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         )
         return tool_config
 
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -48,7 +48,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         pass
 
-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -58,7 +58,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> AsyncGenerator:
         raise NotImplementedError()
 
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -84,7 +84,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         if stream:
             return self._stream_chat_completion(request, client)
         else:
-            return self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request, client)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: OpenAI
@@ -51,7 +51,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         pass
 
-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -61,7 +61,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> AsyncGenerator:
         raise NotImplementedError()
 
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -87,7 +87,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         if stream:
             return self._stream_chat_completion(request, client)
         else:
-            return self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request, client)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: Fireworks
@@ -84,7 +84,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
 
         return ret
 
-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -94,7 +94,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     ) -> AsyncGenerator:
         raise NotImplementedError()
 
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -118,7 +118,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         if stream:
             return self._stream_chat_completion(request)
         else:
-            return self._nonstream_chat_completion(request)
+            return await self._nonstream_chat_completion(request)
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
         return {
@@ -66,7 +66,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def shutdown(self) -> None:
         pass
 
-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -76,7 +76,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     ) -> AsyncGenerator:
         raise NotImplementedError()
 
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -101,7 +101,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         if stream:
             return self._stream_chat_completion(request)
         else:
-            return self._nonstream_chat_completion(request)
+            return await self._nonstream_chat_completion(request)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
@@ -64,7 +64,7 @@ class TogetherInferenceAdapter(
     ) -> AsyncGenerator:
         raise NotImplementedError()
 
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -101,7 +101,7 @@ class TogetherInferenceAdapter(
         if stream:
             return self._stream_chat_completion(request, client)
         else:
-            return self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request, client)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: Together
@@ -424,7 +424,7 @@ class ChatAgent(ShieldRunnerMixin):
         stop_reason = None
 
         with tracing.span("inference"):
-            async for chunk in self.inference_api.chat_completion(
+            async for chunk in await self.inference_api.chat_completion(
                 self.agent_config.model,
                 input_messages,
                 tools=self._get_tools(),
@@ -23,11 +23,6 @@ from fairscale.nn.model_parallel.initialize import (
 )
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
-from llama_models.llama3.api.datatypes import (
-    InterleavedTextMedia,
-    Message,
-    ToolPromptFormat,
-)
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (
@@ -38,7 +33,11 @@ from llama_models.sku_list import resolve_model
 from pydantic import BaseModel
 from termcolor import cprint
 
+from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.utils.model_utils import model_local_dir
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_messages,
+)
 
 from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
 
@@ -297,15 +296,11 @@ class Llama:
             if all(eos_reached):
                 break
 
-    def text_completion(
+    def completion(
         self,
-        content: InterleavedTextMedia,
-        temperature: float = 0.6,
-        top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
-        logprobs: bool = False,
-        echo: bool = False,
+        request: CompletionRequest,
     ) -> Generator:
+        sampling_params = request.sampling_params
         if (
             max_gen_len is None
             or max_gen_len == 0
@@ -313,26 +308,24 @@ class Llama:
         ):
             max_gen_len = self.model.params.max_seq_len - 1
 
-        model_input = self.formatter.encode_content(content)
+        model_input = self.formatter.encode_content(request.content)
 
         yield from self.generate(
             model_input=model_input,
             max_gen_len=max_gen_len,
-            temperature=temperature,
-            top_p=top_p,
-            logprobs=logprobs,
-            echo=echo,
+            temperature=sampling_params.temperature,
+            top_p=sampling_params.top_p,
+            logprobs=bool(request.logprobs),
+            echo=False,
         )
 
     def chat_completion(
         self,
-        messages: List[Message],
-        temperature: float = 0.6,
-        top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
-        logprobs: bool = False,
-        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
+        request: ChatCompletionRequest,
     ) -> Generator:
+        messages = chat_completion_request_to_messages(request)
+        sampling_params = request.sampling_params
+        max_gen_len = sampling_params.max_tokens
         if (
             max_gen_len is None
             or max_gen_len == 0
@@ -343,12 +336,12 @@ class Llama:
         yield from self.generate(
             model_input=self.formatter.encode_dialog_prompt(
                 messages,
-                tool_prompt_format,
+                request.tool_prompt_format,
             ),
             max_gen_len=max_gen_len,
-            temperature=temperature,
-            top_p=top_p,
-            logprobs=logprobs,
+            temperature=sampling_params.temperature,
+            top_p=sampling_params.top_p,
+            logprobs=bool(request.logprobs),
             include_stop_token=True,
         )
 
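The generator-level change in this file is to pass one request object (carrying `sampling_params`, `logprobs`, and the content) instead of a long keyword-argument list. A hedged sketch of that style, with simplified dataclasses standing in for the llama-stack request types:

```python
from dataclasses import dataclass, field
from typing import Generator, Optional


@dataclass
class SamplingParams:
    temperature: float = 0.6
    top_p: float = 0.9
    max_tokens: Optional[int] = None


@dataclass
class CompletionRequestSketch:
    content: str
    sampling_params: SamplingParams = field(default_factory=SamplingParams)
    logprobs: bool = False


def completion(request: CompletionRequestSketch) -> Generator[str, None, None]:
    # Every generation knob is read off the request object instead of a
    # separate keyword argument, mirroring the signature change above.
    params = request.sampling_params
    max_gen_len = params.max_tokens or 4
    for i in range(max_gen_len):
        yield f"token-{i} (temperature={params.temperature}, top_p={params.top_p})"


for token in completion(CompletionRequestSketch(content="hello")):
    print(token)
```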
@ -13,9 +13,6 @@ from llama_models.sku_list import resolve_model
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
|
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
|
||||||
chat_completion_request_to_messages,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .config import MetaReferenceInferenceConfig
|
from .config import MetaReferenceInferenceConfig
|
||||||
from .generation import Llama
|
from .generation import Llama
|
||||||
|
@ -58,7 +55,18 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
if self.config.create_distributed_process_group:
|
if self.config.create_distributed_process_group:
|
||||||
self.generator.stop()
|
self.generator.stop()
|
||||||
|
|
||||||
def completion(
|
def check_model(self, request) -> None:
|
||||||
|
model = resolve_model(request.model)
|
||||||
|
if model is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Unknown model: {request.model}, Run `llama model list`"
|
||||||
|
)
|
||||||
|
elif model.descriptor() != self.model.descriptor():
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Model mismatch: {request.model} != {self.model.descriptor()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
content: InterleavedTextMedia,
|
content: InterleavedTextMedia,
|
||||||
|
@ -66,9 +74,19 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||||
raise NotImplementedError()
|
if logprobs:
|
||||||
|
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||||
|
|
||||||
def chat_completion(
|
request = CompletionRequest(
|
||||||
|
model=model,
|
||||||
|
content=content,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
stream=stream,
|
||||||
|
logprobs=logprobs,
|
||||||
|
)
|
||||||
|
self.check_model(request)
|
||||||
|
|
||||||
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
|
@ -93,16 +111,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
stream=stream,
|
stream=stream,
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
)
|
)
|
||||||
|
self.check_model(request)
|
||||||
model = resolve_model(request.model)
|
|
||||||
if model is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Unknown model: {request.model}, Run `llama model list`"
|
|
||||||
)
|
|
||||||
elif model.descriptor() != self.model.descriptor():
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Model mismatch: {request.model} != {self.model.descriptor()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.config.create_distributed_process_group:
|
if self.config.create_distributed_process_group:
|
||||||
if SEMAPHORE.locked():
|
if SEMAPHORE.locked():
|
||||||
|
@ -111,26 +120,17 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
if request.stream:
|
if request.stream:
|
||||||
return self._stream_chat_completion(request)
|
return self._stream_chat_completion(request)
|
||||||
else:
|
else:
|
||||||
return self._nonstream_chat_completion(request)
|
return await self._nonstream_chat_completion(request)
|
||||||
|
|
||||||
async def _nonstream_chat_completion(
|
async def _nonstream_chat_completion(
|
||||||
self, request: ChatCompletionRequest
|
self, request: ChatCompletionRequest
|
||||||
) -> ChatCompletionResponse:
|
) -> ChatCompletionResponse:
|
||||||
def impl():
|
def impl():
|
||||||
messages = chat_completion_request_to_messages(request)
|
|
||||||
|
|
||||||
tokens = []
|
tokens = []
|
||||||
logprobs = []
|
logprobs = []
|
||||||
stop_reason = None
|
stop_reason = None
|
||||||
|
|
||||||
for token_result in self.generator.chat_completion(
|
for token_result in self.generator.chat_completion(request):
|
||||||
messages=messages,
|
|
||||||
temperature=request.sampling_params.temperature,
|
|
||||||
top_p=request.sampling_params.top_p,
|
|
||||||
max_gen_len=request.sampling_params.max_tokens,
|
|
||||||
logprobs=request.logprobs,
|
|
||||||
tool_prompt_format=request.tool_prompt_format,
|
|
||||||
):
|
|
||||||
tokens.append(token_result.token)
|
tokens.append(token_result.token)
|
||||||
|
|
||||||
if token_result.text == "<|eot_id|>":
|
if token_result.text == "<|eot_id|>":
|
||||||
|
@ -170,8 +170,6 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
self, request: ChatCompletionRequest
|
self, request: ChatCompletionRequest
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
def impl():
|
def impl():
|
||||||
messages = chat_completion_request_to_messages(request)
|
|
||||||
|
|
||||||
yield ChatCompletionResponseStreamChunk(
|
yield ChatCompletionResponseStreamChunk(
|
||||||
event=ChatCompletionResponseEvent(
|
event=ChatCompletionResponseEvent(
|
||||||
event_type=ChatCompletionResponseEventType.start,
|
event_type=ChatCompletionResponseEventType.start,
|
||||||
|
@ -184,14 +182,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
stop_reason = None
|
stop_reason = None
|
||||||
ipython = False
|
ipython = False
|
||||||
|
|
||||||
for token_result in self.generator.chat_completion(
|
for token_result in self.generator.chat_completion(request):
|
||||||
messages=messages,
|
|
||||||
temperature=request.sampling_params.temperature,
|
|
||||||
top_p=request.sampling_params.top_p,
|
|
||||||
max_gen_len=request.sampling_params.max_tokens,
|
|
||||||
logprobs=request.logprobs,
|
|
||||||
tool_prompt_format=request.tool_prompt_format,
|
|
||||||
):
|
|
||||||
tokens.append(token_result.token)
|
tokens.append(token_result.token)
|
||||||
|
|
||||||
if not ipython and token_result.text.startswith("<|python_tag|>"):
|
if not ipython and token_result.text.startswith("<|python_tag|>"):
|
||||||
|
|
|
@@ -7,16 +7,17 @@
 import os
 from copy import deepcopy
 from functools import partial
-from typing import Generator, List, Optional
+from typing import Any, Generator
 
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 
+from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
+
 from .config import MetaReferenceInferenceConfig
 from .generation import Llama, model_checkpoint_dir
-from .parallel_utils import InferenceArgs, ModelParallelProcessGroup
+from .parallel_utils import ModelParallelProcessGroup
 
 
 class ModelRunner:
@@ -24,15 +25,13 @@ class ModelRunner:
         self.llama = llama
 
     # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
-    def __call__(self, task: InferenceArgs):
-        return self.llama.chat_completion(
-            task.messages,
-            task.temperature,
-            task.top_p,
-            task.max_gen_len,
-            task.logprobs,
-            task.tool_prompt_format,
-        )
+    def __call__(self, req: Any):
+        if isinstance(req, ChatCompletionRequest):
+            return self.llama.chat_completion(req)
+        elif isinstance(req, CompletionRequest):
+            return self.llama.completion(req)
+        else:
+            raise ValueError(f"Unexpected task type {type(req)}")
 
 
 def init_model_cb(config: MetaReferenceInferenceConfig):
@@ -77,23 +76,18 @@ class LlamaModelParallelGenerator:
     def __exit__(self, exc_type, exc_value, exc_traceback):
         self.group.stop()
 
-    def chat_completion(
+    def completion(
         self,
-        messages: List[Message],
-        temperature: float = 0.6,
-        top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
-        logprobs: bool = False,
-        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
+        request: CompletionRequest,
     ) -> Generator:
-        req_obj = InferenceArgs(
-            messages=deepcopy(messages),
-            temperature=temperature,
-            top_p=top_p,
-            max_gen_len=max_gen_len,
-            logprobs=logprobs or False,
-            tool_prompt_format=tool_prompt_format,
-        )
-
+        req_obj = deepcopy(request)
+        gen = self.group.run_inference(req_obj)
+        yield from gen
+
+    def chat_completion(
+        self,
+        request: ChatCompletionRequest,
+    ) -> Generator:
+        req_obj = deepcopy(request)
         gen = self.group.run_inference(req_obj)
         yield from gen
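`ModelRunner` now receives the request object itself and dispatches on its type, so the same worker callable serves both completion and chat-completion tasks. A compact sketch of that dispatch, with illustrative dataclasses in place of the real request classes:

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class CompletionReq:
    content: str


@dataclass
class ChatCompletionReq:
    messages: List[str]


class Runner:
    """One picklable callable for the worker process; the request type picks the code path."""

    def __call__(self, req: Any) -> str:
        if isinstance(req, ChatCompletionReq):
            return f"chat over {len(req.messages)} messages"
        elif isinstance(req, CompletionReq):
            return f"completion of {req.content!r}"
        else:
            raise ValueError(f"Unexpected task type {type(req)}")


runner = Runner()
print(runner(CompletionReq(content="hi")))
print(runner(ChatCompletionReq(messages=["hi", "there"])))
```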
@@ -4,6 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
 import json
 import multiprocessing
 import os
@@ -11,10 +17,9 @@ import tempfile
 import time
 import uuid
 from enum import Enum
-from typing import Callable, Generator, List, Literal, Optional, Union
+from typing import Callable, Generator, Literal, Optional, Union
 
 import torch
-
 import zmq
 
 from fairscale.nn.model_parallel.initialize import (
@@ -23,25 +28,16 @@ from fairscale.nn.model_parallel.initialize import (
     get_model_parallel_src_rank,
 )
 
-from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
-
 from pydantic import BaseModel, Field
 
 from torch.distributed.launcher.api import elastic_launch, LaunchConfig
 from typing_extensions import Annotated
 
+from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
+
 from .generation import TokenResult
 
 
-class InferenceArgs(BaseModel):
-    messages: List[Message]
-    temperature: float
-    top_p: float
-    max_gen_len: int
-    logprobs: bool
-    tool_prompt_format: ToolPromptFormat
-
-
 class ProcessingMessageName(str, Enum):
     ready_request = "ready_request"
     ready_response = "ready_response"
@@ -80,7 +76,7 @@ class TaskRequest(BaseModel):
     type: Literal[ProcessingMessageName.task_request] = (
         ProcessingMessageName.task_request
     )
-    task: InferenceArgs
+    task: Union[CompletionRequest, ChatCompletionRequest]
 
 
 class TaskResponse(BaseModel):
@@ -349,11 +345,13 @@ class ModelParallelProcessGroup:
         self.process.join()
         self.started = False
 
-    def run_inference(self, inference_args: InferenceArgs) -> Generator:
+    def run_inference(
+        self, req: Union[CompletionRequest, ChatCompletionRequest]
+    ) -> Generator:
         assert not self.running, "inference already running"
 
         self.running = True
-        self.request_socket.send(encode_msg(TaskRequest(task=inference_args)))
+        self.request_socket.send(encode_msg(TaskRequest(task=req)))
         try:
             while True:
                 obj_json = self.request_socket.recv()
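`TaskRequest.task` becomes a `Union[CompletionRequest, ChatCompletionRequest]`, so either request kind can ride the same message to the model-parallel worker. A rough sketch of how a Union-typed Pydantic field round-trips through JSON; this assumes Pydantic v2 and toy request models, and the real code serializes with its own `encode_msg` helper rather than plain `model_dump_json`:

```python
from typing import Literal, Union

from pydantic import BaseModel, Field


class CompletionReq(BaseModel):
    kind: Literal["completion"] = "completion"
    content: str


class ChatCompletionReq(BaseModel):
    kind: Literal["chat_completion"] = "chat_completion"
    messages: list[str]


class TaskRequestSketch(BaseModel):
    # The discriminator keeps deserialization unambiguous: the `kind`
    # field decides which request model to rebuild on the other side.
    task: Union[CompletionReq, ChatCompletionReq] = Field(discriminator="kind")


wire = TaskRequestSketch(task=ChatCompletionReq(messages=["hello"])).model_dump_json()
decoded = TaskRequestSketch.model_validate_json(wire)
print(type(decoded.task).__name__, decoded.task.messages)
```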
@@ -184,7 +184,7 @@ class LlamaGuardShield(ShieldBase):
 
         # TODO: llama-stack inference protocol has issues with non-streaming inference code
         content = ""
-        async for chunk in self.inference_api.chat_completion(
+        async for chunk in await self.inference_api.chat_completion(
             model=self.model,
             messages=[shield_input_message],
             stream=True,
@@ -134,7 +134,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
         if self.engine:
             self.engine.shutdown_background_loop()
 
-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -152,7 +152,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
             logprobs=logprobs,
         )
 
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: list[Message],
@@ -189,7 +189,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
         if stream:
             return self._stream_chat_completion(request, results_generator)
         else:
-            return self._nonstream_chat_completion(request, results_generator)
+            return await self._nonstream_chat_completion(request, results_generator)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, results_generator: AsyncGenerator
@@ -146,7 +146,7 @@ async def test_chat_completion_streaming(inference_settings, sample_messages):
     inference_impl = inference_settings["impl"]
     response = [
         r
-        async for r in inference_impl.chat_completion(
+        async for r in await inference_impl.chat_completion(
            messages=sample_messages,
            stream=True,
            **inference_settings["common_params"],
@@ -217,7 +217,7 @@ async def test_chat_completion_with_tool_calling_streaming(
 
     response = [
         r
-        async for r in inference_impl.chat_completion(
+        async for r in await inference_impl.chat_completion(
            messages=messages,
            tools=[sample_tool_definition],
            stream=True,