forked from phoenix-oss/llama-stack-mirror
Make all methods `async def` again; add completion() for meta-reference (#270)
PR #201 made several changes while trying to fix the stream=False branches of the inference and agents APIs. One of those changes was slightly gratuitous: it turned chat_completion() and its sibling methods into plain `def` instead of `async def`. The rationale was that callers within llama-stack could then write:

```
async for chunk in api.chat_completion(params)
```

However, this caused unnecessary confusion for several folks. Given that clients (e.g., llama-stack-apps) use the SDK methods (which are completely isolated) anyway, the choice was not ideal. This change reverts it, so the call now looks like:

```
async for chunk in await api.chat_completion(params)
```

Bonus: added a completion() implementation for the meta-reference provider. Technically this should have been a separate PR :)
parent 95a96afe34
commit 2089427d60
23 changed files with 330 additions and 213 deletions
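To make the calling convention concrete, here is a minimal, self-contained sketch (not code from this repo; `ToyInferenceAPI` and its methods are made-up stand-ins) of an `async def` endpoint that returns either a full response or an async generator depending on `stream`, and of how a caller consumes both branches with `await`:

```python
import asyncio
from typing import AsyncGenerator, Union


class ToyInferenceAPI:
    """Illustrative stand-in for an Inference-style API; not the real llama-stack class."""

    async def chat_completion(
        self, prompt: str, stream: bool = False
    ) -> Union[str, AsyncGenerator[str, None]]:
        # Because this is `async def`, callers must `await` the call first; when
        # stream=True the awaited value is an async generator they can iterate.
        if stream:
            return self._stream(prompt)
        return f"echo: {prompt}"

    async def _stream(self, prompt: str) -> AsyncGenerator[str, None]:
        # Async generator: calling it returns a generator object to iterate over.
        for token in prompt.split():
            yield token


async def main() -> None:
    api = ToyInferenceAPI()

    # Streaming: await the coroutine, then iterate the generator it returned.
    async for chunk in await api.chat_completion("hello streaming world", stream=True):
        print(chunk)

    # Non-streaming: a single awaited response.
    print(await api.chat_completion("hello", stream=False))


if __name__ == "__main__":
    asyncio.run(main())
```

This mirrors the pattern the diff applies across the providers below: `chat_completion()`/`completion()` stay `async def`, returning a stream object for `stream=True` and awaiting the non-streaming helper otherwise.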
@@ -21,7 +21,7 @@
"info": {
"title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1",
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-10 15:29:56.831109"
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-18 20:48:17.730988"
},
"servers": [
{
@@ -2830,8 +2830,11 @@
"CompletionResponse": {
"type": "object",
"properties": {
"completion_message": {
"$ref": "#/components/schemas/CompletionMessage"
"content": {
"type": "string"
},
"stop_reason": {
"$ref": "#/components/schemas/StopReason"
},
"logprobs": {
"type": "array",
@@ -2842,7 +2845,8 @@
},
"additionalProperties": false,
"required": [
"completion_message"
"content",
"stop_reason"
],
"title": "Completion response."
},
@@ -6075,49 +6079,49 @@
],
"tags": [
{
"name": "Evaluations"
},
{
"name": "Inspect"
"name": "Models"
},
{
"name": "RewardScoring"
},
{
"name": "Datasets"
},
{
"name": "Models"
},
{
"name": "Telemetry"
},
{
"name": "PostTraining"
},
{
"name": "SyntheticDataGeneration"
},
{
"name": "BatchInference"
},
{
"name": "Inference"
},
{
"name": "Agents"
},
{
"name": "Memory"
},
{
"name": "Safety"
"name": "MemoryBanks"
},
{
"name": "Shields"
},
{
"name": "MemoryBanks"
"name": "SyntheticDataGeneration"
},
{
"name": "Inference"
},
{
"name": "Inspect"
},
{
"name": "BatchInference"
},
{
"name": "Memory"
},
{
"name": "Datasets"
},
{
"name": "Agents"
},
{
"name": "PostTraining"
},
{
"name": "Telemetry"
},
{
"name": "Safety"
},
{
"name": "Evaluations"
},
{
"name": "BuiltinTool",

@@ -501,14 +501,17 @@ components:
CompletionResponse:
additionalProperties: false
properties:
completion_message:
$ref: '#/components/schemas/CompletionMessage'
content:
type: string
logprobs:
items:
$ref: '#/components/schemas/TokenLogProbs'
type: array
stop_reason:
$ref: '#/components/schemas/StopReason'
required:
- completion_message
- content
- stop_reason
title: Completion response.
type: object
CompletionResponseStreamChunk:
@@ -2507,7 +2510,7 @@ info:
description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\
\ draft and subject to change.\n Generated at 2024-10-10 15:29:56.831109"
\ draft and subject to change.\n Generated at 2024-10-18 20:48:17.730988"
title: '[DRAFT] Llama Stack Specification'
version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -3712,21 +3715,21 @@ security:
servers:
- url: http://any-hosted-llama-stack.com
tags:
- name: Evaluations
- name: Inspect
- name: RewardScoring
- name: Datasets
- name: Models
- name: Telemetry
- name: PostTraining
- name: SyntheticDataGeneration
- name: BatchInference
- name: Inference
- name: Agents
- name: Memory
- name: Safety
- name: Shields
- name: RewardScoring
- name: MemoryBanks
- name: Shields
- name: SyntheticDataGeneration
- name: Inference
- name: Inspect
- name: BatchInference
- name: Memory
- name: Datasets
- name: Agents
- name: PostTraining
- name: Telemetry
- name: Safety
- name: Evaluations
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
name: BuiltinTool
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"

@@ -421,10 +421,8 @@ class Agents(Protocol):
agent_config: AgentConfig,
) -> AgentCreateResponse: ...

# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `AgentTurnCreateResponse` depending on the value of `stream`.
@webmethod(route="/agents/turn/create")
def create_agent_turn(
async def create_agent_turn(
self,
agent_id: str,
session_id: str,

@@ -67,14 +67,14 @@ class AgentsClient(Agents):
response.raise_for_status()
return AgentSessionCreateResponse(**response.json())

def create_agent_turn(
async def create_agent_turn(
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
if request.stream:
return self._stream_agent_turn(request)
else:
return self._nonstream_agent_turn(request)
return await self._nonstream_agent_turn(request)

async def _stream_agent_turn(
self, request: AgentTurnCreateRequest
@@ -126,7 +126,7 @@ async def _run_agent(

for content in user_prompts:
cprint(f"User> {content}", color="white", attrs=["bold"])
iterator = api.create_agent_turn(
iterator = await api.create_agent_turn(
AgentTurnCreateRequest(
agent_id=create_response.agent_id,
session_id=session_response.session_id,

@@ -42,10 +42,10 @@ class InferenceClient(Inference):
async def shutdown(self) -> None:
pass

def completion(self, request: CompletionRequest) -> AsyncGenerator:
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()

def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -139,7 +139,8 @@ async def run_main(
else:
logprobs_config = None

iterator = client.chat_completion(
assert stream, "Non streaming not supported here"
iterator = await client.chat_completion(
model=model,
messages=[message],
stream=stream,

@@ -88,7 +88,8 @@ class CompletionRequest(BaseModel):
class CompletionResponse(BaseModel):
"""Completion response."""

completion_message: CompletionMessage
content: str
stop_reason: StopReason
logprobs: Optional[List[TokenLogProbs]] = None

@@ -113,7 +114,7 @@ class BatchCompletionRequest(BaseModel):
class BatchCompletionResponse(BaseModel):
"""Batch completion response."""

completion_message_batch: List[CompletionMessage]
batch: List[CompletionResponse]

@json_schema_type
@@ -165,7 +166,7 @@ class BatchChatCompletionRequest(BaseModel):

@json_schema_type
class BatchChatCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage]
batch: List[ChatCompletionResponse]

@json_schema_type
@@ -181,10 +182,8 @@ class ModelStore(Protocol):
class Inference(Protocol):
model_store: ModelStore

# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/completion")
def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -196,7 +195,7 @@ class Inference(Protocol):
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/chat_completion")
def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],

@@ -70,7 +70,7 @@ class InferenceRouter(Inference):
async def register_model(self, model: ModelDef) -> None:
await self.routing_table.register_model(model)

def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -93,11 +93,11 @@ class InferenceRouter(Inference):
)
provider = self.routing_table.get_provider_impl(model)
if stream:
return (chunk async for chunk in provider.chat_completion(**params))
return (chunk async for chunk in await provider.chat_completion(**params))
else:
return provider.chat_completion(**params)
return await provider.chat_completion(**params)

def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -114,9 +114,9 @@ class InferenceRouter(Inference):
logprobs=logprobs,
)
if stream:
return (chunk async for chunk in provider.completion(**params))
return (chunk async for chunk in await provider.completion(**params))
else:
return provider.completion(**params)
return await provider.completion(**params)

async def embeddings(
self,

@@ -47,7 +47,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
async def shutdown(self) -> None:
self.client.close()

def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -283,7 +283,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
)
return tool_config

def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],

@@ -48,7 +48,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
async def shutdown(self) -> None:
pass

def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -58,7 +58,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
) -> AsyncGenerator:
raise NotImplementedError()

def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -84,7 +84,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
if stream:
return self._stream_chat_completion(request, client)
else:
return self._nonstream_chat_completion(request, client)
return await self._nonstream_chat_completion(request, client)

async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI

@@ -51,7 +51,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
async def shutdown(self) -> None:
pass

def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -61,7 +61,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
) -> AsyncGenerator:
raise NotImplementedError()

def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -87,7 +87,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
if stream:
return self._stream_chat_completion(request, client)
else:
return self._nonstream_chat_completion(request, client)
return await self._nonstream_chat_completion(request, client)

async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: Fireworks

@@ -84,7 +84,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):

return ret

def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -94,7 +94,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
) -> AsyncGenerator:
raise NotImplementedError()

def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -118,7 +118,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
if stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
return await self._nonstream_chat_completion(request)

def _get_params(self, request: ChatCompletionRequest) -> dict:
return {

@@ -66,7 +66,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
async def shutdown(self) -> None:
pass

def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -76,7 +76,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
) -> AsyncGenerator:
raise NotImplementedError()

def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -101,7 +101,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
if stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
return await self._nonstream_chat_completion(request)

async def _nonstream_chat_completion(
self, request: ChatCompletionRequest

@@ -64,7 +64,7 @@ class TogetherInferenceAdapter(
) -> AsyncGenerator:
raise NotImplementedError()

def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -101,7 +101,7 @@ class TogetherInferenceAdapter(
if stream:
return self._stream_chat_completion(request, client)
else:
return self._nonstream_chat_completion(request, client)
return await self._nonstream_chat_completion(request, client)

async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: Together

@@ -424,7 +424,7 @@ class ChatAgent(ShieldRunnerMixin):
stop_reason = None

with tracing.span("inference"):
async for chunk in self.inference_api.chat_completion(
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=self._get_tools(),

@@ -105,7 +105,7 @@ class MetaReferenceAgentsImpl(Agents):
session_id=session_id,
)

def create_agent_turn(
async def create_agent_turn(
self,
agent_id: str,
session_id: str,

@@ -23,11 +23,6 @@ from fairscale.nn.model_parallel.initialize import (
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3.api.datatypes import (
InterleavedTextMedia,
Message,
ToolPromptFormat,
)
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import (
@@ -38,7 +33,11 @@ from llama_models.sku_list import resolve_model
from pydantic import BaseModel
from termcolor import cprint

from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
)

from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig

@@ -297,15 +296,12 @@ class Llama:
if all(eos_reached):
break

def text_completion(
def completion(
self,
content: InterleavedTextMedia,
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
request: CompletionRequest,
) -> Generator:
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if (
max_gen_len is None
or max_gen_len == 0
@@ -313,26 +309,25 @@ class Llama:
):
max_gen_len = self.model.params.max_seq_len - 1

model_input = self.formatter.encode_content(content)

model_input = self.formatter.encode_content(request.content)
yield from self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
echo=echo,
temperature=sampling_params.temperature,
top_p=sampling_params.top_p,
logprobs=bool(request.logprobs),
include_stop_token=True,
echo=False,
)

def chat_completion(
self,
messages: List[Message],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
request: ChatCompletionRequest,
) -> Generator:
messages = chat_completion_request_to_messages(request)

sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if (
max_gen_len is None
or max_gen_len == 0
@@ -343,12 +338,12 @@ class Llama:
yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(
messages,
tool_prompt_format,
request.tool_prompt_format,
),
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
temperature=sampling_params.temperature,
top_p=sampling_params.top_p,
logprobs=bool(request.logprobs),
include_stop_token=True,
)

@@ -13,9 +13,6 @@ from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
)

from .config import MetaReferenceInferenceConfig
from .generation import Llama
@@ -58,7 +55,18 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
if self.config.create_distributed_process_group:
self.generator.stop()

def completion(
def check_model(self, request) -> None:
model = resolve_model(request.model)
if model is None:
raise RuntimeError(
f"Unknown model: {request.model}, Run `llama model list`"
)
elif model.descriptor() != self.model.descriptor():
raise RuntimeError(
f"Model mismatch: {request.model} != {self.model.descriptor()}"
)

async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -66,9 +74,114 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
raise NotImplementedError()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"

def chat_completion(
request = CompletionRequest(
model=model,
content=content,
sampling_params=sampling_params,
stream=stream,
logprobs=logprobs,
)
self.check_model(request)

if request.stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)

async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
def impl():
stop_reason = None

for token_result in self.generator.completion(request):
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text

logprobs = None
if stop_reason is None:
if request.logprobs:
assert len(token_result.logprobs) == 1

logprobs = [
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
]

yield CompletionResponseStreamChunk(
delta=text,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)

if stop_reason is None:
yield CompletionResponseStreamChunk(
delta="",
stop_reason=StopReason.out_of_tokens,
)

if self.config.create_distributed_process_group:
async with SEMAPHORE:
for x in impl():
yield x
else:
for x in impl():
yield x

async def _nonstream_completion(
self, request: CompletionRequest
) -> CompletionResponse:
def impl():
tokens = []
logprobs = []
stop_reason = None

tokenizer = self.generator.formatter.tokenizer
for token_result in self.generator.completion(request):
tokens.append(token_result.token)

if token_result.token in tokenizer.stop_tokens:
# not quite right semantically
stop_reason = StopReason.end_of_turn

if request.logprobs:
assert len(token_result.logprobs) == 1

logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)

if stop_reason is None:
stop_reason = StopReason.out_of_tokens

content = self.generator.formatter.tokenizer.decode(tokens)
return CompletionResponse(
content=content,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)

if self.config.create_distributed_process_group:
async with SEMAPHORE:
return impl()
else:
return impl()

async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -93,16 +206,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
stream=stream,
logprobs=logprobs,
)

model = resolve_model(request.model)
if model is None:
raise RuntimeError(
f"Unknown model: {request.model}, Run `llama model list`"
)
elif model.descriptor() != self.model.descriptor():
raise RuntimeError(
f"Model mismatch: {request.model} != {self.model.descriptor()}"
)
self.check_model(request)

if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
@@ -111,26 +215,17 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
if request.stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
return await self._nonstream_chat_completion(request)

async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
def impl():
messages = chat_completion_request_to_messages(request)

tokens = []
logprobs = []
stop_reason = None

for token_result in self.generator.chat_completion(
messages=messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
tool_prompt_format=request.tool_prompt_format,
):
for token_result in self.generator.chat_completion(request):
tokens.append(token_result.token)

if token_result.text == "<|eot_id|>":
@@ -170,8 +265,6 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
self, request: ChatCompletionRequest
) -> AsyncGenerator:
def impl():
messages = chat_completion_request_to_messages(request)

yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
@@ -184,14 +277,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
stop_reason = None
ipython = False

for token_result in self.generator.chat_completion(
messages=messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
tool_prompt_format=request.tool_prompt_format,
):
for token_result in self.generator.chat_completion(request):
tokens.append(token_result.token)

if not ipython and token_result.text.startswith("<|python_tag|>"):

@@ -7,16 +7,17 @@
import os
from copy import deepcopy
from functools import partial
from typing import Generator, List, Optional
from typing import Any, Generator

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model

from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest

from .config import MetaReferenceInferenceConfig
from .generation import Llama, model_checkpoint_dir
from .parallel_utils import InferenceArgs, ModelParallelProcessGroup
from .parallel_utils import ModelParallelProcessGroup


class ModelRunner:
@@ -24,15 +25,13 @@ class ModelRunner:
self.llama = llama

# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
def __call__(self, task: InferenceArgs):
return self.llama.chat_completion(
task.messages,
task.temperature,
task.top_p,
task.max_gen_len,
task.logprobs,
task.tool_prompt_format,
)
def __call__(self, req: Any):
if isinstance(req, ChatCompletionRequest):
return self.llama.chat_completion(req)
elif isinstance(req, CompletionRequest):
return self.llama.completion(req)
else:
raise ValueError(f"Unexpected task type {type(req)}")


def init_model_cb(config: MetaReferenceInferenceConfig):
@@ -77,23 +76,18 @@ class LlamaModelParallelGenerator:
def __exit__(self, exc_type, exc_value, exc_traceback):
self.group.stop()

def chat_completion(
def completion(
self,
messages: List[Message],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
request: CompletionRequest,
) -> Generator:
req_obj = InferenceArgs(
messages=deepcopy(messages),
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs or False,
tool_prompt_format=tool_prompt_format,
)

req_obj = deepcopy(request)
gen = self.group.run_inference(req_obj)
yield from gen

def chat_completion(
self,
request: ChatCompletionRequest,
) -> Generator:
req_obj = deepcopy(request)
gen = self.group.run_inference(req_obj)
yield from gen

@@ -4,6 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import multiprocessing
import os
@@ -11,10 +17,9 @@ import tempfile
import time
import uuid
from enum import Enum
from typing import Callable, Generator, List, Literal, Optional, Union
from typing import Callable, Generator, Literal, Optional, Union

import torch

import zmq

from fairscale.nn.model_parallel.initialize import (
@@ -23,25 +28,16 @@ from fairscale.nn.model_parallel.initialize import (
get_model_parallel_src_rank,
)

from llama_models.llama3.api.datatypes import Message, ToolPromptFormat

from pydantic import BaseModel, Field

from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from typing_extensions import Annotated

from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest

from .generation import TokenResult


class InferenceArgs(BaseModel):
messages: List[Message]
temperature: float
top_p: float
max_gen_len: int
logprobs: bool
tool_prompt_format: ToolPromptFormat


class ProcessingMessageName(str, Enum):
ready_request = "ready_request"
ready_response = "ready_response"
@@ -80,7 +76,7 @@ class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = (
ProcessingMessageName.task_request
)
task: InferenceArgs
task: Union[CompletionRequest, ChatCompletionRequest]


class TaskResponse(BaseModel):
@@ -349,11 +345,13 @@ class ModelParallelProcessGroup:
self.process.join()
self.started = False

def run_inference(self, inference_args: InferenceArgs) -> Generator:
def run_inference(
self, req: Union[CompletionRequest, ChatCompletionRequest]
) -> Generator:
assert not self.running, "inference already running"

self.running = True
self.request_socket.send(encode_msg(TaskRequest(task=inference_args)))
self.request_socket.send(encode_msg(TaskRequest(task=req)))
try:
while True:
obj_json = self.request_socket.recv()

@@ -184,7 +184,7 @@ class LlamaGuardShield(ShieldBase):

# TODO: llama-stack inference protocol has issues with non-streaming inference code
content = ""
async for chunk in self.inference_api.chat_completion(
async for chunk in await self.inference_api.chat_completion(
model=self.model,
messages=[shield_input_message],
stream=True,

@@ -134,7 +134,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
if self.engine:
self.engine.shutdown_background_loop()

def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -152,7 +152,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
logprobs=logprobs,
)

def chat_completion(
async def chat_completion(
self,
model: str,
messages: list[Message],
@@ -189,7 +189,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
if stream:
return self._stream_chat_completion(request, results_generator)
else:
return self._nonstream_chat_completion(request, results_generator)
return await self._nonstream_chat_completion(request, results_generator)

async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, results_generator: AsyncGenerator

@@ -116,7 +116,7 @@ async def test_create_agent_turn(agents_settings, sample_messages):
)

turn_response = [
chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]

assert len(turn_response) > 0
@@ -204,7 +204,7 @@ async def test_rag_agent_as_attachments(
)

turn_response = [
chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]

assert len(turn_response) > 0
@@ -218,7 +218,7 @@ async def test_rag_agent_as_attachments(
)

turn_response = [
chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]

assert len(turn_response) > 0
@@ -270,7 +270,7 @@ async def test_create_agent_turn_with_brave_search(
)

turn_response = [
chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]

assert len(turn_response) > 0

@@ -126,6 +126,45 @@ async def test_model_list(inference_settings):
assert model_def.identifier == params["model"]


@pytest.mark.asyncio
async def test_completion(inference_settings):
inference_impl = inference_settings["impl"]
params = inference_settings["common_params"]

provider = inference_impl.routing_table.get_provider_impl(params["model"])
if provider.__provider_id__ != "meta-reference":
pytest.skip("Other inference providers don't support completion() yet")

response = await inference_impl.completion(
content="Roses are red,",
stream=False,
model=params["model"],
sampling_params=SamplingParams(
max_tokens=50,
),
)

assert isinstance(response, CompletionResponse)
assert "violets are blue" in response.content

chunks = [
r
async for r in await inference_impl.completion(
content="Roses are red,",
stream=True,
model=params["model"],
sampling_params=SamplingParams(
max_tokens=50,
),
)
]

assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
assert len(chunks) == 51
last = chunks[-1]
assert last.stop_reason == StopReason.out_of_tokens


@pytest.mark.asyncio
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
inference_impl = inference_settings["impl"]
@@ -146,7 +185,7 @@ async def test_chat_completion_streaming(inference_settings, sample_messages):
inference_impl = inference_settings["impl"]
response = [
r
async for r in inference_impl.chat_completion(
async for r in await inference_impl.chat_completion(
messages=sample_messages,
stream=True,
**inference_settings["common_params"],
@@ -217,7 +256,7 @@ async def test_chat_completion_with_tool_calling_streaming(

response = [
r
async for r in inference_impl.chat_completion(
async for r in await inference_impl.chat_completion(
messages=messages,
tools=[sample_tool_definition],
stream=True,