mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 13:44:38 +00:00
fixed update_prompt to properly handle latest and default version, made version a required parameter, and removed unused CreatePromptRequest
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
1390660dcf
commit
beb2db487d
6 changed files with 43 additions and 41 deletions
|
@ -42,7 +42,7 @@ class TestPrompts:
|
|||
|
||||
async def test_update_prompt(self, store):
|
||||
prompt = await store.create_prompt("Original")
|
||||
updated = await store.update_prompt(prompt.prompt_id, "Updated", ["v"])
|
||||
updated = await store.update_prompt(prompt.prompt_id, "Updated", "1", ["v"])
|
||||
assert updated.version == "2"
|
||||
assert updated.prompt == "Updated"
|
||||
|
||||
|
@ -51,16 +51,16 @@ class TestPrompts:
|
|||
|
||||
prompt = await store.create_prompt("Original")
|
||||
assert prompt.version == "1"
|
||||
prompt = await store.update_prompt(prompt.prompt_id, "Updated", ["v"], version_for_update)
|
||||
prompt = await store.update_prompt(prompt.prompt_id, "Updated", version_for_update, ["v"])
|
||||
assert prompt.version == "2"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# now this is a stale version
|
||||
await store.update_prompt(prompt.prompt_id, "Another Update", ["v"], version_for_update)
|
||||
await store.update_prompt(prompt.prompt_id, "Another Update", version_for_update, ["v"])
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# this version does not exist
|
||||
await store.update_prompt(prompt.prompt_id, "Another Update", ["v"], "99")
|
||||
await store.update_prompt(prompt.prompt_id, "Another Update", "99", ["v"])
|
||||
|
||||
async def test_delete_prompt(self, store):
|
||||
prompt = await store.create_prompt("to be deleted")
|
||||
|
@ -80,7 +80,7 @@ class TestPrompts:
|
|||
|
||||
async def test_version(self, store):
|
||||
prompt = await store.create_prompt("V1")
|
||||
await store.update_prompt(prompt.prompt_id, "V2")
|
||||
await store.update_prompt(prompt.prompt_id, "V2", "1")
|
||||
|
||||
v1 = await store.get_prompt(prompt.prompt_id, version="1")
|
||||
assert v1.version == "1" and v1.prompt == "V1"
|
||||
|
@ -89,11 +89,16 @@ class TestPrompts:
|
|||
assert latest.version == "2" and latest.prompt == "V2"
|
||||
|
||||
async def test_set_default_version(self, store):
|
||||
prompt = await store.create_prompt("V1")
|
||||
await store.update_prompt(prompt.prompt_id, "V2")
|
||||
prompt0 = await store.create_prompt("V1")
|
||||
prompt1 = await store.update_prompt(prompt0.prompt_id, "V2", "1")
|
||||
|
||||
await store.set_default_version(prompt.prompt_id, "1")
|
||||
assert (await store.get_prompt(prompt.prompt_id)).version == "1"
|
||||
assert (await store.get_prompt(prompt0.prompt_id)).version == "2"
|
||||
prompt_default = await store.set_default_version(prompt0.prompt_id, "1")
|
||||
assert (await store.get_prompt(prompt0.prompt_id)).version == "1"
|
||||
assert prompt_default.version == "1"
|
||||
|
||||
prompt2 = await store.update_prompt(prompt0.prompt_id, "V3", prompt1.version)
|
||||
assert prompt2.version == "3"
|
||||
|
||||
async def test_prompt_id_generation_and_validation(self, store):
|
||||
prompt = await store.create_prompt("Test")
|
||||
|
@ -105,8 +110,8 @@ class TestPrompts:
|
|||
|
||||
async def test_list_shows_default_versions(self, store):
|
||||
prompt = await store.create_prompt("V1")
|
||||
await store.update_prompt(prompt.prompt_id, "V2")
|
||||
await store.update_prompt(prompt.prompt_id, "V3")
|
||||
await store.update_prompt(prompt.prompt_id, "V2", "1")
|
||||
await store.update_prompt(prompt.prompt_id, "V3", "2")
|
||||
|
||||
response = await store.list_prompts()
|
||||
listed_prompt = response.data[0]
|
||||
|
@ -120,8 +125,8 @@ class TestPrompts:
|
|||
|
||||
async def test_get_all_prompt_versions(self, store):
|
||||
prompt = await store.create_prompt("V1")
|
||||
await store.update_prompt(prompt.prompt_id, "V2")
|
||||
await store.update_prompt(prompt.prompt_id, "V3")
|
||||
await store.update_prompt(prompt.prompt_id, "V2", "1")
|
||||
await store.update_prompt(prompt.prompt_id, "V3", "2")
|
||||
|
||||
versions = (await store.list_prompt_versions(prompt.prompt_id)).data
|
||||
assert len(versions) == 3
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue