From a654467552f654a2deaad7618933dcd9ac68c20b Mon Sep 17 00:00:00 2001
From: Michael Dawson
Date: Wed, 28 May 2025 12:23:15 -0700
Subject: [PATCH 1/6] feat: add cpu/cuda config for prompt guard (#2194)

# What does this PR do?
Previously prompt guard was hard-coded to require cuda, which prevented it from being used on an instance without cuda support. This PR allows prompt guard to be configured to use either cpu or cuda.

[//]: # (If resolving an issue, uncomment and update the line below)
Closes [#2133](https://github.com/meta-llama/llama-stack/issues/2133)

## Test Plan
(Edited after incorporating suggestion)
1) Started a stack configured with prompt guard on a system without a GPU and validated that prompt guard could be used through the APIs.
2) Validated on a system with a GPU (but without llama stack) that the Python code selecting between cpu and cuda returned the right value when a cuda device was available.
3) Ran the unit tests as per https://github.com/meta-llama/llama-stack/blob/main/tests/unit/README.md

[//]: # (## Documentation)

---------
Signed-off-by: Michael Dawson
---
 .../providers/inline/safety/prompt_guard/prompt_guard.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
index 56ce8285f..ff87889ea 100644
--- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
+++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
@@ -75,7 +75,9 @@ class PromptGuardShield:
         self.temperature = temperature
         self.threshold = threshold

-        self.device = "cuda"
+        self.device = "cpu"
+        if torch.cuda.is_available():
+            self.device = "cuda"

         # load model and tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)

From bfdd15d1fa2abcd40b56cf6bb895a4fb3c4211b2 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Wed, 28 May 2025 13:17:48 -0700
Subject: [PATCH 2/6] fix(responses): use input, not original_input when storing the Response (#2300)

We must store the full (re-hydrated) input, not just the original input, in the Response object. Of course, this is not very space-efficient and we should likely find a better storage scheme so that we can only store unique entries in the database and then re-hydrate them efficiently later. But that can be done safely later.
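To make the intent concrete, here is a minimal sketch of what "storing the re-hydrated input" means when a `previous_response_id` is supplied. The helper name and the plain-dict message shapes are hypothetical and purely illustrative; the actual change simply passes the already re-hydrated `input` (rather than the original user input) to `_store_response`.

```python
# Illustrative sketch only -- hypothetical helper, not the actual llama-stack code.
def rehydrate_input(previous_response: dict, new_input: str) -> list[dict]:
    """Combine a prior response's input and output with the new user message."""
    full_input = list(previous_response["input"])     # earlier user/assistant turns
    full_input.extend(previous_response["output"])    # the assistant's prior answer
    full_input.append({"role": "user", "content": new_input})  # the current turn
    return full_input                                 # this is what should be stored
```

The unit test added below asserts exactly this shape: the previous user message, the previous assistant output, then the current user message.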
Closes https://github.com/meta-llama/llama-stack/issues/2299 ## Test Plan Unit test --- .../agents/meta_reference/openai_responses.py | 21 +++--- .../meta_reference/test_openai_responses.py | 66 +++++++++++++++++++ 2 files changed, 76 insertions(+), 11 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py index 3a56d41ef..1fcb1c461 100644 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py @@ -292,12 +292,12 @@ class OpenAIResponsesImpl: async def _store_response( self, response: OpenAIResponseObject, - original_input: str | list[OpenAIResponseInput], + input: str | list[OpenAIResponseInput], ) -> None: new_input_id = f"msg_{uuid.uuid4()}" - if isinstance(original_input, str): + if isinstance(input, str): # synthesize a message from the input string - input_content = OpenAIResponseInputMessageContentText(text=original_input) + input_content = OpenAIResponseInputMessageContentText(text=input) input_content_item = OpenAIResponseMessage( role="user", content=[input_content], @@ -307,7 +307,7 @@ class OpenAIResponsesImpl: else: # we already have a list of messages input_items_data = [] - for input_item in original_input: + for input_item in input: if isinstance(input_item, OpenAIResponseMessage): # These may or may not already have an id, so dump to dict, check for id, and add if missing input_item_dict = input_item.model_dump() @@ -334,7 +334,6 @@ class OpenAIResponsesImpl: tools: list[OpenAIResponseInputTool] | None = None, ): stream = False if stream is None else stream - original_input = input # Keep reference for storage output_messages: list[OpenAIResponseOutput] = [] @@ -372,7 +371,7 @@ class OpenAIResponsesImpl: inference_result=inference_result, ctx=ctx, output_messages=output_messages, - original_input=original_input, + input=input, model=model, store=store, tools=tools, @@ -382,7 +381,7 @@ class OpenAIResponsesImpl: inference_result=inference_result, ctx=ctx, output_messages=output_messages, - original_input=original_input, + input=input, model=model, store=store, tools=tools, @@ -393,7 +392,7 @@ class OpenAIResponsesImpl: inference_result: Any, ctx: ChatCompletionContext, output_messages: list[OpenAIResponseOutput], - original_input: str | list[OpenAIResponseInput], + input: str | list[OpenAIResponseInput], model: str, store: bool | None, tools: list[OpenAIResponseInputTool] | None, @@ -423,7 +422,7 @@ class OpenAIResponsesImpl: if store: await self._store_response( response=response, - original_input=original_input, + input=input, ) return response @@ -433,7 +432,7 @@ class OpenAIResponsesImpl: inference_result: Any, ctx: ChatCompletionContext, output_messages: list[OpenAIResponseOutput], - original_input: str | list[OpenAIResponseInput], + input: str | list[OpenAIResponseInput], model: str, store: bool | None, tools: list[OpenAIResponseInputTool] | None, @@ -544,7 +543,7 @@ class OpenAIResponsesImpl: if store: await self._store_response( response=final_response, - original_input=original_input, + input=input, ) # Emit response.completed diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index 9c491accb..5b6cee0ec 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ 
-628,3 +628,69 @@ async def test_responses_store_list_input_items_logic(): result = await responses_store.list_response_input_items("resp_123", limit=0, order=Order.asc) assert result.object == "list" assert len(result.data) == 0 # Should return no items + + +@pytest.mark.asyncio +async def test_store_response_uses_rehydrated_input_with_previous_response( + openai_responses_impl, mock_responses_store, mock_inference_api +): + """Test that _store_response uses the full re-hydrated input (including previous responses) + rather than just the original input when previous_response_id is provided.""" + + # Setup - Create a previous response that should be included in the stored input + previous_response = OpenAIResponseObjectWithInput( + id="resp-previous-123", + object="response", + created_at=1234567890, + model="meta-llama/Llama-3.1-8B-Instruct", + status="completed", + input=[ + OpenAIResponseMessage( + id="msg-prev-user", role="user", content=[OpenAIResponseInputMessageContentText(text="What is 2+2?")] + ) + ], + output=[ + OpenAIResponseMessage( + id="msg-prev-assistant", + role="assistant", + content=[OpenAIResponseOutputMessageContentOutputText(text="2+2 equals 4.")], + ) + ], + ) + + mock_responses_store.get_response_object.return_value = previous_response + + current_input = "Now what is 3+3?" + model = "meta-llama/Llama-3.1-8B-Instruct" + mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml") + mock_inference_api.openai_chat_completion.return_value = mock_chat_completion + + # Execute - Create response with previous_response_id + result = await openai_responses_impl.create_openai_response( + input=current_input, + model=model, + previous_response_id="resp-previous-123", + store=True, + ) + + store_call_args = mock_responses_store.store_response_object.call_args + stored_input = store_call_args.kwargs["input"] + + # Verify that the stored input contains the full re-hydrated conversation: + # 1. Previous user message + # 2. Previous assistant response + # 3. Current user message + assert len(stored_input) == 3 + + assert stored_input[0].role == "user" + assert stored_input[0].content[0].text == "What is 2+2?" + + assert stored_input[1].role == "assistant" + assert stored_input[1].content[0].text == "2+2 equals 4." + + assert stored_input[2].role == "user" + assert stored_input[2].content == "Now what is 3+3?" + + # Verify the response itself is correct + assert result.model == model + assert result.status == "completed" From f0d8ceb2422247b1c68bfda9d92f9561012310df Mon Sep 17 00:00:00 2001 From: Mark Campbell Date: Thu, 29 May 2025 17:53:45 +0100 Subject: [PATCH 3/6] chore: fix flaky distro_codegen script (#2305) # What does this PR do? Adds an import for all of the template modules before the executor to prevent deadlock Closes #2278 ## Test Plan ``` # Run the pre-commit multiple times and verify the deadlock doesn't occur for i in {1..10}; do pre-commit run --all-files; done ``` --- scripts/distro_codegen.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/scripts/distro_codegen.py b/scripts/distro_codegen.py index 8820caf55..d33c5de67 100755 --- a/scripts/distro_codegen.py +++ b/scripts/distro_codegen.py @@ -107,6 +107,13 @@ def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[ return None, [] +def pre_import_templates(template_dirs: list[Path]) -> None: + # Pre-import all template modules to avoid deadlocks. 
+ for template_dir in template_dirs: + module_name = f"llama_stack.templates.{template_dir.name}" + importlib.import_module(module_name) + + def main(): templates_dir = REPO_ROOT / "llama_stack" / "templates" change_tracker = ChangedPathTracker() @@ -118,6 +125,8 @@ def main(): template_dirs = list(find_template_dirs(templates_dir)) task = progress.add_task("Processing distribution templates...", total=len(template_dirs)) + pre_import_templates(template_dirs) + # Create a partial function with the progress bar process_func = partial(process_template, progress=progress, change_tracker=change_tracker) From 168c7113dfd779825cf116727974e60204fda148 Mon Sep 17 00:00:00 2001 From: Jorge Piedrahita Ortiz Date: Thu, 29 May 2025 11:54:23 -0500 Subject: [PATCH 4/6] fix(providers): update sambanova json schema mode (#2306) # What does this PR do? Updates sambanova inference to use strict as false in json_schema structured output ## Test Plan pytest -s -v tests/integration/inference/test_text_inference.py --stack-config=sambanova --text-model=sambanova/Meta-Llama-3.3-70B-Instruct --- llama_stack/providers/remote/inference/sambanova/sambanova.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index d182aa1dc..20f863665 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -218,7 +218,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin): "json_schema": { "name": name, "schema": fmt, - "strict": True, + "strict": False, }, } if request.tools: From 2603f10f95fcd302297158adb709d2a84c9f60af Mon Sep 17 00:00:00 2001 From: ehhuang Date: Thu, 29 May 2025 14:33:09 -0700 Subject: [PATCH 5/6] feat: support postgresql inference store (#2310) # What does this PR do? 
* Added support postgresql inference store * Added 'oracle' template that demos how to config postgresql stores (except for telemetry, which is not supported currently) ## Test Plan llama stack build --template oracle --image-type conda --run LLAMA_STACK_CONFIG=http://localhost:8321 pytest -s -v tests/integration/ --text-model accounts/fireworks/models/llama-v3p3-70b-instruct -k 'inference_store' --- llama_stack/providers/utils/kvstore/config.py | 2 +- .../sqlite.py => sqlalchemy_sqlstore.py} | 38 +-- .../providers/utils/sqlstore/sqlstore.py | 38 ++- llama_stack/templates/bedrock/build.yaml | 1 + llama_stack/templates/cerebras/build.yaml | 1 + llama_stack/templates/ci-tests/build.yaml | 1 + llama_stack/templates/dell/build.yaml | 2 +- llama_stack/templates/fireworks/build.yaml | 2 +- llama_stack/templates/groq/build.yaml | 1 + llama_stack/templates/hf-endpoint/build.yaml | 2 +- .../templates/hf-serverless/build.yaml | 2 +- llama_stack/templates/llama_api/build.yaml | 1 + .../templates/meta-reference-gpu/build.yaml | 2 +- llama_stack/templates/nvidia/build.yaml | 2 +- llama_stack/templates/ollama/build.yaml | 2 +- .../templates/open-benchmark/build.yaml | 1 + llama_stack/templates/passthrough/build.yaml | 2 +- .../templates/postgres-demo/__init__.py | 7 + .../templates/postgres-demo/build.yaml | 24 ++ .../templates/postgres-demo/postgres_demo.py | 164 +++++++++++++ llama_stack/templates/postgres-demo/run.yaml | 224 ++++++++++++++++++ llama_stack/templates/remote-vllm/build.yaml | 2 +- llama_stack/templates/sambanova/build.yaml | 1 + llama_stack/templates/starter/build.yaml | 1 + llama_stack/templates/template.py | 14 +- llama_stack/templates/tgi/build.yaml | 2 +- llama_stack/templates/together/build.yaml | 2 +- llama_stack/templates/verification/build.yaml | 1 + llama_stack/templates/vllm-gpu/build.yaml | 1 + llama_stack/templates/watsonx/build.yaml | 1 + .../inference/test_openai_completion.py | 21 +- tests/unit/utils/test_sqlstore.py | 4 +- 32 files changed, 516 insertions(+), 53 deletions(-) rename llama_stack/providers/utils/sqlstore/{sqlite/sqlite.py => sqlalchemy_sqlstore.py} (83%) create mode 100644 llama_stack/templates/postgres-demo/__init__.py create mode 100644 llama_stack/templates/postgres-demo/build.yaml create mode 100644 llama_stack/templates/postgres-demo/postgres_demo.py create mode 100644 llama_stack/templates/postgres-demo/run.yaml diff --git a/llama_stack/providers/utils/kvstore/config.py b/llama_stack/providers/utils/kvstore/config.py index e9aac6e8c..bbb0c5c0a 100644 --- a/llama_stack/providers/utils/kvstore/config.py +++ b/llama_stack/providers/utils/kvstore/config.py @@ -65,7 +65,7 @@ class SqliteKVStoreConfig(CommonConfig): class PostgresKVStoreConfig(CommonConfig): type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value host: str = "localhost" - port: int = 5432 + port: str = "5432" db: str = "llamastack" user: str password: str | None = None diff --git a/llama_stack/providers/utils/sqlstore/sqlite/sqlite.py b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py similarity index 83% rename from llama_stack/providers/utils/sqlstore/sqlite/sqlite.py rename to llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py index 0ef5f0fa1..825220679 100644 --- a/llama_stack/providers/utils/sqlstore/sqlite/sqlite.py +++ b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py @@ -19,10 +19,10 @@ from sqlalchemy import ( Text, select, ) -from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio import async_sessionmaker, 
create_async_engine -from ..api import ColumnDefinition, ColumnType, SqlStore -from ..sqlstore import SqliteSqlStoreConfig +from .api import ColumnDefinition, ColumnType, SqlStore +from .sqlstore import SqlAlchemySqlStoreConfig TYPE_MAPPING: dict[ColumnType, Any] = { ColumnType.INTEGER: Integer, @@ -35,9 +35,10 @@ TYPE_MAPPING: dict[ColumnType, Any] = { } -class SqliteSqlStoreImpl(SqlStore): - def __init__(self, config: SqliteSqlStoreConfig): - self.engine = create_async_engine(config.engine_str) +class SqlAlchemySqlStoreImpl(SqlStore): + def __init__(self, config: SqlAlchemySqlStoreConfig): + self.config = config + self.async_session = async_sessionmaker(create_async_engine(config.engine_str)) self.metadata = MetaData() async def create_table( @@ -78,13 +79,14 @@ class SqliteSqlStoreImpl(SqlStore): # Create the table in the database if it doesn't exist # checkfirst=True ensures it doesn't try to recreate if it's already there - async with self.engine.begin() as conn: + engine = create_async_engine(self.config.engine_str) + async with engine.begin() as conn: await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True) async def insert(self, table: str, data: Mapping[str, Any]) -> None: - async with self.engine.begin() as conn: - await conn.execute(self.metadata.tables[table].insert(), data) - await conn.commit() + async with self.async_session() as session: + await session.execute(self.metadata.tables[table].insert(), data) + await session.commit() async def fetch_all( self, @@ -93,7 +95,7 @@ class SqliteSqlStoreImpl(SqlStore): limit: int | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, ) -> list[dict[str, Any]]: - async with self.engine.begin() as conn: + async with self.async_session() as session: query = select(self.metadata.tables[table]) if where: for key, value in where.items(): @@ -117,7 +119,7 @@ class SqliteSqlStoreImpl(SqlStore): query = query.order_by(self.metadata.tables[table].c[name].desc()) else: raise ValueError(f"Invalid order '{order_type}' for column '{name}'") - result = await conn.execute(query) + result = await session.execute(query) if result.rowcount == 0: return [] return [dict(row._mapping) for row in result] @@ -142,20 +144,20 @@ class SqliteSqlStoreImpl(SqlStore): if not where: raise ValueError("where is required for update") - async with self.engine.begin() as conn: + async with self.async_session() as session: stmt = self.metadata.tables[table].update() for key, value in where.items(): stmt = stmt.where(self.metadata.tables[table].c[key] == value) - await conn.execute(stmt, data) - await conn.commit() + await session.execute(stmt, data) + await session.commit() async def delete(self, table: str, where: Mapping[str, Any]) -> None: if not where: raise ValueError("where is required for delete") - async with self.engine.begin() as conn: + async with self.async_session() as session: stmt = self.metadata.tables[table].delete() for key, value in where.items(): stmt = stmt.where(self.metadata.tables[table].c[key] == value) - await conn.execute(stmt) - await conn.commit() + await session.execute(stmt) + await session.commit() diff --git a/llama_stack/providers/utils/sqlstore/sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlstore.py index 99f64805f..3091e8f96 100644 --- a/llama_stack/providers/utils/sqlstore/sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/sqlstore.py @@ -5,6 +5,7 @@ # the root directory of this source tree. 
+from abc import abstractmethod from enum import Enum from pathlib import Path from typing import Annotated, Literal @@ -21,7 +22,18 @@ class SqlStoreType(Enum): postgres = "postgres" -class SqliteSqlStoreConfig(BaseModel): +class SqlAlchemySqlStoreConfig(BaseModel): + @property + @abstractmethod + def engine_str(self) -> str: ... + + # TODO: move this when we have a better way to specify dependencies with internal APIs + @property + def pip_packages(self) -> list[str]: + return ["sqlalchemy[asyncio]"] + + +class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig): type: Literal["sqlite"] = SqlStoreType.sqlite.value db_path: str = Field( default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), @@ -39,18 +51,26 @@ class SqliteSqlStoreConfig(BaseModel): db_path="${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name, ) - # TODO: move this when we have a better way to specify dependencies with internal APIs @property def pip_packages(self) -> list[str]: - return ["sqlalchemy[asyncio]"] + return super().pip_packages + ["aiosqlite"] -class PostgresSqlStoreConfig(BaseModel): +class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig): type: Literal["postgres"] = SqlStoreType.postgres.value + host: str = "localhost" + port: str = "5432" + db: str = "llamastack" + user: str + password: str | None = None + + @property + def engine_str(self) -> str: + return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}" @property def pip_packages(self) -> list[str]: - raise NotImplementedError("Postgres is not implemented yet") + return super().pip_packages + ["asyncpg"] SqlStoreConfig = Annotated[ @@ -60,12 +80,10 @@ SqlStoreConfig = Annotated[ def sqlstore_impl(config: SqlStoreConfig) -> SqlStore: - if config.type == SqlStoreType.sqlite.value: - from .sqlite.sqlite import SqliteSqlStoreImpl + if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]: + from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl - impl = SqliteSqlStoreImpl(config) - elif config.type == SqlStoreType.postgres.value: - raise NotImplementedError("Postgres is not implemented yet") + impl = SqlAlchemySqlStoreImpl(config) else: raise ValueError(f"Unknown sqlstore type {config.type}") diff --git a/llama_stack/templates/bedrock/build.yaml b/llama_stack/templates/bedrock/build.yaml index 09fbf307d..97a06f77a 100644 --- a/llama_stack/templates/bedrock/build.yaml +++ b/llama_stack/templates/bedrock/build.yaml @@ -30,4 +30,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/cerebras/build.yaml b/llama_stack/templates/cerebras/build.yaml index 95b0302f2..f26f4ed9b 100644 --- a/llama_stack/templates/cerebras/build.yaml +++ b/llama_stack/templates/cerebras/build.yaml @@ -30,4 +30,5 @@ distribution_spec: - inline::rag-runtime image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/ci-tests/build.yaml b/llama_stack/templates/ci-tests/build.yaml index 6fe96c603..9f4fbbdda 100644 --- a/llama_stack/templates/ci-tests/build.yaml +++ b/llama_stack/templates/ci-tests/build.yaml @@ -31,4 +31,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/dell/build.yaml b/llama_stack/templates/dell/build.yaml index d37215f35..513df16c1 100644 --- a/llama_stack/templates/dell/build.yaml +++ b/llama_stack/templates/dell/build.yaml @@ 
-31,5 +31,5 @@ distribution_spec: - inline::rag-runtime image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/fireworks/build.yaml b/llama_stack/templates/fireworks/build.yaml index f162d9b43..be19181c0 100644 --- a/llama_stack/templates/fireworks/build.yaml +++ b/llama_stack/templates/fireworks/build.yaml @@ -32,5 +32,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/groq/build.yaml b/llama_stack/templates/groq/build.yaml index 92b46ce66..819df22f0 100644 --- a/llama_stack/templates/groq/build.yaml +++ b/llama_stack/templates/groq/build.yaml @@ -27,4 +27,5 @@ distribution_spec: - inline::rag-runtime image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/hf-endpoint/build.yaml b/llama_stack/templates/hf-endpoint/build.yaml index 4d09cc33e..8ede83694 100644 --- a/llama_stack/templates/hf-endpoint/build.yaml +++ b/llama_stack/templates/hf-endpoint/build.yaml @@ -30,5 +30,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/hf-serverless/build.yaml b/llama_stack/templates/hf-serverless/build.yaml index d06c628ac..d0752db9a 100644 --- a/llama_stack/templates/hf-serverless/build.yaml +++ b/llama_stack/templates/hf-serverless/build.yaml @@ -31,5 +31,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/llama_api/build.yaml b/llama_stack/templates/llama_api/build.yaml index d0dc08923..857e5f014 100644 --- a/llama_stack/templates/llama_api/build.yaml +++ b/llama_stack/templates/llama_api/build.yaml @@ -31,4 +31,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/meta-reference-gpu/build.yaml b/llama_stack/templates/meta-reference-gpu/build.yaml index e0ac87e47..53ad411e3 100644 --- a/llama_stack/templates/meta-reference-gpu/build.yaml +++ b/llama_stack/templates/meta-reference-gpu/build.yaml @@ -30,5 +30,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/nvidia/build.yaml b/llama_stack/templates/nvidia/build.yaml index e1e6fb3d8..6bd8a0100 100644 --- a/llama_stack/templates/nvidia/build.yaml +++ b/llama_stack/templates/nvidia/build.yaml @@ -25,5 +25,5 @@ distribution_spec: - inline::rag-runtime image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/ollama/build.yaml b/llama_stack/templates/ollama/build.yaml index 9d8ba3a1e..36a120897 100644 --- a/llama_stack/templates/ollama/build.yaml +++ b/llama_stack/templates/ollama/build.yaml @@ -33,5 +33,5 @@ distribution_spec: - remote::wolfram-alpha image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/open-benchmark/build.yaml b/llama_stack/templates/open-benchmark/build.yaml index aa6d876fe..840f1e1db 100644 --- a/llama_stack/templates/open-benchmark/build.yaml +++ 
b/llama_stack/templates/open-benchmark/build.yaml @@ -34,4 +34,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/passthrough/build.yaml b/llama_stack/templates/passthrough/build.yaml index 7560f1032..46b99cb75 100644 --- a/llama_stack/templates/passthrough/build.yaml +++ b/llama_stack/templates/passthrough/build.yaml @@ -32,5 +32,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/postgres-demo/__init__.py b/llama_stack/templates/postgres-demo/__init__.py new file mode 100644 index 000000000..81473cb73 --- /dev/null +++ b/llama_stack/templates/postgres-demo/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .postgres_demo import get_distribution_template # noqa: F401 diff --git a/llama_stack/templates/postgres-demo/build.yaml b/llama_stack/templates/postgres-demo/build.yaml new file mode 100644 index 000000000..8f3648abe --- /dev/null +++ b/llama_stack/templates/postgres-demo/build.yaml @@ -0,0 +1,24 @@ +version: '2' +distribution_spec: + description: Quick start template for running Llama Stack with several popular providers + providers: + inference: + - remote::fireworks + - remote::vllm + vector_io: + - remote::chromadb + safety: + - inline::llama-guard + agents: + - inline::meta-reference + telemetry: + - inline::meta-reference + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::rag-runtime + - remote::model-context-protocol +image_type: conda +additional_pip_packages: +- asyncpg +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/postgres-demo/postgres_demo.py b/llama_stack/templates/postgres-demo/postgres_demo.py new file mode 100644 index 000000000..d2e352320 --- /dev/null +++ b/llama_stack/templates/postgres-demo/postgres_demo.py @@ -0,0 +1,164 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ + +from llama_stack.distribution.datatypes import ( + ModelInput, + Provider, + ShieldInput, + ToolGroupInput, +) +from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig +from llama_stack.providers.remote.inference.fireworks.models import ( + MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES, +) +from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig +from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig +from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry +from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig +from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig +from llama_stack.templates.template import ( + DistributionTemplate, + RunConfigSettings, + get_model_registry, +) + + +def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]: + # in this template, we allow each API key to be optional + providers = [ + ( + "fireworks", + FIREWORKS_MODEL_ENTRIES, + FireworksImplConfig.sample_run_config(api_key="${env.FIREWORKS_API_KEY:}"), + ), + ] + inference_providers = [] + available_models = {} + for provider_id, model_entries, config in providers: + inference_providers.append( + Provider( + provider_id=provider_id, + provider_type=f"remote::{provider_id}", + config=config, + ) + ) + available_models[provider_id] = model_entries + inference_providers.append( + Provider( + provider_id="vllm-inference", + provider_type="remote::vllm", + config=VLLMInferenceAdapterConfig.sample_run_config( + url="${env.VLLM_URL:http://localhost:8000/v1}", + ), + ) + ) + return inference_providers, available_models + + +def get_distribution_template() -> DistributionTemplate: + inference_providers, available_models = get_inference_providers() + providers = { + "inference": ([p.provider_type for p in inference_providers]), + "vector_io": ["remote::chromadb"], + "safety": ["inline::llama-guard"], + "agents": ["inline::meta-reference"], + "telemetry": ["inline::meta-reference"], + "tool_runtime": [ + "remote::brave-search", + "remote::tavily-search", + "inline::rag-runtime", + "remote::model-context-protocol", + ], + } + name = "postgres-demo" + + vector_io_providers = [ + Provider( + provider_id="${env.ENABLE_CHROMADB+chromadb}", + provider_type="remote::chromadb", + config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"), + ), + ] + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ToolGroupInput( + toolgroup_id="builtin::rag", + provider_id="rag-runtime", + ), + ] + + default_models = get_model_registry(available_models) + default_models.append( + ModelInput( + model_id="${env.INFERENCE_MODEL}", + provider_id="vllm-inference", + ) + ) + postgres_config = { + "type": "postgres", + "host": "${env.POSTGRES_HOST:localhost}", + "port": "${env.POSTGRES_PORT:5432}", + "db": "${env.POSTGRES_DB:llamastack}", + "user": "${env.POSTGRES_USER:llamastack}", + "password": "${env.POSTGRES_PASSWORD:llamastack}", + } + + return DistributionTemplate( + name=name, + distro_type="self_hosted", + description="Quick start template for running Llama Stack with several popular providers", + container_image=None, + template_path=None, + providers=providers, + available_models_by_provider=available_models, + run_configs={ + "run.yaml": RunConfigSettings( + provider_overrides={ + "inference": inference_providers, + "vector_io": vector_io_providers, + 
"agents": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + config=dict( + persistence_store=postgres_config, + responses_store=postgres_config, + ), + ) + ], + "telemetry": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + config=dict( + service_name="${env.OTEL_SERVICE_NAME:}", + sinks="${env.TELEMETRY_SINKS:console}", + ), + ) + ], + }, + default_models=default_models, + default_tool_groups=default_tool_groups, + default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], + metadata_store=PostgresKVStoreConfig.model_validate(postgres_config), + inference_store=PostgresSqlStoreConfig.model_validate(postgres_config), + ), + }, + run_config_env_vars={ + "LLAMA_STACK_PORT": ( + "8321", + "Port for the Llama Stack distribution server", + ), + "FIREWORKS_API_KEY": ( + "", + "Fireworks API Key", + ), + }, + ) diff --git a/llama_stack/templates/postgres-demo/run.yaml b/llama_stack/templates/postgres-demo/run.yaml new file mode 100644 index 000000000..889b8eaa7 --- /dev/null +++ b/llama_stack/templates/postgres-demo/run.yaml @@ -0,0 +1,224 @@ +version: '2' +image_name: postgres-demo +apis: +- agents +- inference +- safety +- telemetry +- tool_runtime +- vector_io +providers: + inference: + - provider_id: fireworks + provider_type: remote::fireworks + config: + url: https://api.fireworks.ai/inference/v1 + api_key: ${env.FIREWORKS_API_KEY:} + - provider_id: vllm-inference + provider_type: remote::vllm + config: + url: ${env.VLLM_URL:http://localhost:8000/v1} + max_tokens: ${env.VLLM_MAX_TOKENS:4096} + api_token: ${env.VLLM_API_TOKEN:fake} + tls_verify: ${env.VLLM_TLS_VERIFY:true} + vector_io: + - provider_id: ${env.ENABLE_CHROMADB+chromadb} + provider_type: remote::chromadb + config: + url: ${env.CHROMADB_URL:} + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: + excluded_categories: [] + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: postgres + host: ${env.POSTGRES_HOST:localhost} + port: ${env.POSTGRES_PORT:5432} + db: ${env.POSTGRES_DB:llamastack} + user: ${env.POSTGRES_USER:llamastack} + password: ${env.POSTGRES_PASSWORD:llamastack} + responses_store: + type: postgres + host: ${env.POSTGRES_HOST:localhost} + port: ${env.POSTGRES_PORT:5432} + db: ${env.POSTGRES_DB:llamastack} + user: ${env.POSTGRES_USER:llamastack} + password: ${env.POSTGRES_PASSWORD:llamastack} + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: ${env.OTEL_SERVICE_NAME:} + sinks: ${env.TELEMETRY_SINKS:console} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime + config: {} + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + config: {} +metadata_store: + type: postgres + host: ${env.POSTGRES_HOST:localhost} + port: ${env.POSTGRES_PORT:5432} + db: ${env.POSTGRES_DB:llamastack} + user: ${env.POSTGRES_USER:llamastack} + password: ${env.POSTGRES_PASSWORD:llamastack} + table_name: llamastack_kvstore +inference_store: + type: postgres + host: ${env.POSTGRES_HOST:localhost} + port: ${env.POSTGRES_PORT:5432} + db: ${env.POSTGRES_DB:llamastack} + 
user: ${env.POSTGRES_USER:llamastack} + password: ${env.POSTGRES_PASSWORD:llamastack} +models: +- metadata: {} + model_id: accounts/fireworks/models/llama-v3p1-8b-instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.1-8B-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct + model_type: llm +- metadata: {} + model_id: accounts/fireworks/models/llama-v3p1-70b-instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.1-70B-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct + model_type: llm +- metadata: {} + model_id: accounts/fireworks/models/llama-v3p1-405b-instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.1-405B-Instruct-FP8 + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct + model_type: llm +- metadata: {} + model_id: accounts/fireworks/models/llama-v3p2-3b-instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-3B-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct + model_type: llm +- metadata: {} + model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-11B-Vision-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct + model_type: llm +- metadata: {} + model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-90B-Vision-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct + model_type: llm +- metadata: {} + model_id: accounts/fireworks/models/llama-v3p3-70b-instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.3-70B-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct + model_type: llm +- metadata: {} + model_id: accounts/fireworks/models/llama-guard-3-8b + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-guard-3-8b + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-Guard-3-8B + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-guard-3-8b + model_type: llm +- metadata: {} + model_id: accounts/fireworks/models/llama-guard-3-11b-vision + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-Guard-3-11B-Vision + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision + model_type: llm +- metadata: {} + model_id: accounts/fireworks/models/llama4-scout-instruct-basic + provider_id: fireworks + 
provider_model_id: accounts/fireworks/models/llama4-scout-instruct-basic + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama4-scout-instruct-basic + model_type: llm +- metadata: {} + model_id: accounts/fireworks/models/llama4-maverick-instruct-basic + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic + model_type: llm +- metadata: + embedding_dimension: 768 + context_length: 8192 + model_id: nomic-ai/nomic-embed-text-v1.5 + provider_id: fireworks + provider_model_id: nomic-ai/nomic-embed-text-v1.5 + model_type: embedding +- metadata: {} + model_id: ${env.INFERENCE_MODEL} + provider_id: vllm-inference + model_type: llm +shields: +- shield_id: meta-llama/Llama-Guard-3-8B +vector_dbs: [] +datasets: [] +scoring_fns: [] +benchmarks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::rag + provider_id: rag-runtime +server: + port: 8321 diff --git a/llama_stack/templates/remote-vllm/build.yaml b/llama_stack/templates/remote-vllm/build.yaml index fcd4deeff..16fe5d4fd 100644 --- a/llama_stack/templates/remote-vllm/build.yaml +++ b/llama_stack/templates/remote-vllm/build.yaml @@ -32,5 +32,5 @@ distribution_spec: - remote::wolfram-alpha image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/sambanova/build.yaml b/llama_stack/templates/sambanova/build.yaml index b644dcfdc..14b1c8974 100644 --- a/llama_stack/templates/sambanova/build.yaml +++ b/llama_stack/templates/sambanova/build.yaml @@ -23,4 +23,5 @@ distribution_spec: - remote::wolfram-alpha image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/starter/build.yaml b/llama_stack/templates/starter/build.yaml index 652814ffd..ec97c7d3e 100644 --- a/llama_stack/templates/starter/build.yaml +++ b/llama_stack/templates/starter/build.yaml @@ -36,4 +36,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py index ec5cd38ea..4013f08f9 100644 --- a/llama_stack/templates/template.py +++ b/llama_stack/templates/template.py @@ -28,8 +28,8 @@ from llama_stack.distribution.datatypes import ( from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry -from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig -from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig +from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig def get_model_registry( @@ -64,6 +64,8 @@ class RunConfigSettings(BaseModel): default_tool_groups: list[ToolGroupInput] | None = None default_datasets: list[DatasetInput] | None = None default_benchmarks: list[BenchmarkInput] | None = None + metadata_store: KVStoreConfig | None = None + inference_store: 
SqlStoreConfig | None = None def run_config( self, @@ -114,11 +116,13 @@ class RunConfigSettings(BaseModel): container_image=container_image, apis=apis, providers=provider_configs, - metadata_store=SqliteKVStoreConfig.sample_run_config( + metadata_store=self.metadata_store + or SqliteKVStoreConfig.sample_run_config( __distro_dir__=f"~/.llama/distributions/{name}", db_name="registry.db", ), - inference_store=SqliteSqlStoreConfig.sample_run_config( + inference_store=self.inference_store + or SqliteSqlStoreConfig.sample_run_config( __distro_dir__=f"~/.llama/distributions/{name}", db_name="inference_store.db", ), @@ -164,7 +168,7 @@ class DistributionTemplate(BaseModel): providers=self.providers, ), image_type="conda", # default to conda, can be overridden - additional_pip_packages=additional_pip_packages, + additional_pip_packages=sorted(set(additional_pip_packages)), ) def generate_markdown_docs(self) -> str: diff --git a/llama_stack/templates/tgi/build.yaml b/llama_stack/templates/tgi/build.yaml index 652900c84..361b0b680 100644 --- a/llama_stack/templates/tgi/build.yaml +++ b/llama_stack/templates/tgi/build.yaml @@ -31,5 +31,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/together/build.yaml b/llama_stack/templates/together/build.yaml index 4a556a66f..5ffeac873 100644 --- a/llama_stack/templates/together/build.yaml +++ b/llama_stack/templates/together/build.yaml @@ -32,5 +32,5 @@ distribution_spec: - remote::wolfram-alpha image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/verification/build.yaml b/llama_stack/templates/verification/build.yaml index cb7ab4798..ce083dbba 100644 --- a/llama_stack/templates/verification/build.yaml +++ b/llama_stack/templates/verification/build.yaml @@ -36,4 +36,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/vllm-gpu/build.yaml b/llama_stack/templates/vllm-gpu/build.yaml index 5a9d003cb..d5ff0f1f4 100644 --- a/llama_stack/templates/vllm-gpu/build.yaml +++ b/llama_stack/templates/vllm-gpu/build.yaml @@ -31,4 +31,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/watsonx/build.yaml b/llama_stack/templates/watsonx/build.yaml index 87233fb26..e68ace183 100644 --- a/llama_stack/templates/watsonx/build.yaml +++ b/llama_stack/templates/watsonx/build.yaml @@ -29,4 +29,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index 2cd76a23d..190840f70 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -268,9 +268,9 @@ def test_openai_chat_completion_streaming_with_n(compat_client, client_with_mode False, ], ) -def test_inference_store(openai_client, client_with_models, text_model_id, stream): +def test_inference_store(compat_client, client_with_models, text_model_id, stream): skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) - client = openai_client + client = compat_client # make a chat 
completion message = "Hello, world!" response = client.chat.completions.create( @@ -301,9 +301,14 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea retrieved_response = client.chat.completions.retrieve(response_id) assert retrieved_response.id == response_id - assert retrieved_response.input_messages[0]["content"] == message, retrieved_response assert retrieved_response.choices[0].message.content == content, retrieved_response + input_content = ( + getattr(retrieved_response.input_messages[0], "content", None) + or retrieved_response.input_messages[0]["content"] + ) + assert input_content == message, retrieved_response + @pytest.mark.parametrize( "stream", @@ -312,9 +317,9 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea False, ], ) -def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream): +def test_inference_store_tool_calls(compat_client, client_with_models, text_model_id, stream): skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) - client = openai_client + client = compat_client # make a chat completion message = "What's the weather in Tokyo? Use the get_weather function to get the weather." response = client.chat.completions.create( @@ -361,7 +366,11 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode retrieved_response = client.chat.completions.retrieve(response_id) assert retrieved_response.id == response_id - assert retrieved_response.input_messages[0]["content"] == message + input_content = ( + getattr(retrieved_response.input_messages[0], "content", None) + or retrieved_response.input_messages[0]["content"] + ) + assert input_content == message, retrieved_response tool_calls = retrieved_response.choices[0].message.tool_calls # sometimes model doesn't ouptut tool calls, but we still want to test that the tool was called if tool_calls: diff --git a/tests/unit/utils/test_sqlstore.py b/tests/unit/utils/test_sqlstore.py index 8ded760ef..6231e9082 100644 --- a/tests/unit/utils/test_sqlstore.py +++ b/tests/unit/utils/test_sqlstore.py @@ -9,7 +9,7 @@ from tempfile import TemporaryDirectory import pytest from llama_stack.providers.utils.sqlstore.api import ColumnType -from llama_stack.providers.utils.sqlstore.sqlite.sqlite import SqliteSqlStoreImpl +from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig @@ -17,7 +17,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig async def test_sqlite_sqlstore(): with TemporaryDirectory() as tmp_dir: db_name = "test.db" - sqlstore = SqliteSqlStoreImpl( + sqlstore = SqlAlchemySqlStoreImpl( SqliteSqlStoreConfig( db_path=tmp_dir + "/" + db_name, ) From f2c2a05f588e13c8e369a0c7799b1dc4fe23beea Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Thu, 29 May 2025 15:27:59 -0700 Subject: [PATCH 6/6] OpenAI compat embeddings API --- docs/_static/llama-stack-spec.html | 176 ++++++++++++++++++ docs/_static/llama-stack-spec.yaml | 144 ++++++++++++++ llama_stack/apis/inference/inference.py | 62 ++++++ llama_stack/distribution/routers/inference.py | 29 +++ .../providers/inline/inference/vllm/vllm.py | 11 ++ .../remote/inference/bedrock/bedrock.py | 11 ++ .../remote/inference/cerebras/cerebras.py | 11 ++ .../remote/inference/databricks/databricks.py | 11 ++ .../remote/inference/fireworks/fireworks.py | 11 ++ .../remote/inference/nvidia/nvidia.py | 11 
++ .../remote/inference/ollama/ollama.py | 11 ++ .../remote/inference/openai/openai.py | 52 ++++++ .../inference/passthrough/passthrough.py | 11 ++ .../remote/inference/runpod/runpod.py | 11 ++ .../providers/remote/inference/tgi/tgi.py | 11 ++ .../remote/inference/together/together.py | 11 ++ .../providers/remote/inference/vllm/vllm.py | 11 ++ .../remote/inference/watsonx/watsonx.py | 11 ++ .../utils/inference/embedding_mixin.py | 49 +++++ .../utils/inference/litellm_openai_mixin.py | 51 +++++ 20 files changed, 706 insertions(+) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 9c1c3170f..770abfb27 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -3607,6 +3607,49 @@ } } }, + "/v1/openai/v1/embeddings": { + "post": { + "responses": { + "200": { + "description": "An OpenAIEmbeddingsResponse containing the embeddings.", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OpenAIEmbeddingsResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Inference" + ], + "description": "Generate OpenAI-compatible embeddings for the given input using the specified model.", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OpenaiEmbeddingsRequest" + } + } + }, + "required": true + } + } + }, "/v1/openai/v1/models": { "get": { "responses": { @@ -11767,6 +11810,139 @@ "title": "OpenAICompletionChoice", "description": "A choice from an OpenAI-compatible completion response." }, + "OpenaiEmbeddingsRequest": { + "type": "object", + "properties": { + "model": { + "type": "string", + "description": "The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint." + }, + "input": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "type": "string" + } + } + ], + "description": "Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings." + }, + "encoding_format": { + "type": "string", + "description": "(Optional) The format to return the embeddings in. Can be either \"float\" or \"base64\". Defaults to \"float\"." + }, + "dimensions": { + "type": "integer", + "description": "(Optional) The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models." + }, + "user": { + "type": "string", + "description": "(Optional) A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse." 
+ } + }, + "additionalProperties": false, + "required": [ + "model", + "input" + ], + "title": "OpenaiEmbeddingsRequest" + }, + "OpenAIEmbeddingData": { + "type": "object", + "properties": { + "object": { + "type": "string", + "const": "embedding", + "default": "embedding", + "description": "The object type, which will be \"embedding\"" + }, + "embedding": { + "oneOf": [ + { + "type": "array", + "items": { + "type": "number" + } + }, + { + "type": "string" + } + ], + "description": "The embedding vector as a list of floats (when encoding_format=\"float\") or as a base64-encoded string (when encoding_format=\"base64\")" + }, + "index": { + "type": "integer", + "description": "The index of the embedding in the input list" + } + }, + "additionalProperties": false, + "required": [ + "object", + "embedding", + "index" + ], + "title": "OpenAIEmbeddingData", + "description": "A single embedding data object from an OpenAI-compatible embeddings response." + }, + "OpenAIEmbeddingUsage": { + "type": "object", + "properties": { + "prompt_tokens": { + "type": "integer", + "description": "The number of tokens in the input" + }, + "total_tokens": { + "type": "integer", + "description": "The total number of tokens used" + } + }, + "additionalProperties": false, + "required": [ + "prompt_tokens", + "total_tokens" + ], + "title": "OpenAIEmbeddingUsage", + "description": "Usage information for an OpenAI-compatible embeddings response." + }, + "OpenAIEmbeddingsResponse": { + "type": "object", + "properties": { + "object": { + "type": "string", + "const": "list", + "default": "list", + "description": "The object type, which will be \"list\"" + }, + "data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIEmbeddingData" + }, + "description": "List of embedding data objects" + }, + "model": { + "type": "string", + "description": "The model that was used to generate the embeddings" + }, + "usage": { + "$ref": "#/components/schemas/OpenAIEmbeddingUsage", + "description": "Usage information" + } + }, + "additionalProperties": false, + "required": [ + "object", + "data", + "model", + "usage" + ], + "title": "OpenAIEmbeddingsResponse", + "description": "Response from an OpenAI-compatible embeddings request." + }, "OpenAIModel": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 1afe870cf..15842ff19 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -2520,6 +2520,38 @@ paths: schema: $ref: '#/components/schemas/OpenaiCompletionRequest' required: true + /v1/openai/v1/embeddings: + post: + responses: + '200': + description: >- + An OpenAIEmbeddingsResponse containing the embeddings. + content: + application/json: + schema: + $ref: '#/components/schemas/OpenAIEmbeddingsResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Inference + description: >- + Generate OpenAI-compatible embeddings for the given input using the specified + model. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/OpenaiEmbeddingsRequest' + required: true /v1/openai/v1/models: get: responses: @@ -8177,6 +8209,118 @@ components: title: OpenAICompletionChoice description: >- A choice from an OpenAI-compatible completion response. 
+ OpenaiEmbeddingsRequest: + type: object + properties: + model: + type: string + description: >- + The identifier of the model to use. The model must be an embedding model + registered with Llama Stack and available via the /models endpoint. + input: + oneOf: + - type: string + - type: array + items: + type: string + description: >- + Input text to embed, encoded as a string or array of strings. To embed + multiple inputs in a single request, pass an array of strings. + encoding_format: + type: string + description: >- + (Optional) The format to return the embeddings in. Can be either "float" + or "base64". Defaults to "float". + dimensions: + type: integer + description: >- + (Optional) The number of dimensions the resulting output embeddings should + have. Only supported in text-embedding-3 and later models. + user: + type: string + description: >- + (Optional) A unique identifier representing your end-user, which can help + OpenAI to monitor and detect abuse. + additionalProperties: false + required: + - model + - input + title: OpenaiEmbeddingsRequest + OpenAIEmbeddingData: + type: object + properties: + object: + type: string + const: embedding + default: embedding + description: >- + The object type, which will be "embedding" + embedding: + oneOf: + - type: array + items: + type: number + - type: string + description: >- + The embedding vector as a list of floats (when encoding_format="float") + or as a base64-encoded string (when encoding_format="base64") + index: + type: integer + description: >- + The index of the embedding in the input list + additionalProperties: false + required: + - object + - embedding + - index + title: OpenAIEmbeddingData + description: >- + A single embedding data object from an OpenAI-compatible embeddings response. + OpenAIEmbeddingUsage: + type: object + properties: + prompt_tokens: + type: integer + description: The number of tokens in the input + total_tokens: + type: integer + description: The total number of tokens used + additionalProperties: false + required: + - prompt_tokens + - total_tokens + title: OpenAIEmbeddingUsage + description: >- + Usage information for an OpenAI-compatible embeddings response. + OpenAIEmbeddingsResponse: + type: object + properties: + object: + type: string + const: list + default: list + description: The object type, which will be "list" + data: + type: array + items: + $ref: '#/components/schemas/OpenAIEmbeddingData' + description: List of embedding data objects + model: + type: string + description: >- + The model that was used to generate the embeddings + usage: + $ref: '#/components/schemas/OpenAIEmbeddingUsage' + description: Usage information + additionalProperties: false + required: + - object + - data + - model + - usage + title: OpenAIEmbeddingsResponse + description: >- + Response from an OpenAI-compatible embeddings request. OpenAIModel: type: object properties: diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index e79dc6d94..74697dd18 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -783,6 +783,48 @@ class OpenAICompletion(BaseModel): object: Literal["text_completion"] = "text_completion" +@json_schema_type +class OpenAIEmbeddingData(BaseModel): + """A single embedding data object from an OpenAI-compatible embeddings response. 
+ + :param object: The object type, which will be "embedding" + :param embedding: The embedding vector as a list of floats (when encoding_format="float") or as a base64-encoded string (when encoding_format="base64") + :param index: The index of the embedding in the input list + """ + + object: Literal["embedding"] = "embedding" + embedding: list[float] | str + index: int + + +@json_schema_type +class OpenAIEmbeddingUsage(BaseModel): + """Usage information for an OpenAI-compatible embeddings response. + + :param prompt_tokens: The number of tokens in the input + :param total_tokens: The total number of tokens used + """ + + prompt_tokens: int + total_tokens: int + + +@json_schema_type +class OpenAIEmbeddingsResponse(BaseModel): + """Response from an OpenAI-compatible embeddings request. + + :param object: The object type, which will be "list" + :param data: List of embedding data objects + :param model: The model that was used to generate the embeddings + :param usage: Usage information + """ + + object: Literal["list"] = "list" + data: list[OpenAIEmbeddingData] + model: str + usage: OpenAIEmbeddingUsage + + class ModelStore(Protocol): async def get_model(self, identifier: str) -> Model: ... @@ -1076,6 +1118,26 @@ class InferenceProvider(Protocol): """ ... + @webmethod(route="/openai/v1/embeddings", method="POST") + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + """Generate OpenAI-compatible embeddings for the given input using the specified model. + + :param model: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint. + :param input: Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings. + :param encoding_format: (Optional) The format to return the embeddings in. Can be either "float" or "base64". Defaults to "float". + :param dimensions: (Optional) The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models. + :param user: (Optional) A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. + :returns: An OpenAIEmbeddingsResponse containing the embeddings. + """ + ... + class Inference(InferenceProvider): """Llama Stack Inference API for generating completions, chat completions, and embeddings. 
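The API surface above is OpenAI-compatible, so any OpenAI client can target it once a Llama Stack server exposes the route. A minimal usage sketch, assuming a locally running stack with an already-registered embedding model (the base URL, API key, and model id below are illustrative placeholders, not values introduced by this patch):

```python
# Illustrative only: point the standard OpenAI SDK at a Llama Stack server's
# OpenAI-compatible prefix and call the new embeddings route.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8321/v1/openai/v1",  # assumed local stack URL
    api_key="not-needed",                           # assumption: key is not checked locally
)

response = client.embeddings.create(
    model="all-MiniLM-L6-v2",  # placeholder: any registered embedding model
    input=["hello world", "goodbye world"],
    encoding_format="float",
)

for item in response.data:
    print(item.index, len(item.embedding))
```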
diff --git a/llama_stack/distribution/routers/inference.py b/llama_stack/distribution/routers/inference.py index f77b19302..763bd9105 100644 --- a/llama_stack/distribution/routers/inference.py +++ b/llama_stack/distribution/routers/inference.py @@ -45,6 +45,7 @@ from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, OpenAIChatCompletionChunk, OpenAICompletion, + OpenAIEmbeddingsResponse, OpenAIMessageParam, OpenAIResponseFormatParam, ) @@ -546,6 +547,34 @@ class InferenceRouter(Inference): await self.store.store_chat_completion(response, messages) return response + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + logger.debug( + f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}", + ) + model_obj = await self.routing_table.get_model(model) + if model_obj is None: + raise ValueError(f"Model '{model}' not found") + if model_obj.model_type != ModelType.embedding: + raise ValueError(f"Model '{model}' is not an embedding model") + + params = dict( + model=model_obj.identifier, + input=input, + encoding_format=encoding_format, + dimensions=dimensions, + user=user, + ) + + provider = self.routing_table.get_provider_impl(model_obj.identifier) + return await provider.openai_embeddings(**params) + async def list_chat_completions( self, after: str | None = None, diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index 438cb14a0..bf54462b5 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -40,6 +40,7 @@ from llama_stack.apis.inference import ( JsonSchemaResponseFormat, LogProbConfig, Message, + OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -410,6 +411,16 @@ class VLLMInferenceImpl( ) -> EmbeddingsResponse: raise NotImplementedError() + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() + async def chat_completion( self, model_id: str, diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index 0404a578f..952d86f1a 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -22,6 +22,7 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, + OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -197,3 +198,13 @@ class BedrockInferenceAdapter( response_body = json.loads(response.get("body").read()) embeddings.append(response_body.get("embedding")) return EmbeddingsResponse(embeddings=embeddings) + + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 685375346..952118e24 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ 
b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -21,6 +21,7 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, + OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -194,3 +195,13 @@ class CerebrasInferenceAdapter( task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: raise NotImplementedError() + + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 5c36eac3e..1dc18b97f 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -20,6 +20,7 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, + OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -152,3 +153,13 @@ class DatabricksInferenceAdapter( task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: raise NotImplementedError() + + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index b6d3984c6..fe21685dd 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -37,6 +37,7 @@ from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, OpenAIChatCompletionChunk, OpenAICompletion, + OpenAIEmbeddingsResponse, OpenAIMessageParam, OpenAIResponseFormatParam, ) @@ -286,6 +287,16 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv embeddings = [data.embedding for data in response.data] return EmbeddingsResponse(embeddings=embeddings) + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() + async def openai_completion( self, model: str, diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 333486fe4..4c68322e0 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -29,6 +29,7 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, + OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -238,6 +239,16 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): # return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data]) + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() + async def chat_completion( self, model_id: str, diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py 
b/llama_stack/providers/remote/inference/ollama/ollama.py index 3b4287673..8863e0edc 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -32,6 +32,7 @@ from llama_stack.apis.inference import ( JsonSchemaResponseFormat, LogProbConfig, Message, + OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -370,6 +371,16 @@ class OllamaInferenceAdapter( return model + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() + async def openai_completion( self, model: str, diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py index c3c25edd3..6f3a686a8 100644 --- a/llama_stack/providers/remote/inference/openai/openai.py +++ b/llama_stack/providers/remote/inference/openai/openai.py @@ -14,6 +14,9 @@ from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, OpenAIChatCompletionChunk, OpenAICompletion, + OpenAIEmbeddingData, + OpenAIEmbeddingsResponse, + OpenAIEmbeddingUsage, OpenAIMessageParam, OpenAIResponseFormatParam, ) @@ -38,6 +41,7 @@ logger = logging.getLogger(__name__) # | batch_chat_completion | LiteLLMOpenAIMixin | # | openai_completion | AsyncOpenAI | # | openai_chat_completion | AsyncOpenAI | +# | openai_embeddings | AsyncOpenAI | # class OpenAIInferenceAdapter(LiteLLMOpenAIMixin): def __init__(self, config: OpenAIConfig) -> None: @@ -171,3 +175,51 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin): user=user, ) return await self._openai_client.chat.completions.create(**params) + + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + model_id = (await self.model_store.get_model(model)).provider_resource_id + if model_id.startswith("openai/"): + model_id = model_id[len("openai/") :] + + # Prepare parameters for OpenAI embeddings API + params = { + "model": model_id, + "input": input, + } + + if encoding_format is not None: + params["encoding_format"] = encoding_format + if dimensions is not None: + params["dimensions"] = dimensions + if user is not None: + params["user"] = user + + # Call OpenAI embeddings API + response = await self._openai_client.embeddings.create(**params) + + data = [] + for i, embedding_data in enumerate(response.data): + data.append( + OpenAIEmbeddingData( + embedding=embedding_data.embedding, + index=i, + ) + ) + + usage = OpenAIEmbeddingUsage( + prompt_tokens=response.usage.prompt_tokens, + total_tokens=response.usage.total_tokens, + ) + + return OpenAIEmbeddingsResponse( + data=data, + model=response.model, + usage=usage, + ) diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py index 78ee52641..6cf4680e2 100644 --- a/llama_stack/providers/remote/inference/passthrough/passthrough.py +++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py @@ -19,6 +19,7 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, + OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -210,6 +211,16 @@ class PassthroughInferenceAdapter(Inference): task_type=task_type, ) + async def openai_embeddings( + self, 
+ model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() + async def openai_completion( self, model: str, diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index 2706aa15e..f8c98893e 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -8,6 +8,7 @@ from collections.abc import AsyncGenerator from openai import OpenAI from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.inference.inference import OpenAIEmbeddingsResponse # from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper @@ -134,3 +135,13 @@ class RunpodInferenceAdapter( task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: raise NotImplementedError() + + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 8f6666462..292d74ef8 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -23,6 +23,7 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, + OpenAIEmbeddingsResponse, ResponseFormat, ResponseFormatType, SamplingParams, @@ -291,6 +292,16 @@ class _HfAdapter( ) -> EmbeddingsResponse: raise NotImplementedError() + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() + class TGIAdapter(_HfAdapter): async def initialize(self, config: TGIImplConfig) -> None: diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 562e6e0ff..7305a638d 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -23,6 +23,7 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, + OpenAIEmbeddingsResponse, ResponseFormat, ResponseFormatType, SamplingParams, @@ -267,6 +268,16 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi embeddings = [item.embedding for item in r.data] return EmbeddingsResponse(embeddings=embeddings) + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() + async def openai_completion( self, model: str, diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index fe2d8bec1..9f38d9abf 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -38,6 +38,7 @@ from llama_stack.apis.inference import ( JsonSchemaResponseFormat, LogProbConfig, Message, + OpenAIEmbeddingsResponse, ResponseFormat, 
SamplingParams, TextTruncation, @@ -507,6 +508,16 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): embeddings = [data.embedding for data in response.data] return EmbeddingsResponse(embeddings=embeddings) + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() + async def openai_completion( self, model: str, diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py index c1299e11f..59f5f5562 100644 --- a/llama_stack/providers/remote/inference/watsonx/watsonx.py +++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py @@ -21,6 +21,7 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, + OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -260,6 +261,16 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper): ) -> EmbeddingsResponse: raise NotImplementedError("embedding is not supported for watsonx") + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() + async def openai_completion( self, model: str, diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py index 7c8144c62..97cf87360 100644 --- a/llama_stack/providers/utils/inference/embedding_mixin.py +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -4,7 +4,9 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+import base64 import logging +import struct from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -15,6 +17,9 @@ from llama_stack.apis.inference import ( EmbeddingTaskType, InterleavedContentItem, ModelStore, + OpenAIEmbeddingData, + OpenAIEmbeddingsResponse, + OpenAIEmbeddingUsage, TextTruncation, ) from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str @@ -43,6 +48,50 @@ class SentenceTransformerEmbeddingMixin: ) return EmbeddingsResponse(embeddings=embeddings) + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + # Convert input to list format if it's a single string + input_list = [input] if isinstance(input, str) else input + if not input_list: + raise ValueError("Empty list not supported") + + # Get the model and generate embeddings + model_obj = await self.model_store.get_model(model) + embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id) + embeddings = embedding_model.encode(input_list, show_progress_bar=False) + + # Convert embeddings to the requested format + data = [] + for i, embedding in enumerate(embeddings): + if encoding_format == "base64": + # Convert float array to base64 string + float_bytes = struct.pack(f"{len(embedding)}f", *embedding) + embedding_value = base64.b64encode(float_bytes).decode("ascii") + else: + # Default to float format + embedding_value = embedding.tolist() + + data.append( + OpenAIEmbeddingData( + embedding=embedding_value, + index=i, + ) + ) + + # Not returning actual token usage + usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1) + return OpenAIEmbeddingsResponse( + data=data, + model=model_obj.provider_resource_id, + usage=usage, + ) + def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer": global EMBEDDING_MODELS diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 4d17db21e..dab10bc55 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+import base64 +import struct from collections.abc import AsyncGenerator, AsyncIterator from typing import Any @@ -35,6 +37,9 @@ from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, OpenAIChatCompletionChunk, OpenAICompletion, + OpenAIEmbeddingData, + OpenAIEmbeddingsResponse, + OpenAIEmbeddingUsage, OpenAIMessageParam, OpenAIResponseFormatParam, ) @@ -264,6 +269,52 @@ class LiteLLMOpenAIMixin( embeddings = [data["embedding"] for data in response["data"]] return EmbeddingsResponse(embeddings=embeddings) + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + model_obj = await self.model_store.get_model(model) + + # Convert input to list if it's a string + input_list = [input] if isinstance(input, str) else input + + # Call litellm embedding function + # litellm.drop_params = True + response = litellm.embedding( + model=self.get_litellm_model_name(model_obj.provider_resource_id), + input=input_list, + api_key=self.get_api_key(), + api_base=self.api_base, + dimensions=dimensions, + ) + + # Convert response to OpenAI format + data = [] + for i, embedding_data in enumerate(response["data"]): + # we encode to base64 if the encoding format is base64 in the request + if encoding_format == "base64": + byte_data = b"".join(struct.pack("f", f) for f in embedding_data["embedding"]) + embedding = base64.b64encode(byte_data).decode("utf-8") + else: + embedding = embedding_data["embedding"] + + data.append(OpenAIEmbeddingData(embedding=embedding, index=i)) + + usage = OpenAIEmbeddingUsage( + prompt_tokens=response["usage"]["prompt_tokens"], + total_tokens=response["usage"]["total_tokens"], + ) + + return OpenAIEmbeddingsResponse( + data=data, + model=model_obj.provider_resource_id, + usage=usage, + ) + async def openai_completion( self, model: str,
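Both `SentenceTransformerEmbeddingMixin.openai_embeddings` and `LiteLLMOpenAIMixin.openai_embeddings` above pack the float vector with `struct` and base64-encode it when `encoding_format="base64"` is requested. For completeness, a client-side sketch of the inverse operation, assuming the same packed-float32 layout used in this patch:

```python
import base64
import struct


def decode_base64_embedding(value: str) -> list[float]:
    """Decode a base64-encoded embedding back into a list of floats.

    Assumes the packed-float32 layout produced by the mixins in this patch.
    """
    raw = base64.b64decode(value)
    count = len(raw) // struct.calcsize("f")
    return list(struct.unpack(f"{count}f", raw))
```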