rename ModelInference to Inference

rsm 2024-07-21 12:19:52 -07:00
parent 245461620d
commit 67f0510edd
18 changed files with 468 additions and 1636 deletions


@@ -75,7 +75,7 @@ class RemoteImplConfig(BaseModel):
     url: str = Field(..., description="The URL of the remote module")
-class ModelInferenceConfig(BaseModel):
+class InferenceConfig(BaseModel):
     impl_config: Annotated[
         Union[InlineImplConfig, RemoteImplConfig],
         Field(discriminator="impl_type"),
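
With the discriminated union above, Pydantic dispatches on the "impl_type" tag to pick the concrete config class. A minimal sketch of what that looks like at a call site (the values are illustrative, not part of this commit):

    config = InferenceConfig(
        impl_config={"impl_type": "remote", "url": "http://localhost:5000"}
    )
    assert isinstance(config.impl_config, RemoteImplConfig)  # tag selected the class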
@@ -130,7 +130,7 @@ class RemoteImplHydraConfig:
 @dataclass
-class ModelInferenceHydraConfig:
+class InferenceHydraConfig:
     impl_type: str
     inline_config: Optional[InlineImplHydraConfig] = None
     remote_config: Optional[RemoteImplHydraConfig] = None
@@ -142,18 +142,18 @@ class ModelInferenceHydraConfig:
         if self.impl_type == "remote":
             assert self.remote_config is not None
-    def convert_to_model_inferene_config(self):
+    def convert_to_inference_config(self):
         if self.impl_type == "inline":
             inline_config = InlineImplHydraConfig(**self.inline_config)
-            return ModelInferenceConfig(
+            return InferenceConfig(
                 impl_config=inline_config.convert_to_inline_impl_config()
             )
         elif self.impl_type == "remote":
             remote_config = RemoteImplHydraConfig(**self.remote_config)
-            return ModelInferenceConfig(
+            return InferenceConfig(
                 impl_config=remote_config.convert_to_remote_impl_config()
             )
 cs = ConfigStore.instance()
-cs.store(name="model_inference_config", node=ModelInferenceHydraConfig)
+cs.store(name="inference_config", node=InferenceHydraConfig)


@@ -90,7 +90,7 @@ class BatchChatCompletionResponse(BaseModel):
     completion_message_batch: List[CompletionMessage]
-class ModelInference(Protocol):
+class Inference(Protocol):
     @webmethod(route="/inference/completion")
     async def completion(
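
Since Inference is a typing.Protocol, conformance is structural: anything exposing the right async methods satisfies it, whether or not it subclasses Inference (the concrete classes in this commit do subclass it). A sketch of coding against the protocol; the completion signature is an assumption, since the diff truncates it:

    async def run(engine: Inference, request: CompletionRequest):
        # InferenceImpl (inline) and InferenceClient (remote) are
        # interchangeable behind the protocol.
        return await engine.completion(request)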


@@ -1,12 +1,12 @@
-from .api.config import ImplType, ModelInferenceConfig
+from .api.config import ImplType, InferenceConfig
-async def get_inference_api_instance(config: ModelInferenceConfig):
+async def get_inference_api_instance(config: InferenceConfig):
     if config.impl_config.impl_type == ImplType.inline.value:
-        from .inference import ModelInferenceImpl
+        from .inference import InferenceImpl
-        return ModelInferenceImpl(config.impl_config)
+        return InferenceImpl(config.impl_config)
-    from .client import ModelInferenceClient
+    from .client import InferenceClient
-    return ModelInferenceClient(config.impl_config.url)
+    return InferenceClient(config.impl_config.url)
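
The factory keeps callers independent of the backend choice. A hypothetical invocation (values illustrative):

    config = InferenceConfig(
        impl_config={"impl_type": "remote", "url": "http://localhost:5000"}
    )
    api = await get_inference_api_instance(config)  # returns an InferenceClient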


@@ -10,12 +10,12 @@ from .api import (
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
     InstructModel,
-    ModelInference,
+    Inference,
     UserMessage,
 )
-class ModelInferenceClient(ModelInference):
+class InferenceClient(Inference):
     def __init__(self, base_url: str):
         self.base_url = base_url
@@ -48,7 +48,7 @@ class ModelInferenceClient(ModelInference):
 async def run_main(host: str, port: int):
-    client = ModelInferenceClient(f"http://{host}:{port}")
+    client = InferenceClient(f"http://{host}:{port}")
     message = UserMessage(content="hello world, help me out here")
     req = ChatCompletionRequest(
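
A plausible way to drive the client entry point, assuming run_main is called directly; host and port are illustrative:

    import asyncio

    asyncio.run(run_main("localhost", 5000))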


@@ -18,12 +18,12 @@ from .api.endpoints import (
     ChatCompletionRequest,
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
-    ModelInference,
+    Inference,
 )
 from .model_parallel import LlamaModelParallelGenerator
-class ModelInferenceImpl(ModelInference):
+class InferenceImpl(Inference):
     def __init__(self, config: InlineImplConfig) -> None:
         self.config = config


@@ -11,7 +11,7 @@ from fastapi.responses import StreamingResponse
 from omegaconf import OmegaConf
 from toolchain.utils import get_default_config_dir, parse_config
-from .api.config import ModelInferenceHydraConfig
+from .api.config import InferenceHydraConfig
 from .api.endpoints import ChatCompletionRequest, ChatCompletionResponseStreamChunk
 from .api_instance import get_inference_api_instance
@@ -43,13 +43,13 @@ async def startup():
     global InferenceApiInstance
     config = get_config()
-    hydra_config = ModelInferenceHydraConfig(
-        **OmegaConf.to_container(config["model_inference_config"], resolve=True)
+    hydra_config = InferenceHydraConfig(
+        **OmegaConf.to_container(config["inference_config"], resolve=True)
     )
-    model_inference_config = hydra_config.convert_to_model_inferene_config()
+    inference_config = hydra_config.convert_to_inference_config()
     InferenceApiInstance = await get_inference_api_instance(
-        model_inference_config,
+        inference_config,
     )
     await InferenceApiInstance.initialize()
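
For reference, a sketch of the config shape the startup hook now expects under the renamed key; the remote_config fields are an assumption based on RemoteImplHydraConfig:

    from omegaconf import OmegaConf

    config = OmegaConf.create(
        {
            "inference_config": {
                "impl_type": "remote",
                "remote_config": {"url": "http://localhost:5000"},  # assumed field
            }
        }
    )
    hydra_config = InferenceHydraConfig(
        **OmegaConf.to_container(config["inference_config"], resolve=True)
    )
    inference_config = hydra_config.convert_to_inference_config()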