Pass 1 for pre-commit fixes

Matt Clayton committed on 2025-04-27 15:24:37 -04:00
parent cfc6bdae68
commit 59e1c5f4a0
7 changed files with 119 additions and 109 deletions

View file

@@ -33,7 +33,7 @@ The following models are available by default:
- `meta-llama-3.1-70b-instruct `
- `llama-3.2-1b-instruct `
- `llama-3.2-3b-instruct `
- `llama-3.2-70b-instruct `
- `llama-3.3-70b-instruct `
- `nomic-embed-text-v1.5 `
- `all-minilm-l6-v2 `
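For orientation (not part of this commit), a minimal sketch of how one of the default models listed above can be queried through LM Studio's OpenAI-compatible endpoint, the same endpoint the adapter below targets; the localhost:1234 address and the model choice are assumptions about a default local LM Studio server.

import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    # LM Studio ignores the API key, but the client requires a value
    client = AsyncOpenAI(base_url="http://localhost:1234/v1", api_key="lmstudio")
    response = await client.chat.completions.create(
        model="llama-3.2-1b-instruct",
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
    )
    print(response.choices[0].message.content)


asyncio.run(main())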

View file

@@ -1,13 +1,20 @@
import asyncio
from typing import AsyncIterator, AsyncGenerator, List, Literal, Optional, Union
import lmstudio as lms
import json
import re
import logging
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
import logging
import re
from typing import Any, AsyncIterator, List, Literal, Optional, Union
import lmstudio as lms
from openai import AsyncOpenAI as OpenAI
from llama_stack.apis.common.content_types import InterleavedContent, TextDelta
from llama_stack.apis.inference.inference import (
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
@@ -16,15 +23,14 @@ from llama_stack.apis.inference.inference import (
CompletionMessage,
CompletionResponse,
CompletionResponseStreamChunk,
GrammarResponseFormat,
GreedySamplingStrategy,
JsonSchemaResponseFormat,
Message,
ToolConfig,
ToolDefinition,
)
from llama_stack.models.llama.datatypes import (
GreedySamplingStrategy,
SamplingParams,
StopReason,
ToolConfig,
ToolDefinition,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
@@ -38,7 +44,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
content_has_media,
interleaved_content_as_str,
)
from openai import AsyncOpenAI as OpenAI
LlmPredictionStopReason = Literal[
"userStopped",
@@ -57,13 +62,15 @@ class LMStudioClient:
self.url = url
self.sdk_client = lms.Client(self.url)
self.openai_client = OpenAI(base_url=f"http://{url}/v1", api_key="lmstudio")
# Standard error handling helper methods
def _log_error(self, error, context=""):
"""Centralized error logging method"""
logging.warning(f"Error in LMStudio {context}: {error}")
async def _create_fallback_chat_stream(self, error_message="I encountered an error processing your request."):
async def _create_fallback_chat_stream(
self, error_message="I encountered an error processing your request."
) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
"""Create a standardized fallback stream for chat completions"""
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
@@ -83,30 +90,32 @@ class LMStudioClient:
delta=TextDelta(text=""),
)
)
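The fallback chunks above mirror the shape of the real streaming chunks, so a caller consumes them the same way; a hedged sketch follows (the adapter instance, its construction, and the model id are assumptions, and only text deltas are handled).

from llama_stack.apis.inference import UserMessage


async def stream_chat(adapter) -> None:
    # With stream=True the adapter returns an async iterator of
    # ChatCompletionResponseStreamChunk objects like the ones yielded above.
    stream = await adapter.chat_completion(
        model_id="llama-3.2-1b-instruct",
        messages=[UserMessage(content="Stream a short greeting.")],
        stream=True,
    )
    async for chunk in stream:
        print(chunk.event.delta.text, end="", flush=True)  # Assumes TextDelta chunks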
async def _create_fallback_completion_stream(self, error_message="Error processing response"):
"""Create a standardized fallback stream for text completions"""
yield CompletionResponseStreamChunk(
delta=error_message,
)
def _create_fallback_chat_response(self, error_message="I encountered an error processing your request."):
def _create_fallback_chat_response(
self, error_message="I encountered an error processing your request."
) -> ChatCompletionResponse:
"""Create a standardized fallback response for chat completions"""
return ChatCompletionResponse(
message=Message(
completion_message=CompletionMessage(
role="assistant",
content=error_message,
),
stop_reason=StopReason.end_of_message,
stop_reason=StopReason.end_of_message,
)
)
def _create_fallback_completion_response(self, error_message="Error processing response"):
def _create_fallback_completion_response(self, error_message="Error processing response") -> CompletionResponse:
"""Create a standardized fallback response for text completions"""
return CompletionResponse(
content=error_message,
stop_reason=StopReason.end_of_message,
)
def _handle_json_extraction(self, content, context="JSON extraction"):
"""Standardized method to extract valid JSON from potentially malformed content"""
try:
@@ -114,14 +123,14 @@ class LMStudioClient:
return json.dumps(json_content) # Re-serialize to ensure valid JSON
except json.JSONDecodeError as e:
self._log_error(e, f"{context} - Attempting to extract valid JSON")
json_patterns = [
r'(\{.*\})', # Match anything between curly braces
r'(\[.*\])', # Match anything between square brackets
r'```json\s*([\s\S]*?)\s*```', # Match content in JSON code blocks
r'```\s*([\s\S]*?)\s*```', # Match content in any code blocks
r"(\{.*\})", # Match anything between curly braces
r"(\[.*\])", # Match anything between square brackets
r"```json\s*([\s\S]*?)\s*```", # Match content in JSON code blocks
r"```\s*([\s\S]*?)\s*```", # Match content in any code blocks
]
for pattern in json_patterns:
json_match = re.search(pattern, content, re.DOTALL)
if json_match:
@@ -131,7 +140,7 @@ class LMStudioClient:
return json.dumps(json_content) # Re-serialize to ensure valid JSON
except json.JSONDecodeError:
continue # Try the next pattern
# If we couldn't extract valid JSON, log a warning
self._log_error("Failed to extract valid JSON", context)
return None
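For readers skimming the hunk above, a standalone, simplified sketch of the same JSON-extraction fallback; the function name is illustrative and error logging is omitted.

import json
import re
from typing import Optional


def extract_json(content: str) -> Optional[str]:
    try:
        return json.dumps(json.loads(content))  # Already valid JSON: re-serialize and return
    except json.JSONDecodeError:
        pass
    patterns = [
        r"(\{.*\})",  # Match anything between curly braces
        r"(\[.*\])",  # Match anything between square brackets
        r"```json\s*([\s\S]*?)\s*```",  # Match content in JSON code blocks
        r"```\s*([\s\S]*?)\s*```",  # Match content in any code blocks
    ]
    for pattern in patterns:
        match = re.search(pattern, content, re.DOTALL)
        if match:
            try:
                return json.dumps(json.loads(match.group(1)))
            except json.JSONDecodeError:
                continue  # Try the next pattern
    return None  # Caller falls back to the original content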
@@ -148,14 +157,10 @@ class LMStudioClient:
return False
async def get_embedding_model(self, provider_model_id: str):
model = await asyncio.to_thread(
self.sdk_client.embedding.model, provider_model_id
)
model = await asyncio.to_thread(self.sdk_client.embedding.model, provider_model_id)
return model
async def embed(
self, embedding_model: lms.EmbeddingModel, contents: Union[str, List[str]]
):
async def embed(self, embedding_model: lms.EmbeddingModel, contents: Union[str, List[str]]):
embeddings = await asyncio.to_thread(embedding_model.embed, contents)
return embeddings
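A hedged sketch of the underlying lmstudio SDK calls these helpers wrap when run outside the adapter; the host, the model name, and the shape of the returned vectors are assumptions about a local LM Studio install.

import lmstudio as lms

client = lms.Client("localhost:1234")
embedding_model = client.embedding.model("nomic-embed-text-v1.5")

# embed() accepts a string or a list of strings, matching the embed() helper above
vectors = embedding_model.embed(["What is the capital of France?"])
print(len(vectors[0]))  # Assumed: one vector per input string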
@@ -170,14 +175,12 @@ class LMStudioClient:
sampling_params: Optional[SamplingParams] = None,
json_schema: Optional[JsonSchemaResponseFormat] = None,
stream: Optional[bool] = False,
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
chat = self._convert_message_list_to_lmstudio_chat(messages)
config = self._get_completion_config_from_params(sampling_params)
if stream:
async def stream_generator():
async def stream_generator() -> AsyncIterator[ChatCompletionResponseStreamChunk]:
prediction_stream = await asyncio.to_thread(
llm.respond_stream,
history=chat,
@@ -191,7 +194,7 @@ class LMStudioClient:
delta=TextDelta(text=""),
)
)
async for chunk in self._async_iterate(prediction_stream):
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
@@ -225,9 +228,7 @@ class LMStudioClient:
stream: Optional[bool] = False,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
try:
model_key = llm.get_info().model_key
request = ChatCompletionRequest(
@@ -240,17 +241,15 @@ class LMStudioClient:
stream=stream,
)
rest_request = await self._convert_request_to_rest_call(request)
if stream:
try:
stream = await self.openai_client.chat.completions.create(**rest_request)
return convert_openai_chat_completion_stream(
stream, enable_incremental_tool_calls=True
)
return convert_openai_chat_completion_stream(stream, enable_incremental_tool_calls=True)
except Exception as e:
self._log_error(e, "streaming tool calling")
return self._create_fallback_chat_stream()
try:
response = await self.openai_client.chat.completions.create(**rest_request)
if response:
@@ -280,9 +279,7 @@ class LMStudioClient:
stream: Optional[bool] = False,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
if tools is None or len(tools) == 0:
return await self._llm_respond_non_tools(
llm=llm,
@@ -313,7 +310,7 @@ class LMStudioClient:
config = self._get_completion_config_from_params(sampling_params)
if stream:
async def stream_generator():
async def stream_generator() -> AsyncIterator[CompletionResponseStreamChunk]:
try:
prediction_stream = await asyncio.to_thread(
llm.complete_stream,
@@ -341,7 +338,7 @@ class LMStudioClient:
config=config,
response_format=json_schema,
)
# If we have a JSON schema, ensure the response is valid JSON
if json_schema is not None:
valid_json = self._handle_json_extraction(response.content, "completion response")
@@ -351,7 +348,7 @@ class LMStudioClient:
stop_reason=self._get_stop_reason(response.stats.stop_reason),
)
# If we couldn't extract valid JSON, continue with the original content
return CompletionResponse(
content=response.content,
stop_reason=self._get_stop_reason(response.stats.stop_reason),
@@ -361,15 +358,11 @@ class LMStudioClient:
# Return a fallback response with an error message
return self._create_fallback_completion_response()
def _convert_message_list_to_lmstudio_chat(
self, messages: List[Message]
) -> lms.Chat:
def _convert_message_list_to_lmstudio_chat(self, messages: List[Message]) -> lms.Chat:
chat = lms.Chat()
for message in messages:
if content_has_media(message.content):
raise NotImplementedError(
"Media content is not supported in LMStudio messages"
)
raise NotImplementedError("Media content is not supported in LMStudio messages")
if message.role == "user":
chat.add_user_message(interleaved_content_as_str(message.content))
elif message.role == "system":
@@ -380,9 +373,7 @@ class LMStudioClient:
raise ValueError(f"Unsupported message role: {message.role}")
return chat
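A hedged sketch of the lmstudio chat objects this conversion produces, used directly against the SDK; the client.llm.model lookup and the synchronous respond call are assumptions, since only Client, Chat, add_user_message, and respond_stream appear in this diff.

import lmstudio as lms

client = lms.Client("localhost:1234")
llm = client.llm.model("llama-3.2-3b-instruct")  # Assumed lookup, mirroring client.embedding.model above

chat = lms.Chat()
chat.add_user_message("Name one planet.")

result = llm.respond(history=chat)  # Assumed non-streaming counterpart of respond_stream
print(result.content)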
def _convert_prediction_to_chat_response(
self, result: lms.PredictionResult
) -> ChatCompletionResponse:
def _convert_prediction_to_chat_response(self, result: lms.PredictionResult) -> ChatCompletionResponse:
response = ChatCompletionResponse(
completion_message=CompletionMessage(
content=result.content,
@@ -415,11 +406,7 @@ class LMStudioClient:
options.update(
{
"maxTokens": params.max_tokens if params.max_tokens != 0 else None,
"repetitionPenalty": (
params.repetition_penalty
if params.repetition_penalty != 0
else None
),
"repetitionPenalty": (params.repetition_penalty if params.repetition_penalty != 0 else None),
}
)
return options
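To make the option mapping above concrete, a hedged sketch of turning a llama-stack SamplingParams into an LM Studio prediction config; only maxTokens and repetitionPenalty appear in this hunk, so the temperature and topPSampling keys are assumptions.

from llama_stack.models.llama.datatypes import SamplingParams, TopPSamplingStrategy

params = SamplingParams(
    strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.9),
    max_tokens=256,
    repetition_penalty=1.1,
)

config = {
    "maxTokens": params.max_tokens if params.max_tokens != 0 else None,
    "repetitionPenalty": params.repetition_penalty if params.repetition_penalty != 0 else None,
}
if isinstance(params.strategy, TopPSamplingStrategy):
    config["temperature"] = params.strategy.temperature  # Assumed LM Studio config key
    config["topPSampling"] = params.strategy.top_p  # Assumed LM Studio config key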
@@ -449,32 +436,30 @@ class LMStudioClient:
break
yield item
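The hunk above only shows the tail of _async_iterate; here is a general sketch of the blocking-iterator-to-async-generator bridge it appears to implement (the sentinel approach is an assumption, not necessarily the adapter's exact code).

import asyncio
from typing import AsyncIterator, Iterator, TypeVar

T = TypeVar("T")
_SENTINEL = object()


async def async_iterate(iterable: Iterator[T]) -> AsyncIterator[T]:
    iterator = iter(iterable)
    while True:
        # next() blocks, so run it in a worker thread to keep the event loop responsive
        item = await asyncio.to_thread(next, iterator, _SENTINEL)
        if item is _SENTINEL:
            break
        yield item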
async def _convert_request_to_rest_call(
self, request: ChatCompletionRequest
) -> dict:
async def _convert_request_to_rest_call(self, request: ChatCompletionRequest) -> dict:
compatible_request = self._convert_sampling_params(request.sampling_params)
compatible_request["model"] = request.model
compatible_request["messages"] = [
await convert_message_to_openai_dict_new(m) for m in request.messages
]
compatible_request["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages]
if request.response_format:
compatible_request["response_format"] = {
"type": "json_schema",
"json_schema": request.response_format.json_schema,
}
if isinstance(request.response_format, JsonSchemaResponseFormat):
compatible_request["response_format"] = {
"type": "json_schema",
"json_schema": request.response_format.json_schema,
}
elif isinstance(request.response_format, GrammarResponseFormat):
compatible_request["response_format"] = {
"type": "grammar",
"bnf": request.response_format.bnf,
}
if request.tools is not None:
compatible_request["tools"] = [
convert_tooldef_to_openai_tool(tool) for tool in request.tools
]
compatible_request["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
compatible_request["logprobs"] = False
compatible_request["stream"] = request.stream
compatible_request["extra_headers"] = {
b"User-Agent": b"llama-stack: lmstudio-inference-adapter"
}
compatible_request["extra_headers"] = {b"User-Agent": b"llama-stack: lmstudio-inference-adapter"}
return compatible_request
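The dict returned above is passed straight to the OpenAI-compatible client configured in __init__; a short hedged usage sketch (host and port are assumptions about a local server).

from openai import AsyncOpenAI


async def send(rest_request: dict) -> str:
    client = AsyncOpenAI(base_url="http://localhost:1234/v1", api_key="lmstudio")
    response = await client.chat.completions.create(**rest_request)
    return response.choices[0].message.content or ""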
def _convert_sampling_params(self, sampling_params: Optional[SamplingParams]) -> dict:
params = {}
params: dict[str, Any] = {}
if sampling_params is None:
return params

View file

@@ -14,7 +14,9 @@ from llama_stack.apis.inference import (
ChatCompletionResponse,
EmbeddingsResponse,
EmbeddingTaskType,
GrammarResponseFormat,
Inference,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
ResponseFormat,
@@ -29,7 +31,6 @@ from llama_stack.apis.inference.inference import (
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionResponseStreamChunk,
ResponseFormatType,
)
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.remote.inference.lmstudio._client import LMStudioClient
@@ -50,6 +51,18 @@ class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
def client(self) -> LMStudioClient:
return LMStudioClient(url=self.url)
async def batch_chat_completion(self, *args, **kwargs):
raise NotImplementedError("Batch chat completion not supported by LM Studio Provider")
async def batch_completion(self, *args, **kwargs):
raise NotImplementedError("Batch completion not supported by LM Studio Provider")
async def openai_chat_completion(self, *args, **kwargs):
raise NotImplementedError("OpenAI chat completion not supported by LM Studio Provider")
async def openai_completion(self, *args, **kwargs):
raise NotImplementedError("OpenAI completion not supported by LM Studio Provider")
async def initialize(self) -> None:
pass
@@ -71,9 +84,12 @@ class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
assert all(not content_has_media(content) for content in contents), (
"Media content not supported in embedding model"
)
if self.model_store is None:
raise ValueError("ModelStore is not initialized")
model = await self.model_store.get_model(model_id)
embedding_model = await self.client.get_embedding_model(model.provider_model_id)
embeddings = await self.client.embed(embedding_model, contents)
string_contents = [item.text if hasattr(item, "text") else str(item) for item in contents]
embeddings = await self.client.embed(embedding_model, string_contents)
return EmbeddingsResponse(embeddings=embeddings)
async def chat_completion(
@@ -81,26 +97,31 @@ class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_choice: Optional[ToolChoice] = None, # Default value changed from ToolChoice.auto to None
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[
Union[JsonSchemaResponseFormat, GrammarResponseFormat]
] = None, # Moved and type changed
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
if self.model_store is None:
raise ValueError("ModelStore is not initialized")
model = await self.model_store.get_model(model_id)
llm = await self.client.get_llm(model.provider_model_id)
if response_format is not None and response_format.type != ResponseFormatType.json_schema.value:
raise ValueError(f"Response format type {response_format.type} not supported for LM Studio Provider")
json_schema = response_format.json_schema if response_format else None
json_schema_format = response_format if isinstance(response_format, JsonSchemaResponseFormat) else None
if response_format is not None and not isinstance(response_format, JsonSchemaResponseFormat):
raise ValueError(
f"Response format type {type(response_format).__name__} not supported for LM Studio Provider"
)
return await self.client.llm_respond(
llm=llm,
messages=messages,
sampling_params=sampling_params,
json_schema=json_schema,
json_schema=json_schema_format,
stream=stream,
tool_config=tool_config,
tools=tools,
@@ -115,13 +136,16 @@ class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, # Skip this for now
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
if self.model_store is None:
raise ValueError("ModelStore is not initialized")
model = await self.model_store.get_model(model_id)
llm = await self.client.get_llm(model.provider_model_id)
if content_has_media(content):
raise NotImplementedError("Media content not supported in LM Studio Provider")
if response_format is not None and response_format.type != ResponseFormatType.json_schema.value:
raise ValueError(f"Response format type {response_format.type} not supported for LM Studio Provider")
json_schema = response_format.json_schema if response_format else None
if not isinstance(response_format, JsonSchemaResponseFormat):
raise ValueError(
f"Response format type {type(response_format).__name__} not supported for LM Studio Provider"
)
return await self.client.llm_completion(llm, content, sampling_params, json_schema, stream)
return await self.client.llm_completion(llm, content, sampling_params, response_format, stream)
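A hedged sketch of exercising the structured-output path this file now validates; the schema, the model id, and the pre-initialized adapter instance are illustrative assumptions.

from llama_stack.apis.inference import JsonSchemaResponseFormat, UserMessage

response_format = JsonSchemaResponseFormat(
    json_schema={
        "type": "object",
        "properties": {"answer": {"type": "string"}},
        "required": ["answer"],
    }
)


async def ask(adapter) -> None:
    response = await adapter.chat_completion(
        model_id="llama-3.2-3b-instruct",
        messages=[UserMessage(content="Reply as JSON with an 'answer' field.")],
        response_format=response_format,
    )
    print(response.completion_message.content)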

View file

@@ -5,7 +5,7 @@
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.models.llama.sku_list import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
@@ -63,9 +63,7 @@ MODEL_ENTRIES = [
},
),
ProviderModelEntry(
model_id="all-MiniLM-L6-v2",
provider_model_id="all-minilm-l6-v2",
provider_id="lmstudio",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,

View file

@@ -351,10 +351,12 @@
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"lmstudio",
"matplotlib",
"nltk",
@@ -367,6 +369,7 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@@ -374,6 +377,7 @@
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"meta-reference-gpu": [

View file

@@ -24,7 +24,6 @@ distribution_spec:
telemetry:
- inline::meta-reference
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::rag-runtime

View file

@@ -75,7 +75,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/lmstudio/trace_store.db}
tool_runtime:
@@ -125,9 +125,9 @@ models:
provider_model_id: llama-3.2-3b-instruct
model_type: llm
- metadata: {}
model_id: llama-3.2-70b-instruct
model_id: llama-3.3-70b-instruct
provider_id: lmstudio
provider_model_id: llama-3.2-70b-instruct
provider_model_id: llama-3.3-70b-instruct
model_type: llm
- metadata:
embedding_dimension: 768