diff --git a/llama_stack/testing/inference_recorder.py b/llama_stack/testing/inference_recorder.py
index 478f77773..67a46a1c5 100644
--- a/llama_stack/testing/inference_recorder.py
+++ b/llama_stack/testing/inference_recorder.py
@@ -10,12 +10,15 @@ import hashlib
 import json
 import os
 import sqlite3
+import uuid
 from collections.abc import Generator
 from contextlib import contextmanager
 from enum import StrEnum
 from pathlib import Path
 from typing import Any, Literal, cast
 
+from openai.types.chat import ChatCompletion, ChatCompletionChunk
+
 from llama_stack.log import get_logger
 
 logger = get_logger(__name__, category="testing")
@@ -248,6 +251,20 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         recording = _current_storage.find_recording(request_hash)
         if recording:
             response_body = recording["response"]["body"]
+            if (
+                isinstance(response_body, list)
+                and len(response_body) > 0
+                and isinstance(response_body[0], ChatCompletionChunk)
+            ):
+                # Recordings are stored in a sqlite database with a unique constraint
+                # on the completion id, so a replayed response needs a fresh id.
+                new_id = uuid.uuid4().hex
+                response_body[0].id = "chatcmpl-" + new_id
+            elif isinstance(response_body, ChatCompletion):
+                # Same reason as above: swap in a fresh id so the replayed completion
+                # doesn't collide with the unique constraint on the recorded id.
+                new_id = uuid.uuid4().hex
+                response_body.id = "chatcmpl-" + new_id
 
             if recording["response"].get("is_streaming", False):
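
For reviewers who want to exercise the replay-path id rewrite in isolation, here is a minimal, self-contained sketch of the same logic. `rewrite_replayed_id` and the sample `ChatCompletion` payload are illustrative, not part of this patch; only the branch structure and the `"chatcmpl-" + uuid` rewrite mirror the hunk above.

```python
import uuid

from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice


def rewrite_replayed_id(response_body):
    """Hypothetical standalone version of the replay-path id rewrite."""
    new_id = uuid.uuid4().hex
    if (
        isinstance(response_body, list)
        and len(response_body) > 0
        and isinstance(response_body[0], ChatCompletionChunk)
    ):
        # Streaming responses are recorded as a list of chunks; as in the
        # patch, only the first chunk's id is replaced.
        response_body[0].id = "chatcmpl-" + new_id
    elif isinstance(response_body, ChatCompletion):
        response_body.id = "chatcmpl-" + new_id
    return response_body


# Illustrative recorded response (not real recorder output).
recorded = ChatCompletion(
    id="chatcmpl-recorded",
    choices=[
        Choice(
            finish_reason="stop",
            index=0,
            message=ChatCompletionMessage(role="assistant", content="hello"),
        )
    ],
    created=0,
    model="example-model",
    object="chat.completion",
)

replayed = rewrite_replayed_id(recorded)
assert replayed.id != "chatcmpl-recorded"
```

The isinstance checks keep the rewrite narrowly scoped: non-chat responses (embeddings, raw dicts) pass through untouched.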
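
The comments in the patch reference a sqlite unique constraint on the completion id; the snippet below demonstrates the failure mode that motivates the rewrite. The table name and schema here are hypothetical, chosen only to show the constraint behavior; the recorder's actual schema is not visible in this hunk.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Hypothetical schema; a PRIMARY KEY carries an implicit unique constraint.
conn.execute("CREATE TABLE recordings (id TEXT PRIMARY KEY, body TEXT)")
conn.execute("INSERT INTO recordings VALUES (?, ?)", ("chatcmpl-recorded", "{}"))
try:
    # Storing a replayed completion under the same id trips the constraint.
    conn.execute("INSERT INTO recordings VALUES (?, ?)", ("chatcmpl-recorded", "{}"))
except sqlite3.IntegrityError as err:
    print(err)  # UNIQUE constraint failed: recordings.id
```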