forked from phoenix-oss/llama-stack-mirror
		
	Enable remote::vllm (#384)
* Enable remote::vllm
* Kill the giant list of hard-coded models
This commit is contained in:
		
							parent
							
								
									093c9f1987
								
							
						
					
					
						commit
						b10e9f46bb
					
				
					 5 changed files with 80 additions and 53 deletions
				
			
		|  | @ -4,12 +4,15 @@ | |||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| from .config import VLLMImplConfig | ||||
| from .vllm import VLLMInferenceAdapter | ||||
| from .config import VLLMInferenceAdapterConfig | ||||
| 
 | ||||
| 
 | ||||
| async def get_adapter_impl(config: VLLMImplConfig, _deps): | ||||
|     assert isinstance(config, VLLMImplConfig), f"Unexpected config type: {type(config)}" | ||||
| async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps): | ||||
|     from .vllm import VLLMInferenceAdapter | ||||
| 
 | ||||
|     assert isinstance( | ||||
|         config, VLLMInferenceAdapterConfig | ||||
|     ), f"Unexpected config type: {type(config)}" | ||||
|     impl = VLLMInferenceAdapter(config) | ||||
|     await impl.initialize() | ||||
|     return impl | ||||
|  |  | |||
|  | @ -11,12 +11,16 @@ from pydantic import BaseModel, Field | |||
| 
 | ||||
| 
 | ||||
| @json_schema_type | ||||
| class VLLMImplConfig(BaseModel): | ||||
| class VLLMInferenceAdapterConfig(BaseModel): | ||||
|     url: Optional[str] = Field( | ||||
|         default=None, | ||||
|         description="The URL for the vLLM model serving endpoint", | ||||
|     ) | ||||
|     max_tokens: int = Field( | ||||
|         default=4096, | ||||
|         description="Maximum number of tokens to generate.", | ||||
|     ) | ||||
|     api_token: Optional[str] = Field( | ||||
|         default=None, | ||||
|         default="fake", | ||||
|         description="The API token", | ||||
|     ) | ||||
|  |  | |||
|  | @ -8,6 +8,7 @@ from typing import AsyncGenerator | |||
| from llama_models.llama3.api.chat_format import ChatFormat | ||||
| from llama_models.llama3.api.datatypes import Message | ||||
| from llama_models.llama3.api.tokenizer import Tokenizer | ||||
| from llama_models.sku_list import all_registered_models, resolve_model | ||||
| 
 | ||||
| from openai import OpenAI | ||||
| 
 | ||||
|  | @ -23,42 +24,19 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( | |||
|     chat_completion_request_to_prompt, | ||||
| ) | ||||
| 
 | ||||
| from .config import VLLMImplConfig | ||||
| 
 | ||||
| VLLM_SUPPORTED_MODELS = { | ||||
|     "Llama3.1-8B": "meta-llama/Llama-3.1-8B", | ||||
|     "Llama3.1-70B": "meta-llama/Llama-3.1-70B", | ||||
|     "Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B", | ||||
|     "Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8", | ||||
|     "Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B", | ||||
|     "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct", | ||||
|     "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct", | ||||
|     "Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct", | ||||
|     "Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8", | ||||
|     "Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct", | ||||
|     "Llama3.2-1B": "meta-llama/Llama-3.2-1B", | ||||
|     "Llama3.2-3B": "meta-llama/Llama-3.2-3B", | ||||
|     "Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision", | ||||
|     "Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision", | ||||
|     "Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct", | ||||
|     "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct", | ||||
|     "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct", | ||||
|     "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct", | ||||
|     "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision", | ||||
|     "Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4", | ||||
|     "Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B", | ||||
|     "Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B", | ||||
|     "Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8", | ||||
|     "Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M", | ||||
|     "Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B", | ||||
| } | ||||
| from .config import VLLMInferenceAdapterConfig | ||||
| 
 | ||||
| 
 | ||||
| class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): | ||||
|     def __init__(self, config: VLLMImplConfig) -> None: | ||||
|     def __init__(self, config: VLLMInferenceAdapterConfig) -> None: | ||||
|         self.config = config | ||||
|         self.formatter = ChatFormat(Tokenizer.get_instance()) | ||||
|         self.client = None | ||||
|         self.huggingface_repo_to_llama_model_id = { | ||||
|             model.huggingface_repo: model.descriptor() | ||||
|             for model in all_registered_models() | ||||
|             if model.huggingface_repo | ||||
|         } | ||||
| 
 | ||||
|     async def initialize(self) -> None: | ||||
|         self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token) | ||||
|  | @ -70,10 +48,21 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): | |||
|         pass | ||||
| 
 | ||||
|     async def list_models(self) -> List[ModelDef]: | ||||
|         return [ | ||||
|             ModelDef(identifier=model.id, llama_model=model.id) | ||||
|             for model in self.client.models.list() | ||||
|         ] | ||||
|         models = [] | ||||
|         for model in self.client.models.list(): | ||||
|             repo = model.id | ||||
|             if repo not in self.huggingface_repo_to_llama_model_id: | ||||
|                 print(f"Unknown model served by vllm: {repo}") | ||||
|                 continue | ||||
| 
 | ||||
|             identifier = self.huggingface_repo_to_llama_model_id[repo] | ||||
|             models.append( | ||||
|                 ModelDef( | ||||
|                     identifier=identifier, | ||||
|                     llama_model=identifier, | ||||
|                 ) | ||||
|             ) | ||||
|         return models | ||||
| 
 | ||||
|     async def completion( | ||||
|         self, | ||||
|  | @ -118,7 +107,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): | |||
|     ) -> ChatCompletionResponse: | ||||
|         params = self._get_params(request) | ||||
|         r = client.completions.create(**params) | ||||
|         return process_chat_completion_response(request, r, self.formatter) | ||||
|         return process_chat_completion_response(r, self.formatter) | ||||
| 
 | ||||
|     async def _stream_chat_completion( | ||||
|         self, request: ChatCompletionRequest, client: OpenAI | ||||
|  | @ -139,11 +128,19 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): | |||
|             yield chunk | ||||
| 
 | ||||
|     def _get_params(self, request: ChatCompletionRequest) -> dict: | ||||
|         options = get_sampling_options(request.sampling_params) | ||||
|         if "max_tokens" not in options: | ||||
|             options["max_tokens"] = self.config.max_tokens | ||||
| 
 | ||||
|         model = resolve_model(request.model) | ||||
|         if model is None: | ||||
|             raise ValueError(f"Unknown model: {request.model}") | ||||
| 
 | ||||
|         return { | ||||
|             "model": VLLM_SUPPORTED_MODELS[request.model], | ||||
|             "model": model.huggingface_repo, | ||||
|             "prompt": chat_completion_request_to_prompt(request, self.formatter), | ||||
|             "stream": request.stream, | ||||
|             **get_sampling_options(request.sampling_params), | ||||
|             **options, | ||||
|         } | ||||
| 
 | ||||
|     async def embeddings( | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue