Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-10 13:28:40 +00:00)
rename ModelInference to Inference

Parent: 245461620d
Commit: 67f0510edd

18 changed files with 468 additions and 1636 deletions
@@ -75,7 +75,7 @@ class RemoteImplConfig(BaseModel):
     url: str = Field(..., description="The URL of the remote module")


-class ModelInferenceConfig(BaseModel):
+class InferenceConfig(BaseModel):
     impl_config: Annotated[
         Union[InlineImplConfig, RemoteImplConfig],
         Field(discriminator="impl_type"),
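For context, a minimal, self-contained sketch of the discriminated-union pattern the renamed InferenceConfig relies on. Only the url field and the impl_config annotation come from the hunk above; the Literal discriminator tags and the InlineImplConfig body are illustrative assumptions.

```python
# Sketch only: the Literal tags and InlineImplConfig body are assumptions,
# not the repository's actual definitions.
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field


class InlineImplConfig(BaseModel):
    impl_type: Literal["inline"] = "inline"  # discriminator tag (assumed)
    checkpoint_dir: str = ""                 # placeholder field for the sketch


class RemoteImplConfig(BaseModel):
    impl_type: Literal["remote"] = "remote"  # discriminator tag (assumed)
    url: str = Field(..., description="The URL of the remote module")


class InferenceConfig(BaseModel):
    impl_config: Annotated[
        Union[InlineImplConfig, RemoteImplConfig],
        Field(discriminator="impl_type"),
    ]


# Pydantic picks the right sub-config from the "impl_type" tag.
cfg = InferenceConfig(
    impl_config={"impl_type": "remote", "url": "http://localhost:5000"}
)
assert isinstance(cfg.impl_config, RemoteImplConfig)
```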
@@ -130,7 +130,7 @@ class RemoteImplHydraConfig:


 @dataclass
-class ModelInferenceHydraConfig:
+class InferenceHydraConfig:
     impl_type: str
     inline_config: Optional[InlineImplHydraConfig] = None
     remote_config: Optional[RemoteImplHydraConfig] = None
@@ -142,18 +142,18 @@ class ModelInferenceHydraConfig:
         if self.impl_type == "remote":
             assert self.remote_config is not None

-    def convert_to_model_inferene_config(self):
+    def convert_to_inference_config(self):
         if self.impl_type == "inline":
             inline_config = InlineImplHydraConfig(**self.inline_config)
-            return ModelInferenceConfig(
+            return InferenceConfig(
                 impl_config=inline_config.convert_to_inline_impl_config()
             )
         elif self.impl_type == "remote":
             remote_config = RemoteImplHydraConfig(**self.remote_config)
-            return ModelInferenceConfig(
+            return InferenceConfig(
                 impl_config=remote_config.convert_to_remote_impl_config()
             )


 cs = ConfigStore.instance()
-cs.store(name="model_inference_config", node=ModelInferenceHydraConfig)
+cs.store(name="inference_config", node=InferenceHydraConfig)
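The last lines of that hunk re-register the dataclass with Hydra's ConfigStore under the new name. As a rough illustration of how a structured config registered this way can be composed (the names mirror the diff, but the compose() usage below is a generic Hydra pattern assumed for the sketch, not code from this repository):

```python
# Generic Hydra structured-config usage; not taken from the repository.
from dataclasses import dataclass
from typing import Optional

from hydra import compose, initialize
from hydra.core.config_store import ConfigStore


@dataclass
class RemoteImplHydraConfig:
    url: str = "http://localhost:5000"  # placeholder default for the sketch


@dataclass
class InferenceHydraConfig:
    impl_type: str = "remote"
    remote_config: Optional[RemoteImplHydraConfig] = None


cs = ConfigStore.instance()
cs.store(name="inference_config", node=InferenceHydraConfig)

# Compose the registered node directly; the server hunk further below instead
# pulls it out of a larger config via OmegaConf.to_container.
with initialize(version_base=None, config_path=None):
    cfg = compose(config_name="inference_config")
    print(cfg.impl_type)  # -> "remote"
```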
@@ -90,7 +90,7 @@ class BatchChatCompletionResponse(BaseModel):
     completion_message_batch: List[CompletionMessage]


-class ModelInference(Protocol):
+class Inference(Protocol):

     @webmethod(route="/inference/completion")
     async def completion(
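Since Inference is a typing.Protocol, conforming classes only need matching method signatures; the InferenceImpl and InferenceClient classes in later hunks subclass it explicitly, but structural compatibility is what matters. A simplified illustration (the real protocol takes request objects and carries @webmethod routes, both omitted here):

```python
# Simplified stand-in for the real protocol: one method, plain-string request.
from typing import Protocol, runtime_checkable


@runtime_checkable
class Inference(Protocol):
    async def completion(self, request: str) -> str:
        ...


class EchoBackend:
    """Satisfies Inference structurally, without inheriting from it."""

    async def completion(self, request: str) -> str:
        return f"echo: {request}"


# runtime_checkable lets isinstance() verify the method is present.
assert isinstance(EchoBackend(), Inference)
```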
@@ -1,12 +1,12 @@
-from .api.config import ImplType, ModelInferenceConfig
+from .api.config import ImplType, InferenceConfig


-async def get_inference_api_instance(config: ModelInferenceConfig):
+async def get_inference_api_instance(config: InferenceConfig):
     if config.impl_config.impl_type == ImplType.inline.value:
-        from .inference import ModelInferenceImpl
+        from .inference import InferenceImpl

-        return ModelInferenceImpl(config.impl_config)
+        return InferenceImpl(config.impl_config)

-    from .client import ModelInferenceClient
+    from .client import InferenceClient

-    return ModelInferenceClient(config.impl_config.url)
+    return InferenceClient(config.impl_config.url)
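The factory above dispatches on impl_config.impl_type and lazily imports either the local implementation or the HTTP client. A self-contained sketch of that dispatch-and-construct shape, using stub classes rather than the repository's:

```python
# Stub illustration of the factory's shape; LocalStub/RemoteStub are not the
# repository's InferenceImpl/InferenceClient.
import asyncio
from dataclasses import dataclass


@dataclass
class StubImplConfig:
    impl_type: str
    url: str = "http://localhost:5000"


class LocalStub:
    def __init__(self, impl_config: StubImplConfig) -> None:
        self.impl_config = impl_config


class RemoteStub:
    def __init__(self, base_url: str) -> None:
        self.base_url = base_url


async def get_instance(impl_config: StubImplConfig):
    # Inline backends receive the whole impl config; remote backends only the
    # URL, mirroring get_inference_api_instance above.
    if impl_config.impl_type == "inline":
        return LocalStub(impl_config)
    return RemoteStub(impl_config.url)


instance = asyncio.run(get_instance(StubImplConfig(impl_type="remote")))
print(type(instance).__name__, instance.base_url)
```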
@@ -10,12 +10,12 @@ from .api import (
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
     InstructModel,
-    ModelInference,
+    Inference,
     UserMessage,
 )


-class ModelInferenceClient(ModelInference):
+class InferenceClient(Inference):
     def __init__(self, base_url: str):
         self.base_url = base_url
@@ -48,7 +48,7 @@ class ModelInferenceClient(ModelInference):


 async def run_main(host: str, port: int):
-    client = ModelInferenceClient(f"http://{host}:{port}")
+    client = InferenceClient(f"http://{host}:{port}")

     message = UserMessage(content="hello world, help me out here")
     req = ChatCompletionRequest(
@@ -18,12 +18,12 @@ from .api.endpoints import (
     ChatCompletionRequest,
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
-    ModelInference,
+    Inference,
 )
 from .model_parallel import LlamaModelParallelGenerator


-class ModelInferenceImpl(ModelInference):
+class InferenceImpl(Inference):

     def __init__(self, config: InlineImplConfig) -> None:
         self.config = config
@@ -11,7 +11,7 @@ from fastapi.responses import StreamingResponse
 from omegaconf import OmegaConf

 from toolchain.utils import get_default_config_dir, parse_config
-from .api.config import ModelInferenceHydraConfig
+from .api.config import InferenceHydraConfig
 from .api.endpoints import ChatCompletionRequest, ChatCompletionResponseStreamChunk

 from .api_instance import get_inference_api_instance
@@ -43,13 +43,13 @@ async def startup():
     global InferenceApiInstance

     config = get_config()
-    hydra_config = ModelInferenceHydraConfig(
-        **OmegaConf.to_container(config["model_inference_config"], resolve=True)
+    hydra_config = InferenceHydraConfig(
+        **OmegaConf.to_container(config["inference_config"], resolve=True)
     )
-    model_inference_config = hydra_config.convert_to_model_inferene_config()
+    inference_config = hydra_config.convert_to_inference_config()

     InferenceApiInstance = await get_inference_api_instance(
-        model_inference_config,
+        inference_config,
     )
     await InferenceApiInstance.initialize()
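The startup hunk follows the common FastAPI pattern of building a global backend instance in a startup event. A self-contained sketch of that pattern, with a stub standing in for the real factory and the Hydra-derived config:

```python
# Stub wiring only: _StubBackend and the hard-coded config dict stand in for
# the repository's inference implementation and its InferenceConfig.
from fastapi import FastAPI

app = FastAPI()
InferenceApiInstance = None  # populated once the app starts


class _StubBackend:
    async def initialize(self) -> None:
        print("inference backend ready")


async def get_inference_api_instance(config: dict) -> _StubBackend:
    # The real factory dispatches on config.impl_config.impl_type (see earlier hunk).
    return _StubBackend()


@app.on_event("startup")
async def startup() -> None:
    global InferenceApiInstance
    InferenceApiInstance = await get_inference_api_instance({"impl_type": "inline"})
    await InferenceApiInstance.initialize()
```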