mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 20:14:13 +00:00
incorporating feedback
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
af33a8c982
commit
574dffbe38
5 changed files with 56 additions and 25 deletions
6
docs/_static/llama-stack-spec.html
vendored
6
docs/_static/llama-stack-spec.html
vendored
|
@ -5261,7 +5261,7 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"/v1/prompts/{prompt_id}/default-version": {
|
"/v1/prompts/{prompt_id}/set-default-version": {
|
||||||
"post": {
|
"post": {
|
||||||
"responses": {
|
"responses": {
|
||||||
"200": {
|
"200": {
|
||||||
|
@ -17829,6 +17829,10 @@
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"description": "Updated dictionary of variable names to their default values."
|
"description": "Updated dictionary of variable names to their default values."
|
||||||
|
},
|
||||||
|
"version": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The current version of the prompt being updated (as a string)."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
|
6
docs/_static/llama-stack-spec.yaml
vendored
6
docs/_static/llama-stack-spec.yaml
vendored
|
@ -3726,7 +3726,7 @@ paths:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/ScoreBatchRequest'
|
$ref: '#/components/schemas/ScoreBatchRequest'
|
||||||
required: true
|
required: true
|
||||||
/v1/prompts/{prompt_id}/default-version:
|
/v1/prompts/{prompt_id}/set-default-version:
|
||||||
post:
|
post:
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
|
@ -13231,6 +13231,10 @@ components:
|
||||||
type: string
|
type: string
|
||||||
description: >-
|
description: >-
|
||||||
Updated dictionary of variable names to their default values.
|
Updated dictionary of variable names to their default values.
|
||||||
|
version:
|
||||||
|
type: string
|
||||||
|
description: >-
|
||||||
|
The current version of the prompt being updated (as a string).
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- prompt
|
- prompt
|
||||||
|
|
|
@ -153,12 +153,14 @@ class Prompts(Protocol):
|
||||||
prompt_id: str,
|
prompt_id: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
variables: dict[str, str] | None = None,
|
variables: dict[str, str] | None = None,
|
||||||
|
version: str | None = None,
|
||||||
) -> Prompt:
|
) -> Prompt:
|
||||||
"""Update an existing prompt (increments version).
|
"""Update an existing prompt (increments version).
|
||||||
|
|
||||||
:param prompt_id: The identifier of the prompt to update.
|
:param prompt_id: The identifier of the prompt to update.
|
||||||
:param prompt: The updated prompt text content.
|
:param prompt: The updated prompt text content.
|
||||||
:param variables: Updated dictionary of variable names to their default values.
|
:param variables: Updated dictionary of variable names to their default values.
|
||||||
|
:param version: The current version of the prompt being updated (as a string).
|
||||||
:returns: The updated Prompt resource with incremented version.
|
:returns: The updated Prompt resource with incremented version.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
@ -174,7 +176,7 @@ class Prompts(Protocol):
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
@webmethod(route="/prompts/{prompt_id}/default-version", method="PUT")
|
@webmethod(route="/prompts/{prompt_id}/set-default-version", method="PUT")
|
||||||
async def set_default_version(
|
async def set_default_version(
|
||||||
self,
|
self,
|
||||||
prompt_id: str,
|
prompt_id: str,
|
||||||
|
|
|
@ -46,10 +46,20 @@ class PromptServiceImpl(Prompts):
|
||||||
)
|
)
|
||||||
self.kvstore = await kvstore_impl(kvstore_config)
|
self.kvstore = await kvstore_impl(kvstore_config)
|
||||||
|
|
||||||
def _get_prompt_key(self, prompt_id: str, version: str | None = None) -> str:
|
def _get_default_key(self, prompt_id: str) -> str:
|
||||||
|
"""Get the KVStore key that stores the default version number."""
|
||||||
|
return f"prompts:v1:{prompt_id}:default"
|
||||||
|
|
||||||
|
async def _get_prompt_key(self, prompt_id: str, version: str | None = None) -> str:
|
||||||
|
"""Get the KVStore key for prompt data, returning default version if applicable."""
|
||||||
if version:
|
if version:
|
||||||
return self._get_version_key(prompt_id, version)
|
return self._get_version_key(prompt_id, version)
|
||||||
return f"prompts:v1:{prompt_id}:default"
|
|
||||||
|
default_key = self._get_default_key(prompt_id)
|
||||||
|
resolved_version = await self.kvstore.get(default_key)
|
||||||
|
if resolved_version is None:
|
||||||
|
raise ValueError(f"Prompt {prompt_id}:default not found")
|
||||||
|
return self._get_version_key(prompt_id, resolved_version)
|
||||||
|
|
||||||
def _get_version_key(self, prompt_id: str, version: str) -> str:
|
def _get_version_key(self, prompt_id: str, version: str) -> str:
|
||||||
"""Get the KVStore key for a specific prompt version."""
|
"""Get the KVStore key for a specific prompt version."""
|
||||||
|
@ -102,22 +112,10 @@ class PromptServiceImpl(Prompts):
|
||||||
|
|
||||||
async def get_prompt(self, prompt_id: str, version: str | None = None) -> Prompt:
|
async def get_prompt(self, prompt_id: str, version: str | None = None) -> Prompt:
|
||||||
"""Get a prompt by its identifier and optional version."""
|
"""Get a prompt by its identifier and optional version."""
|
||||||
if version:
|
key = await self._get_prompt_key(prompt_id, version)
|
||||||
key = self._get_version_key(prompt_id, version)
|
data = await self.kvstore.get(key)
|
||||||
data = await self.kvstore.get(key)
|
if data is None:
|
||||||
if data is None:
|
raise ValueError(f"Prompt {prompt_id}:{version if version else 'default'} not found")
|
||||||
raise ValueError(f"Prompt {prompt_id} version {version} not found")
|
|
||||||
else:
|
|
||||||
default_key = self._get_prompt_key(prompt_id)
|
|
||||||
default_version = await self.kvstore.get(default_key)
|
|
||||||
if default_version is None:
|
|
||||||
raise ValueError(f"Prompt with ID '{prompt_id}' not found")
|
|
||||||
|
|
||||||
key = self._get_version_key(prompt_id, default_version)
|
|
||||||
data = await self.kvstore.get(key)
|
|
||||||
if data is None:
|
|
||||||
raise ValueError(f"Prompt with ID '{prompt_id}' not found")
|
|
||||||
|
|
||||||
return self._deserialize_prompt(data)
|
return self._deserialize_prompt(data)
|
||||||
|
|
||||||
async def create_prompt(
|
async def create_prompt(
|
||||||
|
@ -135,7 +133,7 @@ class PromptServiceImpl(Prompts):
|
||||||
data = self._serialize_prompt(prompt_obj)
|
data = self._serialize_prompt(prompt_obj)
|
||||||
await self.kvstore.set(version_key, data)
|
await self.kvstore.set(version_key, data)
|
||||||
|
|
||||||
default_key = self._get_prompt_key(prompt_obj.prompt_id)
|
default_key = self._get_default_key(prompt_obj.prompt_id)
|
||||||
await self.kvstore.set(default_key, "1")
|
await self.kvstore.set(default_key, "1")
|
||||||
|
|
||||||
return prompt_obj
|
return prompt_obj
|
||||||
|
@ -145,13 +143,20 @@ class PromptServiceImpl(Prompts):
|
||||||
prompt_id: str,
|
prompt_id: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
variables: dict[str, str] | None = None,
|
variables: dict[str, str] | None = None,
|
||||||
|
version: str | None = None,
|
||||||
) -> Prompt:
|
) -> Prompt:
|
||||||
"""Update an existing prompt (increments version)."""
|
"""Update an existing prompt (increments version)."""
|
||||||
if variables is None:
|
if variables is None:
|
||||||
variables = {}
|
variables = {}
|
||||||
|
|
||||||
current_prompt = await self.get_prompt(prompt_id)
|
current_prompt = await self.get_prompt(prompt_id)
|
||||||
new_version = str(int(current_prompt.version) + 1)
|
if version and current_prompt.version != version:
|
||||||
|
raise ValueError(
|
||||||
|
f"'{version}' is not the latest prompt version for prompt_id='{prompt_id}'. Use the latest version '{current_prompt.version}' in request."
|
||||||
|
)
|
||||||
|
|
||||||
|
current_version = current_prompt.version if version is None else version
|
||||||
|
new_version = str(int(current_version) + 1)
|
||||||
|
|
||||||
updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)
|
updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)
|
||||||
|
|
||||||
|
@ -159,7 +164,7 @@ class PromptServiceImpl(Prompts):
|
||||||
data = self._serialize_prompt(updated_prompt)
|
data = self._serialize_prompt(updated_prompt)
|
||||||
await self.kvstore.set(version_key, data)
|
await self.kvstore.set(version_key, data)
|
||||||
|
|
||||||
default_key = self._get_prompt_key(prompt_id)
|
default_key = self._get_default_key(prompt_id)
|
||||||
await self.kvstore.set(default_key, new_version)
|
await self.kvstore.set(default_key, new_version)
|
||||||
|
|
||||||
return updated_prompt
|
return updated_prompt
|
||||||
|
@ -207,7 +212,7 @@ class PromptServiceImpl(Prompts):
|
||||||
if data is None:
|
if data is None:
|
||||||
raise ValueError(f"Prompt {prompt_id} version {version} not found")
|
raise ValueError(f"Prompt {prompt_id} version {version} not found")
|
||||||
|
|
||||||
default_key = self._get_prompt_key(prompt_id)
|
default_key = self._get_default_key(prompt_id)
|
||||||
await self.kvstore.set(default_key, version)
|
await self.kvstore.set(default_key, version)
|
||||||
|
|
||||||
return self._deserialize_prompt(data)
|
return self._deserialize_prompt(data)
|
||||||
|
|
|
@ -46,6 +46,22 @@ class TestPrompts:
|
||||||
assert updated.version == "2"
|
assert updated.version == "2"
|
||||||
assert updated.prompt == "Updated"
|
assert updated.prompt == "Updated"
|
||||||
|
|
||||||
|
async def test_update_prompt_with_version(self, store):
|
||||||
|
version_for_update = "1"
|
||||||
|
|
||||||
|
prompt = await store.create_prompt("Original")
|
||||||
|
assert prompt.version == "1"
|
||||||
|
prompt = await store.update_prompt(prompt.prompt_id, "Updated", {"v": "2"}, version_for_update)
|
||||||
|
assert prompt.version == "2"
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# now this is a stale version
|
||||||
|
await store.update_prompt(prompt.prompt_id, "Another Update", {"v": "2"}, version_for_update)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# this version does not exist
|
||||||
|
await store.update_prompt(prompt.prompt_id, "Another Update", {"v": "2"}, "99")
|
||||||
|
|
||||||
async def test_delete_prompt(self, store):
|
async def test_delete_prompt(self, store):
|
||||||
prompt = await store.create_prompt("to be deleted")
|
prompt = await store.create_prompt("to be deleted")
|
||||||
await store.delete_prompt(prompt.prompt_id)
|
await store.delete_prompt(prompt.prompt_id)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue