mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-29 04:21:59 +00:00
278 lines
11 KiB
Python
278 lines
11 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from typing import Any, AsyncIterator, Dict, List, Optional, Union
|
|
|
|
from llama_stack.apis.common.content_types import (
|
|
InterleavedContent,
|
|
InterleavedContentItem,
|
|
)
|
|
from llama_stack.apis.inference import (
|
|
ChatCompletionResponse,
|
|
EmbeddingsResponse,
|
|
EmbeddingTaskType,
|
|
GrammarResponseFormat,
|
|
Inference,
|
|
JsonSchemaResponseFormat,
|
|
LogProbConfig,
|
|
Message,
|
|
OpenAIChatCompletion,
|
|
OpenAIChatCompletionChunk,
|
|
OpenAICompletion,
|
|
OpenAIMessageParam,
|
|
OpenAIResponseFormatParam,
|
|
ResponseFormat,
|
|
SamplingParams,
|
|
TextTruncation,
|
|
ToolChoice,
|
|
ToolConfig,
|
|
ToolDefinition,
|
|
ToolPromptFormat,
|
|
)
|
|
from llama_stack.apis.inference.inference import (
|
|
ChatCompletionResponseStreamChunk,
|
|
CompletionResponse,
|
|
CompletionResponseStreamChunk,
|
|
)
|
|
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
|
from llama_stack.providers.remote.inference.lmstudio._client import LMStudioClient
|
|
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
|
content_has_media,
|
|
)
|
|
|
|
from .models import MODEL_ENTRIES
|
|
|
|
|
|
class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|
def __init__(self, url: str) -> None:
|
|
self.url = url
|
|
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
|
|
|
|
@property
|
|
def client(self) -> LMStudioClient:
|
|
return LMStudioClient(url=self.url)
|
|
|
|
async def batch_completion(
|
|
self,
|
|
model_id: str,
|
|
content_batch: List[InterleavedContent],
|
|
sampling_params: Optional[SamplingParams] = None,
|
|
response_format: Optional[ResponseFormat] = None,
|
|
logprobs: Optional[LogProbConfig] = None,
|
|
):
|
|
raise NotImplementedError("Batch completion is not supported by LM Studio Provider")
|
|
|
|
async def batch_chat_completion(
|
|
self,
|
|
model_id: str,
|
|
messages_batch: List[List[Message]],
|
|
sampling_params: Optional[SamplingParams] = None,
|
|
tools: Optional[List[ToolDefinition]] = None,
|
|
tool_config: Optional[ToolConfig] = None,
|
|
response_format: Optional[ResponseFormat] = None,
|
|
logprobs: Optional[LogProbConfig] = None,
|
|
):
|
|
raise NotImplementedError("Batch completion is not supported by LM Studio Provider")
|
|
|
|
async def openai_chat_completion(
|
|
self,
|
|
model: str,
|
|
messages: List[OpenAIMessageParam],
|
|
frequency_penalty: Optional[float] = None,
|
|
function_call: Optional[Union[str, Dict[str, Any]]] = None,
|
|
functions: Optional[List[Dict[str, Any]]] = None,
|
|
logit_bias: Optional[Dict[str, float]] = None,
|
|
logprobs: Optional[bool] = None,
|
|
max_completion_tokens: Optional[int] = None,
|
|
max_tokens: Optional[int] = None,
|
|
n: Optional[int] = None,
|
|
parallel_tool_calls: Optional[bool] = None,
|
|
presence_penalty: Optional[float] = None,
|
|
response_format: Optional[OpenAIResponseFormatParam] = None,
|
|
seed: Optional[int] = None,
|
|
stop: Optional[Union[str, List[str]]] = None,
|
|
stream: Optional[bool] = None,
|
|
stream_options: Optional[Dict[str, Any]] = None,
|
|
temperature: Optional[float] = None,
|
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
|
tools: Optional[List[Dict[str, Any]]] = None,
|
|
top_logprobs: Optional[int] = None,
|
|
top_p: Optional[float] = None,
|
|
user: Optional[str] = None,
|
|
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
|
if self.model_store is None:
|
|
raise ValueError("ModelStore is not initialized")
|
|
model_obj = await self.model_store.get_model(model)
|
|
params = {
|
|
k: v
|
|
for k, v in {
|
|
"model": model_obj.provider_resource_id,
|
|
"messages": messages,
|
|
"frequency_penalty": frequency_penalty,
|
|
"function_call": function_call,
|
|
"functions": functions,
|
|
"logit_bias": logit_bias,
|
|
"logprobs": logprobs,
|
|
"max_completion_tokens": max_completion_tokens,
|
|
"max_tokens": max_tokens,
|
|
"n": n,
|
|
"parallel_tool_calls": parallel_tool_calls,
|
|
"presence_penalty": presence_penalty,
|
|
"response_format": response_format,
|
|
"seed": seed,
|
|
"stop": stop,
|
|
"stream": stream,
|
|
"stream_options": stream_options,
|
|
"temperature": temperature,
|
|
"tool_choice": tool_choice,
|
|
"tools": tools,
|
|
"top_logprobs": top_logprobs,
|
|
"top_p": top_p,
|
|
"user": user,
|
|
}.items()
|
|
if v is not None
|
|
}
|
|
return await self.openai_client.chat.completions.create(**params) # type: ignore
|
|
|
|
async def openai_completion(
|
|
self,
|
|
model: str,
|
|
prompt: Union[str, List[str], List[int], List[List[int]]],
|
|
best_of: Optional[int] = None,
|
|
echo: Optional[bool] = None,
|
|
frequency_penalty: Optional[float] = None,
|
|
logit_bias: Optional[Dict[str, float]] = None,
|
|
logprobs: Optional[bool] = None,
|
|
max_tokens: Optional[int] = None,
|
|
n: Optional[int] = None,
|
|
presence_penalty: Optional[float] = None,
|
|
seed: Optional[int] = None,
|
|
stop: Optional[Union[str, List[str]]] = None,
|
|
stream: Optional[bool] = None,
|
|
stream_options: Optional[Dict[str, Any]] = None,
|
|
temperature: Optional[float] = None,
|
|
top_p: Optional[float] = None,
|
|
user: Optional[str] = None,
|
|
guided_choice: Optional[List[str]] = None,
|
|
prompt_logprobs: Optional[int] = None,
|
|
) -> OpenAICompletion:
|
|
if not isinstance(prompt, str):
|
|
raise ValueError("LM Studio does not support non-string prompts for completion")
|
|
if self.model_store is None:
|
|
raise ValueError("ModelStore is not initialized")
|
|
model_obj = await self.model_store.get_model(model)
|
|
params = {
|
|
k: v
|
|
for k, v in {
|
|
"model": model_obj.provider_resource_id,
|
|
"prompt": prompt,
|
|
"best_of": best_of,
|
|
"echo": echo,
|
|
"frequency_penalty": frequency_penalty,
|
|
"logit_bias": logit_bias,
|
|
"logprobs": logprobs,
|
|
"max_tokens": max_tokens,
|
|
"n": n,
|
|
"presence_penalty": presence_penalty,
|
|
"seed": seed,
|
|
"stop": stop,
|
|
"stream": stream,
|
|
"stream_options": stream_options,
|
|
"temperature": temperature,
|
|
"top_p": top_p,
|
|
"user": user,
|
|
}.items()
|
|
if v is not None
|
|
}
|
|
return await self.openai_client.completions.create(**params) # type: ignore
|
|
|
|
async def initialize(self) -> None:
|
|
pass
|
|
|
|
async def register_model(self, model):
|
|
await self.register_helper.register_model(model)
|
|
return model
|
|
|
|
async def unregister_model(self, model_id):
|
|
pass
|
|
|
|
async def embeddings(
|
|
self,
|
|
model_id: str,
|
|
contents: List[str] | List[InterleavedContentItem],
|
|
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
|
output_dimension: Optional[int] = None,
|
|
task_type: Optional[EmbeddingTaskType] = None,
|
|
) -> EmbeddingsResponse:
|
|
assert all(not content_has_media(content) for content in contents), (
|
|
"Media content not supported in embedding model"
|
|
)
|
|
if self.model_store is None:
|
|
raise ValueError("ModelStore is not initialized")
|
|
model = await self.model_store.get_model(model_id)
|
|
embedding_model = await self.client.get_embedding_model(model.provider_model_id)
|
|
string_contents = [item.text if hasattr(item, "text") else str(item) for item in contents]
|
|
embeddings = await self.client.embed(embedding_model, string_contents)
|
|
return EmbeddingsResponse(embeddings=embeddings)
|
|
|
|
async def chat_completion(
|
|
self,
|
|
model_id: str,
|
|
messages: List[Message],
|
|
sampling_params: Optional[SamplingParams] = None,
|
|
tools: Optional[List[ToolDefinition]] = None,
|
|
tool_choice: Optional[ToolChoice] = None, # Default value changed from ToolChoice.auto to None
|
|
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
|
response_format: Optional[
|
|
Union[JsonSchemaResponseFormat, GrammarResponseFormat]
|
|
] = None, # Moved and type changed
|
|
stream: Optional[bool] = False,
|
|
logprobs: Optional[LogProbConfig] = None,
|
|
tool_config: Optional[ToolConfig] = None,
|
|
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
|
if self.model_store is None:
|
|
raise ValueError("ModelStore is not initialized")
|
|
model = await self.model_store.get_model(model_id)
|
|
llm = await self.client.get_llm(model.provider_model_id)
|
|
|
|
json_schema_format = response_format if isinstance(response_format, JsonSchemaResponseFormat) else None
|
|
if response_format is not None and not isinstance(response_format, JsonSchemaResponseFormat):
|
|
raise ValueError(
|
|
f"Response format type {type(response_format).__name__} not supported for LM Studio Provider"
|
|
)
|
|
return await self.client.llm_respond(
|
|
llm=llm,
|
|
messages=messages,
|
|
sampling_params=sampling_params,
|
|
json_schema=json_schema_format,
|
|
stream=stream,
|
|
tool_config=tool_config,
|
|
tools=tools,
|
|
)
|
|
|
|
async def completion(
|
|
self,
|
|
model_id: str,
|
|
content: InterleavedContent,
|
|
sampling_params: Optional[SamplingParams] = None,
|
|
response_format: Optional[ResponseFormat] = None,
|
|
stream: Optional[bool] = False,
|
|
logprobs: Optional[LogProbConfig] = None, # Skip this for now
|
|
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
|
|
if self.model_store is None:
|
|
raise ValueError("ModelStore is not initialized")
|
|
model = await self.model_store.get_model(model_id)
|
|
llm = await self.client.get_llm(model.provider_model_id)
|
|
if content_has_media(content):
|
|
raise NotImplementedError("Media content not supported in LM Studio Provider")
|
|
|
|
if not isinstance(response_format, JsonSchemaResponseFormat):
|
|
raise ValueError(
|
|
f"Response format type {type(response_format).__name__} not supported for LM Studio Provider"
|
|
)
|
|
|
|
return await self.client.llm_completion(llm, content, sampling_params, response_format, stream)
|