mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-25 09:05:37 +00:00 
			
		
		
		
	This flips #2823 and #2805 by making the Stack periodically query the providers for models, rather than having the providers go behind its back and call "register" on the registry themselves. This also adds support for model listing for all other providers via `ModelRegistryHelper`. Once this is done, we no longer need to manually list or register models via `run.yaml`, which removes both noise and annoyance (setting `INFERENCE_MODEL` environment variables, for example) from the new-user experience. In addition, it adds a configuration variable, `allowed_models`, which can optionally be used to restrict the set of models exposed by a provider.
		
			
				
	
	
		
			387 lines
		
	
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			387 lines
		
	
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| from collections.abc import AsyncGenerator, AsyncIterator
 | |
| from typing import Any
 | |
| 
 | |
| from openai import AsyncOpenAI
 | |
| from together import AsyncTogether
 | |
| 
 | |
| from llama_stack.apis.common.content_types import (
 | |
|     InterleavedContent,
 | |
|     InterleavedContentItem,
 | |
| )
 | |
| from llama_stack.apis.inference import (
 | |
|     ChatCompletionRequest,
 | |
|     ChatCompletionResponse,
 | |
|     CompletionRequest,
 | |
|     EmbeddingsResponse,
 | |
|     EmbeddingTaskType,
 | |
|     Inference,
 | |
|     LogProbConfig,
 | |
|     Message,
 | |
|     OpenAIChatCompletion,
 | |
|     OpenAIChatCompletionChunk,
 | |
|     OpenAICompletion,
 | |
|     OpenAIEmbeddingsResponse,
 | |
|     OpenAIMessageParam,
 | |
|     OpenAIResponseFormatParam,
 | |
|     ResponseFormat,
 | |
|     ResponseFormatType,
 | |
|     SamplingParams,
 | |
|     TextTruncation,
 | |
|     ToolChoice,
 | |
|     ToolConfig,
 | |
|     ToolDefinition,
 | |
|     ToolPromptFormat,
 | |
| )
 | |
| from llama_stack.distribution.request_headers import NeedsRequestProviderData
 | |
| from llama_stack.log import get_logger
 | |
| from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 | |
| from llama_stack.providers.utils.inference.openai_compat import (
 | |
|     convert_message_to_openai_dict,
 | |
|     get_sampling_options,
 | |
|     prepare_openai_completion_params,
 | |
|     process_chat_completion_response,
 | |
|     process_chat_completion_stream_response,
 | |
|     process_completion_response,
 | |
|     process_completion_stream_response,
 | |
| )
 | |
| from llama_stack.providers.utils.inference.prompt_adapter import (
 | |
|     chat_completion_request_to_prompt,
 | |
|     completion_request_to_prompt,
 | |
|     content_has_media,
 | |
|     interleaved_content_as_str,
 | |
|     request_has_media,
 | |
| )
 | |
| 
 | |
| from .config import TogetherImplConfig
 | |
| from .models import MODEL_ENTRIES
 | |
| 
 | |
| logger = get_logger(name=__name__, category="inference")
 | |
| 
 | |
