大语言模型推理框架llama.cpp开发实战
<div id="container" data-v-1d7a5742="" data-element="root" contentScore="8369">译者 | 朱先忠
审校 | 重楼
本文首先探索当前热门的大语言模型推理框架llama.cpp的内部架构,然后使用此框架实现了一个基本形式的聊天程序。
简介
当前,llama.cpp框架以其简单性被业界广泛采用,彻底改变了LLM推理领域。它支持企业和个人开发人员能够在从SBC到多GPU集群的各类型设备上部署机器学习大型语言模型。尽管llama.cpp的语言绑定方式使其使用方式变得容易,但是对于性能敏感或资源受限的情况,使用C/C++编程方案可能是一个更为可行的选择。
本文旨在让读者详细了解如何使用直接来自llama.cpp的低级函数执行LLM推理。具体地讲,我们将详细探讨llama.cpp框架开发程序的详细流程、llama.cpp框架的架构,最后实现一个简单的聊天应用程序。
请注意,我们将在本文中编写的C++代码也用于SmolChat应用程序中,这是一个原生Android应用程序,它允许用户在聊天界面中与LLM/SLM实现完全在设备上的交互。具体来说,我们将使用文章前面将定义的LLMInference类与JNI绑定一起使用,从而实现共同执行GGUF模型。
另外,本文将分析的代码实现可以在链接处找到。
还有,上述代码也派生自llama.cpp的官方简单聊天示例程序。
关于llama.cpp
llama.cpp是一个C/C++框架,用于在多个执行后端推断以GGUF格式定义的机器学习模型。这个框架最初是Meta著名的Llama系列LLM的纯C/C++实现,可以在苹果公司自研的Silicon处理器、AVX/AVX-512、CUDA和基于Arm Neon的环境中推断。此外,这个框架还包括一个基于CLI的工具llama-cli来运行GGUF LLM模型,还提供一个llama-server(OpenAI兼容服务器)通过HTTP请求方式执行模型。
llama.cpp使用机器学习的张量库ggml,这是一个低级框架,提供深度学习模型所需的原始函数,并从用户那里抽象后端实现细节。Georgi Gerganov是ggml库和llama.cpp框架的创建者。
此外,llama.cpp框架存储库的README文件还列出了其他编程语言中基于llama.cpp构建的包装器。Ollama和LM Studio等流行工具也使用llama.cpp上的绑定来增强用户友好性。该项目不依赖其他第三方库。
llama.cpp与PyTorch/TensorFlow有何不同?
llama.cpp从一开始就强调ML模型的推理,而PyTorch 和TensorFlow 是端到端解决方案,通过一个安装包的形式来提供数据处理、模型训练/验证和高效推理。
注意:PyTorch和TensorFlow也有各自的轻量级推理扩展,即ExecuTorch和TensorFlowLite。
仅考虑模型的推理阶段,llama.cpp的实现是轻量的,因为它没有第三方依赖项,并且自动支持大量可用的运算符或模型格式。此外,顾名思义,该项目最初是一个用于推断LLM(来自Meta的Llama模型)的高效库,并继续支持广泛的开源LLM架构。
如果把PyTorch/TensorFlow比作是豪华、耗电的游轮的话,那么llama.cpp就是小型、快速的摩托艇。PyTorch/TF和llama.cpp都有各自的使用场景。
设置
我们在基于Linux的环境(本机或WSL环境)中进行开发;为此,需要安装cmake和GNU/clang工具链。我们将从源代码编译llama.cpp,并将其作为共享库添加到我们的可执行聊天程序中。
首先,我们创建一个项目目录smol_chat,并使用一个externals目录来存储克隆自原项目的llama.cpp存储库。
mkdir smol_chatcd smol_chatmkdir srcmkdir externalstouch CMakeLists.txtcd externalsgit clone --depth=1 https://github.com/ggerganov/llama.cpp
[*]1.
[*]2.
[*]3.
[*]4.
[*]5.
[*]6.
[*]7.
[*]8.
[*]9.
CMakeLists.txt是我们定义构建项目方案的文件,通过引用来自externals/llama.cpp的标准头文件和共享库,允许CMake使用默认工具链(GNU/clang)编译我们的C/C++代码。
cmake_minimum_required(VERSION 3.10)project(llama_inference)set(CMAKE_CXX_STANDARD 17)set(LLAMA_BUILD_COMMON On)add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/externals/llama.cpp")add_executable(chatsrc/LLMInference.cpp src/main.cpp)target_link_libraries(chat PRIVATEcommon llama ggml)
[*]1.
[*]2.
[*]3.
[*]4.
[*]5.
[*]6.
[*]7.
[*]8.
[*]9.
[*]10.
[*]11.
[*]12.
[*]13.
[*]14.
[*]15.
[*]16.
加载模型
现在,我们已经定义了如何通过CMake构建我们的项目。接下来,我们创建一个头文件LLMInference.h,它声明了一个包含高级函数的类,用于与LLM交互。llama.cpp提供了一个C样式的API,因此将其嵌入到类中将有助于我们抽象/隐藏内部工作细节。
#ifndef LLMINFERENCE_H#define LLMINFERENCE_H#include "common.h"#include "llama.h"#include #include class LLMInference {// llama.cpp特定的数据类型llama_context* _ctx;llama_model* _model;llama_sampler* _sampler;llama_batch _batch;llama_token _currToken;// 用于在聊天程序中存储用户/助手信息的容器std::vector _messages;//将聊天模板应用于所有消息后生成的字符串存储在“_messages”中std::vector _formattedMessages;// 将最后查询的标记存储到“_messages”中std::vector _promptTokens;int _prevLen = 0;// 存储给定查询的完整响应std::string _response = "";public:void loadModel(const std::string& modelPath, float minP, float temperature);void addChatMessage(const std::string& message, const std::string& role);void startCompletion(const std::string& query);std::string completionLoop();void stopCompletion();~LLMInference();};#endif
[*]1.
[*]2.
[*]3.
[*]4.
[*]5.
[*]6.
[*]7.
[*]8.
[*]9.
[*]10.
[*]11.
[*]12.
[*]13.
[*]14.
[*]15.
[*]16.
[*]17.
[*]18.
[*]19.
[*]20.
[*]21.
[*]22.
[*]23.
[*]24.
[*]25.
[*]26.
[*]27.
[*]28.
[*]29.
[*]30.
[*]31.
[*]32.
[*]33.
[*]34.
[*]35.
[*]36.
[*]37.
[*]38.
[*]39.
[*]40.
[*]41.
[*]42.
[*]43.
[*]44.
上面头文件中声明的私有成员将用于实现本文后续部分中描述的公共成员函数。首先,让我们在LLMInference.cpp中定义每个成员函数。
#include "LLMInference.h"#include #include void LLMInference::loadModel(const std::string& model_path, float min_p, float temperature) {//创建一个llama_model的实例llama_model_params model_params = llama_model_default_params();_model = llama_load_model_from_file(model_path.data(), model_params);if (!_model) {throw std::runtime_error("load_model() failed");}//创建 llama_context 实例llama_context_params ctx_params = llama_context_default_params();ctx_params.n_ctx = 0; // 从模型 GGUF 文件中获取上下文大小ctx_params.no_perf = true; // 禁用性能指标_ctx = llama_new_context_with_model(_model, ctx_params);if (!_ctx) {throw std::runtime_error("llama_new_context_with_model() returned null");}//初始化采样器llama_sampler_chain_params sampler_params = llama_sampler_chain_default_params();sampler_params.no_perf = true; // 禁用性能指标_sampler = llama_sampler_chain_init(sampler_params);llama_sampler_chain_add(_sampler, llama_sampler_init_min_p(min_p, 1));llama_sampler_chain_add(_sampler, llama_sampler_init_temp(temperature));llama_sampler_chain_add(_sampler, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));_formattedMessages = std::vector(llama_n_ctx(_ctx));_messages.clear();}
[*]1.
[*]2.
[*]3.
[*]4.
[*]5.
[*]6.
[*]7.
[*]8.
[*]9.
[*]10.
[*]11.
[*]12.
[*]13.
[*]14.
[*]15.
[*]16.
[*]17.
[*]18.
[*]19.
[*]20.
[*]21.
[*]22.
[*]23.
[*]24.
[*]25.
[*]26.
[*]27.
[*]28.
[*]29.
[*]30.
[*]31.
[*]32.
[*]33.
[*]34.
上述代码中,llama_load_model_from_file使用llama_load_model从文件内部读取模型,并使用给定的llama_model_params填充llama_model实例。用户可以提供参数,但我们可以使用llama_model_default_params获取预初始化的默认结构。
llama_context表示加载的GGUF模型的执行环境。llama_new_context_with_model实例化新的llama_context,并通过读取llama_model_params或自动检测可用的后端来准备执行的后端。它还初始化K-V缓存,这在解码或推理步骤中是很重要的。管理跨多个后端的计算的后端调度程序也被初始化。
llama_sampler决定了我们如何从模型(特别是LLM的解码器)的输出(logits)得出概率分布中的采样/选择标记。LLM为词汇表中存在的每个标记分配一个概率,表示该标记出现在序列中的下一个概率。我们使用llama_sampler_init_temp和llama_sampler_init_min_p设置的温度和min-p是控制标记采样过程的两个参数。
执行推理
推理过程涉及多个步骤,该过程将用户的文本查询作为输入并返回LLM的响应。
1. 将聊天模板应用于查询
对于LLM,传入消息被归类为属于三个角色,即用户、助手和系统。其中,用户和助手消息分别由用户和LLM给出,而系统表示整个对话中遵循的系统范围提示。每条消息都由角色和内容组成,其中内容是实际文本,角色是三个角色中的任何一个。
[*]1.
系统提示是对话的第一条消息。在我们的代码中,消息存储为名为_messages的std::vector。其中,llama_chat_message是具有角色和内容属性的llama.cpp结构。我们使用llama.cpp中的llama_chat_apply_template函数将存储在GGUF文件中的聊天模板应用为元数据。我们将应用聊天模板后获得的字符串或std::vector存储在_formattedMessages中。
2. 标记化
标记化是将给定文本划分为较小部分(标记)的过程。我们为每个部分/标记分配一个唯一的整数ID,从而将输入文本转换为整数序列,形成LLM的输入。llama.cpp提供common_tokenize或llama_tokenize函数来执行标记化,其中common_tokenize将标记序列作为std::vector返回。
void LLMInference::startCompletion(const std::string& query) {addChatMessage(query, "user");// 应用聊天模板 int new_len = llama_chat_apply_template(_model,nullptr,_messages.data(),_messages.size(),true,_formattedMessages.data(),_formattedMessages.size());if (new_len > (int)_formattedMessages.size()) {//调整输出缓冲区 `_formattedMessages`的大小并重新应用聊天模板_formattedMessages.resize(new_len);new_len = llama_chat_apply_template(_model, nullptr, _messages.data(), _messages.size(), true, _formattedMessages.data(), _formattedMessages.size());}if (new_len < 0) {throw std::runtime_error("llama_chat_apply_template() in LLMInference::start_completion() failed");}std::string prompt(_formattedMessages.begin() + _prevLen, _formattedMessages.begin() + new_len);// 标记化_promptTokens = common_tokenize(_model, prompt, true, true);// 创建一个包含单个序列的llama_batch// see llama_batch_init for more details_batch.token = _promptTokens.data();_batch.n_tokens = _promptTokens.size();}
[*]1.
[*]2.
[*]3.
[*]4.
[*]5.
[*]6.
[*]7.
[*]8.
[*]9.
[*]10.
[*]11.
[*]12.
[*]13.
[*]14.
[*]15.
[*]16.
[*]17.
[*]18.
[*]19.
[*]20.
[*]21.
[*]22.
[*]23.
[*]24.
[*]25.
[*]26.
[*]27.
[*]28.
[*]29.
[*]30.
[*]31.
在上面代码中,我们应用聊天模板并在LLMInference::startCompletion方法中执行标记化,然后创建一个llama_batch实例来保存模型的最终输入。
3. 解码、采样和KV缓存
如前所述,LLM通过连续预测给定序列中的下一个标记来生成响应。LLM还经过训练以预测特殊的生成结束(EOG)标记,指示预测标记序列的结束。completion_loop函数返回序列中的下一个标记,并不断被调用,直到它返回的标记是EOG标记。
[*]通过llama_n_ctx和llama_get_kv_cached_used_cells,我们可以确定用于存储输入的上下文的长度。目前,如果标记化输入的长度超过上下文大小的话,我们会抛出一个错误。
[*]llama_decode根据变量_batch中的输入信息对模型进行前向传递。
[*]通过在LLMInference::loadModel中初始化的_sampler,我们抽样或选择一个标记作为我们的预测并将其存储在_currToken中。我们检查该标记是否为EOG标记,然后返回“EOG”,表示应终止调用LLMInference::completionLoop的文本生成循环。终止时,我们将一条新消息附加到_messages,这是具有角色assistant的LLM给出的完整响应信息。
[*] _currToken仍然是一个整数,它由common_token_to_piece函数转换为字符串标记片段。此字符串标记从finishLoop方法返回。
[*]我们需要重新初始化_batch以确保它现在仅包含_currToken而不是整个输入序列,即_promptTokens。这是因为所有先前标记的“键”和“值”都已缓存。通过避免计算_promptTokens中所有标记的所有“键”和“值”,可以减少推理时间。
std::string LLMInference::completionLoop() {// 检查模型输入的长度是否超出了模型的上下文大小int contextSize = llama_n_ctx(_ctx);int nCtxUsed = llama_get_kv_cache_used_cells(_ctx);if (nCtxUsed + _batch.n_tokens > contextSize) {std::cerr addChatMessage("You are a helpful assistant", "system");while (true) {std::cout startCompletion(query);std::string predictedToken;while ((predictedToken = llmInference->completionLoop()) != "") {std::cout
页:
[1]