mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
rename toolchain/ --> llama_toolchain/
This commit is contained in:
parent
d95f5f863d
commit
f9111652ef
73 changed files with 36 additions and 37 deletions
2
llama_toolchain/inference/api/__init__.py
Normal file
2
llama_toolchain/inference/api/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
|||
from .datatypes import * # noqa: F401 F403
|
||||
from .endpoints import * # noqa: F401 F403
|
159
llama_toolchain/inference/api/config.py
Normal file
159
llama_toolchain/inference/api/config.py
Normal file
|
@ -0,0 +1,159 @@
|
|||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from hydra.core.config_store import ConfigStore
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from .datatypes import QuantizationConfig
|
||||
|
||||
|
||||
class ImplType(Enum):
    """Kinds of inference implementation.

    The member values are used as the ``impl_type`` discriminator tags of
    ``InlineImplConfig`` and ``RemoteImplConfig``.
    """

    inline = "inline"
    remote = "remote"
|
||||
|
||||
|
||||
class CheckpointType(Enum):
    """Supported checkpoint sources.

    The member values are used as the ``checkpoint_type`` discriminator tags
    of ``PytorchCheckpoint`` and ``HuggingFaceCheckpoint``.
    """

    pytorch = "pytorch"
    huggingface = "huggingface"
|
||||
|
||||
|
||||
# This enum represents the format in which weights are specified.
# This does not necessarily always equal what quantization is desired
# at runtime since there can be on-the-fly conversions done.
class CheckpointQuantizationFormat(Enum):
    # default format
    bf16 = "bf16"

    # used for enabling fp8_rowwise inference, some weights are bf16
    fp8_mixed = "fp8_mixed"
|
||||
|
||||
|
||||
# Checkpoint description tagged with checkpoint_type == "pytorch".
class PytorchCheckpoint(BaseModel):
    # Discriminator tag; the only accepted value is "pytorch".
    checkpoint_type: Literal[CheckpointType.pytorch.value] = (
        CheckpointType.pytorch.value
    )
    checkpoint_dir: str
    tokenizer_path: str
    model_parallel_size: int
    # Format in which the weights are specified; bf16 unless stated otherwise.
    quantization_format: CheckpointQuantizationFormat = (
        CheckpointQuantizationFormat.bf16
    )
|
||||
|
||||
|
||||
# Checkpoint description tagged with checkpoint_type == "huggingface".
class HuggingFaceCheckpoint(BaseModel):
    # Discriminator tag; the only accepted value is "huggingface".
    checkpoint_type: Literal[CheckpointType.huggingface.value] = (
        CheckpointType.huggingface.value
    )
    repo_id: str  # or model_name ?
    model_parallel_size: int
    # Format in which the weights are specified; bf16 unless stated otherwise.
    quantization_format: CheckpointQuantizationFormat = (
        CheckpointQuantizationFormat.bf16
    )
|
||||
|
||||
|
||||
# Wrapper holding exactly one checkpoint description.
class ModelCheckpointConfig(BaseModel):
    # Tagged union: the "checkpoint_type" field decides which model is used.
    checkpoint: Annotated[
        Union[PytorchCheckpoint, HuggingFaceCheckpoint],
        Field(discriminator="checkpoint_type"),
    ]
|
||||
|
||||
|
||||
# Configuration tagged with impl_type == "inline".
class InlineImplConfig(BaseModel):
    # Discriminator tag; the only accepted value is "inline".
    impl_type: Literal[ImplType.inline.value] = ImplType.inline.value
    checkpoint_config: ModelCheckpointConfig
    # Optional settings default to None when not supplied.
    quantization: Optional[QuantizationConfig] = None
    torch_seed: Optional[int] = None
    max_seq_len: int
    max_batch_size: int = 1
|
||||
|
||||
|
||||
# Configuration tagged with impl_type == "remote".
class RemoteImplConfig(BaseModel):
    # Discriminator tag; the only accepted value is "remote".
    impl_type: Literal[ImplType.remote.value] = ImplType.remote.value
    # Required field (no default).
    url: str = Field(..., description="The URL of the remote module")
