fix agentic calling inference

This commit is contained in:
Xi Yan 2024-09-11 18:30:09 -07:00
parent 2501b3d7de
commit f55ffa8b53
4 changed files with 8 additions and 22 deletions

View file

@ -132,7 +132,7 @@ async def run_main(host: str, port: int):
api = AgenticSystemClient(f"http://{host}:{port}") api = AgenticSystemClient(f"http://{host}:{port}")
tool_definitions = [ tool_definitions = [
SearchToolDefinition(engine=SearchEngineType.brave), SearchToolDefinition(engine=SearchEngineType.bing),
WolframAlphaToolDefinition(), WolframAlphaToolDefinition(),
CodeInterpreterToolDefinition(), CodeInterpreterToolDefinition(),
] ]

View file

@ -388,19 +388,17 @@ class ChatAgent(ShieldRunnerMixin):
) )
) )
req = ChatCompletionRequest( tool_calls = []
model=self.agent_config.model, content = ""
messages=input_messages, stop_reason = None
async for chunk in self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=self._get_tools(), tools=self._get_tools(),
tool_prompt_format=self.agent_config.tool_prompt_format, tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True, stream=True,
sampling_params=sampling_params, sampling_params=sampling_params,
) ):
tool_calls = []
content = ""
stop_reason = None
async for chunk in self.inference_api.chat_completion_impl(req):
event = chunk.event event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start: if event.event_type == ChatCompletionResponseEventType.start:
continue continue

View file

@ -77,14 +77,6 @@ class MetaReferenceInferenceImpl(Inference):
logprobs=logprobs, logprobs=logprobs,
) )
async for chunk in self.chat_completion_impl(request):
yield chunk
async def chat_completion_impl(
self, request: ChatCompletionRequest
) -> AsyncIterator[
Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
]:
messages = prepare_messages(request) messages = prepare_messages(request)
model = resolve_model(request.model) model = resolve_model(request.model)
if model is None: if model is None:

View file

@ -125,11 +125,7 @@ Event = Annotated[
class Telemetry(Protocol): class Telemetry(Protocol):
@webmethod(route="/telemetry/log_event") @webmethod(route="/telemetry/log_event")
<<<<<<< migrate_request_wrapper
async def log_event(self, event: Event) -> None: ... async def log_event(self, event: Event) -> None: ...
=======
async def log_event(self, event: Event): ...
>>>>>>> main
@webmethod(route="/telemetry/get_trace", method="GET") @webmethod(route="/telemetry/get_trace", method="GET")
async def get_trace(self, trace_id: str) -> Trace: ... async def get_trace(self, trace_id: str) -> Trace: ...