| 
 | |
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
    """Inference adapter that forwards requests to Together.

    Native completion, chat-completion and embedding calls go through an
    ``AsyncTogether`` client; the ``openai_*`` entry points go through an
    ``AsyncOpenAI`` client pointed at the same base URL and API key.
    A fresh client is built for every request (see ``_get_client``).
    """

    def __init__(self, config: TogetherImplConfig) -> None:
        """Register the known model entries and remember the provider config."""
        ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
        self.config = config

    async def initialize(self) -> None:
        """No start-up work is required for this adapter."""

    async def shutdown(self) -> None:
        """No tear-down work is required for this adapter."""

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: SamplingParams | None = None,
        response_format: ResponseFormat | None = None,
        stream: bool | None = False,
        logprobs: LogProbConfig | None = None,
    ) -> AsyncGenerator:
        """Run a text completion against the model registered as ``model_id``.

        Returns the async generator from ``_stream_completion`` when ``stream``
        is truthy; otherwise awaits ``_nonstream_completion`` and returns its
        result directly.
        """
        effective_sampling = SamplingParams() if sampling_params is None else sampling_params
        registered = await self.model_store.get_model(model_id)
        request = CompletionRequest(
            model=registered.provider_resource_id,
            content=content,
            sampling_params=effective_sampling,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        if not stream:
            return await self._nonstream_completion(request)
        return self._stream_completion(request)

    def _get_client(self) -> AsyncTogether:
        """Build an ``AsyncTogether`` client for the current request.

        The API key configured on the provider wins; when it is absent or
        empty the key is taken from the per-request provider data.

        Raises:
            ValueError: if neither source supplies a Together API key.
        """
        configured_key = self.config.api_key
        api_key = configured_key.get_secret_value() if configured_key else None
        if not api_key:
            provider_data = self.get_request_provider_data()
            if provider_data is None or not provider_data.together_api_key:
                raise ValueError(
                    'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
                )
            api_key = provider_data.together_api_key
        return AsyncTogether(api_key=api_key)

    def _get_openai_client(self) -> AsyncOpenAI:
        """Build an ``AsyncOpenAI`` client reusing the Together client's base URL and key."""
        inner = self._get_client().client
        return AsyncOpenAI(base_url=inner.base_url, api_key=inner.api_key)

    # NOTE(review): annotated as ChatCompletionResponse although the value comes
    # from process_completion_response — confirm the intended return type.
    async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
        """Issue a single non-streaming completion call and convert the reply."""
        call_kwargs = await self._get_params(request)
        together = self._get_client()
        raw = await together.completions.create(**call_kwargs)
        return process_completion_response(raw)

    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        """Yield converted chunks from a streaming completion call."""
        call_kwargs = await self._get_params(request)
        together = self._get_client()
        raw_stream = await together.completions.create(**call_kwargs)
        async for piece in process_completion_stream_response(raw_stream):
            yield piece

    def _build_options(
        self,
        sampling_params: SamplingParams | None,
        logprobs: LogProbConfig | None,
        fmt: ResponseFormat,
    ) -> dict:
        """Translate sampling, response-format and logprob settings into request options.

        ``fmt`` is checked for truthiness before use (callers pass
        ``request.response_format``, which appears to be optional).

        Raises:
            NotImplementedError: for a grammar response format.
            ValueError: for any other non-JSON-schema response format, or when
                ``logprobs.top_k`` is set to something other than 1.
        """
        options = get_sampling_options(sampling_params)

        if fmt:
            if fmt.type == ResponseFormatType.json_schema.value:
                # A JSON-schema request is sent as "json_object" plus the schema.
                options["response_format"] = {
                    "type": "json_object",
                    "schema": fmt.json_schema,
                }
            else:
                if fmt.type == ResponseFormatType.grammar.value:
                    raise NotImplementedError("Grammar response format not supported yet")
                raise ValueError(f"Unknown response format {fmt.type}")

        wants_logprobs = bool(logprobs and logprobs.top_k)
        if wants_logprobs:
            if logprobs.top_k != 1:
                raise ValueError(
                    f"Unsupported value: Together only supports logprobs top_k=1. {logprobs.top_k} was provided",
                )
            options["logprobs"] = 1

        return options

    async def chat_completion(
        self,
        model_id: str,
        messages: list[Message],
        sampling_params: SamplingParams | None = None,
        tools: list[ToolDefinition] | None = None,
        tool_choice: ToolChoice | None = ToolChoice.auto,
        tool_prompt_format: ToolPromptFormat | None = None,
        response_format: ResponseFormat | None = None,
        stream: bool | None = False,
        logprobs: LogProbConfig | None = None,
        tool_config: ToolConfig | None = None,
    ) -> AsyncGenerator:
        """Run a chat completion against the model registered as ``model_id``.

        ``tool_choice`` and ``tool_prompt_format`` are accepted but not placed
        on the request built here; only ``tool_config`` is forwarded.

        Returns the async generator from ``_stream_chat_completion`` when
        ``stream`` is truthy; otherwise awaits ``_nonstream_chat_completion``.
        """
        effective_sampling = SamplingParams() if sampling_params is None else sampling_params
        registered = await self.model_store.get_model(model_id)
        request = ChatCompletionRequest(
            model=registered.provider_resource_id,
            messages=messages,
            sampling_params=effective_sampling,
            tools=tools or [],
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
            tool_config=tool_config,
        )

        if not stream:
            return await self._nonstream_chat_completion(request)
        return self._stream_chat_completion(request)

    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
        """Issue one non-streaming call and convert the reply to a chat response."""
        call_kwargs = await self._get_params(request)
        together = self._get_client()
        # _get_params produces either "messages" (chat endpoint) or "prompt"
        # (raw completions endpoint); pick the endpoint accordingly.
        endpoint = together.chat.completions if "messages" in call_kwargs else together.completions
        raw = await endpoint.create(**call_kwargs)
        return process_chat_completion_response(raw, request)

    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        """Yield converted chunks from a streaming chat-completion call."""
        call_kwargs = await self._get_params(request)
        together = self._get_client()
        # Same endpoint selection as the non-streaming path.
        endpoint = together.chat.completions if "messages" in call_kwargs else together.completions
        raw_stream = await endpoint.create(**call_kwargs)

        async for piece in process_chat_completion_stream_response(raw_stream, request):
            yield piece

    async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
        """Assemble the keyword arguments for a Together API call.

        Chat requests are sent as OpenAI-style ``messages`` when media is
        present or no Llama model is resolved for ``request.model``; otherwise
        they are rendered into a single ``prompt``. Plain completion requests
        are always rendered into a ``prompt`` and must not contain media.
        """
        has_media = request_has_media(request)
        llama_model = self.get_llama_model(request.model)

        if not isinstance(request, ChatCompletionRequest):
            assert not has_media, "Together does not support media for Completion requests"
            payload = {"prompt": await completion_request_to_prompt(request)}
        elif llama_model and not has_media:
            payload = {"prompt": await chat_completion_request_to_prompt(request, llama_model)}
        else:
            payload = {"messages": [await convert_message_to_openai_dict(m) for m in request.messages]}

        params = {
            "model": request.model,
            **payload,
            "stream": request.stream,
            **self._build_options(request.sampling_params, request.logprobs, request.response_format),
        }
        logger.debug(f"params to together: {params}")
        return params

    async def embeddings(
        self,
        model_id: str,
        contents: list[str] | list[InterleavedContentItem],
        text_truncation: TextTruncation | None = TextTruncation.none,
        output_dimension: int | None = None,
        task_type: EmbeddingTaskType | None = None,
    ) -> EmbeddingsResponse:
        """Embed text-only ``contents`` with the model registered as ``model_id``.

        ``text_truncation``, ``output_dimension`` and ``task_type`` are accepted
        but not forwarded to the Together call made here.
        """
        registered = await self.model_store.get_model(model_id)
        assert not any(content_has_media(item) for item in contents), (
            "Together does not support media for embeddings"
        )
        together = self._get_client()
        response = await together.embeddings.create(
            model=registered.provider_resource_id,
            input=[interleaved_content_as_str(item) for item in contents],
        )
        return EmbeddingsResponse(embeddings=[entry.embedding for entry in response.data])

    async def openai_embeddings(
        self,
        model: str,
        input: str | list[str],
        encoding_format: str | None = "float",
        dimensions: int | None = None,
        user: str | None = None,
    ) -> OpenAIEmbeddingsResponse:
        """OpenAI-compatible embeddings are not implemented for this adapter."""
        raise NotImplementedError()

    async def openai_completion(
        self,
        model: str,
        prompt: str | list[str] | list[int] | list[list[int]],
        best_of: int | None = None,
        echo: bool | None = None,
        frequency_penalty: float | None = None,
        logit_bias: dict[str, float] | None = None,
        logprobs: bool | None = None,
        max_tokens: int | None = None,
        n: int | None = None,
        presence_penalty: float | None = None,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stream: bool | None = None,
        stream_options: dict[str, Any] | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        user: str | None = None,
        guided_choice: list[str] | None = None,
        prompt_logprobs: int | None = None,
        suffix: str | None = None,
    ) -> OpenAICompletion:
        """Forward an OpenAI-style completion request through the OpenAI client.

        ``guided_choice``, ``prompt_logprobs`` and ``suffix`` are accepted but
        not included in the parameters sent upstream.
        """
        registered = await self.model_store.get_model(model)
        request_fields = {
            "model": registered.provider_resource_id,
            "prompt": prompt,
            "best_of": best_of,
            "echo": echo,
            "frequency_penalty": frequency_penalty,
            "logit_bias": logit_bias,
            "logprobs": logprobs,
            "max_tokens": max_tokens,
            "n": n,
            "presence_penalty": presence_penalty,
            "seed": seed,
            "stop": stop,
            "stream": stream,
            "stream_options": stream_options,
            "temperature": temperature,
            "top_p": top_p,
            "user": user,
        }
        params = await prepare_openai_completion_params(**request_fields)
        openai_client = self._get_openai_client()
        return await openai_client.completions.create(**params)  # type: ignore

    async def openai_chat_completion(
        self,
        model: str,
        messages: list[OpenAIMessageParam],
        frequency_penalty: float | None = None,
        function_call: str | dict[str, Any] | None = None,
        functions: list[dict[str, Any]] | None = None,
        logit_bias: dict[str, float] | None = None,
        logprobs: bool | None = None,
        max_completion_tokens: int | None = None,
        max_tokens: int | None = None,
        n: int | None = None,
        parallel_tool_calls: bool | None = None,
        presence_penalty: float | None = None,
        response_format: OpenAIResponseFormatParam | None = None,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stream: bool | None = None,
        stream_options: dict[str, Any] | None = None,
        temperature: float | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        tools: list[dict[str, Any]] | None = None,
        top_logprobs: int | None = None,
        top_p: float | None = None,
        user: str | None = None,
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        """Forward an OpenAI-style chat completion request through the OpenAI client.

        When the prepared parameters carry a truthy ``stream`` value the
        filtered generator from ``_stream_openai_chat_completion`` is returned;
        otherwise the upstream call is awaited and its result returned.
        """
        registered = await self.model_store.get_model(model)
        request_fields = {
            "model": registered.provider_resource_id,
            "messages": messages,
            "frequency_penalty": frequency_penalty,
            "function_call": function_call,
            "functions": functions,
            "logit_bias": logit_bias,
            "logprobs": logprobs,
            "max_completion_tokens": max_completion_tokens,
            "max_tokens": max_tokens,
            "n": n,
            "parallel_tool_calls": parallel_tool_calls,
            "presence_penalty": presence_penalty,
            "response_format": response_format,
            "seed": seed,
            "stop": stop,
            "stream": stream,
            "stream_options": stream_options,
            "temperature": temperature,
            "tool_choice": tool_choice,
            "tools": tools,
            "top_logprobs": top_logprobs,
            "top_p": top_p,
            "user": user,
        }
        params = await prepare_openai_completion_params(**request_fields)
        if params.get("stream", False):
            return self._stream_openai_chat_completion(params)
        openai_client = self._get_openai_client()
        return await openai_client.chat.completions.create(**params)  # type: ignore

    async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
        """Yield upstream chat chunks, dropping an unrequested trailing usage chunk."""
        # together.ai sometimes adds usage data to the stream, even if include_usage is False.
        # That shows up as an unexpected final chunk with an empty choices array, which
        # clients may not handle gracefully, so it is filtered out below.
        stream_options = params.get("stream_options", None)
        include_usage = stream_options.get("include_usage", False) if stream_options else False
        stream = await self._get_openai_client().chat.completions.create(**params)

        finished = False
        async for chunk in stream:
            # After a finish_reason has been seen, a chunk with no choices is the
            # usage chunk the caller did not ask for: stop without emitting it.
            if finished and not include_usage and len(chunk.choices) == 0:
                break
            yield chunk
            if any(choice.finish_reason for choice in chunk.choices):
                finished = True
 |