|
||||
|
||||
|
||||
# Top-level inference configuration.
class InferenceConfig(BaseModel):
    # Tagged union: the "impl_type" field decides which model is used.
    impl_config: Annotated[
        Union[InlineImplConfig, RemoteImplConfig],
        Field(discriminator="impl_type"),
    ]
|
||||
|
||||
|
||||
# Hydra does not like unions of containers and
|
||||
# Pydantic does not like Literals
|
||||
# Adding a simple dataclass with custom conversion
|
||||
# to config classes
|
||||
|
||||
|
||||
@dataclass
class InlineImplHydraConfig:
    """Flat dataclass form of the inline implementation settings.

    Call ``convert_to_inline_impl_config`` to obtain the equivalent
    ``InlineImplConfig`` pydantic model.
    """

    checkpoint_type: str  # "pytorch" / "HF"
    # pytorch checkpoint required args
    checkpoint_dir: str
    tokenizer_path: str
    model_parallel_size: int
    max_seq_len: int
    max_batch_size: int = 1
    quantization: Optional[QuantizationConfig] = None
    # TODO: huggingface checkpoint required args

    def convert_to_inline_impl_config(self):
        """Build an ``InlineImplConfig`` from this dataclass.

        Returns:
            An ``InlineImplConfig`` wrapping a ``PytorchCheckpoint``.

        Raises:
            NotImplementedError: if ``checkpoint_type`` is anything other
                than ``"pytorch"``.
        """
        # Only the pytorch layout is handled; every other value is rejected.
        if self.checkpoint_type != CheckpointType.pytorch.value:
            raise NotImplementedError("HF Checkpoint not supported yet")

        pytorch_checkpoint = PytorchCheckpoint(
            checkpoint_type=CheckpointType.pytorch.value,
            checkpoint_dir=self.checkpoint_dir,
            tokenizer_path=self.tokenizer_path,
            model_parallel_size=self.model_parallel_size,
        )
        return InlineImplConfig(
            checkpoint_config=ModelCheckpointConfig(checkpoint=pytorch_checkpoint),
            quantization=self.quantization,
            max_seq_len=self.max_seq_len,
            max_batch_size=self.max_batch_size,
        )
|
||||
|
||||
|
||||
@dataclass
class RemoteImplHydraConfig:
    """Flat dataclass form of the remote implementation settings."""

    url: str

    def convert_to_remote_impl_config(self):
        """Return a ``RemoteImplConfig`` carrying this object's ``url``."""
        return RemoteImplConfig(url=self.url)
|
||||
|
||||
|
||||
@dataclass
class InferenceHydraConfig:
    """Dataclass schema for the inference configuration.

    ``impl_type`` selects which of ``inline_config`` / ``remote_config`` is
    used; the selected one must be provided.
    """

    impl_type: str
    inline_config: Optional[InlineImplHydraConfig] = None
    remote_config: Optional[RemoteImplHydraConfig] = None

    def __post_init__(self):
        """Check that ``impl_type`` is known and its matching config is set."""
        assert self.impl_type in ["inline", "remote"], (
            f"impl_type must be 'inline' or 'remote', got {self.impl_type!r}"
        )
        if self.impl_type == "inline":
            assert self.inline_config is not None, (
                "inline_config is required when impl_type is 'inline'"
            )
        if self.impl_type == "remote":
            assert self.remote_config is not None, (
                "remote_config is required when impl_type is 'remote'"
            )

    def convert_to_inference_config(self):
        """Build the ``InferenceConfig`` pydantic model for this object.

        The selected sub-config may be either an already-built dataclass
        instance (as the field annotations declare) or a mapping of its
        fields; a mapping is unpacked into the dataclass first.

        Returns:
            An ``InferenceConfig``. If ``impl_type`` was changed after
            construction to an unknown value, no branch matches and ``None``
            is returned.
        """
        if self.impl_type == "inline":
            inline_config = self.inline_config
            # `**` unpacking only works on mappings, so skip it for instances.
            if not isinstance(inline_config, InlineImplHydraConfig):
                inline_config = InlineImplHydraConfig(**inline_config)
            return InferenceConfig(
                impl_config=inline_config.convert_to_inline_impl_config()
            )
        elif self.impl_type == "remote":
            remote_config = self.remote_config
            if not isinstance(remote_config, RemoteImplHydraConfig):
                remote_config = RemoteImplHydraConfig(**remote_config)
            return InferenceConfig(
                impl_config=remote_config.convert_to_remote_impl_config()
            )
