Add inference test

Run it as:

```
PROVIDER_ID=test-remote \
 PROVIDER_CONFIG=$PWD/llama_stack/providers/tests/inference/provider_config_example.yaml \
 pytest -s llama_stack/providers/tests/inference/test_inference.py \
 --tb=auto \
 --disable-warnings
```
This commit is contained in:
Ashwin Bharambe 2024-10-07 15:46:16 -07:00 committed by Ashwin Bharambe
parent 4fa467731e
commit 3ae2b712e8
8 changed files with 356 additions and 54 deletions

View file

@ -67,25 +67,26 @@ class InferenceClient(Inference):
logprobs=logprobs,
)
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
f"{self.base_url}/inference/chat_completion",
json=encodable_dict(request),
headers={"Content-Type": "application/json"},
timeout=20,
) as response:
if response.status_code != 200:
content = await response.aread()
cprint(
f"Error: HTTP {response.status_code} {content.decode()}", "red"
)
return
if stream:
async with client.stream(
"POST",
f"{self.base_url}/inference/chat_completion",
json=encodable_dict(request),
headers={"Content-Type": "application/json"},
timeout=20,
) as response:
if response.status_code != 200:
content = await response.aread()
cprint(
f"Error: HTTP {response.status_code} {content.decode()}",
"red",
)
return
async for line in response.aiter_lines():
if line.startswith("data:"):
data = line[len("data: ") :]
try:
if request.stream:
async for line in response.aiter_lines():
if line.startswith("data:"):
data = line[len("data: ") :]
try:
if "error" in data:
cprint(data, "red")
continue
@ -93,11 +94,20 @@ class InferenceClient(Inference):
yield ChatCompletionResponseStreamChunk(
**json.loads(data)
)
else:
yield ChatCompletionResponse(**json.loads(data))
except Exception as e:
print(data)
print(f"Error with parsing or validation: {e}")
except Exception as e:
print(data)
print(f"Error with parsing or validation: {e}")
else:
response = await client.post(
f"{self.base_url}/inference/chat_completion",
json=encodable_dict(request),
headers={"Content-Type": "application/json"},
timeout=20,
)
response.raise_for_status()
j = response.json()
yield ChatCompletionResponse(**j)
async def run_main(