mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 01:03:59 +00:00

Pass 1 for pre-commit fixes

This commit is contained in:
parent cfc6bdae68
commit 59e1c5f4a0

7 changed files with 119 additions and 109 deletions
@@ -33,7 +33,7 @@ The following models are available by default:
 - `meta-llama-3.1-70b-instruct `
 - `llama-3.2-1b-instruct `
 - `llama-3.2-3b-instruct `
-- `llama-3.2-70b-instruct `
+- `llama-3.3-70b-instruct `
 - `nomic-embed-text-v1.5 `
 - `all-minilm-l6-v2 `
@@ -1,13 +1,20 @@
-import asyncio
-from typing import AsyncIterator, AsyncGenerator, List, Literal, Optional, Union
-import lmstudio as lms
-import json
-import re
-import logging
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+import json
+import logging
+import re
+from typing import Any, AsyncIterator, List, Literal, Optional, Union
+
+import lmstudio as lms
+from openai import AsyncOpenAI as OpenAI
 
 from llama_stack.apis.common.content_types import InterleavedContent, TextDelta
-from llama_stack.apis.inference.inference import (
+from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseEvent,

@@ -16,15 +23,14 @@ from llama_stack.apis.inference.inference import (
     CompletionMessage,
     CompletionResponse,
     CompletionResponseStreamChunk,
+    GrammarResponseFormat,
+    GreedySamplingStrategy,
     JsonSchemaResponseFormat,
     Message,
-    ToolConfig,
-    ToolDefinition,
-)
-from llama_stack.models.llama.datatypes import (
-    GreedySamplingStrategy,
     SamplingParams,
     StopReason,
+    ToolConfig,
+    ToolDefinition,
     TopKSamplingStrategy,
     TopPSamplingStrategy,
 )

@@ -38,7 +44,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     content_has_media,
     interleaved_content_as_str,
 )
-from openai import AsyncOpenAI as OpenAI
 
 LlmPredictionStopReason = Literal[
     "userStopped",

@@ -57,13 +62,15 @@ class LMStudioClient:
         self.url = url
         self.sdk_client = lms.Client(self.url)
         self.openai_client = OpenAI(base_url=f"http://{url}/v1", api_key="lmstudio")
 
     # Standard error handling helper methods
     def _log_error(self, error, context=""):
         """Centralized error logging method"""
         logging.warning(f"Error in LMStudio {context}: {error}")
 
-    async def _create_fallback_chat_stream(self, error_message="I encountered an error processing your request."):
+    async def _create_fallback_chat_stream(
+        self, error_message="I encountered an error processing your request."
+    ) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
         """Create a standardized fallback stream for chat completions"""
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
@@ -83,30 +90,32 @@ class LMStudioClient:
                 delta=TextDelta(text=""),
             )
         )
 
     async def _create_fallback_completion_stream(self, error_message="Error processing response"):
         """Create a standardized fallback stream for text completions"""
         yield CompletionResponseStreamChunk(
             delta=error_message,
         )
 
-    def _create_fallback_chat_response(self, error_message="I encountered an error processing your request."):
+    def _create_fallback_chat_response(
+        self, error_message="I encountered an error processing your request."
+    ) -> ChatCompletionResponse:
         """Create a standardized fallback response for chat completions"""
         return ChatCompletionResponse(
-            message=Message(
+            completion_message=CompletionMessage(
                 role="assistant",
                 content=error_message,
-            ),
-            stop_reason=StopReason.end_of_message,
+                stop_reason=StopReason.end_of_message,
+            )
         )
 
-    def _create_fallback_completion_response(self, error_message="Error processing response"):
+    def _create_fallback_completion_response(self, error_message="Error processing response") -> CompletionResponse:
         """Create a standardized fallback response for text completions"""
         return CompletionResponse(
             content=error_message,
             stop_reason=StopReason.end_of_message,
         )
 
     def _handle_json_extraction(self, content, context="JSON extraction"):
         """Standardized method to extract valid JSON from potentially malformed content"""
         try:
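Note: the fallback helper above now nests stop_reason inside the CompletionMessage instead of passing it to ChatCompletionResponse directly. A minimal sketch of the corrected construction, assuming ChatCompletionResponse, CompletionMessage, and StopReason are importable from llama_stack.apis.inference as the consolidated import above arranges (illustrative only, not part of the commit):

# Sketch (assumes llama-stack is installed and exposes these names as in the
# import hunk above); shows only the corrected fallback construction.
from llama_stack.apis.inference import ChatCompletionResponse, CompletionMessage, StopReason

fallback = ChatCompletionResponse(
    completion_message=CompletionMessage(
        role="assistant",
        content="I encountered an error processing your request.",
        stop_reason=StopReason.end_of_message,  # now carried by the message itself
    )
)
print(fallback.completion_message.content)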
@@ -114,14 +123,14 @@ class LMStudioClient:
             return json.dumps(json_content)  # Re-serialize to ensure valid JSON
         except json.JSONDecodeError as e:
             self._log_error(e, f"{context} - Attempting to extract valid JSON")
 
             json_patterns = [
-                r'(\{.*\})',  # Match anything between curly braces
-                r'(\[.*\])',  # Match anything between square brackets
-                r'```json\s*([\s\S]*?)\s*```',  # Match content in JSON code blocks
-                r'```\s*([\s\S]*?)\s*```',  # Match content in any code blocks
+                r"(\{.*\})",  # Match anything between curly braces
+                r"(\[.*\])",  # Match anything between square brackets
+                r"```json\s*([\s\S]*?)\s*```",  # Match content in JSON code blocks
+                r"```\s*([\s\S]*?)\s*```",  # Match content in any code blocks
             ]
 
             for pattern in json_patterns:
                 json_match = re.search(pattern, content, re.DOTALL)
                 if json_match:

@@ -131,7 +140,7 @@ class LMStudioClient:
                         return json.dumps(json_content)  # Re-serialize to ensure valid JSON
                     except json.JSONDecodeError:
                         continue  # Try the next pattern
 
             # If we couldn't extract valid JSON, log a warning
             self._log_error("Failed to extract valid JSON", context)
             return None
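Note: the regex fallback above is self-contained enough to exercise outside the client. The sketch below is a stdlib-only rendering of the same strategy; extract_json is a hypothetical name, not the commit's method:

import json
import re
from typing import Optional

# Try strict parsing first, then probe progressively looser patterns for an
# embedded JSON payload (braces, brackets, fenced code blocks).
JSON_PATTERNS = [
    r"(\{.*\})",                    # anything between curly braces
    r"(\[.*\])",                    # anything between square brackets
    r"```json\s*([\s\S]*?)\s*```",  # fenced ```json blocks
    r"```\s*([\s\S]*?)\s*```",      # any fenced code block
]

def extract_json(content: str) -> Optional[str]:
    try:
        return json.dumps(json.loads(content))  # already valid: re-serialize
    except json.JSONDecodeError:
        pass
    for pattern in JSON_PATTERNS:
        match = re.search(pattern, content, re.DOTALL)
        if not match:
            continue
        try:
            return json.dumps(json.loads(match.group(1)))
        except json.JSONDecodeError:
            continue  # try the next, looser pattern
    return None  # nothing recoverable

print(extract_json('Here you go: ```json\n{"ok": true}\n```'))  # {"ok": true}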
@@ -148,14 +157,10 @@ class LMStudioClient:
         return False
 
     async def get_embedding_model(self, provider_model_id: str):
-        model = await asyncio.to_thread(
-            self.sdk_client.embedding.model, provider_model_id
-        )
+        model = await asyncio.to_thread(self.sdk_client.embedding.model, provider_model_id)
         return model
 
-    async def embed(
-        self, embedding_model: lms.EmbeddingModel, contents: Union[str, List[str]]
-    ):
+    async def embed(self, embedding_model: lms.EmbeddingModel, contents: Union[str, List[str]]):
         embeddings = await asyncio.to_thread(embedding_model.embed, contents)
         return embeddings
 
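Note: both helpers above push blocking lmstudio SDK calls onto a worker thread with asyncio.to_thread so the event loop is not stalled. A stdlib-only sketch of that pattern, with load_model standing in for the real SDK call:

import asyncio
import time

def load_model(model_id: str) -> str:
    """Stand-in for a blocking SDK call such as sdk_client.embedding.model(...)."""
    time.sleep(0.1)  # simulate blocking I/O
    return f"handle-for-{model_id}"

async def get_model(model_id: str) -> str:
    # Same shape as the helpers above: run the blocking call in a thread and
    # await the result without blocking the event loop.
    return await asyncio.to_thread(load_model, model_id)

print(asyncio.run(get_model("nomic-embed-text-v1.5")))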
@@ -170,14 +175,12 @@ class LMStudioClient:
         sampling_params: Optional[SamplingParams] = None,
         json_schema: Optional[JsonSchemaResponseFormat] = None,
         stream: Optional[bool] = False,
-    ) -> Union[
-        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
-    ]:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         chat = self._convert_message_list_to_lmstudio_chat(messages)
         config = self._get_completion_config_from_params(sampling_params)
         if stream:
 
-            async def stream_generator():
+            async def stream_generator() -> AsyncIterator[ChatCompletionResponseStreamChunk]:
                 prediction_stream = await asyncio.to_thread(
                     llm.respond_stream,
                     history=chat,

@@ -191,7 +194,7 @@ class LMStudioClient:
                         delta=TextDelta(text=""),
                     )
                 )
 
                 async for chunk in self._async_iterate(prediction_stream):
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
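Note: the _async_iterate call above bridges the SDK's synchronous prediction stream into async for. A stdlib-only sketch of one way such a bridge can be written (async_iterate is a hypothetical helper, not necessarily the commit's implementation): it pulls one item per asyncio.to_thread call and stops on a sentinel.

import asyncio
from typing import AsyncIterator, Iterable, TypeVar

T = TypeVar("T")
_DONE = object()  # sentinel returned when the blocking iterator is exhausted

async def async_iterate(sync_iterable: Iterable[T]) -> AsyncIterator[T]:
    """Yield items from a blocking iterator without stalling the event loop."""
    iterator = iter(sync_iterable)
    while True:
        item = await asyncio.to_thread(next, iterator, _DONE)
        if item is _DONE:
            break
        yield item

async def main() -> None:
    async for token in async_iterate(["Hel", "lo", "!"]):
        print(token, end="")
    print()

asyncio.run(main())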
@@ -225,9 +228,7 @@ class LMStudioClient:
         stream: Optional[bool] = False,
         tools: Optional[List[ToolDefinition]] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> Union[
-        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
-    ]:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         try:
             model_key = llm.get_info().model_key
             request = ChatCompletionRequest(

@@ -240,17 +241,15 @@ class LMStudioClient:
                 stream=stream,
             )
             rest_request = await self._convert_request_to_rest_call(request)
 
             if stream:
                 try:
                     stream = await self.openai_client.chat.completions.create(**rest_request)
-                    return convert_openai_chat_completion_stream(
-                        stream, enable_incremental_tool_calls=True
-                    )
+                    return convert_openai_chat_completion_stream(stream, enable_incremental_tool_calls=True)
                 except Exception as e:
                     self._log_error(e, "streaming tool calling")
                     return self._create_fallback_chat_stream()
 
             try:
                 response = await self.openai_client.chat.completions.create(**rest_request)
                 if response:

@@ -280,9 +279,7 @@ class LMStudioClient:
         stream: Optional[bool] = False,
         tools: Optional[List[ToolDefinition]] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> Union[
-        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
-    ]:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         if tools is None or len(tools) == 0:
             return await self._llm_respond_non_tools(
                 llm=llm,

@@ -313,7 +310,7 @@ class LMStudioClient:
         config = self._get_completion_config_from_params(sampling_params)
         if stream:
 
-            async def stream_generator():
+            async def stream_generator() -> AsyncIterator[CompletionResponseStreamChunk]:
                 try:
                     prediction_stream = await asyncio.to_thread(
                         llm.complete_stream,

@@ -341,7 +338,7 @@ class LMStudioClient:
                 config=config,
                 response_format=json_schema,
             )
 
             # If we have a JSON schema, ensure the response is valid JSON
             if json_schema is not None:
                 valid_json = self._handle_json_extraction(response.content, "completion response")

@@ -351,7 +348,7 @@ class LMStudioClient:
                         stop_reason=self._get_stop_reason(response.stats.stop_reason),
                     )
                 # If we couldn't extract valid JSON, continue with the original content
 
             return CompletionResponse(
                 content=response.content,
                 stop_reason=self._get_stop_reason(response.stats.stop_reason),

@@ -361,15 +358,11 @@ class LMStudioClient:
             # Return a fallback response with an error message
             return self._create_fallback_completion_response()
 
-    def _convert_message_list_to_lmstudio_chat(
-        self, messages: List[Message]
-    ) -> lms.Chat:
+    def _convert_message_list_to_lmstudio_chat(self, messages: List[Message]) -> lms.Chat:
         chat = lms.Chat()
         for message in messages:
             if content_has_media(message.content):
-                raise NotImplementedError(
-                    "Media content is not supported in LMStudio messages"
-                )
+                raise NotImplementedError("Media content is not supported in LMStudio messages")
             if message.role == "user":
                 chat.add_user_message(interleaved_content_as_str(message.content))
             elif message.role == "system":

@@ -380,9 +373,7 @@ class LMStudioClient:
             raise ValueError(f"Unsupported message role: {message.role}")
         return chat
 
-    def _convert_prediction_to_chat_response(
-        self, result: lms.PredictionResult
-    ) -> ChatCompletionResponse:
+    def _convert_prediction_to_chat_response(self, result: lms.PredictionResult) -> ChatCompletionResponse:
         response = ChatCompletionResponse(
             completion_message=CompletionMessage(
                 content=result.content,
@@ -415,11 +406,7 @@ class LMStudioClient:
             options.update(
                 {
                     "maxTokens": params.max_tokens if params.max_tokens != 0 else None,
-                    "repetitionPenalty": (
-                        params.repetition_penalty
-                        if params.repetition_penalty != 0
-                        else None
-                    ),
+                    "repetitionPenalty": (params.repetition_penalty if params.repetition_penalty != 0 else None),
                 }
             )
         return options
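Note: the option mapping above treats a zero value as "unset" and forwards None so the SDK falls back to its own defaults. A stdlib-only sketch of that rule; FakeSamplingParams is a hypothetical stand-in for llama-stack's SamplingParams:

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class FakeSamplingParams:
    """Hypothetical stand-in for llama-stack's SamplingParams."""
    max_tokens: int = 0
    repetition_penalty: float = 0.0

def to_lmstudio_options(params: Optional[FakeSamplingParams]) -> dict[str, Any]:
    options: dict[str, Any] = {}
    if params is None:
        return options
    # Zero means "not set" in the request, so map it to None for the SDK.
    options.update(
        {
            "maxTokens": params.max_tokens if params.max_tokens != 0 else None,
            "repetitionPenalty": (params.repetition_penalty if params.repetition_penalty != 0 else None),
        }
    )
    return options

print(to_lmstudio_options(FakeSamplingParams(max_tokens=256)))
# {'maxTokens': 256, 'repetitionPenalty': None}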
@@ -449,32 +436,30 @@ class LMStudioClient:
                 break
             yield item
 
-    async def _convert_request_to_rest_call(
-        self, request: ChatCompletionRequest
-    ) -> dict:
+    async def _convert_request_to_rest_call(self, request: ChatCompletionRequest) -> dict:
         compatible_request = self._convert_sampling_params(request.sampling_params)
         compatible_request["model"] = request.model
-        compatible_request["messages"] = [
-            await convert_message_to_openai_dict_new(m) for m in request.messages
-        ]
+        compatible_request["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages]
         if request.response_format:
-            compatible_request["response_format"] = {
-                "type": "json_schema",
-                "json_schema": request.response_format.json_schema,
-            }
+            if isinstance(request.response_format, JsonSchemaResponseFormat):
+                compatible_request["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": request.response_format.json_schema,
+                }
+            elif isinstance(request.response_format, GrammarResponseFormat):
+                compatible_request["response_format"] = {
+                    "type": "grammar",
+                    "bnf": request.response_format.bnf,
+                }
         if request.tools is not None:
-            compatible_request["tools"] = [
-                convert_tooldef_to_openai_tool(tool) for tool in request.tools
-            ]
+            compatible_request["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
         compatible_request["logprobs"] = False
         compatible_request["stream"] = request.stream
-        compatible_request["extra_headers"] = {
-            b"User-Agent": b"llama-stack: lmstudio-inference-adapter"
-        }
+        compatible_request["extra_headers"] = {b"User-Agent": b"llama-stack: lmstudio-inference-adapter"}
         return compatible_request
 
     def _convert_sampling_params(self, sampling_params: Optional[SamplingParams]) -> dict:
-        params = {}
+        params: dict[str, Any] = {}
 
         if sampling_params is None:
             return params
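Note: the request conversion above now emits a different response_format payload depending on whether the stack supplied a JSON-schema or a BNF-grammar format. A stdlib-only sketch of the two resulting shapes, with plain arguments standing in for the llama-stack format classes (response_format_payload is a hypothetical helper):

from typing import Any

def response_format_payload(kind: str, spec: Any) -> dict[str, Any]:
    """Sketch of the two payload shapes produced by _convert_request_to_rest_call."""
    if kind == "json_schema":
        return {"type": "json_schema", "json_schema": spec}  # spec: a JSON schema dict
    if kind == "grammar":
        return {"type": "grammar", "bnf": spec}  # spec: a BNF grammar string
    raise ValueError(f"Unsupported response format: {kind}")

print(response_format_payload("json_schema", {"type": "object", "properties": {"name": {"type": "string"}}}))
print(response_format_payload("grammar", 'root ::= "yes" | "no"'))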
@@ -14,7 +14,9 @@ from llama_stack.apis.inference import (
     ChatCompletionResponse,
     EmbeddingsResponse,
     EmbeddingTaskType,
+    GrammarResponseFormat,
     Inference,
+    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,

@@ -29,7 +31,6 @@ from llama_stack.apis.inference.inference import (
     ChatCompletionResponseStreamChunk,
     CompletionResponse,
     CompletionResponseStreamChunk,
-    ResponseFormatType,
 )
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.remote.inference.lmstudio._client import LMStudioClient

@@ -50,6 +51,18 @@ class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
     def client(self) -> LMStudioClient:
         return LMStudioClient(url=self.url)
 
+    async def batch_chat_completion(self, *args, **kwargs):
+        raise NotImplementedError("Batch chat completion not supported by LM Studio Provider")
+
+    async def batch_completion(self, *args, **kwargs):
+        raise NotImplementedError("Batch completion not supported by LM Studio Provider")
+
+    async def openai_chat_completion(self, *args, **kwargs):
+        raise NotImplementedError("OpenAI chat completion not supported by LM Studio Provider")
+
+    async def openai_completion(self, *args, **kwargs):
+        raise NotImplementedError("OpenAI completion not supported by LM Studio Provider")
+
     async def initialize(self) -> None:
         pass
@@ -71,9 +84,12 @@ class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
         assert all(not content_has_media(content) for content in contents), (
             "Media content not supported in embedding model"
         )
+        if self.model_store is None:
+            raise ValueError("ModelStore is not initialized")
         model = await self.model_store.get_model(model_id)
         embedding_model = await self.client.get_embedding_model(model.provider_model_id)
-        embeddings = await self.client.embed(embedding_model, contents)
+        string_contents = [item.text if hasattr(item, "text") else str(item) for item in contents]
+        embeddings = await self.client.embed(embedding_model, string_contents)
         return EmbeddingsResponse(embeddings=embeddings)
 
     async def chat_completion(
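Note: the new string_contents line normalizes content items to plain strings before they reach the embedding SDK; items exposing a .text attribute contribute that text, anything else is stringified. A stdlib-only sketch of the rule, with TextItem as a hypothetical stand-in for a llama-stack text content item:

from dataclasses import dataclass
from typing import Any, List

@dataclass
class TextItem:
    """Hypothetical stand-in for a llama-stack text content item."""
    text: str

def to_strings(contents: List[Any]) -> List[str]:
    # Same rule as the adapter: prefer .text when present, otherwise stringify.
    return [item.text if hasattr(item, "text") else str(item) for item in contents]

print(to_strings([TextItem("hello world"), "already a string", 42]))
# ['hello world', 'already a string', '42']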
|
@ -81,26 +97,31 @@ class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
sampling_params: Optional[SamplingParams] = None,
|
sampling_params: Optional[SamplingParams] = None,
|
||||||
response_format: Optional[ResponseFormat] = None,
|
|
||||||
tools: Optional[List[ToolDefinition]] = None,
|
tools: Optional[List[ToolDefinition]] = None,
|
||||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
tool_choice: Optional[ToolChoice] = None, # Default value changed from ToolChoice.auto to None
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||||
|
response_format: Optional[
|
||||||
|
Union[JsonSchemaResponseFormat, GrammarResponseFormat]
|
||||||
|
] = None, # Moved and type changed
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
tool_config: Optional[ToolConfig] = None,
|
tool_config: Optional[ToolConfig] = None,
|
||||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||||
|
if self.model_store is None:
|
||||||
|
raise ValueError("ModelStore is not initialized")
|
||||||
model = await self.model_store.get_model(model_id)
|
model = await self.model_store.get_model(model_id)
|
||||||
llm = await self.client.get_llm(model.provider_model_id)
|
llm = await self.client.get_llm(model.provider_model_id)
|
||||||
|
|
||||||
if response_format is not None and response_format.type != ResponseFormatType.json_schema.value:
|
json_schema_format = response_format if isinstance(response_format, JsonSchemaResponseFormat) else None
|
||||||
raise ValueError(f"Response format type {response_format.type} not supported for LM Studio Provider")
|
if response_format is not None and not isinstance(response_format, JsonSchemaResponseFormat):
|
||||||
json_schema = response_format.json_schema if response_format else None
|
raise ValueError(
|
||||||
|
f"Response format type {type(response_format).__name__} not supported for LM Studio Provider"
|
||||||
|
)
|
||||||
return await self.client.llm_respond(
|
return await self.client.llm_respond(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
json_schema=json_schema,
|
json_schema=json_schema_format,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
tool_config=tool_config,
|
tool_config=tool_config,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
|
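Note: chat_completion now narrows response_format with isinstance rather than comparing response_format.type against ResponseFormatType.json_schema.value. A minimal sketch of that narrowing, assuming the format classes are importable from llama_stack.apis.inference as the import hunk for this file adds; narrow_to_json_schema is a hypothetical helper name:

from typing import Optional, Union

from llama_stack.apis.inference import GrammarResponseFormat, JsonSchemaResponseFormat

def narrow_to_json_schema(
    response_format: Optional[Union[JsonSchemaResponseFormat, GrammarResponseFormat]],
) -> Optional[JsonSchemaResponseFormat]:
    """Keep JSON-schema formats, pass None through, reject anything else."""
    if response_format is None:
        return None
    if isinstance(response_format, JsonSchemaResponseFormat):
        return response_format
    raise ValueError(
        f"Response format type {type(response_format).__name__} not supported for LM Studio Provider"
    )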
@@ -115,13 +136,16 @@ class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,  # Skip this for now
     ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
+        if self.model_store is None:
+            raise ValueError("ModelStore is not initialized")
         model = await self.model_store.get_model(model_id)
         llm = await self.client.get_llm(model.provider_model_id)
         if content_has_media(content):
             raise NotImplementedError("Media content not supported in LM Studio Provider")
 
-        if response_format is not None and response_format.type != ResponseFormatType.json_schema.value:
-            raise ValueError(f"Response format type {response_format.type} not supported for LM Studio Provider")
-        json_schema = response_format.json_schema if response_format else None
+        if not isinstance(response_format, JsonSchemaResponseFormat):
+            raise ValueError(
+                f"Response format type {type(response_format).__name__} not supported for LM Studio Provider"
+            )
 
-        return await self.client.llm_completion(llm, content, sampling_params, json_schema, stream)
+        return await self.client.llm_completion(llm, content, sampling_params, response_format, stream)
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from llama_stack.apis.models.models import ModelType
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_list import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ProviderModelEntry,
 )

@@ -63,9 +63,7 @@ MODEL_ENTRIES = [
         },
     ),
     ProviderModelEntry(
-        model_id="all-MiniLM-L6-v2",
         provider_model_id="all-minilm-l6-v2",
-        provider_id="lmstudio",
         model_type=ModelType.embedding,
         metadata={
             "embedding_dimension": 384,
@@ -351,10 +351,12 @@
     "chardet",
     "chromadb-client",
     "datasets",
+    "emoji",
     "faiss-cpu",
     "fastapi",
     "fire",
     "httpx",
+    "langdetect",
     "lmstudio",
     "matplotlib",
     "nltk",

@@ -367,6 +369,7 @@
     "psycopg2-binary",
     "pymongo",
     "pypdf",
+    "pythainlp",
     "redis",
     "requests",
     "scikit-learn",

@@ -374,6 +377,7 @@
     "sentencepiece",
     "tqdm",
     "transformers",
+    "tree_sitter",
     "uvicorn"
   ],
   "meta-reference-gpu": [
@@ -24,7 +24,6 @@ distribution_spec:
   telemetry:
   - inline::meta-reference
   tool_runtime:
-  - remote::brave-search
   - remote::tavily-search
   - inline::code-interpreter
   - inline::rag-runtime
@@ -75,7 +75,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/lmstudio/trace_store.db}
   tool_runtime:

@@ -125,9 +125,9 @@ models:
     provider_model_id: llama-3.2-3b-instruct
     model_type: llm
 - metadata: {}
-  model_id: llama-3.2-70b-instruct
+  model_id: llama-3.3-70b-instruct
   provider_id: lmstudio
-  provider_model_id: llama-3.2-70b-instruct
+  provider_model_id: llama-3.3-70b-instruct
   model_type: llm
 - metadata:
     embedding_dimension: 768