From 7ece21c28abd2ec3f299db72934a3e14c5a8fb14 Mon Sep 17 00:00:00 2001 From: namtranase Date: Sun, 17 Mar 2024 12:10:52 +0700 Subject: [PATCH 1/3] update: change codebase update to the org repo --- .gitignore | 5 +- CMakeLists.txt | 2 +- DEVELOPERS.md | 2 + requirements.txt | 2 - src/gemma_binding.cpp | 594 +++++++++++++++++++++--------------------- src/run.cpp | 298 +++++++++++++++++++++ 6 files changed, 606 insertions(+), 297 deletions(-) create mode 100644 DEVELOPERS.md create mode 100644 src/run.cpp diff --git a/.gitignore b/.gitignore index b5d1648..551b2a3 100644 --- a/.gitignore +++ b/.gitignore @@ -160,4 +160,7 @@ cython_debug/ #.idea/ # Vscode -.vscode/ \ No newline at end of file +.vscode/ + +#p Precommit +.pre-commit-config.yaml diff --git a/CMakeLists.txt b/CMakeLists.txt index 8414cbc..1494591 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,7 +11,7 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON) FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c) FetchContent_MakeAvailable(sentencepiece) -FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG e29cd566cf3367671e8f59419a04e308796a7c57) +FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main) FetchContent_MakeAvailable(gemma) FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f) diff --git a/DEVELOPERS.md b/DEVELOPERS.md new file mode 100644 index 0000000..7c761a6 --- /dev/null +++ b/DEVELOPERS.md @@ -0,0 +1,2 @@ +## 🤝 Contributing +Contributions are welcome. Please clone the repository, push your changes to a new branch, and submit a pull request. diff --git a/requirements.txt b/requirements.txt index 59a8043..e69de29 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +0,0 @@ -pybind11 -pre-commit diff --git a/src/gemma_binding.cpp b/src/gemma_binding.cpp index 53607db..768940b 100644 --- a/src/gemma_binding.cpp +++ b/src/gemma_binding.cpp @@ -1,17 +1,27 @@ #include #include -// #include "gemma.h" // Adjust include path as necessary + +// Command line text interface to gemma. + #include #include #include #include -#include // NOLINT +#include // NOLINT #include +// copybara:import_next_line:gemma_cpp #include "compression/compress.h" -#include "gemma.h" // Gemma +// copybara:end +// copybara:import_next_line:gemma_cpp +#include "gemma.h" // Gemma +// copybara:end +// copybara:import_next_line:gemma_cpp #include "util/app.h" -#include "util/args.h" // HasHelp +// copybara:end +// copybara:import_next_line:gemma_cpp +#include "util/args.h" // HasHelp +// copybara:end #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" @@ -21,291 +31,289 @@ namespace py = pybind11; -namespace gcpp -{ - - void ShowHelp(gcpp::LoaderArgs &loader, gcpp::InferenceArgs &inference, - gcpp::AppArgs &app) - { - fprintf(stderr, - "\ngemma.cpp\n---------\n\nTo run gemma.cpp, you need to " - "specify 3 required model loading arguments: --tokenizer, " - "--compressed_weights, " - "and --model.\n\nModel Loading Arguments\n\n"); - loader.Help(); - fprintf(stderr, "\nInference Arguments\n\n"); - inference.Help(); - fprintf(stderr, "\nApplication Arguments\n\n"); - app.Help(); - fprintf(stderr, "\n\n"); - } +namespace gcpp { + +static constexpr std::string_view kAsciiArtBanner = + " __ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __\n" + " / _` |/ _ \\ '_ ` _ \\| '_ ` _ \\ / _` | / __| '_ \\| '_ \\\n" + "| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) |\n" + " \\__, |\\___|_| |_| |_|_| |_| |_|\\__,_(_)___| .__/| .__/\n" + " __/ | | | | |\n" + " |___/ |_| |_|"; + +void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { + loader.Print(app.verbosity); + inference.Print(app.verbosity); + app.Print(app.verbosity); + + if (app.verbosity >= 2) { + time_t now = time(nullptr); + char* dt = ctime(&now); // NOLINT + std::cout << "Date & Time : " << dt + << "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize + << "\n" + << "Hardware concurrency : " + << std::thread::hardware_concurrency() << std::endl + << "Instruction set : " + << hwy::TargetName(hwy::DispatchedTarget()) << " (" + << hwy::VectorBytes() * 8 << " bits)" << "\n" + << "Compiled config : " << CompiledConfig() << "\n" + << "Weight Type : " + << gcpp::TypeName(gcpp::WeightT()) << "\n" + << "EmbedderInput Type : " + << gcpp::TypeName(gcpp::EmbedderInputT()) << "\n"; + } +} - void ShowConfig(LoaderArgs &loader, InferenceArgs &inference, AppArgs &app) - { - loader.Print(app.verbosity); - inference.Print(app.verbosity); - app.Print(app.verbosity); +void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference, + gcpp::AppArgs& app) { + std::cerr + << kAsciiArtBanner + << "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n" + "==========================================================\n\n" + "To run gemma.cpp, you need to " + "specify 3 required model loading arguments:\n --tokenizer\n " + "--compressed_weights\n" + " --model.\n"; + std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm " + "--compressed_weights 2b-it-sfp.sbs --model 2b-it\n"; + std::cerr << "\n*Model Loading Arguments*\n\n"; + loader.Help(); + std::cerr << "\n*Inference Arguments*\n\n"; + inference.Help(); + std::cerr << "\n*Application Arguments*\n\n"; + app.Help(); + std::cerr << "\n"; +} - if (app.verbosity >= 2) - { - time_t now = time(nullptr); - char *dt = ctime(&now); // NOLINT - std::cout << "Date & Time : " << dt - << "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize - << "\n" - << "Hardware concurrency : " - << std::thread::hardware_concurrency() << std::endl - << "Instruction set : " - << hwy::TargetName(hwy::DispatchedTarget()) << " (" - << hwy::VectorBytes() * 8 << " bits)" - << "\n" - << "Weight Type : " - << gcpp::TypeName(gcpp::WeightT()) << "\n" - << "EmbedderInput Type : " - << gcpp::TypeName(gcpp::EmbedderInputT()) << "\n"; +void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, + const InferenceArgs& args, int verbosity, + const gcpp::AcceptFunc& accept_token, std::string& eot_line) { + PROFILER_ZONE("Gen.misc"); + int abs_pos = 0; // absolute token index over all turns + int current_pos = 0; // token index within the current turn + int prompt_size{}; + + std::mt19937 gen; + if (args.deterministic) { + gen.seed(42); + } else { + std::random_device rd; + gen.seed(rd()); + } + + // callback function invoked for each generated token. + auto stream_token = [&abs_pos, ¤t_pos, &args, &gen, &prompt_size, + tokenizer = model.Tokenizer(), + verbosity](int token, float) { + ++abs_pos; + ++current_pos; + if (current_pos < prompt_size) { + std::cerr << "." << std::flush; + } else if (token == gcpp::EOS_ID) { + if (!args.multiturn) { + abs_pos = 0; + if (args.deterministic) { + gen.seed(42); + } + } + if (verbosity >= 2) { + std::cout << "\n[ End ]\n"; + } + } else { + std::string token_text; + HWY_ASSERT(tokenizer->Decode(std::vector{token}, &token_text).ok()); + // +1 since position is incremented above + if (current_pos == prompt_size + 1) { + // first token of response + token_text.erase(0, token_text.find_first_not_of(" \t\n")); + if (verbosity >= 1) { + std::cout << std::endl << std::endl; } + } + std::cout << token_text << std::flush; } + return true; + }; - void ReplGemma(gcpp::Gemma &model, hwy::ThreadPool &pool, - hwy::ThreadPool &inner_pool, const InferenceArgs &args, - int verbosity, const gcpp::AcceptFunc &accept_token) + while (abs_pos < args.max_tokens) { + std::string prompt_string; + std::vector prompt; + current_pos = 0; { - PROFILER_ZONE("Gen.misc"); - int abs_pos = 0; // absolute token index over all turns - int current_pos = 0; // token index within the current turn - int prompt_size{}; - - std::mt19937 gen; - if (args.deterministic) - { - gen.seed(42); - } - else - { - std::random_device rd; - gen.seed(rd()); + PROFILER_ZONE("Gen.input"); + if (verbosity >= 1) { + std::cout << "> " << std::flush; + } + + if (eot_line.size() == 0) { + std::getline(std::cin, prompt_string); + } else { + std::string line; + while (std::getline(std::cin, line)) { + if (line == eot_line) { + break; + } + prompt_string += line + "\n"; } + } + } - // callback function invoked for each generated token. - auto stream_token = [&abs_pos, ¤t_pos, &args, &gen, &prompt_size, - tokenizer = &model.Tokenizer(), - verbosity](int token, float) - { - ++abs_pos; - ++current_pos; - if (current_pos < prompt_size) - { - std::cerr << "." << std::flush; - } - else if (token == gcpp::EOS_ID) - { - if (!args.multiturn) - { - abs_pos = 0; - if (args.deterministic) - { - gen.seed(42); - } - } - if (verbosity >= 2) - { - std::cout << "\n[ End ]" << std::endl; - } - } - else - { - std::string token_text; - HWY_ASSERT(tokenizer->Decode(std::vector{token}, &token_text).ok()); - // +1 since position is incremented above - if (current_pos == prompt_size + 1) - { - // first token of response - token_text.erase(0, token_text.find_first_not_of(" \t\n")); - if (verbosity >= 1) - { - std::cout << std::endl - << std::endl; - } - } - // TODO(austinvhuang): is explicit space necessary? - std::cout << token_text << std::flush; - } - return true; - }; - - while (abs_pos < args.max_tokens) - { - std::string prompt_string; - std::vector prompt; - current_pos = 0; - { - PROFILER_ZONE("Gen.input"); - if (verbosity >= 1) - { - std::cout << "> " << std::flush; - } - std::getline(std::cin, prompt_string); - } - - if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") - { - return; - } - - if (model.model_training == ModelTraining::GEMMA_IT) - { - // For instruction-tuned models: add control tokens. - prompt_string = "user\n" + prompt_string + - "\nmodel\n"; - if (abs_pos > 0) - { - // Prepend "" token if this is a multi-turn dialogue - // continuation. - prompt_string = "\n" + prompt_string; - } - } - - HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok()); - - // For both pre-trained and instruction-tuned models: prepend "" token - // if needed. - if (abs_pos == 0) - { - prompt.insert(prompt.begin(), 2); - } - - prompt_size = prompt.size(); - - std::cerr << std::endl - << "[ Reading prompt ] " << std::flush; - - const double time_start = hwy::platform::Now(); - GenerateGemma(model, args, prompt, abs_pos, pool, inner_pool, stream_token, - accept_token, gen, verbosity); - const double time_end = hwy::platform::Now(); - const double tok_sec = current_pos / (time_end - time_start); - if (verbosity >= 2) - { - std::cout << current_pos << " tokens (" << abs_pos << " total tokens)" - << std::endl - << tok_sec << " tokens / sec" << std::endl; - } - std::cout << std::endl - << std::endl; - } - std::cout - << "max_tokens (" << args.max_tokens - << ") exceeded. Use a larger value if desired using the --max_tokens " - << "command line flag.\n"; + if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") { + return; } - void Run(LoaderArgs &loader, InferenceArgs &inference, AppArgs &app) - { - PROFILER_ZONE("Run.misc"); + if (prompt_string == "%c" || prompt_string == "%C") { + abs_pos = 0; + continue; + } - hwy::ThreadPool inner_pool(0); - hwy::ThreadPool pool(app.num_threads); - // For many-core, pinning threads to cores helps. - if (app.num_threads > 10) - { - PinThreadToCore(app.num_threads - 1); // Main thread + if (model.model_training == ModelTraining::GEMMA_IT) { + // For instruction-tuned models: add control tokens. + prompt_string = "user\n" + prompt_string + + "\nmodel\n"; + if (abs_pos > 0) { + // Prepend "" token if this is a multi-turn dialogue + // continuation. + prompt_string = "\n" + prompt_string; + } + } - pool.Run(0, pool.NumThreads(), - [](uint64_t /*task*/, size_t thread) - { PinThreadToCore(thread); }); - } + HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt).ok()); - gcpp::Gemma model(loader, pool); + // For both pre-trained and instruction-tuned models: prepend "" token + // if needed. + if (abs_pos == 0) { + prompt.insert(prompt.begin(), 2); + } - if (const char *error = inference.Validate()) - { - ShowHelp(loader, inference, app); - HWY_ABORT("\nInvalid args: %s", error); - } + prompt_size = prompt.size(); - if (app.verbosity >= 1) - { - static const std::string banner_ascii_art = - " __ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __\n" - " / _` |/ _ \\ '_ ` _ \\| '_ ` _ \\ / _` | / __| '_ \\| '_ \\\n" - "| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) |\n" - " \\__, |\\___|_| |_| |_|_| |_| |_|\\__,_(_)___| .__/| .__/\n" - " __/ | | | | |\n" - " |___/ |_| |_|"; - - const std::string instructions = - "*Usage*\n" - " Enter an instruction and press enter (%Q quits).\n\n" - "*Examples*\n" - " - Write an email to grandma thanking her for the cookies.\n" - " - What are some historical attractions to visit around " - "Massachusetts?\n" - " - Compute the nth fibonacci number in javascript.\n" - " - Write a standup comedy bit about GPU programming.\n"; - - std::cout << "\033[2J\033[1;1H" // clear screen - << banner_ascii_art << "\n\n"; - ShowConfig(loader, inference, app); - std::cout << "\n" - << instructions << "\n"; - } + std::cerr << std::endl << "[ Reading prompt ] " << std::flush; - ReplGemma(model, pool, inner_pool, inference, app.verbosity, - /*accept_token=*/[](int) - { return true; }); + const double time_start = hwy::platform::Now(); + GenerateGemma(model, args.max_tokens, args.max_generated_tokens, + args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool, + stream_token, accept_token, gen, verbosity); + const double time_end = hwy::platform::Now(); + const double tok_sec = current_pos / (time_end - time_start); + if (verbosity >= 2) { + std::cout << current_pos << " tokens (" << abs_pos << " total tokens)" + << std::endl + << tok_sec << " tokens / sec" << std::endl; } + std::cout << std::endl << std::endl; + } + std::cout + << "max_tokens (" << args.max_tokens + << ") exceeded. Use a larger value if desired using the --max_tokens " + << "command line flag.\n"; +} - std::string decode(gcpp::Gemma &model, hwy::ThreadPool &pool, - hwy::ThreadPool &inner_pool, const InferenceArgs &args, - int verbosity, const gcpp::AcceptFunc &accept_token, std::string &prompt_string) - { - std::string generated_text; - // Seed the random number generator - std::random_device rd; - std::mt19937 gen(rd()); - int prompt_size{}; - if (model.model_training == ModelTraining::GEMMA_IT) - { - // For instruction-tuned models: add control tokens. - prompt_string = "user\n" + prompt_string + - "\nmodel\n"; - } - // Encode the prompt string into tokens - std::vector prompt; - HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok()); - // Placeholder for generated token IDs - std::vector generated_tokens; - // Define lambda for token decoding - StreamFunc stream_token = [&generated_tokens](int token, float /* probability */) -> bool { - generated_tokens.push_back(token); - return true; // Continue generating - }; - // Decode tokens - prompt_size = prompt.size(); - GenerateGemma(model, args, prompt, /*start_pos=*/0, pool, inner_pool, stream_token, accept_token, gen, verbosity); - HWY_ASSERT(model.Tokenizer().Decode(generated_tokens, &generated_text).ok()); - generated_text = generated_text.substr(prompt_string.size()); - - return generated_text; - } +void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { + PROFILER_ZONE("Run.misc"); - std::string completion(LoaderArgs &loader, InferenceArgs &inference, AppArgs &app, std::string &prompt_string) - { - hwy::ThreadPool inner_pool(0); - hwy::ThreadPool pool(app.num_threads); - if (app.num_threads > 10) - { - PinThreadToCore(app.num_threads - 1); // Main thread + hwy::ThreadPool inner_pool(0); + hwy::ThreadPool pool(app.num_threads); + // For many-core, pinning threads to cores helps. + if (app.num_threads > 10) { + PinThreadToCore(app.num_threads - 1); // Main thread - pool.Run(0, pool.NumThreads(), - [](uint64_t /*task*/, size_t thread) - { PinThreadToCore(thread); }); - } - gcpp::Gemma model(loader, pool); - return decode(model, pool, inner_pool, inference, app.verbosity, /*accept_token=*/[](int) - { return true; }, prompt_string); + pool.Run(0, pool.NumThreads(), + [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); + } - } + gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, + loader.ModelType(), pool); + + auto kv_cache = CreateKVCache(loader.ModelType()); + + if (const char* error = inference.Validate()) { + ShowHelp(loader, inference, app); + HWY_ABORT("\nInvalid args: %s", error); + } + + if (app.verbosity >= 1) { + const std::string instructions = + "*Usage*\n" + " Enter an instruction and press enter (%C resets conversation, " + "%Q quits).\n" + + (inference.multiturn == 0 + ? std::string(" Since multiturn is set to 0, conversation will " + "automatically reset every turn.\n\n") + : "\n") + + "*Examples*\n" + " - Write an email to grandma thanking her for the cookies.\n" + " - What are some historical attractions to visit around " + "Massachusetts?\n" + " - Compute the nth fibonacci number in javascript.\n" + " - Write a standup comedy bit about GPU programming.\n"; + + std::cout << "\033[2J\033[1;1H" // clear screen + << kAsciiArtBanner << "\n\n"; + ShowConfig(loader, inference, app); + std::cout << "\n" << instructions << "\n"; + } + + ReplGemma( + model, kv_cache, pool, inner_pool, inference, app.verbosity, + /*accept_token=*/[](int) { return true; }, app.eot_line); +} + +// std::string decode(gcpp::Gemma &model, hwy::ThreadPool &pool, +// hwy::ThreadPool &inner_pool, const InferenceArgs &args, +// int verbosity, const gcpp::AcceptFunc &accept_token, +// std::string &prompt_string) +// { +// std::string generated_text; +// // Seed the random number generator +// std::random_device rd; +// std::mt19937 gen(rd()); +// int prompt_size{}; +// if (model.model_training == ModelTraining::GEMMA_IT) +// { +// // For instruction-tuned models: add control tokens. +// prompt_string = "user\n" + prompt_string + +// "\nmodel\n"; +// } +// // Encode the prompt string into tokens +// std::vector prompt; +// HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt).ok()); +// // Placeholder for generated token IDs +// std::vector generated_tokens; +// // Define lambda for token decoding +// StreamFunc stream_token = [&generated_tokens](int token, float /* probability */) -> bool { +// generated_tokens.push_back(token); +// return true; // Continue generating +// }; +// // Decode tokens +// prompt_size = prompt.size(); +// GenerateGemma(model, args, prompt, /*start_pos=*/0, pool, inner_pool, stream_token, +// accept_token, gen, verbosity); +// HWY_ASSERT(model.Tokenizer()->Decode(generated_tokens, &generated_text).ok()); +// generated_text = generated_text.substr(prompt_string.size()); + +// return generated_text; +// } + +// std::string completion(LoaderArgs &loader, InferenceArgs &inference, AppArgs &app, +// std::string &prompt_string){ +// hwy::ThreadPool inner_pool(0); +// hwy::ThreadPool pool(app.num_threads); +// if (app.num_threads > 10) +// { +// PinThreadToCore(app.num_threads - 1); // Main thread + +// pool.Run(0, pool.NumThreads(), +// [](uint64_t /*task*/, size_t thread) +// { PinThreadToCore(thread); }); +// } +// gcpp::Gemma model(loader, pool); +// return decode(model, pool, inner_pool, inference, app.verbosity, /*accept_token=*/[](int) +// { return true; }, prompt_string); + +// } } // namespace gcpp @@ -335,30 +343,30 @@ void chat_base(int argc, char **argv) PROFILER_PRINT_RESULTS(); // Must call outside the zone above. // return 1; } -std::string completion_base(int argc, char **argv) -{ - gcpp::LoaderArgs loader(argc, argv); - gcpp::InferenceArgs inference(argc, argv); - gcpp::AppArgs app(argc, argv); - std::string prompt_string = argv[argc-1]; - return gcpp::completion(loader, inference, app, prompt_string); -} -std::string completion_base_wrapper(const std::vector &args,std::string &prompt_string) -{ - int argc = args.size() + 2; // +1 for the program name - std::vector argv_vec; - argv_vec.reserve(argc); - - argv_vec.push_back(const_cast("pygemma")); - - for (const auto &arg : args) - { - argv_vec.push_back(const_cast(arg.c_str())); - } - argv_vec.push_back(const_cast(prompt_string.c_str())); - char **argv = argv_vec.data(); - return completion_base(argc, argv); -} +// std::string completion_base(int argc, char **argv) +// { +// gcpp::LoaderArgs loader(argc, argv); +// gcpp::InferenceArgs inference(argc, argv); +// gcpp::AppArgs app(argc, argv); +// std::string prompt_string = argv[argc-1]; +// return gcpp::completion(loader, inference, app, prompt_string); +// } +// std::string completion_base_wrapper(const std::vector &args,std::string &prompt_string) +// { +// int argc = args.size() + 2; // +1 for the program name +// std::vector argv_vec; +// argv_vec.reserve(argc); + +// argv_vec.push_back(const_cast("pygemma")); + +// for (const auto &arg : args) +// { +// argv_vec.push_back(const_cast(arg.c_str())); +// } +// argv_vec.push_back(const_cast(prompt_string.c_str())); +// char **argv = argv_vec.data(); +// return completion_base(argc, argv); +// } void show_help_wrapper() { // Assuming ShowHelp does not critically depend on argv content @@ -392,5 +400,5 @@ PYBIND11_MODULE(pygemma, m) m.doc() = "Pybind11 integration for chat_base function"; m.def("chat_base", &chat_base_wrapper, "A wrapper for the chat_base function accepting Python list of strings as arguments"); m.def("show_help", &show_help_wrapper, "A wrapper for show_help function"); - m.def("completion", &completion_base_wrapper, "A wrapper for inference function"); + // m.def("completion", &completion_base_wrapper, "A wrapper for inference function"); } diff --git a/src/run.cpp b/src/run.cpp new file mode 100644 index 0000000..b08e4ca --- /dev/null +++ b/src/run.cpp @@ -0,0 +1,298 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Command line text interface to gemma. + +#include +#include +#include +#include +#include // NOLINT +#include + +// copybara:import_next_line:gemma_cpp +#include "compression/compress.h" +// copybara:end +// copybara:import_next_line:gemma_cpp +#include "gemma.h" // Gemma +// copybara:end +// copybara:import_next_line:gemma_cpp +#include "util/app.h" +// copybara:end +// copybara:import_next_line:gemma_cpp +#include "util/args.h" // HasHelp +// copybara:end +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/highway.h" +#include "hwy/per_target.h" +#include "hwy/profiler.h" +#include "hwy/timer.h" + +namespace gcpp { + +static constexpr std::string_view kAsciiArtBanner = + " __ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __\n" + " / _` |/ _ \\ '_ ` _ \\| '_ ` _ \\ / _` | / __| '_ \\| '_ \\\n" + "| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) |\n" + " \\__, |\\___|_| |_| |_|_| |_| |_|\\__,_(_)___| .__/| .__/\n" + " __/ | | | | |\n" + " |___/ |_| |_|"; + +void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { + loader.Print(app.verbosity); + inference.Print(app.verbosity); + app.Print(app.verbosity); + + if (app.verbosity >= 2) { + time_t now = time(nullptr); + char* dt = ctime(&now); // NOLINT + std::cout << "Date & Time : " << dt + << "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize + << "\n" + << "Hardware concurrency : " + << std::thread::hardware_concurrency() << std::endl + << "Instruction set : " + << hwy::TargetName(hwy::DispatchedTarget()) << " (" + << hwy::VectorBytes() * 8 << " bits)" << "\n" + << "Compiled config : " << CompiledConfig() << "\n" + << "Weight Type : " + << gcpp::TypeName(gcpp::WeightT()) << "\n" + << "EmbedderInput Type : " + << gcpp::TypeName(gcpp::EmbedderInputT()) << "\n"; + } +} + +void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference, + gcpp::AppArgs& app) { + std::cerr + << kAsciiArtBanner + << "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n" + "==========================================================\n\n" + "To run gemma.cpp, you need to " + "specify 3 required model loading arguments:\n --tokenizer\n " + "--compressed_weights\n" + " --model.\n"; + std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm " + "--compressed_weights 2b-it-sfp.sbs --model 2b-it\n"; + std::cerr << "\n*Model Loading Arguments*\n\n"; + loader.Help(); + std::cerr << "\n*Inference Arguments*\n\n"; + inference.Help(); + std::cerr << "\n*Application Arguments*\n\n"; + app.Help(); + std::cerr << "\n"; +} + +void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, + const InferenceArgs& args, int verbosity, + const gcpp::AcceptFunc& accept_token, std::string& eot_line) { + PROFILER_ZONE("Gen.misc"); + int abs_pos = 0; // absolute token index over all turns + int current_pos = 0; // token index within the current turn + int prompt_size{}; + + std::mt19937 gen; + if (args.deterministic) { + gen.seed(42); + } else { + std::random_device rd; + gen.seed(rd()); + } + + // callback function invoked for each generated token. + auto stream_token = [&abs_pos, ¤t_pos, &args, &gen, &prompt_size, + tokenizer = model.Tokenizer(), + verbosity](int token, float) { + ++abs_pos; + ++current_pos; + if (current_pos < prompt_size) { + std::cerr << "." << std::flush; + } else if (token == gcpp::EOS_ID) { + if (!args.multiturn) { + abs_pos = 0; + if (args.deterministic) { + gen.seed(42); + } + } + if (verbosity >= 2) { + std::cout << "\n[ End ]\n"; + } + } else { + std::string token_text; + HWY_ASSERT(tokenizer->Decode(std::vector{token}, &token_text).ok()); + // +1 since position is incremented above + if (current_pos == prompt_size + 1) { + // first token of response + token_text.erase(0, token_text.find_first_not_of(" \t\n")); + if (verbosity >= 1) { + std::cout << std::endl << std::endl; + } + } + std::cout << token_text << std::flush; + } + return true; + }; + + while (abs_pos < args.max_tokens) { + std::string prompt_string; + std::vector prompt; + current_pos = 0; + { + PROFILER_ZONE("Gen.input"); + if (verbosity >= 1) { + std::cout << "> " << std::flush; + } + + if (eot_line.size() == 0) { + std::getline(std::cin, prompt_string); + } else { + std::string line; + while (std::getline(std::cin, line)) { + if (line == eot_line) { + break; + } + prompt_string += line + "\n"; + } + } + } + + if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") { + return; + } + + if (prompt_string == "%c" || prompt_string == "%C") { + abs_pos = 0; + continue; + } + + if (model.model_training == ModelTraining::GEMMA_IT) { + // For instruction-tuned models: add control tokens. + prompt_string = "user\n" + prompt_string + + "\nmodel\n"; + if (abs_pos > 0) { + // Prepend "" token if this is a multi-turn dialogue + // continuation. + prompt_string = "\n" + prompt_string; + } + } + + HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt).ok()); + + // For both pre-trained and instruction-tuned models: prepend "" token + // if needed. + if (abs_pos == 0) { + prompt.insert(prompt.begin(), 2); + } + + prompt_size = prompt.size(); + + std::cerr << std::endl << "[ Reading prompt ] " << std::flush; + + const double time_start = hwy::platform::Now(); + GenerateGemma(model, args.max_tokens, args.max_generated_tokens, + args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool, + stream_token, accept_token, gen, verbosity); + const double time_end = hwy::platform::Now(); + const double tok_sec = current_pos / (time_end - time_start); + if (verbosity >= 2) { + std::cout << current_pos << " tokens (" << abs_pos << " total tokens)" + << std::endl + << tok_sec << " tokens / sec" << std::endl; + } + std::cout << std::endl << std::endl; + } + std::cout + << "max_tokens (" << args.max_tokens + << ") exceeded. Use a larger value if desired using the --max_tokens " + << "command line flag.\n"; +} + +void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { + PROFILER_ZONE("Run.misc"); + + hwy::ThreadPool inner_pool(0); + hwy::ThreadPool pool(app.num_threads); + // For many-core, pinning threads to cores helps. + if (app.num_threads > 10) { + PinThreadToCore(app.num_threads - 1); // Main thread + + pool.Run(0, pool.NumThreads(), + [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); + } + + gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, + loader.ModelType(), pool); + + auto kv_cache = CreateKVCache(loader.ModelType()); + + if (const char* error = inference.Validate()) { + ShowHelp(loader, inference, app); + HWY_ABORT("\nInvalid args: %s", error); + } + + if (app.verbosity >= 1) { + const std::string instructions = + "*Usage*\n" + " Enter an instruction and press enter (%C resets conversation, " + "%Q quits).\n" + + (inference.multiturn == 0 + ? std::string(" Since multiturn is set to 0, conversation will " + "automatically reset every turn.\n\n") + : "\n") + + "*Examples*\n" + " - Write an email to grandma thanking her for the cookies.\n" + " - What are some historical attractions to visit around " + "Massachusetts?\n" + " - Compute the nth fibonacci number in javascript.\n" + " - Write a standup comedy bit about GPU programming.\n"; + + std::cout << "\033[2J\033[1;1H" // clear screen + << kAsciiArtBanner << "\n\n"; + ShowConfig(loader, inference, app); + std::cout << "\n" << instructions << "\n"; + } + + ReplGemma( + model, kv_cache, pool, inner_pool, inference, app.verbosity, + /*accept_token=*/[](int) { return true; }, app.eot_line); +} + +} // namespace gcpp + +int main(int argc, char** argv) { + { + PROFILER_ZONE("Startup.misc"); + + gcpp::LoaderArgs loader(argc, argv); + gcpp::InferenceArgs inference(argc, argv); + gcpp::AppArgs app(argc, argv); + + if (gcpp::HasHelp(argc, argv)) { + ShowHelp(loader, inference, app); + return 0; + } + + if (const char* error = loader.Validate()) { + ShowHelp(loader, inference, app); + HWY_ABORT("\nInvalid args: %s", error); + } + + gcpp::Run(loader, inference, app); + } + PROFILER_PRINT_RESULTS(); // Must call outside the zone above. + return 0; +} From 016f59d77a903c27dc5bc8e12fe6727d24604246 Mon Sep 17 00:00:00 2001 From: namtranase Date: Sun, 17 Mar 2024 21:44:40 +0700 Subject: [PATCH 2/3] update: add new interface based on new gemma.cpp --- .clang-format | 2 + CMakeLists.txt | 4 +- setup.py | 5 +- src/gemma_binding.cpp | 233 +++++++++--------------------------------- src/gemma_binding.h | 46 +++++++++ tests/test_chat.py | 39 ++----- 6 files changed, 116 insertions(+), 213 deletions(-) create mode 100644 .clang-format create mode 100644 src/gemma_binding.h diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..2bd8588 --- /dev/null +++ b/.clang-format @@ -0,0 +1,2 @@ +Language: Cpp +BasedOnStyle: Google diff --git a/CMakeLists.txt b/CMakeLists.txt index 1494591..60b0238 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,7 +11,7 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON) FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c) FetchContent_MakeAvailable(sentencepiece) -FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main) +FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG 8fb44ed6dd123f63dca95c20c561e8ca1de511d7) FetchContent_MakeAvailable(gemma) FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f) @@ -30,3 +30,5 @@ FetchContent_GetProperties(gemma) FetchContent_GetProperties(sentencepiece) target_include_directories(pygemma PRIVATE ${gemma_SOURCE_DIR}) target_include_directories(pygemma PRIVATE ${sentencepiece_SOURCE_DIR}) +target_compile_definitions(libgemma PRIVATE $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX>) +target_compile_options(libgemma PRIVATE $<$:-Wno-deprecated-declarations>) diff --git a/setup.py b/setup.py index 57be3dd..9bd0c0b 100644 --- a/setup.py +++ b/setup.py @@ -3,6 +3,7 @@ import sys from setuptools import setup, find_packages, Extension from setuptools.command.build_ext import build_ext +import platform class CMakeExtension(Extension): @@ -39,7 +40,7 @@ def build_extension(self, ext): "--", "-j", "12", - ] # Specifies the number of jobs to run simultaneously + ] if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) @@ -58,7 +59,7 @@ def build_extension(self, ext): version="0.1.2", author="Nam Tran", author_email="namtran.ase@gmail.com", - description="A Python package with a C++ backend using gemma.", + description="A Python package with a C++ backend using gemma.cpp", long_description=""" This package provides Python bindings to a C++ library using pybind11. """, diff --git a/src/gemma_binding.cpp b/src/gemma_binding.cpp index 768940b..8baa3e1 100644 --- a/src/gemma_binding.cpp +++ b/src/gemma_binding.cpp @@ -1,38 +1,9 @@ #include #include -// Command line text interface to gemma. - -#include -#include -#include -#include -#include // NOLINT -#include - -// copybara:import_next_line:gemma_cpp -#include "compression/compress.h" -// copybara:end -// copybara:import_next_line:gemma_cpp -#include "gemma.h" // Gemma -// copybara:end -// copybara:import_next_line:gemma_cpp -#include "util/app.h" -// copybara:end -// copybara:import_next_line:gemma_cpp -#include "util/args.h" // HasHelp -// copybara:end -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/highway.h" -#include "hwy/per_target.h" -#include "hwy/profiler.h" -#include "hwy/timer.h" - +#include "gemma_binding.h" namespace py = pybind11; -namespace gcpp { - static constexpr std::string_view kAsciiArtBanner = " __ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __\n" " / _` |/ _ \\ '_ ` _ \\| '_ ` _ \\ / _` | / __| '_ \\| '_ \\\n" @@ -211,35 +182,51 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache, << "command line flag.\n"; } -void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { +void GemmaWrapper::loadModel(const std::vector &args) { + int argc = args.size() + 1; // +1 for the program name + std::vector argv_vec; + argv_vec.reserve(argc); + argv_vec.push_back(const_cast("pygemma")); + for (const auto &arg : args) + { + argv_vec.push_back(const_cast(arg.c_str())); + } + + char **argv = argv_vec.data(); + + this->m_loader = gcpp::LoaderArgs(argc, argv); + this->m_inference = gcpp::InferenceArgs(argc, argv); + this->m_app = gcpp::AppArgs(argc, argv); + PROFILER_ZONE("Run.misc"); hwy::ThreadPool inner_pool(0); - hwy::ThreadPool pool(app.num_threads); + hwy::ThreadPool pool(this->m_app.num_threads); // For many-core, pinning threads to cores helps. - if (app.num_threads > 10) { - PinThreadToCore(app.num_threads - 1); // Main thread + if (this->m_app.num_threads > 10) { + PinThreadToCore(this->m_app.num_threads - 1); // Main thread pool.Run(0, pool.NumThreads(), [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); } - gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, - loader.ModelType(), pool); - - auto kv_cache = CreateKVCache(loader.ModelType()); + if (!this->m_model) { + this->m_model.reset(new gcpp::Gemma(this->m_loader.tokenizer, this->m_loader.compressed_weights, this->m_loader.ModelType(), pool)); + } +// auto kvcache = CreateKVCache(loader.ModelType()); + this->m_kvcache = CreateKVCache(this->m_loader.ModelType()); - if (const char* error = inference.Validate()) { - ShowHelp(loader, inference, app); + if (const char* error = this->m_inference.Validate()) { + ShowHelp(this->m_loader, this->m_inference, this->m_app); HWY_ABORT("\nInvalid args: %s", error); } - if (app.verbosity >= 1) { + if (this->m_app.verbosity >= 1) { const std::string instructions = "*Usage*\n" " Enter an instruction and press enter (%C resets conversation, " "%Q quits).\n" + - (inference.multiturn == 0 + (this->m_inference.multiturn == 0 ? std::string(" Since multiturn is set to 0, conversation will " "automatically reset every turn.\n\n") : "\n") + @@ -252,153 +239,35 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { std::cout << "\033[2J\033[1;1H" // clear screen << kAsciiArtBanner << "\n\n"; - ShowConfig(loader, inference, app); + ShowConfig(this->m_loader, this->m_inference, this->m_app); std::cout << "\n" << instructions << "\n"; } - - ReplGemma( - model, kv_cache, pool, inner_pool, inference, app.verbosity, - /*accept_token=*/[](int) { return true; }, app.eot_line); } -// std::string decode(gcpp::Gemma &model, hwy::ThreadPool &pool, -// hwy::ThreadPool &inner_pool, const InferenceArgs &args, -// int verbosity, const gcpp::AcceptFunc &accept_token, -// std::string &prompt_string) -// { -// std::string generated_text; -// // Seed the random number generator -// std::random_device rd; -// std::mt19937 gen(rd()); -// int prompt_size{}; -// if (model.model_training == ModelTraining::GEMMA_IT) -// { -// // For instruction-tuned models: add control tokens. -// prompt_string = "user\n" + prompt_string + -// "\nmodel\n"; -// } -// // Encode the prompt string into tokens -// std::vector prompt; -// HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt).ok()); -// // Placeholder for generated token IDs -// std::vector generated_tokens; -// // Define lambda for token decoding -// StreamFunc stream_token = [&generated_tokens](int token, float /* probability */) -> bool { -// generated_tokens.push_back(token); -// return true; // Continue generating -// }; -// // Decode tokens -// prompt_size = prompt.size(); -// GenerateGemma(model, args, prompt, /*start_pos=*/0, pool, inner_pool, stream_token, -// accept_token, gen, verbosity); -// HWY_ASSERT(model.Tokenizer()->Decode(generated_tokens, &generated_text).ok()); -// generated_text = generated_text.substr(prompt_string.size()); - -// return generated_text; -// } - -// std::string completion(LoaderArgs &loader, InferenceArgs &inference, AppArgs &app, -// std::string &prompt_string){ -// hwy::ThreadPool inner_pool(0); -// hwy::ThreadPool pool(app.num_threads); -// if (app.num_threads > 10) -// { -// PinThreadToCore(app.num_threads - 1); // Main thread - -// pool.Run(0, pool.NumThreads(), -// [](uint64_t /*task*/, size_t thread) -// { PinThreadToCore(thread); }); -// } -// gcpp::Gemma model(loader, pool); -// return decode(model, pool, inner_pool, inference, app.verbosity, /*accept_token=*/[](int) -// { return true; }, prompt_string); - -// } - -} // namespace gcpp - -void chat_base(int argc, char **argv) -{ - { - PROFILER_ZONE("Startup.misc"); - - gcpp::LoaderArgs loader(argc, argv); - gcpp::InferenceArgs inference(argc, argv); - gcpp::AppArgs app(argc, argv); - - if (gcpp::HasHelp(argc, argv)) - { - ShowHelp(loader, inference, app); - // return 0; - } - - if (const char *error = loader.Validate()) - { - ShowHelp(loader, inference, app); - HWY_ABORT("\nInvalid args: %s", error); - } - - gcpp::Run(loader, inference, app); - } - PROFILER_PRINT_RESULTS(); // Must call outside the zone above. - // return 1; -} -// std::string completion_base(int argc, char **argv) -// { -// gcpp::LoaderArgs loader(argc, argv); -// gcpp::InferenceArgs inference(argc, argv); -// gcpp::AppArgs app(argc, argv); -// std::string prompt_string = argv[argc-1]; -// return gcpp::completion(loader, inference, app, prompt_string); -// } -// std::string completion_base_wrapper(const std::vector &args,std::string &prompt_string) -// { -// int argc = args.size() + 2; // +1 for the program name -// std::vector argv_vec; -// argv_vec.reserve(argc); - -// argv_vec.push_back(const_cast("pygemma")); - -// for (const auto &arg : args) -// { -// argv_vec.push_back(const_cast(arg.c_str())); -// } -// argv_vec.push_back(const_cast(prompt_string.c_str())); -// char **argv = argv_vec.data(); -// return completion_base(argc, argv); -// } -void show_help_wrapper() -{ - // Assuming ShowHelp does not critically depend on argv content - gcpp::LoaderArgs loader(0, nullptr); - gcpp::InferenceArgs inference(0, nullptr); - gcpp::AppArgs app(0, nullptr); - - ShowHelp(loader, inference, app); +void GemmaWrapper::showConfig() { + ShowConfig(this->m_loader,this->m_inference, this->m_app); } -std::string chat_base_wrapper(const std::vector &args) -{ - int argc = args.size() + 1; // +1 for the program name - std::vector argv_vec; - argv_vec.reserve(argc); - argv_vec.push_back(const_cast("pygemma")); - - for (const auto &arg : args) - { - argv_vec.push_back(const_cast(arg.c_str())); - } - - char **argv = argv_vec.data(); - - chat_base(argc, argv); +void GemmaWrapper::showHelp() { + ShowHelp(this->m_loader,this->m_inference, this->m_app); } -PYBIND11_MODULE(pygemma, m) -{ - m.doc() = "Pybind11 integration for chat_base function"; - m.def("chat_base", &chat_base_wrapper, "A wrapper for the chat_base function accepting Python list of strings as arguments"); - m.def("show_help", &show_help_wrapper, "A wrapper for show_help function"); - // m.def("completion", &completion_base_wrapper, "A wrapper for inference function"); +PYBIND11_MODULE(pygemma, m) { + py::class_(m, "Gemma") + .def(py::init<>()) + .def("show_config", &GemmaWrapper::showConfig) + .def("show_help", &GemmaWrapper::showHelp) + .def("load_model", [](GemmaWrapper &self, + const std::string &tokenizer, + const std::string &compressed_weights, + const std::string &model) { + std::vector args = { + "--tokenizer", tokenizer, + "--compressed_weights", compressed_weights, + "--model", model + }; + self.loadModel(args); // Assuming GemmaWrapper::loadModel accepts std::vector + }, py::arg("tokenizer"), py::arg("compressed_weights"), py::arg("model")) + .def("completion", &GemmaWrapper::completionPrompt); } diff --git a/src/gemma_binding.h b/src/gemma_binding.h new file mode 100644 index 0000000..7f5da9c --- /dev/null +++ b/src/gemma_binding.h @@ -0,0 +1,46 @@ +#pragma once +// Command line text interface to gemma. + +#include +#include +#include +#include +#include // NOLINT +#include + +// copybara:import_next_line:gemma_cpp +#include "compression/compress.h" +// copybara:end +// copybara:import_next_line:gemma_cpp +#include "gemma.h" // Gemma +// copybara:end +// copybara:import_next_line:gemma_cpp +#include "util/app.h" +// copybara:end +// copybara:import_next_line:gemma_cpp +#include "util/args.h" // HasHelp +// copybara:end +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/highway.h" +#include "hwy/per_target.h" +#include "hwy/profiler.h" +#include "hwy/timer.h" + +using namespace gcpp; + +class GemmaWrapper { + public: + // GemmaWrapper(); + void loadModel(const std::vector &args); // Consider exception safety + void showConfig(); + void showHelp(); + std::string completionPrompt(); + + private: + gcpp::LoaderArgs m_loader = gcpp::LoaderArgs(0, nullptr); + gcpp::InferenceArgs m_inference = gcpp::InferenceArgs(0, nullptr); + gcpp::AppArgs m_app = gcpp::AppArgs(0, nullptr); + std::unique_ptr m_model; + KVCache m_kvcache; +}; diff --git a/tests/test_chat.py b/tests/test_chat.py index 43795c7..d2b84f9 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -1,5 +1,5 @@ import argparse -import pygemma +from pygemma import Gemma def main(): @@ -19,36 +19,19 @@ def main(): "--model", type=str, required=True, help="Model type identifier." ) parser.add_argument( - "--input", type=str, required=False, help="Input text to chat with the model. If None, Switch to Chat mode.", - default="Hello." + "--input", + type=str, + required=False, + help="Input text to chat with the model. If None, Switch to Chat mode.", + default="Hello.", ) # Now using the parsed arguments args = parser.parse_args() - if args.input is not None: - string = pygemma.completion( - [ - "--tokenizer", - args.tokenizer, - "--compressed_weights", - args.compressed_weights, - "--model", - args.model, - ], args.input - ) - print(string) - else: - return pygemma.chat_base( - [ - "--tokenizer", - args.tokenizer, - "--compressed_weights", - args.compressed_weights, - "--model", - args.model, - ] - ) - # Optionally, show help if needed - # pygemma.show_help() + + gemma = Gemma() + gemma.show_config() + gemma.show_help() + gemma.load_model(args.tokenizer, args.compressed_weights, args.model) if __name__ == "__main__": From edb2504f24b6875f43a6ce8d767668ba63e903d3 Mon Sep 17 00:00:00 2001 From: namtranase Date: Sun, 17 Mar 2024 21:46:44 +0700 Subject: [PATCH 3/3] update: remove redundant code --- src/run.cpp | 298 ---------------------------------------------------- 1 file changed, 298 deletions(-) delete mode 100644 src/run.cpp diff --git a/src/run.cpp b/src/run.cpp deleted file mode 100644 index b08e4ca..0000000 --- a/src/run.cpp +++ /dev/null @@ -1,298 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Command line text interface to gemma. - -#include -#include -#include -#include -#include // NOLINT -#include - -// copybara:import_next_line:gemma_cpp -#include "compression/compress.h" -// copybara:end -// copybara:import_next_line:gemma_cpp -#include "gemma.h" // Gemma -// copybara:end -// copybara:import_next_line:gemma_cpp -#include "util/app.h" -// copybara:end -// copybara:import_next_line:gemma_cpp -#include "util/args.h" // HasHelp -// copybara:end -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/highway.h" -#include "hwy/per_target.h" -#include "hwy/profiler.h" -#include "hwy/timer.h" - -namespace gcpp { - -static constexpr std::string_view kAsciiArtBanner = - " __ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __\n" - " / _` |/ _ \\ '_ ` _ \\| '_ ` _ \\ / _` | / __| '_ \\| '_ \\\n" - "| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) |\n" - " \\__, |\\___|_| |_| |_|_| |_| |_|\\__,_(_)___| .__/| .__/\n" - " __/ | | | | |\n" - " |___/ |_| |_|"; - -void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { - loader.Print(app.verbosity); - inference.Print(app.verbosity); - app.Print(app.verbosity); - - if (app.verbosity >= 2) { - time_t now = time(nullptr); - char* dt = ctime(&now); // NOLINT - std::cout << "Date & Time : " << dt - << "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize - << "\n" - << "Hardware concurrency : " - << std::thread::hardware_concurrency() << std::endl - << "Instruction set : " - << hwy::TargetName(hwy::DispatchedTarget()) << " (" - << hwy::VectorBytes() * 8 << " bits)" << "\n" - << "Compiled config : " << CompiledConfig() << "\n" - << "Weight Type : " - << gcpp::TypeName(gcpp::WeightT()) << "\n" - << "EmbedderInput Type : " - << gcpp::TypeName(gcpp::EmbedderInputT()) << "\n"; - } -} - -void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference, - gcpp::AppArgs& app) { - std::cerr - << kAsciiArtBanner - << "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n" - "==========================================================\n\n" - "To run gemma.cpp, you need to " - "specify 3 required model loading arguments:\n --tokenizer\n " - "--compressed_weights\n" - " --model.\n"; - std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm " - "--compressed_weights 2b-it-sfp.sbs --model 2b-it\n"; - std::cerr << "\n*Model Loading Arguments*\n\n"; - loader.Help(); - std::cerr << "\n*Inference Arguments*\n\n"; - inference.Help(); - std::cerr << "\n*Application Arguments*\n\n"; - app.Help(); - std::cerr << "\n"; -} - -void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache, - hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, - const InferenceArgs& args, int verbosity, - const gcpp::AcceptFunc& accept_token, std::string& eot_line) { - PROFILER_ZONE("Gen.misc"); - int abs_pos = 0; // absolute token index over all turns - int current_pos = 0; // token index within the current turn - int prompt_size{}; - - std::mt19937 gen; - if (args.deterministic) { - gen.seed(42); - } else { - std::random_device rd; - gen.seed(rd()); - } - - // callback function invoked for each generated token. - auto stream_token = [&abs_pos, ¤t_pos, &args, &gen, &prompt_size, - tokenizer = model.Tokenizer(), - verbosity](int token, float) { - ++abs_pos; - ++current_pos; - if (current_pos < prompt_size) { - std::cerr << "." << std::flush; - } else if (token == gcpp::EOS_ID) { - if (!args.multiturn) { - abs_pos = 0; - if (args.deterministic) { - gen.seed(42); - } - } - if (verbosity >= 2) { - std::cout << "\n[ End ]\n"; - } - } else { - std::string token_text; - HWY_ASSERT(tokenizer->Decode(std::vector{token}, &token_text).ok()); - // +1 since position is incremented above - if (current_pos == prompt_size + 1) { - // first token of response - token_text.erase(0, token_text.find_first_not_of(" \t\n")); - if (verbosity >= 1) { - std::cout << std::endl << std::endl; - } - } - std::cout << token_text << std::flush; - } - return true; - }; - - while (abs_pos < args.max_tokens) { - std::string prompt_string; - std::vector prompt; - current_pos = 0; - { - PROFILER_ZONE("Gen.input"); - if (verbosity >= 1) { - std::cout << "> " << std::flush; - } - - if (eot_line.size() == 0) { - std::getline(std::cin, prompt_string); - } else { - std::string line; - while (std::getline(std::cin, line)) { - if (line == eot_line) { - break; - } - prompt_string += line + "\n"; - } - } - } - - if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") { - return; - } - - if (prompt_string == "%c" || prompt_string == "%C") { - abs_pos = 0; - continue; - } - - if (model.model_training == ModelTraining::GEMMA_IT) { - // For instruction-tuned models: add control tokens. - prompt_string = "user\n" + prompt_string + - "\nmodel\n"; - if (abs_pos > 0) { - // Prepend "" token if this is a multi-turn dialogue - // continuation. - prompt_string = "\n" + prompt_string; - } - } - - HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt).ok()); - - // For both pre-trained and instruction-tuned models: prepend "" token - // if needed. - if (abs_pos == 0) { - prompt.insert(prompt.begin(), 2); - } - - prompt_size = prompt.size(); - - std::cerr << std::endl << "[ Reading prompt ] " << std::flush; - - const double time_start = hwy::platform::Now(); - GenerateGemma(model, args.max_tokens, args.max_generated_tokens, - args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool, - stream_token, accept_token, gen, verbosity); - const double time_end = hwy::platform::Now(); - const double tok_sec = current_pos / (time_end - time_start); - if (verbosity >= 2) { - std::cout << current_pos << " tokens (" << abs_pos << " total tokens)" - << std::endl - << tok_sec << " tokens / sec" << std::endl; - } - std::cout << std::endl << std::endl; - } - std::cout - << "max_tokens (" << args.max_tokens - << ") exceeded. Use a larger value if desired using the --max_tokens " - << "command line flag.\n"; -} - -void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { - PROFILER_ZONE("Run.misc"); - - hwy::ThreadPool inner_pool(0); - hwy::ThreadPool pool(app.num_threads); - // For many-core, pinning threads to cores helps. - if (app.num_threads > 10) { - PinThreadToCore(app.num_threads - 1); // Main thread - - pool.Run(0, pool.NumThreads(), - [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); - } - - gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, - loader.ModelType(), pool); - - auto kv_cache = CreateKVCache(loader.ModelType()); - - if (const char* error = inference.Validate()) { - ShowHelp(loader, inference, app); - HWY_ABORT("\nInvalid args: %s", error); - } - - if (app.verbosity >= 1) { - const std::string instructions = - "*Usage*\n" - " Enter an instruction and press enter (%C resets conversation, " - "%Q quits).\n" + - (inference.multiturn == 0 - ? std::string(" Since multiturn is set to 0, conversation will " - "automatically reset every turn.\n\n") - : "\n") + - "*Examples*\n" - " - Write an email to grandma thanking her for the cookies.\n" - " - What are some historical attractions to visit around " - "Massachusetts?\n" - " - Compute the nth fibonacci number in javascript.\n" - " - Write a standup comedy bit about GPU programming.\n"; - - std::cout << "\033[2J\033[1;1H" // clear screen - << kAsciiArtBanner << "\n\n"; - ShowConfig(loader, inference, app); - std::cout << "\n" << instructions << "\n"; - } - - ReplGemma( - model, kv_cache, pool, inner_pool, inference, app.verbosity, - /*accept_token=*/[](int) { return true; }, app.eot_line); -} - -} // namespace gcpp - -int main(int argc, char** argv) { - { - PROFILER_ZONE("Startup.misc"); - - gcpp::LoaderArgs loader(argc, argv); - gcpp::InferenceArgs inference(argc, argv); - gcpp::AppArgs app(argc, argv); - - if (gcpp::HasHelp(argc, argv)) { - ShowHelp(loader, inference, app); - return 0; - } - - if (const char* error = loader.Validate()) { - ShowHelp(loader, inference, app); - HWY_ABORT("\nInvalid args: %s", error); - } - - gcpp::Run(loader, inference, app); - } - PROFILER_PRINT_RESULTS(); // Must call outside the zone above. - return 0; -}