Further bug fixes

Ashwin Bharambe 2024-09-20 15:15:57 -07:00 committed by Xi Yan
parent 9252e81a7b
commit a57411b4b3
3 changed files with 30 additions and 18 deletions


@@ -10,21 +10,14 @@ from typing import Any, AsyncGenerator

 import fire
 import httpx
-from llama_stack.distribution.datatypes import RemoteProviderConfig
 from pydantic import BaseModel
 from termcolor import cprint

+from llama_stack.distribution.datatypes import RemoteProviderConfig
+
 from .event_logger import EventLogger
-from .inference import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseStreamChunk,
-    CompletionRequest,
-    Inference,
-    UserMessage,
-)
+from llama_stack.apis.inference import *  # noqa: F403


 async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
@@ -48,7 +41,27 @@ class InferenceClient(Inference):
     async def completion(self, request: CompletionRequest) -> AsyncGenerator:
         raise NotImplementedError()

-    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
         async with httpx.AsyncClient() as client:
             async with client.stream(
                 "POST",