Separate chat_completion stream and non-stream implementations

This is a pretty important requirement. The streaming response type is
an AsyncGenerator while the non-stream one is a single object. So far
this has worked _sometimes_ due to various pre-existing hacks (and in
some cases, just failed.)
This commit is contained in:
Ashwin Bharambe 2024-10-08 10:52:16 -07:00 committed by Ashwin Bharambe
parent f8752ab8dc
commit 0c9eb3341c
5 changed files with 346 additions and 287 deletions

View file

@ -42,10 +42,10 @@ class InferenceClient(Inference):
async def shutdown(self) -> None:
pass
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(
def chat_completion(
self,
model: str,
messages: List[Message],
@ -66,48 +66,57 @@ class InferenceClient(Inference):
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async with httpx.AsyncClient() as client:
if stream:
async with client.stream(
"POST",
f"{self.base_url}/inference/chat_completion",
json=encodable_dict(request),
headers={"Content-Type": "application/json"},
timeout=20,
) as response:
if response.status_code != 200:
content = await response.aread()
cprint(
f"Error: HTTP {response.status_code} {content.decode()}",
"red",
)
return
response = await client.post(
f"{self.base_url}/inference/chat_completion",
json=encodable_dict(request),
headers={"Content-Type": "application/json"},
timeout=20,
)
async for line in response.aiter_lines():
if line.startswith("data:"):
data = line[len("data: ") :]
try:
if "error" in data:
cprint(data, "red")
continue
response.raise_for_status()
j = response.json()
yield ChatCompletionResponse(**j)
yield ChatCompletionResponseStreamChunk(
**json.loads(data)
)
except Exception as e:
print(data)
print(f"Error with parsing or validation: {e}")
else:
response = await client.post(
f"{self.base_url}/inference/chat_completion",
json=encodable_dict(request),
headers={"Content-Type": "application/json"},
timeout=20,
)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
f"{self.base_url}/inference/chat_completion",
json=encodable_dict(request),
headers={"Content-Type": "application/json"},
timeout=20,
) as response:
if response.status_code != 200:
content = await response.aread()
cprint(
f"Error: HTTP {response.status_code} {content.decode()}",
"red",
)
return
response.raise_for_status()
j = response.json()
yield ChatCompletionResponse(**j)
async for line in response.aiter_lines():
if line.startswith("data:"):
data = line[len("data: ") :]
try:
if "error" in data:
cprint(data, "red")
continue
yield ChatCompletionResponseStreamChunk(**json.loads(data))
except Exception as e:
print(data)
print(f"Error with parsing or validation: {e}")
async def run_main(