Merge branch 'main' into chroma

This commit is contained in:
Bwook (Byoungwook) Kim 2025-09-12 08:56:46 +09:00 committed by GitHub
commit aaea9fed12
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 780 additions and 105 deletions

View file

@ -13,11 +13,8 @@ on:
branches: [ main ] branches: [ main ]
types: [opened, synchronize, reopened] types: [opened, synchronize, reopened]
paths: paths:
- 'llama_stack/**' - 'docs/_static/llama-stack-spec.yaml'
- '!llama_stack/ui/**' - 'docs/_static/llama-stack-spec.html'
- 'tests/**'
- 'uv.lock'
- 'pyproject.toml'
- '.github/workflows/conformance.yml' # This workflow itself - '.github/workflows/conformance.yml' # This workflow itself
concurrency: concurrency:
@ -43,10 +40,27 @@ jobs:
ref: ${{ github.event.pull_request.base.ref }} ref: ${{ github.event.pull_request.base.ref }}
path: 'base' path: 'base'
# Cache oasdiff to avoid checksum failures and speed up builds
- name: Cache oasdiff
id: cache-oasdiff
uses: actions/cache@0400d5f644dc74513175e3cd8d07132dd4860809
with:
path: ~/oasdiff
key: oasdiff-${{ runner.os }}
# Install oasdiff: https://github.com/oasdiff/oasdiff, a tool for detecting breaking changes in OpenAPI specs. # Install oasdiff: https://github.com/oasdiff/oasdiff, a tool for detecting breaking changes in OpenAPI specs.
- name: Install oasdiff - name: Install oasdiff
if: steps.cache-oasdiff.outputs.cache-hit != 'true'
run: | run: |
curl -fsSL https://raw.githubusercontent.com/oasdiff/oasdiff/main/install.sh | sh curl -fsSL https://raw.githubusercontent.com/oasdiff/oasdiff/main/install.sh | sh
cp /usr/local/bin/oasdiff ~/oasdiff
# Setup cached oasdiff
- name: Setup cached oasdiff
if: steps.cache-oasdiff.outputs.cache-hit == 'true'
run: |
sudo cp ~/oasdiff /usr/local/bin/oasdiff
sudo chmod +x /usr/local/bin/oasdiff
# Run oasdiff to detect breaking changes in the API specification # Run oasdiff to detect breaking changes in the API specification
# This step will fail if incompatible changes are detected, preventing breaking changes from being merged # This step will fail if incompatible changes are detected, preventing breaking changes from being merged

View file

