mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-08 19:10:56 +00:00
Add inference test
Run it as: ``` PROVIDER_ID=test-remote \ PROVIDER_CONFIG=$PWD/llama_stack/providers/tests/inference/provider_config_example.yaml \ pytest -s llama_stack/providers/tests/inference/test_inference.py \ --tb=auto \ --disable-warnings ```
This commit is contained in:
parent
4fa467731e
commit
3ae2b712e8
8 changed files with 356 additions and 54 deletions
|
|
@ -67,25 +67,26 @@ class InferenceClient(Inference):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{self.base_url}/inference/chat_completion",
|
||||
json=encodable_dict(request),
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
content = await response.aread()
|
||||
cprint(
|
||||
f"Error: HTTP {response.status_code} {content.decode()}", "red"
|
||||
)
|
||||
return
|
||||
if stream:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{self.base_url}/inference/chat_completion",
|
||||
json=encodable_dict(request),
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
content = await response.aread()
|
||||
cprint(
|
||||
f"Error: HTTP {response.status_code} {content.decode()}",
|
||||
"red",
|
||||
)
|
||||
return
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data:"):
|
||||
data = line[len("data: ") :]
|
||||
try:
|
||||
if request.stream:
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data:"):
|
||||
data = line[len("data: ") :]
|
||||
try:
|
||||
if "error" in data:
|
||||
cprint(data, "red")
|
||||
continue
|
||||
|
|
@ -93,11 +94,20 @@ class InferenceClient(Inference):
|
|||
yield ChatCompletionResponseStreamChunk(
|
||||
**json.loads(data)
|
||||
)
|
||||
else:
|
||||
yield ChatCompletionResponse(**json.loads(data))
|
||||
except Exception as e:
|
||||
print(data)
|
||||
print(f"Error with parsing or validation: {e}")
|
||||
except Exception as e:
|
||||
print(data)
|
||||
print(f"Error with parsing or validation: {e}")
|
||||
else:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/inference/chat_completion",
|
||||
json=encodable_dict(request),
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
j = response.json()
|
||||
yield ChatCompletionResponse(**j)
|
||||
|
||||
|
||||
async def run_main(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue