Neil Mehta 2025-04-29 08:22:55 -04:00 committed by GitHub
commit ec9fa30d36
21 changed files with 2524 additions and 0 deletions


@ -127,6 +127,7 @@ Here is a list of the various API providers and available distributions that can
| Anthropic | Hosted | | ✅ | | | |
| Gemini | Hosted | | ✅ | | | |
| watsonx | Hosted | | ✅ | | | |
| LM Studio | Single Node | | ✅ | | | |
### Distributions


@ -0,0 +1 @@
../../llama_stack/templates/lmstudio/build.yaml


@ -0,0 +1 @@
../../llama_stack/templates/lmstudio/run.yaml


@ -0,0 +1,70 @@
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
# LM Studio Distribution
The `llamastack/distribution-lmstudio` distribution consists of the following provider configurations.
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::lmstudio` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
### Environment Variables
The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
### Models
The following models are available by default:
- `meta-llama-3-8b-instruct`
- `meta-llama-3-70b-instruct`
- `meta-llama-3.1-8b-instruct`
- `meta-llama-3.1-70b-instruct`
- `llama-3.2-1b-instruct`
- `llama-3.2-3b-instruct`
- `llama-3.3-70b-instruct`
- `nomic-embed-text-v1.5`
- `all-minilm-l6-v2`
## Set up LM Studio
Download LM Studio from [https://lmstudio.ai/download](https://lmstudio.ai/download). Start the server by opening LM Studio and navigating to the `Developer` tab, or by running the CLI command `lms server start`.
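Before starting the stack, it can help to confirm the server is reachable. A quick sanity check against LM Studio's OpenAI-compatible endpoint (assuming the default port `1234`):
```bash
# Start the LM Studio server headlessly (equivalent to enabling it in the Developer tab)
lms server start

# Verify the OpenAI-compatible API is up and list the available models
curl http://localhost:1234/v1/models
```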
## Running Llama Stack with LM Studio
You can run the stack via Conda (building the distribution from source) or via Docker, which uses a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-lmstudio \
--yaml-config /root/my-run.yaml \
--port $LLAMA_STACK_PORT
```
### Via Conda
```bash
llama stack build --template lmstudio --image-type conda
llama stack run ./run.yaml \
--port 5001
```
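With the stack running (via either method above), you can exercise the LM Studio provider from Python. A minimal sketch using the `llama-stack-client` package, assuming the default port `5001` and that `llama-3.2-3b-instruct` has been downloaded in LM Studio:
```python
from llama_stack_client import LlamaStackClient

# Point the client at the running distribution (port taken from the examples above)
client = LlamaStackClient(base_url="http://localhost:5001")

# A simple chat completion routed through the LM Studio inference provider
response = client.inference.chat_completion(
    model_id="llama-3.2-3b-instruct",
    messages=[{"role": "user", "content": "Write a haiku about local inference."}],
)
print(response.completion_message.content)
```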


@ -233,6 +233,7 @@ class InferenceRouter(Inference):
messages: List[Message] | InterleavedContent,
tool_prompt_format: Optional[ToolPromptFormat] = None,
) -> Optional[int]:
return 1  # TODO: stub; bypasses the token-counting logic below, which becomes unreachable
if isinstance(messages, list):
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
else:


@ -298,4 +298,13 @@ def available_providers() -> List[ProviderSpec]:
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="lmstudio",
pip_packages=["lmstudio"],
module="llama_stack.providers.remote.inference.lmstudio",
config_class="llama_stack.providers.remote.inference.lmstudio.LMStudioImplConfig",
),
),
]


@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import LMStudioImplConfig
async def get_adapter_impl(config: LMStudioImplConfig, _deps):
from .lmstudio import LMStudioInferenceAdapter
impl = LMStudioInferenceAdapter(config.url)
await impl.initialize()
return impl


@ -0,0 +1,478 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
import logging
import re
from typing import Any, AsyncIterator, List, Literal, Optional, Union
import lmstudio as lms
from openai import AsyncOpenAI as OpenAI
from llama_stack.apis.common.content_types import InterleavedContent, TextDelta
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionResponse,
CompletionResponseStreamChunk,
GrammarResponseFormat,
GreedySamplingStrategy,
JsonSchemaResponseFormat,
Message,
SamplingParams,
StopReason,
ToolConfig,
ToolDefinition,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict_new,
convert_openai_chat_completion_choice,
convert_openai_chat_completion_stream,
convert_tooldef_to_openai_tool,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
content_has_media,
interleaved_content_as_str,
)
LlmPredictionStopReason = Literal[
"userStopped",
"modelUnloaded",
"failed",
"eosFound",
"stopStringFound",
"toolCalls",
"maxPredictedTokensReached",
"contextLengthReached",
]
class LMStudioClient:
def __init__(self, url: str) -> None:
self.url = url
self.sdk_client = lms.Client(self.url)
self.openai_client = OpenAI(base_url=f"http://{url}/v1", api_key="lmstudio")
# Standard error handling helper methods
def _log_error(self, error, context=""):
"""Centralized error logging method"""
logging.warning(f"Error in LMStudio {context}: {error}")
async def _create_fallback_chat_stream(
self, error_message="I encountered an error processing your request."
) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
"""Create a standardized fallback stream for chat completions"""
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=TextDelta(text=error_message),
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=""),
)
)
async def _create_fallback_completion_stream(self, error_message="Error processing response"):
"""Create a standardized fallback stream for text completions"""
yield CompletionResponseStreamChunk(
delta=error_message,
)
def _create_fallback_chat_response(
self, error_message="I encountered an error processing your request."
) -> ChatCompletionResponse:
"""Create a standardized fallback response for chat completions"""
return ChatCompletionResponse(
completion_message=CompletionMessage(
role="assistant",
content=error_message,
stop_reason=StopReason.end_of_message,
)
)
def _create_fallback_completion_response(self, error_message="Error processing response") -> CompletionResponse:
"""Create a standardized fallback response for text completions"""
return CompletionResponse(
content=error_message,
stop_reason=StopReason.end_of_message,
)
def _handle_json_extraction(self, content, context="JSON extraction"):
"""Standardized method to extract valid JSON from potentially malformed content"""
try:
json_content = json.loads(content)
return json.dumps(json_content) # Re-serialize to ensure valid JSON
except json.JSONDecodeError as e:
self._log_error(e, f"{context} - Attempting to extract valid JSON")
json_patterns = [
r"(\{.*\})", # Match anything between curly braces
r"(\[.*\])", # Match anything between square brackets
r"```json\s*([\s\S]*?)\s*```", # Match content in JSON code blocks
r"```\s*([\s\S]*?)\s*```", # Match content in any code blocks
]
for pattern in json_patterns:
json_match = re.search(pattern, content, re.DOTALL)
if json_match:
valid_json = json_match.group(1)
try:
json_content = json.loads(valid_json)
return json.dumps(json_content) # Re-serialize to ensure valid JSON
except json.JSONDecodeError:
continue # Try the next pattern
# If we couldn't extract valid JSON, log a warning
self._log_error("Failed to extract valid JSON", context)
return None
async def check_if_model_present_in_lmstudio(self, provider_model_id):
models = await asyncio.to_thread(self.sdk_client.list_downloaded_models)
model_ids = [m.model_key for m in models]
if provider_model_id in model_ids:
return True
model_ids = [id.split("/")[-1] for id in model_ids]
if provider_model_id in model_ids:
return True
return False
async def get_embedding_model(self, provider_model_id: str):
model = await asyncio.to_thread(self.sdk_client.embedding.model, provider_model_id)
return model
async def embed(self, embedding_model: lms.EmbeddingModel, contents: Union[str, List[str]]):
embeddings = await asyncio.to_thread(embedding_model.embed, contents)
return embeddings
async def get_llm(self, provider_model_id: str) -> lms.LLM:
model = await asyncio.to_thread(self.sdk_client.llm.model, provider_model_id)
return model
async def _llm_respond_non_tools(
self,
llm: lms.LLM,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
json_schema: Optional[JsonSchemaResponseFormat] = None,
stream: Optional[bool] = False,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
chat = self._convert_message_list_to_lmstudio_chat(messages)
config = self._get_completion_config_from_params(sampling_params)
if stream:
async def stream_generator() -> AsyncIterator[ChatCompletionResponseStreamChunk]:
prediction_stream = await asyncio.to_thread(
llm.respond_stream,
history=chat,
config=config,
response_format=json_schema,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
async for chunk in self._async_iterate(prediction_stream):
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=TextDelta(text=chunk.content),
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=""),
)
)
return stream_generator()
else:
response = await asyncio.to_thread(
llm.respond,
history=chat,
config=config,
response_format=json_schema,
)
return self._convert_prediction_to_chat_response(response)
async def _llm_respond_with_tools(
self,
llm: lms.LLM,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
json_schema: Optional[JsonSchemaResponseFormat] = None,
stream: Optional[bool] = False,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
try:
model_key = llm.get_info().model_key
request = ChatCompletionRequest(
model=model_key,
messages=messages,
sampling_params=sampling_params,
response_format=json_schema,
tools=tools,
tool_config=tool_config,
stream=stream,
)
rest_request = await self._convert_request_to_rest_call(request)
if stream:
try:
stream = await self.openai_client.chat.completions.create(**rest_request)
return convert_openai_chat_completion_stream(stream, enable_incremental_tool_calls=True)
except Exception as e:
self._log_error(e, "streaming tool calling")
return self._create_fallback_chat_stream()
try:
response = await self.openai_client.chat.completions.create(**rest_request)
if response:
result = convert_openai_chat_completion_choice(response.choices[0])
return result
else:
# Handle empty response
self._log_error("Empty response from OpenAI API", "chat completion")
return self._create_fallback_chat_response()
except Exception as e:
self._log_error(e, "non-streaming tool calling")
return self._create_fallback_chat_response()
except Exception as e:
self._log_error(e, "_llm_respond_with_tools")
# Return a fallback response
if stream:
return self._create_fallback_chat_stream()
else:
return self._create_fallback_chat_response()
async def llm_respond(
self,
llm: lms.LLM,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
json_schema: Optional[JsonSchemaResponseFormat] = None,
stream: Optional[bool] = False,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
if tools is None or len(tools) == 0:
return await self._llm_respond_non_tools(
llm=llm,
messages=messages,
sampling_params=sampling_params,
json_schema=json_schema,
stream=stream,
)
else:
return await self._llm_respond_with_tools(
llm=llm,
messages=messages,
sampling_params=sampling_params,
json_schema=json_schema,
stream=stream,
tools=tools,
tool_config=tool_config,
)
async def llm_completion(
self,
llm: lms.LLM,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
json_schema: Optional[JsonSchemaResponseFormat] = None,
stream: Optional[bool] = False,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
config = self._get_completion_config_from_params(sampling_params)
if stream:
async def stream_generator() -> AsyncIterator[CompletionResponseStreamChunk]:
try:
prediction_stream = await asyncio.to_thread(
llm.complete_stream,
prompt=interleaved_content_as_str(content),
config=config,
response_format=json_schema,
)
async for chunk in self._async_iterate(prediction_stream):
yield CompletionResponseStreamChunk(
delta=chunk.content,
)
except Exception as e:
self._log_error(e, "streaming completion")
# Return a fallback response in case of error
yield CompletionResponseStreamChunk(
delta="Error processing response",
)
return stream_generator()
else:
try:
response = await asyncio.to_thread(
llm.complete,
prompt=interleaved_content_as_str(content),
config=config,
response_format=json_schema,
)
# If we have a JSON schema, ensure the response is valid JSON
if json_schema is not None:
valid_json = self._handle_json_extraction(response.content, "completion response")
if valid_json:
return CompletionResponse(
content=valid_json, # Already serialized in _handle_json_extraction
stop_reason=self._get_stop_reason(response.stats.stop_reason),
)
# If we couldn't extract valid JSON, continue with the original content
return CompletionResponse(
content=response.content,
stop_reason=self._get_stop_reason(response.stats.stop_reason),
)
except Exception as e:
self._log_error(e, "LMStudio completion")
# Return a fallback response with an error message
return self._create_fallback_completion_response()
def _convert_message_list_to_lmstudio_chat(self, messages: List[Message]) -> lms.Chat:
chat = lms.Chat()
for message in messages:
if content_has_media(message.content):
raise NotImplementedError("Media content is not supported in LMStudio messages")
if message.role == "user":
chat.add_user_message(interleaved_content_as_str(message.content))
elif message.role == "system":
chat.add_system_prompt(interleaved_content_as_str(message.content))
elif message.role == "assistant":
chat.add_assistant_response(interleaved_content_as_str(message.content))
else:
raise ValueError(f"Unsupported message role: {message.role}")
return chat
def _convert_prediction_to_chat_response(self, result: lms.PredictionResult) -> ChatCompletionResponse:
response = ChatCompletionResponse(
completion_message=CompletionMessage(
content=result.content,
stop_reason=self._get_stop_reason(result.stats.stop_reason),
tool_calls=None,
)
)
return response
def _get_completion_config_from_params(
self,
params: Optional[SamplingParams] = None,
) -> lms.LlmPredictionConfigDict:
options = lms.LlmPredictionConfigDict()
if params is None:
return options
if isinstance(params.strategy, GreedySamplingStrategy):
options.update({"temperature": 0.0})
elif isinstance(params.strategy, TopPSamplingStrategy):
options.update(
{
"temperature": params.strategy.temperature,
"topPSampling": params.strategy.top_p,
}
)
elif isinstance(params.strategy, TopKSamplingStrategy):
options.update({"topKSampling": params.strategy.top_k})
else:
raise ValueError(f"Unsupported sampling strategy: {params.strategy}")
options.update(
{
"maxTokens": params.max_tokens if params.max_tokens != 0 else None,
"repetitionPenalty": (params.repetition_penalty if params.repetition_penalty != 0 else None),
}
)
return options
def _get_stop_reason(self, stop_reason: LlmPredictionStopReason) -> StopReason:
if stop_reason == "eosFound":
return StopReason.end_of_message
elif stop_reason == "maxPredictedTokensReached":
return StopReason.out_of_tokens
else:
return StopReason.end_of_turn
async def _async_iterate(self, iterable):
"""Asynchronously iterate over a synchronous iterable."""
iterator = iter(iterable)
def safe_next(it):
"""This is necessary to communicate StopIteration across threads"""
try:
return (next(it), False)
except StopIteration:
return (None, True)
while True:
item, done = await asyncio.to_thread(safe_next, iterator)
if done:
break
yield item
async def _convert_request_to_rest_call(self, request: ChatCompletionRequest) -> dict:
compatible_request = self._convert_sampling_params(request.sampling_params)
compatible_request["model"] = request.model
compatible_request["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages]
if request.response_format:
if isinstance(request.response_format, JsonSchemaResponseFormat):
compatible_request["response_format"] = {
"type": "json_schema",
"json_schema": request.response_format.json_schema,
}
elif isinstance(request.response_format, GrammarResponseFormat):
compatible_request["response_format"] = {
"type": "grammar",
"bnf": request.response_format.bnf,
}
if request.tools is not None:
compatible_request["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
compatible_request["logprobs"] = False
compatible_request["stream"] = request.stream
compatible_request["extra_headers"] = {b"User-Agent": b"llama-stack: lmstudio-inference-adapter"}
return compatible_request
def _convert_sampling_params(self, sampling_params: Optional[SamplingParams]) -> dict:
params: dict[str, Any] = {}
if sampling_params is None:
return params
params["frequency_penalty"] = sampling_params.repetition_penalty
if sampling_params.max_tokens:
params["max_completion_tokens"] = sampling_params.max_tokens
if isinstance(sampling_params.strategy, TopPSamplingStrategy):
params["top_p"] = sampling_params.strategy.top_p
if isinstance(sampling_params.strategy, TopKSamplingStrategy):
params["extra_body"]["top_k"] = sampling_params.strategy.top_k
if isinstance(sampling_params.strategy, GreedySamplingStrategy):
params["temperature"] = 0.0
return params


@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
DEFAULT_LMSTUDIO_URL = "localhost:1234"
class LMStudioImplConfig(BaseModel):
url: str = DEFAULT_LMSTUDIO_URL
@classmethod
def sample_run_config(cls, url: str = DEFAULT_LMSTUDIO_URL, **kwargs) -> Dict[str, Any]:
return {"url": url}


@ -0,0 +1,278 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionResponse,
EmbeddingsResponse,
EmbeddingTaskType,
GrammarResponseFormat,
Inference,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionResponseStreamChunk,
)
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.remote.inference.lmstudio._client import LMStudioClient
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.prompt_adapter import (
content_has_media,
)
from .models import MODEL_ENTRIES
class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
def __init__(self, url: str) -> None:
self.url = url
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
@property
def client(self) -> LMStudioClient:
return LMStudioClient(url=self.url)
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch completion is not supported by LM Studio Provider")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch completion is not supported by LM Studio Provider")
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
if self.model_store is None:
raise ValueError("ModelStore is not initialized")
model_obj = await self.model_store.get_model(model)
params = {
k: v
for k, v in {
"model": model_obj.provider_resource_id,
"messages": messages,
"frequency_penalty": frequency_penalty,
"function_call": function_call,
"functions": functions,
"logit_bias": logit_bias,
"logprobs": logprobs,
"max_completion_tokens": max_completion_tokens,
"max_tokens": max_tokens,
"n": n,
"parallel_tool_calls": parallel_tool_calls,
"presence_penalty": presence_penalty,
"response_format": response_format,
"seed": seed,
"stop": stop,
"stream": stream,
"stream_options": stream_options,
"temperature": temperature,
"tool_choice": tool_choice,
"tools": tools,
"top_logprobs": top_logprobs,
"top_p": top_p,
"user": user,
}.items()
if v is not None
}
return await self.client.openai_client.chat.completions.create(**params)  # type: ignore
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
if not isinstance(prompt, str):
raise ValueError("LM Studio does not support non-string prompts for completion")
if self.model_store is None:
raise ValueError("ModelStore is not initialized")
model_obj = await self.model_store.get_model(model)
params = {
k: v
for k, v in {
"model": model_obj.provider_resource_id,
"prompt": prompt,
"best_of": best_of,
"echo": echo,
"frequency_penalty": frequency_penalty,
"logit_bias": logit_bias,
"logprobs": logprobs,
"max_tokens": max_tokens,
"n": n,
"presence_penalty": presence_penalty,
"seed": seed,
"stop": stop,
"stream": stream,
"stream_options": stream_options,
"temperature": temperature,
"top_p": top_p,
"user": user,
}.items()
if v is not None
}
return await self.client.openai_client.completions.create(**params)  # type: ignore
async def initialize(self) -> None:
pass
async def register_model(self, model):
await self.register_helper.register_model(model)
return model
async def unregister_model(self, model_id):
pass
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
assert all(not content_has_media(content) for content in contents), (
"Media content not supported in embedding model"
)
if self.model_store is None:
raise ValueError("ModelStore is not initialized")
model = await self.model_store.get_model(model_id)
embedding_model = await self.client.get_embedding_model(model.provider_model_id)
string_contents = [item.text if hasattr(item, "text") else str(item) for item in contents]
embeddings = await self.client.embed(embedding_model, string_contents)
return EmbeddingsResponse(embeddings=embeddings)
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = None, # Default value changed from ToolChoice.auto to None
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[
Union[JsonSchemaResponseFormat, GrammarResponseFormat]
] = None, # Moved and type changed
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
if self.model_store is None:
raise ValueError("ModelStore is not initialized")
model = await self.model_store.get_model(model_id)
llm = await self.client.get_llm(model.provider_model_id)
json_schema_format = response_format if isinstance(response_format, JsonSchemaResponseFormat) else None
if response_format is not None and not isinstance(response_format, JsonSchemaResponseFormat):
raise ValueError(
f"Response format type {type(response_format).__name__} not supported for LM Studio Provider"
)
return await self.client.llm_respond(
llm=llm,
messages=messages,
sampling_params=sampling_params,
json_schema=json_schema_format,
stream=stream,
tool_config=tool_config,
tools=tools,
)
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, # Skip this for now
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
if self.model_store is None:
raise ValueError("ModelStore is not initialized")
model = await self.model_store.get_model(model_id)
llm = await self.client.get_llm(model.provider_model_id)
if content_has_media(content):
raise NotImplementedError("Media content not supported in LM Studio Provider")
if response_format is not None and not isinstance(response_format, JsonSchemaResponseFormat):
raise ValueError(
f"Response format type {type(response_format).__name__} not supported for LM Studio Provider"
)
return await self.client.llm_completion(llm, content, sampling_params, response_format, stream)


@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.models.llama.sku_list import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
MODEL_ENTRIES = [
ProviderModelEntry(
provider_model_id="meta-llama-3-8b-instruct",
aliases=[],
llama_model=CoreModelId.llama3_8b_instruct.value,
model_type=ModelType.llm,
),
ProviderModelEntry(
provider_model_id="meta-llama-3-70b-instruct",
aliases=[],
llama_model=CoreModelId.llama3_70b_instruct.value,
model_type=ModelType.llm,
),
ProviderModelEntry(
provider_model_id="meta-llama-3.1-8b-instruct",
aliases=[],
llama_model=CoreModelId.llama3_1_8b_instruct.value,
model_type=ModelType.llm,
),
ProviderModelEntry(
provider_model_id="meta-llama-3.1-70b-instruct",
aliases=[],
llama_model=CoreModelId.llama3_1_70b_instruct.value,
model_type=ModelType.llm,
),
ProviderModelEntry(
provider_model_id="llama-3.2-1b-instruct",
aliases=[],
llama_model=CoreModelId.llama3_2_1b_instruct.value,
model_type=ModelType.llm,
),
ProviderModelEntry(
provider_model_id="llama-3.2-3b-instruct",
aliases=[],
llama_model=CoreModelId.llama3_2_3b_instruct.value,
model_type=ModelType.llm,
),
ProviderModelEntry(
provider_model_id="llama-3.3-70b-instruct",
aliases=[],
llama_model=CoreModelId.llama3_3_70b_instruct.value,
model_type=ModelType.llm,
),
# embedding model
ProviderModelEntry(
provider_model_id="nomic-embed-text-v1.5",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 768,
"context_length": 2048,
},
),
ProviderModelEntry(
provider_model_id="all-minilm-l6-v2",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,
},
),
]


@ -344,6 +344,42 @@
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"lmstudio": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"lmstudio",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"meta-reference-gpu": [
"accelerate",
"aiosqlite",


@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .lmstudio import get_distribution_template # noqa: F401


@ -0,0 +1,30 @@
version: '2'
distribution_spec:
description: Use LM Studio for running LLM inference
providers:
inference:
- remote::lmstudio
safety:
- inline::llama-guard
vector_io:
- inline::faiss
- remote::chromadb
- remote::pgvector
agents:
- inline::meta-reference
eval:
- inline::meta-reference
datasetio:
- remote::huggingface
- inline::localfs
scoring:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
telemetry:
- inline::meta-reference
tool_runtime:
- remote::tavily-search
- inline::code-interpreter
- inline::rag-runtime
image_type: conda


@ -0,0 +1,58 @@
# LM Studio Distribution
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
{{ providers_table }}
{% if run_config_env_vars %}
### Environment Variables
The following environment variables can be configured:
{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}
{% if default_models %}
### Models
The following models are available by default:
{% for model in default_models %}
- `{{ model.model_id }} {{ model.doc_string }}`
{% endfor %}
{% endif %}
## Set up LM Studio
Download LM Studio from [https://lmstudio.ai/download](https://lmstudio.ai/download). Start the server by opening LM Studio and navigating to the `Developer` tab, or by running the CLI command `lms server start`.
## Running Llama Stack with LM Studio
You can run the stack via Conda (building the distribution from source) or via Docker, which uses a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-{{ name }} \
--yaml-config /root/my-run.yaml \
--port $LLAMA_STACK_PORT
```
### Via Conda
```bash
llama stack build --template lmstudio --image-type conda
llama stack run ./run.yaml \
--port 5001
```


@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.distribution.datatypes import Provider, ToolGroupInput
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.lmstudio import LMStudioImplConfig
from llama_stack.providers.remote.inference.lmstudio.models import MODEL_ENTRIES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::lmstudio"],
"safety": ["inline::llama-guard"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"agents": ["inline::meta-reference"],
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"telemetry": ["inline::meta-reference"],
"tool_runtime": [
"remote::tavily-search",
"inline::code-interpreter",
"inline::rag-runtime",
],
}
name = "lmstudio"
lmstudio_provider = Provider(
provider_id="lmstudio",
provider_type="remote::lmstudio",
config=LMStudioImplConfig.sample_run_config(),
)
available_models = {
"lmstudio": MODEL_ENTRIES,
}
default_models = get_model_registry(available_models)
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
]
return DistributionTemplate(
name="lmstudio",
distro_type="self_hosted",
description="Use LM Studio for running LLM inference",
container_image=None,
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
available_models_by_provider=available_models,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [lmstudio_provider],
"vector_io": [vector_io_provider],
},
default_models=default_models,
default_shields=[],
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"5001",
"Port for the Llama Stack distribution server",
),
},
)


@ -0,0 +1,44 @@
# Report for LM Studio distribution
## Supported Models
| Model Descriptor | lmstudio |
|:---|:---|
| meta-llama/Llama-3-8B-Instruct | ✅ |
| meta-llama/Llama-3-70B-Instruct | ✅ |
| meta-llama/Llama-3.1-8B-Instruct | ✅ |
| meta-llama/Llama-3.1-70B-Instruct | ✅ |
| meta-llama/Llama-3.1-405B-Instruct-FP8 | ✅ |
| meta-llama/Llama-3.2-1B-Instruct | ✅ |
| meta-llama/Llama-3.2-3B-Instruct | ✅ |
| meta-llama/Llama-3.2-11B-Vision-Instruct | ❌ |
| meta-llama/Llama-3.2-90B-Vision-Instruct | ❌ |
| meta-llama/Llama-3.3-70B-Instruct | ✅ |
| meta-llama/Llama-Guard-3-11B-Vision | ❌ |
| meta-llama/Llama-Guard-3-1B | ❌ |
| meta-llama/Llama-Guard-3-8B | ❌ |
| meta-llama/Llama-Guard-2-8B | ❌ |
## Inference
| Model | API | Capability | Test | Status |
|:----- |:-----|:-----|:-----|:-----|
| Llama-3.1-8B-Instruct | /chat_completion | streaming | test_text_chat_completion_streaming | ✅ |
| Llama-3.2-11B-Vision-Instruct | /chat_completion | streaming | test_image_chat_completion_streaming | ❌ |
| Llama-3.2-11B-Vision-Instruct | /chat_completion | non_streaming | test_image_chat_completion_non_streaming | ❌ |
| Llama-3.1-8B-Instruct | /chat_completion | non_streaming | test_text_chat_completion_non_streaming | ✅ |
| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_streaming | ❌ |
| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_non_streaming | ✅ |
| Llama-3.1-8B-Instruct | /completion | streaming | test_text_completion_streaming | ✅ |
| Llama-3.1-8B-Instruct | /completion | non_streaming | test_text_completion_non_streaming | ✅ |
| Llama-3.1-8B-Instruct | /completion | structured_output | test_text_completion_structured_output | ❌ |
## Vector IO
| API | Capability | Test | Status |
|:-----|:-----|:-----|:-----|
| /retrieve | | test_vector_db_retrieve | ✅ |
## Agents
| API | Capability | Test | Status |
|:-----|:-----|:-----|:-----|
| /create_agent_turn | rag | test_rag_agent | ❓ |
| /create_agent_turn | custom_tool | test_custom_tool | ❓ |
| /create_agent_turn | code_execution | test_code_interpreter_for_attachments | ❓ |


@ -0,0 +1,158 @@
version: '2'
image_name: lmstudio
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: lmstudio
provider_type: remote::lmstudio
config:
url: localhost:1234
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/lmstudio}/faiss_store.db
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/lmstudio}/agents_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/lmstudio}/meta_reference_eval.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/lmstudio}/huggingface_datasetio.db
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/lmstudio}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/lmstudio/trace_store.db}
tool_runtime:
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/lmstudio}/registry.db
models:
- metadata: {}
model_id: meta-llama-3-8b-instruct
provider_id: lmstudio
provider_model_id: meta-llama-3-8b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama-3-70b-instruct
provider_id: lmstudio
provider_model_id: meta-llama-3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama-3.1-8b-instruct
provider_id: lmstudio
provider_model_id: meta-llama-3.1-8b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama-3.1-70b-instruct
provider_id: lmstudio
provider_model_id: meta-llama-3.1-70b-instruct
model_type: llm
- metadata: {}
model_id: llama-3.2-1b-instruct
provider_id: lmstudio
provider_model_id: llama-3.2-1b-instruct
model_type: llm
- metadata: {}
model_id: llama-3.2-3b-instruct
provider_id: lmstudio
provider_model_id: llama-3.2-3b-instruct
model_type: llm
- metadata: {}
model_id: llama-3.3-70b-instruct
provider_id: lmstudio
provider_model_id: llama-3.3-70b-instruct
model_type: llm
- metadata:
embedding_dimension: 768
context_length: 2048
model_id: nomic-embed-text-v1.5
provider_id: lmstudio
provider_model_id: nomic-embed-text-v1.5
model_type: embedding
- metadata:
embedding_dimension: 384
model_id: all-minilm-l6-v2
provider_id: lmstudio
provider_model_id: all-minilm-l6-v2
model_type: embedding
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
server:
port: 8321


@ -19,6 +19,7 @@
| Together | 50.0% | 40 | 80 |
| Fireworks | 50.0% | 40 | 80 |
| Openai | 100.0% | 56 | 56 |
| Lmstudio | 100.0% | 24 | 24 |
@ -230,3 +231,48 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=openai
| test_chat_streaming_tool_calling | ✅ | ✅ |
| test_chat_streaming_tool_choice_none | ✅ | ✅ |
| test_chat_streaming_tool_choice_required | ✅ | ✅ |
## Lmstudio
```bash
# Run all tests for this provider:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=lmstudio -v
# Example: Run only the 'earth' case of test_chat_non_streaming_basic:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=lmstudio -k "test_chat_non_streaming_basic and earth"
```
**Model Key (Lmstudio)**
| Display Name | Full Model ID |
| --- | --- |
| Llama-4-Scout-Instruct | `llama-4-scout-17b-16e-instruct` |
| Test | Llama-4-Scout-Instruct |
| --- | --- |
| test_chat_non_streaming_basic (earth) | ✅ |
| test_chat_non_streaming_basic (saturn) | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (add_product_tool) | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (text_then_weather_tool) | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ |
| test_chat_non_streaming_structured_output (calendar) | ✅ |
| test_chat_non_streaming_structured_output (math) | ✅ |
| test_chat_non_streaming_tool_calling | ✅ |
| test_chat_non_streaming_tool_choice_none | ✅ |
| test_chat_non_streaming_tool_choice_required | ✅ |
| test_chat_streaming_basic (earth) | ✅ |
| test_chat_streaming_basic (saturn) | ✅ |
| test_chat_streaming_multi_turn_tool_calling (add_product_tool) | ✅ |
| test_chat_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ |
| test_chat_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ |
| test_chat_streaming_multi_turn_tool_calling (text_then_weather_tool) | ✅ |
| test_chat_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ |
| test_chat_streaming_structured_output (calendar) | ✅ |
| test_chat_streaming_structured_output (math) | ✅ |
| test_chat_streaming_tool_calling | ✅ |
| test_chat_streaming_tool_choice_none | ✅ |
| test_chat_streaming_tool_choice_required | ✅ |


@ -0,0 +1,10 @@
base_url: http://localhost:1234/v1/
models:
- llama-4-scout-17b-16e-instruct
model_display_names:
llama-4-scout-17b-16e-instruct: Llama-4-Scout-Instruct
test_exclusions:
llama-4-scout-17b-16e-instruct:
- test_chat_non_streaming_image
- test_chat_streaming_image
- test_chat_multi_turn_multiple_images

File diff suppressed because it is too large