Remove request wrapper migration (#64)

* [1/n] migrate inference/chat_completion

* migrate inference/completion

* inference/completion

* inference regenerate openapi spec

* safety api

* migrate agentic system

* migrate apis without implementations

* re-generate openapi spec

* remove hack from openapi generator

* fix inference

* fix inference

* openapi generator rerun

* Simplified Telemetry API and tying it to logger (#57)

* Simplified Telemetry API and tying it to logger

* small update which adds a METRIC type

* move span events one level down into structured log events

---------

Co-authored-by: Ashwin Bharambe <ashwin@meta.com>

* fix api to work with openapi generator

* fix agentic calling inference

* together adapter inference

* update inference adapters

---------

Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
This commit is contained in:
Xi Yan 2024-09-12 15:03:49 -07:00 committed by GitHub
parent 1d0e91d802
commit 5712566061
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 1211 additions and 3031 deletions

View file

@ -416,7 +416,16 @@ class AgenticSystem(Protocol):
@webmethod(route="/agentic_system/turn/create") @webmethod(route="/agentic_system/turn/create")
async def create_agentic_system_turn( async def create_agentic_system_turn(
self, self,
request: AgenticSystemTurnCreateRequest, agent_id: str,
session_id: str,
messages: List[
Union[
UserMessage,
ToolResponseMessage,
]
],
attachments: Optional[List[Attachment]] = None,
stream: Optional[bool] = False,
) -> AgenticSystemTurnResponseStreamChunk: ... ) -> AgenticSystemTurnResponseStreamChunk: ...
@webmethod(route="/agentic_system/turn/get") @webmethod(route="/agentic_system/turn/get")

View file

@ -73,9 +73,7 @@ class AgenticSystemClient(AgenticSystem):
async with client.stream( async with client.stream(
"POST", "POST",
f"{self.base_url}/agentic_system/turn/create", f"{self.base_url}/agentic_system/turn/create",
json={ json=encodable_dict(request),
"request": encodable_dict(request),
},
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
timeout=20, timeout=20,
) as response: ) as response:

View file

@ -388,19 +388,17 @@ class ChatAgent(ShieldRunnerMixin):
) )
) )
req = ChatCompletionRequest( tool_calls = []
model=self.agent_config.model, content = ""
messages=input_messages, stop_reason = None
async for chunk in self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=self._get_tools(), tools=self._get_tools(),
tool_prompt_format=self.agent_config.tool_prompt_format, tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True, stream=True,
sampling_params=sampling_params, sampling_params=sampling_params,
) ):
tool_calls = []
content = ""
stop_reason = None
async for chunk in self.inference_api.chat_completion(req):
event = chunk.event event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start: if event.event_type == ChatCompletionResponseEventType.start:
continue continue

View file

@ -114,8 +114,26 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
async def create_agentic_system_turn( async def create_agentic_system_turn(
self, self,
request: AgenticSystemTurnCreateRequest, agent_id: str,
session_id: str,
messages: List[
Union[
UserMessage,
ToolResponseMessage,
]
],
attachments: Optional[List[Attachment]] = None,
stream: Optional[bool] = False,
) -> AsyncGenerator: ) -> AsyncGenerator:
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = AgenticSystemTurnCreateRequest(
agent_id=agent_id,
session_id=session_id,
messages=messages,
attachments=attachments,
stream=stream,
)
agent_id = request.agent_id agent_id = request.agent_id
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found" assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
agent = AGENT_INSTANCES_BY_ID[agent_id] agent = AGENT_INSTANCES_BY_ID[agent_id]

View file