|
||||
|
||||
|
||||
# Register InferenceHydraConfig with Hydra's ConfigStore under the name
# "inference_config". This runs as a side effect of importing the module.
cs = ConfigStore.instance()
cs.store(name="inference_config", node=InferenceHydraConfig)
|
68
llama_toolchain/inference/api/datatypes.py
Normal file
68
llama_toolchain/inference/api/datatypes.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
from enum import Enum
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from strong_typing.schema import json_schema_type
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_models.llama3_1.api.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
# Log-probability settings attached to requests via their "logprobs" field.
class LogProbConfig(BaseModel):
    # Defaults to 0; None is also accepted.
    top_k: Optional[int] = 0
|
||||
|
||||
|
||||
# Tags used by the "type" field of the quantization config models.
@json_schema_type
class QuantizationType(Enum):
    bf16 = "bf16"
    fp8 = "fp8"
|
||||
|
||||
|
||||
# Quantization config tagged with type == "fp8".
@json_schema_type
class Fp8QuantizationConfig(BaseModel):
    # Discriminator tag; the only accepted value is "fp8".
    type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
|
||||
|
||||
|
||||
# Quantization config tagged with type == "bf16".
@json_schema_type
class Bf16QuantizationConfig(BaseModel):
    # Discriminator tag; the only accepted value is "bf16".
    type: Literal[QuantizationType.bf16.value] = (
        QuantizationType.bf16.value
    )
|
||||
|
||||
|
||||
# Tagged union of the quantization config models; the "type" field decides
# which one is used.
QuantizationConfig = Annotated[
    Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
    Field(discriminator="type"),
]
|
||||
|
||||
|
||||
# Values for ChatCompletionResponseEvent.event_type.
@json_schema_type
class ChatCompletionResponseEventType(Enum):
    start = "start"
    complete = "complete"
    progress = "progress"
|
||||
|
||||
|
||||
# Values for ToolCallDelta.parse_status.
@json_schema_type
class ToolCallParseStatus(Enum):
    started = "started"
    in_progress = "in_progress"
    failure = "failure"
    success = "success"
|
||||
|
||||
|
||||
# Tool-call content paired with its parse status.
@json_schema_type
class ToolCallDelta(BaseModel):
    # Either a plain string or a ToolCall object.
    content: Union[str, ToolCall]
    parse_status: ToolCallParseStatus
|
||||
|
||||
|
||||
@json_schema_type
class ChatCompletionResponseEvent(BaseModel):
    """Chat completion response event."""

    event_type: ChatCompletionResponseEventType
    # Either a plain string or a ToolCallDelta.
    delta: Union[str, ToolCallDelta]
    # Optional fields default to None when not supplied.
    logprobs: Optional[List[TokenLogProbs]] = None
    stop_reason: Optional[StopReason] = None
|
117
llama_toolchain/inference/api/endpoints.py
Normal file
117
llama_toolchain/inference/api/endpoints.py
Normal file
|
@ -0,0 +1,117 @@
|
|||
from .datatypes import * # noqa: F403
|
||||
from typing import Optional, Protocol
|
||||
|
||||
# this dependency is annoying and we need a forked up version anyway
|
||||
from pyopenapi import webmethod
|
||||
|
||||
|
||||
# Request model accepted by Inference.completion.
@json_schema_type
class CompletionRequest(BaseModel):
    model: PretrainedModel
    content: InterleavedTextAttachment
    sampling_params: Optional[SamplingParams] = SamplingParams()

    # Streaming is off unless requested.
    stream: Optional[bool] = False
    logprobs: Optional[LogProbConfig] = None
    quantization_config: Optional[QuantizationConfig] = None