@ -0,0 +1,701 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "1ztegmwm4sp",
"metadata": {},
"source": [
"## LlamaStack + LangChain Integration Tutorial\n",
"\n",
"This notebook demonstrates how to integrate **LlamaStack** with **LangChain** to build a complete RAG (Retrieval-Augmented Generation) system.\n",
"\n",
"### Overview\n",
"\n",
"- **LlamaStack**: Provides the infrastructure for running LLMs and Open AI Compatible Vector Stores\n",
"- **LangChain**: Provides the framework for chaining operations and prompt templates\n",
"- **Integration**: Uses LlamaStack's OpenAI-compatible API with LangChain\n",
"\n",
"### What You'll See\n",
"\n",
"1. Setting up LlamaStack server with Fireworks AI provider\n",
"2. Creating and Querying Vector Stores\n",
"3. Building RAG chains with LangChain + LLAMAStack\n",
"4. Querying the chain for relevant information\n",
"\n",
"### Prerequisites\n",
"\n",
"- Fireworks API key\n",
"\n",
"---\n",
"\n",
"### 1. Installation and Setup"
]
},
{
"cell_type": "markdown",
"id": "2ktr5ls2cas",
"metadata": {},
"source": [
"#### Install Required Dependencies\n",
"\n",
"First, we install all the necessary packages for LangChain and FastAPI integration."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5b6a6a17-b931-4bea-8273-0d6e5563637a",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: uv in /Users/swapna942/miniconda3/lib/python3.12/site-packages (0.7.20)\n",
"\u001b[2mUsing Python 3.12.11 environment at: /Users/swapna942/miniconda3\u001b[0m\n",
"\u001b[2mAudited \u001b[1m7 packages\u001b[0m \u001b[2min 42ms\u001b[0m\u001b[0m\n"
]
}
],
"source": [
"!pip install uv\n",
"!uv pip install fastapi uvicorn \"langchain>=0.2\" langchain-openai \\\n",
" langchain-community langchain-text-splitters \\\n",
" faiss-cpu"
]
},
{
"cell_type": "markdown",
"id": "wmt9jvqzh7n",
"metadata": {},
"source": [
"### 2. LlamaStack Server Setup\n",
"\n",
"#### Build and Start LlamaStack Server\n",
"\n",
"This section sets up the LlamaStack server with:\n",
"- **Fireworks AI** as the inference provider\n",
"- **Sentence Transformers** for embeddings\n",
"\n",
"The server runs on `localhost:8321` and provides OpenAI-compatible endpoints."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "dd2dacf3-ec8b-4cc7-8ff4-b5b6ea4a6e9e",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"import os\n",
"import subprocess\n",
"import time\n",
"\n",
"# Remove UV_SYSTEM_PYTHON to ensure uv creates a proper virtual environment\n",
"# instead of trying to use system Python globally, which could cause permission issues\n",
"# and package conflicts with the system's Python installation\n",
"if \"UV_SYSTEM_PYTHON\" in os.environ:\n",
" del os.environ[\"UV_SYSTEM_PYTHON\"]\n",
"\n",
"def run_llama_stack_server_background():\n",
" \"\"\"Build and run LlamaStack server in one step using --run flag\"\"\"\n",
" log_file = open(\"llama_stack_server.log\", \"w\")\n",
" process = subprocess.Popen(\n",
" \"uv run --with llama-stack llama stack build --distro starter --image-type venv --run\",\n",
" shell=True,\n",
" stdout=log_file,\n",
" stderr=log_file,\n",
" text=True,\n",
" )\n",
"\n",
" print(f\"Building and starting Llama Stack server with PID: {process.pid}\")\n",
" return process\n",
"\n",
"\n",
"def wait_for_server_to_start():\n",
" import requests\n",
" from requests.exceptions import ConnectionError\n",
"\n",
" url = \"http://0.0.0.0:8321/v1/health\"\n",
" max_retries = 30\n",
" retry_interval = 1\n",
"\n",
" print(\"Waiting for server to start\", end=\"\")\n",
" for _ in range(max_retries):\n",
" try:\n",
" response = requests.get(url)\n",
" if response.status_code == 200:\n",
" print(\"\\nServer is ready!\")\n",
" return True\n",
" except ConnectionError:\n",
" print(\".\", end=\"\", flush=True)\n",
" time.sleep(retry_interval)\n",
"\n",
" print(\"\\nServer failed to start after\", max_retries * retry_interval, \"seconds\")\n",
" return False\n",
"\n",
"\n",
"def kill_llama_stack_server():\n",
" # Kill any existing llama stack server processes using pkill command\n",
" os.system(\"pkill -f llama_stack.core.server.server\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "28bd8dbd-4576-4e76-813f-21ab94db44a2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Building and starting Llama Stack server with PID: 19747\n",
"Waiting for server to start....\n",
"Server is ready!\n"
]
}
],
"source": [
"server_process = run_llama_stack_server_background()\n",
"assert wait_for_server_to_start()"
]
},
{
"cell_type": "markdown",
"id": "gr9cdcg4r7n",
"metadata": {},
"source": [
"#### Install LlamaStack Client\n",
"\n",
"Install the client library to interact with the LlamaStack server."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "487d2dbc-d071-400e-b4f0-dcee58f8dc95",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2mUsing Python 3.12.11 environment at: /Users/swapna942/miniconda3\u001b[0m\n",
"\u001b[2mAudited \u001b[1m1 package\u001b[0m \u001b[2min 27ms\u001b[0m\u001b[0m\n"
]
}
],
"source": [
"!uv pip install llama_stack_client"
]
},
{
"cell_type": "markdown",
"id": "0j5hag7l9x89",
"metadata": {},
"source": [
"### 3. Initialize LlamaStack Client\n",
"\n",
"Create a client connection to the LlamaStack server with API keys for different providers:\n",
"\n",
"- **Fireworks API Key**: For Fireworks models\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ab4eff97-4565-4c73-b1b3-0020a4c7e2a5",
"metadata": {},
"outputs": [],
"source": [
"from llama_stack_client import LlamaStackClient\n",
"\n",
"client = LlamaStackClient(\n",
" base_url=\"http://0.0.0.0:8321\",\n",
" provider_data={\"fireworks_api_key\": \"***\"},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "vwhexjy1e8o",
"metadata": {},
"source": [
"#### Explore Available Models and Safety Features\n",
"\n",
"Check what models and safety shields are available through your LlamaStack instance."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "880443ef-ac3c-48b1-a80a-7dab5b25ac61",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:httpx:HTTP Request: GET http://0.0.0.0:8321/v1/models \"HTTP/1.1 200 OK\"\n",
"INFO:httpx:HTTP Request: GET http://0.0.0.0:8321/v1/shields \"HTTP/1.1 200 OK\"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Available Fireworks models:\n",
"- fireworks/accounts/fireworks/models/llama-v3p1-8b-instruct\n",
"- fireworks/accounts/fireworks/models/llama-v3p1-70b-instruct\n",
"- fireworks/accounts/fireworks/models/llama-v3p1-405b-instruct\n",
"- fireworks/accounts/fireworks/models/llama-v3p2-3b-instruct\n",
"- fireworks/accounts/fireworks/models/llama-v3p2-11b-vision-instruct\n",
"- fireworks/accounts/fireworks/models/llama-v3p2-90b-vision-instruct\n",
"- fireworks/accounts/fireworks/models/llama-v3p3-70b-instruct\n",
"- fireworks/accounts/fireworks/models/llama4-scout-instruct-basic\n",
"- fireworks/accounts/fireworks/models/llama4-maverick-instruct-basic\n",
"- fireworks/nomic-ai/nomic-embed-text-v1.5\n",
"- fireworks/accounts/fireworks/models/llama-guard-3-8b\n",
"- fireworks/accounts/fireworks/models/llama-guard-3-11b-vision\n",
"----\n",
"Available shields (safety models):\n",
"code-scanner\n",
"llama-guard\n",
"nemo-guardrail\n",
"----\n"
]
}
],
"source": [
"print(\"Available Fireworks models:\")\n",
"for m in client.models.list():\n",
" if m.identifier.startswith(\"fireworks/\"):\n",
" print(f\"- {m.identifier}\")\n",
"\n",
"print(\"----\")\n",
"print(\"Available shields (safety models):\")\n",
"for s in client.shields.list():\n",
" print(s.identifier)\n",
"print(\"----\")"
]
},
{
"cell_type": "markdown",
"id": "gojp7at31ht",
"metadata": {},
"source": [
"### 4. Vector Store Setup\n",
"\n",
"#### Create a Vector Store with File Upload\n",
"\n",
"Create a vector store using the OpenAI-compatible vector stores API:\n",
"\n",
"- **Vector Store**: OpenAI-compatible vector store for document storage\n",
"- **File Upload**: Automatic chunking and embedding of uploaded files \n",
"- **Embedding Model**: Sentence Transformers model for text embeddings\n",
"- **Dimensions**: 384-dimensional embeddings"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "be2c2899-ea53-4e5f-b6b8-ed425f5d6572",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:httpx:HTTP Request: POST http://0.0.0.0:8321/v1/openai/v1/files \"HTTP/1.1 200 OK\"\n",
"INFO:httpx:HTTP Request: POST http://0.0.0.0:8321/v1/openai/v1/files \"HTTP/1.1 200 OK\"\n",
"INFO:httpx:HTTP Request: POST http://0.0.0.0:8321/v1/openai/v1/files \"HTTP/1.1 200 OK\"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"File(id='file-54652c95c56c4c34918a97d7ff8a4320', bytes=41, created_at=1757442621, expires_at=1788978621, filename='shipping_policy.txt', object='file', purpose='assistants')\n",
"File(id='file-fb1227c1d1854da1bd774d21e5b7e41c', bytes=48, created_at=1757442621, expires_at=1788978621, filename='returns_policy.txt', object='file', purpose='assistants')\n",
"File(id='file-673f874852fe42798675a13d06a256e2', bytes=45, created_at=1757442621, expires_at=1788978621, filename='support.txt', object='file', purpose='assistants')\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:httpx:HTTP Request: POST http://0.0.0.0:8321/v1/openai/v1/vector_stores \"HTTP/1.1 200 OK\"\n"
]
}
],
"source": [
"from io import BytesIO\n",
"\n",
"docs = [\n",
" (\"Acme ships globally in 3-5 business days.\", {\"title\": \"Shipping Policy\"}),\n",
" (\"Returns are accepted within 30 days of purchase.\", {\"title\": \"Returns Policy\"}),\n",
" (\"Support is available 24/7 via chat and email.\", {\"title\": \"Support\"}),\n",
"]\n",
"\n",
"file_ids = []\n",
"for content, metadata in docs:\n",
" with BytesIO(content.encode()) as file_buffer:\n",
" file_buffer.name = f\"{metadata['title'].replace(' ', '_').lower()}.txt\"\n",
" create_file_response = client.files.create(file=file_buffer, purpose=\"assistants\")\n",
" print(create_file_response)\n",
" file_ids.append(create_file_response.id)\n",
"\n",
"# Create vector store with files\n",
"vector_store = client.vector_stores.create(\n",
" name=\"acme_docs\",\n",
" file_ids=file_ids,\n",
" embedding_model=\"sentence-transformers/all-MiniLM-L6-v2\",\n",
" embedding_dimension=384,\n",
" provider_id=\"faiss\"\n",
")"
]
},
{
"cell_type": "markdown",
"id": "9061tmi1zpq",
"metadata": {},
"source": [
"#### Test Vector Store Search\n",
"\n",
"Query the vector store. This performs semantic search to find relevant documents based on the query."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "ba9d1901-bd5e-4216-b3e6-19dc74551cc6",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:httpx:HTTP Request: POST http://0.0.0.0:8321/v1/openai/v1/vector_stores/vs_708c060b-45da-423e-8354-68529b4fd1a6/search \"HTTP/1.1 200 OK\"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Acme ships globally in 3-5 business days.\n",
"Returns are accepted within 30 days of purchase.\n"
]
}
],
"source": [
"search_response = client.vector_stores.search(\n",
" vector_store_id=vector_store.id,\n",
" query=\"How long does shipping take?\",\n",
" max_num_results=2\n",
")\n",
"for result in search_response.data:\n",
" content = result.content[0].text\n",
" print(content)"
]
},
{
"cell_type": "markdown",
"id": "usne6mbspms",
"metadata": {},
"source": [
"### 5. LangChain Integration\n",
"\n",
"#### Configure LangChain with LlamaStack\n",
"\n",
"Set up LangChain to use LlamaStack's OpenAI-compatible API:\n",
"\n",
"- **Base URL**: Points to LlamaStack's OpenAI endpoint\n",
"- **Headers**: Include Fireworks API key for model access\n",
"- **Model**: Use Meta Llama v3p1 8b instruct model for inference"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c378bd10-09c2-417c-bdfc-1e0a2dd19084",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"# Point LangChain to Llamastack Server\n",
"llm = ChatOpenAI(\n",
" base_url=\"http://0.0.0.0:8321/v1/openai/v1\",\n",
" api_key=\"dummy\",\n",
" model=\"fireworks/accounts/fireworks/models/llama-v3p1-8b-instruct\",\n",
" default_headers={\"X-LlamaStack-Provider-Data\": '{\"fireworks_api_key\": \"***\"}'},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "5a4ddpcuk3l",
"metadata": {},
"source": [
"#### Test LLM Connection\n",
"\n",
"Verify that LangChain can successfully communicate with the LlamaStack server."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f88ffb5a-657b-4916-9375-c6ddc156c25e",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:httpx:HTTP Request: POST http://0.0.0.0:8321/v1/openai/v1/chat/completions \"HTTP/1.1 200 OK\"\n"
]
},
{
"data": {
"text/plain": [
"AIMessage(content=\"A llama's gentle eyes shine bright,\\nIn the Andes, it roams through morning light.\", additional_kwargs={'refusal': None}, response_metadata={'token_usage': None, 'model_name': 'fireworks/accounts/fireworks/models/llama-v3p1-8b-instruct', 'system_fingerprint': None, 'id': 'chatcmpl-602b5967-82a3-476b-9cd2-7d3b29b76ee8', 'service_tier': None, 'finish_reason': 'stop', 'logprobs': None}, id='run--0933c465-ff4d-4a7b-b7fb-fd97dd8244f3-0')"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Test llm with simple message\n",
"messages = [\n",
" {\"role\": \"system\", \"content\": \"You are a friendly assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"Write a two-sentence poem about llama.\"},\n",
"]\n",
"llm.invoke(messages)"
]
},
{
"cell_type": "markdown",
"id": "0xh0jg6a0l4a",
"metadata": {},
"source": [
"### 6. Building the RAG Chain\n",
"\n",
"#### Create a Complete RAG Pipeline\n",
"\n",
"Build a LangChain pipeline that combines:\n",
"\n",
"1. **Vector Search**: Query LlamaStack's Open AI compatible Vector Store\n",
"2. **Context Assembly**: Format retrieved documents\n",
"3. **Prompt Template**: Structure the input for the LLM\n",
"4. **LLM Generation**: Generate answers using context\n",
"5. **Output Parsing**: Extract the final response\n",
"\n",
"**Chain Flow**: `Query → Vector Search → Context + Question → LLM → Response`"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "9684427d-dcc7-4544-9af5-8b110d014c42",
"metadata": {},
"outputs": [],
"source": [
"# LangChain for prompt template and chaining + LLAMA Stack Client Vector DB and LLM chat completion\n",
"from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_core.runnables import RunnableLambda, RunnablePassthrough\n",
"\n",
"\n",
"def join_docs(docs):\n",
" return \"\\n\\n\".join([f\"[{d.filename}] {d.content[0].text}\" for d in docs.data])\n",
"\n",
"PROMPT = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"system\", \"You are a helpful assistant. Use the following context to answer.\"),\n",
" (\"user\", \"Question: {question}\\n\\nContext:\\n{context}\"),\n",
" ]\n",
")\n",
"\n",
"vector_step = RunnableLambda(\n",
" lambda x: client.vector_stores.search(\n",
" vector_store_id=vector_store.id,\n",
" query=x,\n",
" max_num_results=2\n",
" )\n",
" )\n",
"\n",
"chain = (\n",
" {\"context\": vector_step | RunnableLambda(join_docs), \"question\": RunnablePassthrough()}\n",
" | PROMPT\n",
" | llm\n",
" | StrOutputParser()\n",
")"
]
},
{
"cell_type": "markdown",
"id": "0onu6rhphlra",
"metadata": {},
"source": [
"### 7. Testing the RAG System\n",
"\n",
"#### Example 1: Shipping Query"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "03322188-9509-446a-a4a8-ce3bb83ec87c",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:httpx:HTTP Request: POST http://0.0.0.0:8321/v1/openai/v1/vector_stores/vs_708c060b-45da-423e-8354-68529b4fd1a6/search \"HTTP/1.1 200 OK\"\n",
"INFO:httpx:HTTP Request: POST http://0.0.0.0:8321/v1/openai/v1/chat/completions \"HTTP/1.1 200 OK\"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"❓ How long does shipping take?\n",
"💡 Acme ships globally in 3-5 business days. This means that shipping typically takes between 3 to 5 working days from the date of dispatch or order fulfillment.\n"
]
}
],
"source": [
"query = \"How long does shipping take?\"\n",
"response = chain.invoke(query)\n",
"print(\"❓\", query)\n",
"print(\"💡\", response)"
]
},
{
"cell_type": "markdown",
"id": "b7krhqj88ku",
"metadata": {},
"source": [
"#### Example 2: Returns Policy Query"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "61995550-bb0b-46a8-a5d0-023207475d60",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:httpx:HTTP Request: POST http://0.0.0.0:8321/v1/openai/v1/vector_stores/vs_708c060b-45da-423e-8354-68529b4fd1a6/search \"HTTP/1.1 200 OK\"\n",
"INFO:httpx:HTTP Request: POST http://0.0.0.0:8321/v1/openai/v1/chat/completions \"HTTP/1.1 200 OK\"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"❓ Can I return a product after 40 days?\n",
"💡 Based on the provided context, you cannot return a product after 40 days. The return window is limited to 30 days from the date of purchase.\n"
]
}
],
"source": [
"query = \"Can I return a product after 40 days?\"\n",
"response = chain.invoke(query)\n",
"print(\"❓\", query)\n",
"print(\"💡\", response)"
]
},
{
"cell_type": "markdown",
"id": "h4w24fadvjs",
"metadata": {},
"source": [
"---\n",
"We have successfully built a RAG system that combines:\n",
"\n",
"- **LlamaStack** for infrastructure (LLM serving + Vector Store)\n",
"- **LangChain** for orchestration (prompts + chains)\n",
"- **Fireworks** for high-quality language models\n",
"\n",
"### Key Benefits\n",
"\n",
"1. **Unified Infrastructure**: Single server for LLMs and Vector Store\n",
"2. **OpenAI Compatibility**: Easy integration with existing LangChain code\n",
"3. **Multi-Provider Support**: Switch between different LLM providers\n",
"4. **Production Ready**: Built-in safety shields and monitoring\n",
"\n",
"### Next Steps\n",
"\n",
"- Add more sophisticated document processing\n",
"- Implement conversation memory\n",
"- Add safety filtering and monitoring\n",
"- Scale to larger document collections\n",
"- Integrate with web frameworks like FastAPI or Streamlit\n",
"\n",
"---\n",
"\n",
"##### 🔧 Cleanup\n",
"\n",
"Don't forget to stop the LlamaStack server when you're done:\n",
"\n",
"```python\n",
"kill_llama_stack_server()\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "15647c46-22ce-4698-af3f-8161329d8e3a",
"metadata": {},
"outputs": [],
"source": [
"kill_llama_stack_server()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View file

@ -6,11 +6,7 @@
import asyncio import asyncio
import json import json
import logging # allow-direct-logging
import threading
import time import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
import pytest import pytest
@ -18,7 +14,7 @@ from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk, ChatCompletionChunk as OpenAIChatCompletionChunk,
) )
from openai.types.chat.chat_completion_chunk import ( from openai.types.chat.chat_completion_chunk import (
Choice as OpenAIChoice, Choice as OpenAIChoiceChunk,
) )
from openai.types.chat.chat_completion_chunk import ( from openai.types.chat.chat_completion_chunk import (
ChoiceDelta as OpenAIChoiceDelta, ChoiceDelta as OpenAIChoiceDelta,
@ -35,6 +31,9 @@ from llama_stack.apis.inference import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionResponseEventType, ChatCompletionResponseEventType,
CompletionMessage, CompletionMessage,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChoice,
SystemMessage, SystemMessage,
ToolChoice, ToolChoice,
ToolConfig, ToolConfig,
@ -61,41 +60,6 @@ from llama_stack.providers.remote.inference.vllm.vllm import (
# -v -s --tb=short --disable-warnings # -v -s --tb=short --disable-warnings
class MockInferenceAdapterWithSleep:
def __init__(self, sleep_time: int, response: dict[str, Any]):
self.httpd = None
class DelayedRequestHandler(BaseHTTPRequestHandler):
# ruff: noqa: N802
def do_POST(self):
time.sleep(sleep_time)
response_body = json.dumps(response).encode("utf-8")
self.send_response(code=200)
self.send_header("Content-Type", "application/json")
self.send_header("Content-Length", len(response_body))
self.end_headers()
self.wfile.write(response_body)
self.request_handler = DelayedRequestHandler
def __enter__(self):
httpd = HTTPServer(("", 0), self.request_handler)
self.httpd = httpd
host, port = httpd.server_address
httpd_thread = threading.Thread(target=httpd.serve_forever)
httpd_thread.daemon = True # stop server if this thread terminates
httpd_thread.start()
config = VLLMInferenceAdapterConfig(url=f"http://{host}:{port}")
inference_adapter = VLLMInferenceAdapter(config)
return inference_adapter
def __exit__(self, _exc_type, _exc_value, _traceback):
if self.httpd:
self.httpd.shutdown()
self.httpd.server_close()
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def mock_openai_models_list(): def mock_openai_models_list():
with patch("openai.resources.models.AsyncModels.list", new_callable=AsyncMock) as mock_list: with patch("openai.resources.models.AsyncModels.list", new_callable=AsyncMock) as mock_list:
@ -201,7 +165,7 @@ async def test_tool_call_delta_empty_tool_call_buf():
async def mock_stream(): async def mock_stream():
delta = OpenAIChoiceDelta(content="", tool_calls=None) delta = OpenAIChoiceDelta(content="", tool_calls=None)
choices = [OpenAIChoice(delta=delta, finish_reason="stop", index=0)] choices = [OpenAIChoiceChunk(delta=delta, finish_reason="stop", index=0)]
mock_chunk = OpenAIChatCompletionChunk( mock_chunk = OpenAIChatCompletionChunk(
id="chunk-1", id="chunk-1",
created=1, created=1,
@ -227,7 +191,7 @@ async def test_tool_call_delta_streaming_arguments_dict():
model="foo", model="foo",
object="chat.completion.chunk", object="chat.completion.chunk",
choices=[ choices=[
OpenAIChoice( OpenAIChoiceChunk(
delta=OpenAIChoiceDelta( delta=OpenAIChoiceDelta(
content="", content="",
tool_calls=[ tool_calls=[
@ -252,7 +216,7 @@ async def test_tool_call_delta_streaming_arguments_dict():
model="foo", model="foo",
object="chat.completion.chunk", object="chat.completion.chunk",
choices=[ choices=[
OpenAIChoice( OpenAIChoiceChunk(
delta=OpenAIChoiceDelta( delta=OpenAIChoiceDelta(
content="", content="",
tool_calls=[ tool_calls=[
@ -277,7 +241,9 @@ async def test_tool_call_delta_streaming_arguments_dict():
model="foo", model="foo",
object="chat.completion.chunk", object="chat.completion.chunk",
choices=[ choices=[
OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0) OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
)
], ],
) )
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]: for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
@ -301,7 +267,7 @@ async def test_multiple_tool_calls():
model="foo", model="foo",
object="chat.completion.chunk", object="chat.completion.chunk",
choices=[ choices=[
OpenAIChoice( OpenAIChoiceChunk(
delta=OpenAIChoiceDelta( delta=OpenAIChoiceDelta(
content="", content="",
tool_calls=[ tool_calls=[
@ -326,7 +292,7 @@ async def test_multiple_tool_calls():
model="foo", model="foo",
object="chat.completion.chunk", object="chat.completion.chunk",
choices=[ choices=[
OpenAIChoice( OpenAIChoiceChunk(
delta=OpenAIChoiceDelta( delta=OpenAIChoiceDelta(
content="", content="",
tool_calls=[ tool_calls=[
@ -351,7 +317,9 @@ async def test_multiple_tool_calls():
model="foo", model="foo",
object="chat.completion.chunk", object="chat.completion.chunk",
choices=[ choices=[
OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0) OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
)
], ],
) )
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]: for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
@ -395,59 +363,6 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
assert chunks[0].event.event_type.value == "start" assert chunks[0].event.event_type.value == "start"
@pytest.mark.allow_network
def test_chat_completion_doesnt_block_event_loop(caplog):
loop = asyncio.new_event_loop()
loop.set_debug(True)
caplog.set_level(logging.WARNING)
# Log when event loop is blocked for more than 200ms
loop.slow_callback_duration = 0.5
# Sleep for 500ms in our delayed http response
sleep_time = 0.5
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
mock_response = {
"id": "chatcmpl-abc123",
"object": "chat.completion",
"created": 1,
"modle": "mock-model",
"choices": [
{
"message": {"content": ""},
"logprobs": None,
"finish_reason": "stop",
"index": 0,
}
],
}
async def do_chat_completion():
await inference_adapter.chat_completion(
"mock-model",
[],
stream=False,
tools=None,
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
)
with MockInferenceAdapterWithSleep(sleep_time, mock_response) as inference_adapter:
inference_adapter.model_store = AsyncMock()
inference_adapter.model_store.get_model.return_value = mock_model
loop.run_until_complete(inference_adapter.initialize())
# Clear the logs so far and run the actual chat completion we care about
caplog.clear()
loop.run_until_complete(do_chat_completion())
# Ensure we don't have any asyncio warnings in the captured log
# records from our chat completion call. A message gets logged
# here any time we exceed the slow_callback_duration configured
# above.
asyncio_warnings = [record.message for record in caplog.records if record.name == "asyncio"]
assert not asyncio_warnings
async def test_get_params_empty_tools(vllm_inference_adapter): async def test_get_params_empty_tools(vllm_inference_adapter):
request = ChatCompletionRequest( request = ChatCompletionRequest(
tools=[], tools=[],
@ -696,3 +611,48 @@ async def test_health_status_failure(vllm_inference_adapter):
assert "Health check failed: Connection failed" in health_response["message"] assert "Health check failed: Connection failed" in health_response["message"]
mock_models.list.assert_called_once() mock_models.list.assert_called_once()
async def test_openai_chat_completion_is_async(vllm_inference_adapter):
"""
Verify that openai_chat_completion is async and doesn't block the event loop.
To do this we mock the underlying inference with a sleep, start multiple
inference calls in parallel, and ensure the total time taken is less
than the sum of the individual sleep times.
"""
sleep_time = 0.5
async def mock_create(*args, **kwargs):
await asyncio.sleep(sleep_time)
return OpenAIChatCompletion(
id="chatcmpl-abc123",
created=1,
model="mock-model",
choices=[
OpenAIChoice(
message=OpenAIAssistantMessageParam(
content="nothing interesting",
),
finish_reason="stop",
index=0,
)
],
)
async def do_inference():
await vllm_inference_adapter.openai_chat_completion(
"mock-model", messages=["one fish", "two fish"], stream=False
)
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(side_effect=mock_create)
mock_create_client.return_value = mock_client
start_time = time.time()
await asyncio.gather(do_inference(), do_inference(), do_inference(), do_inference())
total_time = time.time() - start_time
assert mock_create_client.call_count == 4 # no cheating
assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"