@ -51,11 +51,21 @@ class BatchInference(Protocol):
@webmethod(route="/batch_inference/completion") @webmethod(route="/batch_inference/completion")
async def batch_completion( async def batch_completion(
self, self,
request: BatchCompletionRequest, model: str,
content_batch: List[InterleavedTextMedia],
sampling_params: Optional[SamplingParams] = SamplingParams(),
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse: ... ) -> BatchCompletionResponse: ...
@webmethod(route="/batch_inference/chat_completion") @webmethod(route="/batch_inference/chat_completion")
async def batch_chat_completion( async def batch_chat_completion(
self, self,
request: BatchChatCompletionRequest, model: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = SamplingParams(),
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = list,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse: ... ) -> BatchChatCompletionResponse: ...

View file

@ -46,7 +46,8 @@ class Datasets(Protocol):
@webmethod(route="/datasets/create") @webmethod(route="/datasets/create")
def create_dataset( def create_dataset(
self, self,
request: CreateDatasetRequest, uuid: str,
dataset: TrainEvalDataset,
) -> None: ... ) -> None: ...
@webmethod(route="/datasets/get") @webmethod(route="/datasets/get")

View file

@ -86,19 +86,19 @@ class Evaluations(Protocol):
@webmethod(route="/evaluate/text_generation/") @webmethod(route="/evaluate/text_generation/")
def evaluate_text_generation( def evaluate_text_generation(
self, self,
request: EvaluateTextGenerationRequest, metrics: List[TextGenerationMetric],
) -> EvaluationJob: ... ) -> EvaluationJob: ...
@webmethod(route="/evaluate/question_answering/") @webmethod(route="/evaluate/question_answering/")
def evaluate_question_answering( def evaluate_question_answering(
self, self,
request: EvaluateQuestionAnsweringRequest, metrics: List[QuestionAnsweringMetric],
) -> EvaluationJob: ... ) -> EvaluationJob: ...
@webmethod(route="/evaluate/summarization/") @webmethod(route="/evaluate/summarization/")
def evaluate_summarization( def evaluate_summarization(
self, self,
request: EvaluateSummarizationRequest, metrics: List[SummarizationMetric],
) -> EvaluationJob: ... ) -> EvaluationJob: ...
@webmethod(route="/evaluate/jobs") @webmethod(route="/evaluate/jobs")

View file

@ -76,7 +76,28 @@ class FireworksInferenceAdapter(Inference):
return options return options
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = list(),
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
messages = prepare_messages(request) messages = prepare_messages(request)
# accumulate sampling params and other options to pass to fireworks # accumulate sampling params and other options to pass to fireworks

View file

@ -84,7 +84,28 @@ class OllamaInferenceAdapter(Inference):
return options return options
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = list(),
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
messages = prepare_messages(request) messages = prepare_messages(request)
# accumulate sampling params and other options to pass to ollama # accumulate sampling params and other options to pass to ollama
options = self.get_ollama_chat_options(request) options = self.get_ollama_chat_options(request)

View file

@ -82,7 +82,28 @@ class TGIAdapter(Inference):
return options return options
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = list(),
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
messages = prepare_messages(request) messages = prepare_messages(request)
model_input = self.formatter.encode_dialog_prompt(messages) model_input = self.formatter.encode_dialog_prompt(messages)
prompt = self.tokenizer.decode(model_input.tokens) prompt = self.tokenizer.decode(model_input.tokens)

View file

@ -76,7 +76,29 @@ class TogetherInferenceAdapter(Inference):
return options return options
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = list(),
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
# accumulate sampling params and other options to pass to together # accumulate sampling params and other options to pass to together
options = self.get_together_chat_options(request) options = self.get_together_chat_options(request)
together_model = self.resolve_together_model(request.model) together_model = self.resolve_together_model(request.model)

View file

@ -85,6 +85,8 @@ class CompletionRequest(BaseModel):
@json_schema_type @json_schema_type
class CompletionResponse(BaseModel): class CompletionResponse(BaseModel):
"""Completion response."""
completion_message: CompletionMessage completion_message: CompletionMessage
logprobs: Optional[List[TokenLogProbs]] = None logprobs: Optional[List[TokenLogProbs]] = None
@ -108,6 +110,8 @@ class BatchCompletionRequest(BaseModel):
@json_schema_type @json_schema_type
class BatchCompletionResponse(BaseModel): class BatchCompletionResponse(BaseModel):
"""Batch completion response."""
completion_message_batch: List[CompletionMessage] completion_message_batch: List[CompletionMessage]
@ -137,6 +141,8 @@ class ChatCompletionResponseStreamChunk(BaseModel):
@json_schema_type @json_schema_type
class ChatCompletionResponse(BaseModel): class ChatCompletionResponse(BaseModel):
"""Chat completion response."""
completion_message: CompletionMessage completion_message: CompletionMessage
logprobs: Optional[List[TokenLogProbs]] = None logprobs: Optional[List[TokenLogProbs]] = None
@ -170,13 +176,25 @@ class Inference(Protocol):
@webmethod(route="/inference/completion") @webmethod(route="/inference/completion")
async def completion( async def completion(
self, self,
request: CompletionRequest, model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ... ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
@webmethod(route="/inference/chat_completion") @webmethod(route="/inference/chat_completion")
async def chat_completion( async def chat_completion(
self, self,
request: ChatCompletionRequest, model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = list,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ... ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
@webmethod(route="/inference/embeddings") @webmethod(route="/inference/embeddings")

View file

@ -10,10 +10,10 @@ from typing import Any, AsyncGenerator
import fire import fire
import httpx import httpx
from pydantic import BaseModel
from termcolor import cprint
from llama_toolchain.core.datatypes import RemoteProviderConfig from llama_toolchain.core.datatypes import RemoteProviderConfig
from pydantic import BaseModel
from termcolor import cprint
from .api import ( from .api import (
ChatCompletionRequest, ChatCompletionRequest,
@ -52,9 +52,7 @@ class InferenceClient(Inference):
async with client.stream( async with client.stream(
"POST", "POST",
f"{self.base_url}/inference/chat_completion", f"{self.base_url}/inference/chat_completion",
json={ json=encodable_dict(request),
"request": encodable_dict(request),
},
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
timeout=20, timeout=20,
) as response: ) as response:

View file

@ -22,9 +22,12 @@ from llama_toolchain.inference.api import (
ToolCallParseStatus, ToolCallParseStatus,
) )
from llama_toolchain.inference.prepare_messages import prepare_messages from llama_toolchain.inference.prepare_messages import prepare_messages
from .config import MetaReferenceImplConfig from .config import MetaReferenceImplConfig
from .model_parallel import LlamaModelParallelGenerator from .model_parallel import LlamaModelParallelGenerator
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403
# there's a single model parallel process running serving the model. for now, # there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process. # we don't support multiple concurrent requests to this process.
@ -50,10 +53,30 @@ class MetaReferenceInferenceImpl(Inference):
# hm, when stream=False, we should not be doing SSE :/ which is what the # hm, when stream=False, we should not be doing SSE :/ which is what the
# top-level server is going to do. make the typing more specific here # top-level server is going to do. make the typing more specific here
async def chat_completion( async def chat_completion(
self, request: ChatCompletionRequest self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = list(),
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncIterator[ ) -> AsyncIterator[
Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse] Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
]: ]:
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
messages = prepare_messages(request) messages = prepare_messages(request)
model = resolve_model(request.model) model = resolve_model(request.model)
if model is None: if model is None:

View file

@ -179,13 +179,33 @@ class PostTraining(Protocol):
@webmethod(route="/post_training/supervised_fine_tune") @webmethod(route="/post_training/supervised_fine_tune")
def supervised_fine_tune( def supervised_fine_tune(
self, self,
request: PostTrainingSFTRequest, job_uuid: str,
model: str,
dataset: TrainEvalDataset,
validation_dataset: TrainEvalDataset,
algorithm: FinetuningAlgorithm,
algorithm_config: Union[
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
],
optimizer_config: OptimizerConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post_training/preference_optimize") @webmethod(route="/post_training/preference_optimize")
def preference_optimize( def preference_optimize(
self, self,
request: PostTrainingRLHFRequest, job_uuid: str,
finetuned_model: URL,
dataset: TrainEvalDataset,
validation_dataset: TrainEvalDataset,
algorithm: RLHFAlgorithm,
algorithm_config: Union[DPOAlignmentConfig],
optimizer_config: OptimizerConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post_training/jobs") @webmethod(route="/post_training/jobs")

View file

@ -50,5 +50,6 @@ class RewardScoring(Protocol):
@webmethod(route="/reward_scoring/score") @webmethod(route="/reward_scoring/score")
def reward_score( def reward_score(
self, self,
request: RewardScoringRequest, dialog_generations: List[DialogGenerations],
model: str,
) -> Union[RewardScoringResponse]: ... ) -> Union[RewardScoringResponse]: ...

View file

@ -86,5 +86,6 @@ class Safety(Protocol):
@webmethod(route="/safety/run_shields") @webmethod(route="/safety/run_shields")
async def run_shields( async def run_shields(
self, self,
request: RunShieldRequest, messages: List[Message],
shields: List[ShieldDefinition],
) -> RunShieldResponse: ... ) -> RunShieldResponse: ...

View file

@ -13,10 +13,10 @@ import fire
import httpx import httpx
from llama_models.llama3.api.datatypes import UserMessage from llama_models.llama3.api.datatypes import UserMessage
from pydantic import BaseModel
from termcolor import cprint
from llama_toolchain.core.datatypes import RemoteProviderConfig from llama_toolchain.core.datatypes import RemoteProviderConfig
from pydantic import BaseModel
from termcolor import cprint
from .api import * # noqa: F403 from .api import * # noqa: F403
@ -43,9 +43,7 @@ class SafetyClient(Safety):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.base_url}/safety/run_shields", f"{self.base_url}/safety/run_shields",
json={ json=encodable_dict(request),
"request": encodable_dict(request),
},
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
timeout=20, timeout=20,
) )

View file

@ -52,13 +52,12 @@ class MetaReferenceSafetyImpl(Safety):
async def run_shields( async def run_shields(
self, self,
request: RunShieldRequest, messages: List[Message],
shields: List[ShieldDefinition],
) -> RunShieldResponse: ) -> RunShieldResponse:
shields = [shield_config_to_shield(c, self.config) for c in request.shields] shields = [shield_config_to_shield(c, self.config) for c in shields]
responses = await asyncio.gather( responses = await asyncio.gather(*[shield.run(messages) for shield in shields])
*[shield.run(request.messages) for shield in shields]
)
return RunShieldResponse(responses=responses) return RunShieldResponse(responses=responses)

View file

@ -48,5 +48,7 @@ class SyntheticDataGeneration(Protocol):
@webmethod(route="/synthetic_data_generation/generate") @webmethod(route="/synthetic_data_generation/generate")
def synthetic_data_generate( def synthetic_data_generate(
self, self,
request: SyntheticDataGenerationRequest, dialogs: List[Message],
filtering_function: FilteringFunction = FilteringFunction.none,
model: Optional[str] = None,
) -> Union[SyntheticDataGenerationResponse]: ... ) -> Union[SyntheticDataGenerationResponse]: ...

View file

@ -125,7 +125,7 @@ Event = Annotated[
class Telemetry(Protocol): class Telemetry(Protocol):
@webmethod(route="/telemetry/log_event") @webmethod(route="/telemetry/log_event")
async def log_event(self, event: Event): ... async def log_event(self, event: Event) -> None: ...
@webmethod(route="/telemetry/get_trace", method="GET") @webmethod(route="/telemetry/get_trace", method="GET")
async def get_trace(self, trace_id: str) -> Trace: ... async def get_trace(self, trace_id: str) -> Trace: ...

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -471,14 +471,9 @@ class Generator:
from dataclasses import make_dataclass from dataclasses import make_dataclass
if len(op.request_params) == 1 and "Request" in first[1].__name__: op_name = "".join(word.capitalize() for word in op.name.split("_"))
# TODO(ashwin): Undo the "Request" hack and this entire block eventually request_name = f"{op_name}Request"
request_name = first[1].__name__ + "Wrapper" request_type = make_dataclass(request_name, op.request_params)
request_type = make_dataclass(request_name, op.request_params)
else:
op_name = "".join(word.capitalize() for word in op.name.split("_"))
request_name = f"{op_name}Request"
request_type = make_dataclass(request_name, op.request_params)
requestBody = RequestBody( requestBody = RequestBody(
content={ content={

View file

@ -249,7 +249,12 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
stream=True, stream=True,
tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)], tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
) )
iterator = self.api.chat_completion(request) iterator = self.api.chat_completion(
request.model,
request.messages,
stream=request.stream,
tools=request.tools,
)
events = [] events = []
async for chunk in iterator: async for chunk in iterator:

View file

@ -61,7 +61,9 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
], ],
stream=False, stream=False,
) )
iterator = self.api.chat_completion(request) iterator = self.api.chat_completion(
request.model, request.messages, stream=request.stream
)
async for r in iterator: async for r in iterator:
response = r response = r
print(response.completion_message.content) print(response.completion_message.content)