|
||||
|
||||
|
||||
# One of the two return types declared for Inference.completion.
@json_schema_type
class CompletionResponse(BaseModel):
    completion_message: CompletionMessage
    logprobs: Optional[List[TokenLogProbs]] = None
|
||||
|
||||
|
||||
@json_schema_type
class CompletionResponseStreamChunk(BaseModel):
    """streamed completion response."""

    delta: str
    # Optional fields default to None when not supplied.
    stop_reason: Optional[StopReason] = None
    logprobs: Optional[List[TokenLogProbs]] = None
|
||||
|
||||
|
||||
# Request model accepted by Inference.batch_completion.
@json_schema_type
class BatchCompletionRequest(BaseModel):
    model: PretrainedModel
    # List of content items, in place of CompletionRequest's single `content`.
    content_batch: List[InterleavedTextAttachment]
    sampling_params: Optional[SamplingParams] = SamplingParams()
    logprobs: Optional[LogProbConfig] = None
    quantization_config: Optional[QuantizationConfig] = None
|
||||
|
||||
|
||||
# Response model holding a list of completion messages.
@json_schema_type
class BatchCompletionResponse(BaseModel):
    completion_message_batch: List[CompletionMessage]
|
||||
|
||||
|
||||
# Request model accepted by Inference.chat_completion.
@json_schema_type
class ChatCompletionRequest(BaseModel):
    model: InstructModel
    messages: List[Message]
    sampling_params: Optional[SamplingParams] = SamplingParams()

    # zero-shot tool definitions as input to the model
    available_tools: Optional[List[ToolDefinition]] = Field(default_factory=list)

    # Streaming is off unless requested.
    stream: Optional[bool] = False
    logprobs: Optional[LogProbConfig] = None
    quantization_config: Optional[QuantizationConfig] = None
|
||||
|
||||
|
||||
@json_schema_type
class ChatCompletionResponseStreamChunk(BaseModel):
    """SSE-stream of these events."""

    event: ChatCompletionResponseEvent
|
||||
|
||||
|
||||
# One of the two return types declared for Inference.chat_completion.
@json_schema_type
class ChatCompletionResponse(BaseModel):
    completion_message: CompletionMessage
    logprobs: Optional[List[TokenLogProbs]] = None
|
||||
|
||||
|
||||
# Request model accepted by Inference.batch_chat_completion.
@json_schema_type
class BatchChatCompletionRequest(BaseModel):
    model: InstructModel
    # List of message lists, in place of ChatCompletionRequest's `messages`.
    messages_batch: List[List[Message]]
    sampling_params: Optional[SamplingParams] = SamplingParams()

    # zero-shot tool definitions as input to the model
    available_tools: Optional[List[ToolDefinition]] = Field(default_factory=list)

    logprobs: Optional[LogProbConfig] = None
    quantization_config: Optional[QuantizationConfig] = None
|
||||
|
||||
|
||||
# Response model holding a list of completion messages.
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
    completion_message_batch: List[CompletionMessage]
|
||||
|
||||
|
||||
# Structural interface for inference implementations. Each method is tagged
# with `webmethod` and the route it is declared under; the bodies are stubs.
class Inference(Protocol):

    # Declared to return either a full response or a stream chunk.
    @webmethod(route="/inference/completion")
    async def completion(
        self,
        request: CompletionRequest,
    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...

    # Declared to return either a full response or a stream chunk.
    @webmethod(route="/inference/chat_completion")
    async def chat_completion(
        self,
        request: ChatCompletionRequest,
    ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...

    # NOTE(review): annotated as List[CompletionResponse] although a
    # BatchCompletionResponse model exists in this module - confirm intent.
    @webmethod(route="/inference/batch_completion")
    async def batch_completion(
        self,
        request: BatchCompletionRequest,
    ) -> List[CompletionResponse]: ...

    # NOTE(review): annotated as List[ChatCompletionResponse] although a
    # BatchChatCompletionResponse model exists in this module - confirm intent.
    @webmethod(route="/inference/batch_chat_completion")
    async def batch_chat_completion(
        self,
        request: BatchChatCompletionRequest,
    ) -> List[ChatCompletionResponse]: ...
|
Loading…
Add table
Add a link
Reference in a new issue