diff --git a/.clang-format b/.clang-format
new file mode 100644
index 0000000..2bd8588
--- /dev/null
+++ b/.clang-format
@@ -0,0 +1,2 @@
+Language: Cpp
+BasedOnStyle: Google
diff --git a/.gitignore b/.gitignore
index b5d1648..551b2a3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,4 +160,7 @@ cython_debug/
 #.idea/
 
 # Vscode
-.vscode/
\ No newline at end of file
+.vscode/
+
+# Precommit
+.pre-commit-config.yaml
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8414cbc..60b0238 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -11,7 +11,7 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
 FetchContent_MakeAvailable(sentencepiece)
 
-FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG e29cd566cf3367671e8f59419a04e308796a7c57)
+FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG 8fb44ed6dd123f63dca95c20c561e8ca1de511d7)
 FetchContent_MakeAvailable(gemma)
 
 FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
@@ -30,3 +30,5 @@ FetchContent_GetProperties(gemma)
 FetchContent_GetProperties(sentencepiece)
 target_include_directories(pygemma PRIVATE ${gemma_SOURCE_DIR})
 target_include_directories(pygemma PRIVATE ${sentencepiece_SOURCE_DIR})
+target_compile_definitions(libgemma PRIVATE $<$<CXX_COMPILER_ID:MSVC>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
+target_compile_options(libgemma PRIVATE $<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-Wno-deprecated-declarations>)
diff --git a/DEVELOPERS.md b/DEVELOPERS.md
new file mode 100644
index 0000000..7c761a6
--- /dev/null
+++ b/DEVELOPERS.md
@@ -0,0 +1,2 @@
+## 🤝 Contributing
+Contributions are welcome. Please clone the repository, push your changes to a new branch, and submit a pull request.
diff --git a/requirements.txt b/requirements.txt
index 59a8043..e69de29 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +0,0 @@
-pybind11
-pre-commit
diff --git a/setup.py b/setup.py
index 57be3dd..9bd0c0b 100644
--- a/setup.py
+++ b/setup.py
@@ -3,6 +3,7 @@ import sys
 
 from setuptools import setup, find_packages, Extension
 from setuptools.command.build_ext import build_ext
+import platform
 
 
 class CMakeExtension(Extension):
@@ -39,7 +40,7 @@ def build_extension(self, ext):
             "--",
             "-j",
             "12",
-        ]  # Specifies the number of jobs to run simultaneously
+        ]
 
         if not os.path.exists(self.build_temp):
             os.makedirs(self.build_temp)
@@ -58,7 +59,7 @@ def build_extension(self, ext):
     version="0.1.2",
     author="Nam Tran",
     author_email="namtran.ase@gmail.com",
-    description="A Python package with a C++ backend using gemma.",
+    description="A Python package with a C++ backend using gemma.cpp",
     long_description="""
 This package provides Python bindings to a C++ library using pybind11.
 """,
diff --git a/src/gemma_binding.cpp b/src/gemma_binding.cpp
index 53607db..8baa3e1 100644
--- a/src/gemma_binding.cpp
+++ b/src/gemma_binding.cpp
@@ -1,396 +1,273 @@
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
-// #include "gemma.h" // Adjust include path as necessary
-#include <ctime>
-#include <iostream>
-#include <random>
-#include <string>
-#include <thread>  // NOLINT
-#include <vector>
-
-#include "compression/compress.h"
-#include "gemma.h"  // Gemma
-#include "util/app.h"
-#include "util/args.h"  // HasHelp
-#include "hwy/base.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"
-#include "hwy/highway.h"
-#include "hwy/per_target.h"
-#include "hwy/profiler.h"
-#include "hwy/timer.h"
+#include "gemma_binding.h"
 
 namespace py = pybind11;
 
-namespace gcpp
-{
-
-  void ShowHelp(gcpp::LoaderArgs &loader, gcpp::InferenceArgs &inference,
-                gcpp::AppArgs &app)
-  {
-    fprintf(stderr,
-            "\ngemma.cpp\n---------\n\nTo run gemma.cpp, you need to "
-            "specify 3 required model loading arguments: --tokenizer, "
-            "--compressed_weights, "
-            "and --model.\n\nModel Loading Arguments\n\n");
-    loader.Help();
-    fprintf(stderr, "\nInference Arguments\n\n");
-    inference.Help();
-    fprintf(stderr, "\nApplication Arguments\n\n");
-    app.Help();
-    fprintf(stderr, "\n\n");
-  }
+static constexpr std::string_view kAsciiArtBanner =
+    "  __ _  ___ _ __ ___  _ __ ___   __ _   ___ _ __  _ __\n"
+    " / _` |/ _ \\ '_ ` _ \\| '_ ` _ \\ / _` | / __| '_ \\| '_ \\\n"
+    "| (_| |  __/ | | | | | | | | | | (_| || (__| |_) | |_) |\n"
+    " \\__, |\\___|_| |_| |_|_| |_| |_|\\__,_(_)___| .__/| .__/\n"
+    "  __/ |                                    | |   | |\n"
+    " |___/                                     |_|   |_|";
+
+void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
+  loader.Print(app.verbosity);
+  inference.Print(app.verbosity);
+  app.Print(app.verbosity);
+
+  if (app.verbosity >= 2) {
+    time_t now = time(nullptr);
+    char* dt = ctime(&now);  // NOLINT
+    std::cout << "Date & Time                   : " << dt
+              << "Prefill Token Batch Size      : " << gcpp::kPrefillBatchSize
+              << "\n"
+              << "Hardware concurrency          : "
+              << std::thread::hardware_concurrency() << std::endl
+              << "Instruction set               : "
+              << hwy::TargetName(hwy::DispatchedTarget()) << " ("
+              << hwy::VectorBytes() * 8 << " bits)" << "\n"
+              << "Compiled config               : " << CompiledConfig() << "\n"
+              << "Weight Type                   : "
+              << gcpp::TypeName(gcpp::WeightT()) << "\n"
+              << "EmbedderInput Type            : "
+              << gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
+  }
+}
 
-  void ShowConfig(LoaderArgs &loader, InferenceArgs &inference, AppArgs &app)
-  {
-    loader.Print(app.verbosity);
-    inference.Print(app.verbosity);
-    app.Print(app.verbosity);
-
-    if (app.verbosity >= 2)
-    {
-      time_t now = time(nullptr);
-      char *dt = ctime(&now); // NOLINT
-      std::cout << "Date & Time                   : " << dt
-                << "Prefill Token Batch Size      : " << gcpp::kPrefillBatchSize
-                << "\n"
-                << "Hardware concurrency          : "
-                << std::thread::hardware_concurrency() << std::endl
-                << "Instruction set               : "
-                << hwy::TargetName(hwy::DispatchedTarget()) << " ("
-                << hwy::VectorBytes() * 8 << " bits)"
-                << "\n"
-                << "Weight Type                   : "
-                << gcpp::TypeName(gcpp::WeightT()) << "\n"
-                << "EmbedderInput Type            : "
-                << gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
-    }
-  }
+void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
+              gcpp::AppArgs& app) {
+  std::cerr
+      << kAsciiArtBanner
+      << "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
+         "==========================================================\n\n"
+         "To run gemma.cpp, you need to "
+         "specify 3 required model loading arguments:\n    --tokenizer\n    "
+         "--compressed_weights\n"
+         "    --model.\n";
+  std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
+               "--compressed_weights 2b-it-sfp.sbs --model 2b-it\n";
+  std::cerr << "\n*Model Loading Arguments*\n\n";
+  loader.Help();
+  std::cerr << "\n*Inference Arguments*\n\n";
+  inference.Help();
+  std::cerr << "\n*Application Arguments*\n\n";
+  app.Help();
+  std::cerr << "\n";
+}
 
-  void ReplGemma(gcpp::Gemma &model, hwy::ThreadPool &pool,
-                 hwy::ThreadPool &inner_pool, const InferenceArgs &args,
-                 int verbosity, const gcpp::AcceptFunc &accept_token)
-  {
-    PROFILER_ZONE("Gen.misc");
-    int abs_pos = 0;     // absolute token index over all turns
-    int current_pos = 0; // token index within the current turn
-    int prompt_size{};
-
-    std::mt19937 gen;
-    if (args.deterministic)
-    {
-      gen.seed(42);
-    }
-    else
-    {
-      std::random_device rd;
-      gen.seed(rd());
-    }
-
-    // callback function invoked for each generated token.
-    auto stream_token = [&abs_pos, &current_pos, &args, &gen, &prompt_size,
-                         tokenizer = &model.Tokenizer(),
-                         verbosity](int token, float)
-    {
-      ++abs_pos;
-      ++current_pos;
-      if (current_pos < prompt_size)
-      {
-        std::cerr << "." << std::flush;
-      }
-      else if (token == gcpp::EOS_ID)
-      {
-        if (!args.multiturn)
-        {
-          abs_pos = 0;
-          if (args.deterministic)
-          {
-            gen.seed(42);
-          }
-        }
-        if (verbosity >= 2)
-        {
-          std::cout << "\n[ End ]" << std::endl;
-        }
-      }
-      else
-      {
-        std::string token_text;
-        HWY_ASSERT(
-            tokenizer->Decode(std::vector<int>{token}, &token_text).ok());
-        // +1 since position is incremented above
-        if (current_pos == prompt_size + 1)
-        {
-          // first token of response
-          token_text.erase(0, token_text.find_first_not_of(" \t\n"));
-          if (verbosity >= 1)
-          {
-            std::cout << std::endl
-                      << std::endl;
-          }
-        }
-        // TODO(austinvhuang): is explicit space necessary?
-        std::cout << token_text << std::flush;
-      }
-      return true;
-    };
-
-    while (abs_pos < args.max_tokens)
-    {
-      std::string prompt_string;
-      std::vector<int> prompt;
-      current_pos = 0;
-      {
-        PROFILER_ZONE("Gen.input");
-        if (verbosity >= 1)
-        {
-          std::cout << "> " << std::flush;
-        }
-        std::getline(std::cin, prompt_string);
-      }
-
-      if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q")
-      {
-        return;
-      }
-
-      if (model.model_training == ModelTraining::GEMMA_IT)
-      {
-        // For instruction-tuned models: add control tokens.
-        prompt_string = "<start_of_turn>user\n" + prompt_string +
-                        "<end_of_turn>\n<start_of_turn>model\n";
-        if (abs_pos > 0)
-        {
-          // Prepend "<end_of_turn>" token if this is a multi-turn dialogue
-          // continuation.
-          prompt_string = "<end_of_turn>\n" + prompt_string;
-        }
-      }
-
-      HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
-
-      // For both pre-trained and instruction-tuned models: prepend "<bos>"
-      // token if needed.
-      if (abs_pos == 0)
-      {
-        prompt.insert(prompt.begin(), 2);
-      }
-
-      prompt_size = prompt.size();
-
-      std::cerr << std::endl
-                << "[ Reading prompt ] " << std::flush;
-
-      const double time_start = hwy::platform::Now();
-      GenerateGemma(model, args, prompt, abs_pos, pool, inner_pool,
-                    stream_token, accept_token, gen, verbosity);
-      const double time_end = hwy::platform::Now();
-      const double tok_sec = current_pos / (time_end - time_start);
-      if (verbosity >= 2)
-      {
-        std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
-                  << std::endl
-                  << tok_sec << " tokens / sec" << std::endl;
-      }
-      std::cout << std::endl
-                << std::endl;
-    }
-    std::cout
-        << "max_tokens (" << args.max_tokens
-        << ") exceeded. Use a larger value if desired using the --max_tokens "
-        << "command line flag.\n";
-  }
+void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
+               hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
+               const InferenceArgs& args, int verbosity,
+               const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
+  PROFILER_ZONE("Gen.misc");
+  int abs_pos = 0;      // absolute token index over all turns
+  int current_pos = 0;  // token index within the current turn
+  int prompt_size{};
+
+  std::mt19937 gen;
+  if (args.deterministic) {
+    gen.seed(42);
+  } else {
+    std::random_device rd;
+    gen.seed(rd());
+  }
+
+  // callback function invoked for each generated token.
+  auto stream_token = [&abs_pos, &current_pos, &args, &gen, &prompt_size,
+                       tokenizer = model.Tokenizer(),
+                       verbosity](int token, float) {
+    ++abs_pos;
+    ++current_pos;
+    if (current_pos < prompt_size) {
+      std::cerr << "." << std::flush;
+    } else if (token == gcpp::EOS_ID) {
+      if (!args.multiturn) {
+        abs_pos = 0;
+        if (args.deterministic) {
+          gen.seed(42);
+        }
+      }
+      if (verbosity >= 2) {
+        std::cout << "\n[ End ]\n";
+      }
+    } else {
+      std::string token_text;
+      HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text).ok());
+      // +1 since position is incremented above
+      if (current_pos == prompt_size + 1) {
+        // first token of response
+        token_text.erase(0, token_text.find_first_not_of(" \t\n"));
+        if (verbosity >= 1) {
+          std::cout << std::endl << std::endl;
+        }
+      }
+      std::cout << token_text << std::flush;
+    }
+    return true;
+  };
+
+  while (abs_pos < args.max_tokens) {
+    std::string prompt_string;
+    std::vector<int> prompt;
+    current_pos = 0;
+    {
+      PROFILER_ZONE("Gen.input");
+      if (verbosity >= 1) {
+        std::cout << "> " << std::flush;
+      }
+
+      if (eot_line.size() == 0) {
+        std::getline(std::cin, prompt_string);
+      } else {
+        std::string line;
+        while (std::getline(std::cin, line)) {
+          if (line == eot_line) {
+            break;
+          }
+          prompt_string += line + "\n";
+        }
+      }
+    }
+
+    if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") {
+      return;
+    }
+
+    if (prompt_string == "%c" || prompt_string == "%C") {
+      abs_pos = 0;
+      continue;
+    }
+
+    if (model.model_training == ModelTraining::GEMMA_IT) {
+      // For instruction-tuned models: add control tokens.
+      prompt_string = "<start_of_turn>user\n" + prompt_string +
+                      "<end_of_turn>\n<start_of_turn>model\n";
+      if (abs_pos > 0) {
+        // Prepend "<end_of_turn>" token if this is a multi-turn dialogue
+        // continuation.
+        prompt_string = "<end_of_turn>\n" + prompt_string;
+      }
+    }
+
+    HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt).ok());
+
+    // For both pre-trained and instruction-tuned models: prepend "<bos>" token
+    // if needed.
+    if (abs_pos == 0) {
+      prompt.insert(prompt.begin(), 2);
+    }
+
+    prompt_size = prompt.size();
+
+    std::cerr << std::endl << "[ Reading prompt ] " << std::flush;
+
+    const double time_start = hwy::platform::Now();
+    GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
+                  args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
+                  stream_token, accept_token, gen, verbosity);
+    const double time_end = hwy::platform::Now();
+    const double tok_sec = current_pos / (time_end - time_start);
+    if (verbosity >= 2) {
+      std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
+                << std::endl
+                << tok_sec << " tokens / sec" << std::endl;
+    }
+    std::cout << std::endl << std::endl;
+  }
+  std::cout
+      << "max_tokens (" << args.max_tokens
+      << ") exceeded. Use a larger value if desired using the --max_tokens "
+      << "command line flag.\n";
+}
 
-  void Run(LoaderArgs &loader, InferenceArgs &inference, AppArgs &app)
-  {
-    PROFILER_ZONE("Run.misc");
-
-    hwy::ThreadPool inner_pool(0);
-    hwy::ThreadPool pool(app.num_threads);
-    // For many-core, pinning threads to cores helps.
-    if (app.num_threads > 10)
-    {
-      PinThreadToCore(app.num_threads - 1); // Main thread
-
-      pool.Run(0, pool.NumThreads(),
-               [](uint64_t /*task*/, size_t thread)
-               { PinThreadToCore(thread); });
-    }
-
-    gcpp::Gemma model(loader, pool);
-
-    if (const char *error = inference.Validate())
-    {
-      ShowHelp(loader, inference, app);
-      HWY_ABORT("\nInvalid args: %s", error);
-    }
-
-    if (app.verbosity >= 1)
-    {
-      static const std::string banner_ascii_art =
-          "  __ _  ___ _ __ ___  _ __ ___   __ _   ___ _ __  _ __\n"
-          " / _` |/ _ \\ '_ ` _ \\| '_ ` _ \\ / _` | / __| '_ \\| '_ \\\n"
-          "| (_| |  __/ | | | | | | | | | | (_| || (__| |_) | |_) |\n"
-          " \\__, |\\___|_| |_| |_|_| |_| |_|\\__,_(_)___| .__/| .__/\n"
-          "  __/ |                                    | |   | |\n"
-          " |___/                                     |_|   |_|";
-
-      const std::string instructions =
-          "*Usage*\n"
-          "  Enter an instruction and press enter (%Q quits).\n\n"
-          "*Examples*\n"
-          "  - Write an email to grandma thanking her for the cookies.\n"
-          "  - What are some historical attractions to visit around "
-          "Massachusetts?\n"
-          "  - Compute the nth fibonacci number in javascript.\n"
-          "  - Write a standup comedy bit about GPU programming.\n";
-
-      std::cout << "\033[2J\033[1;1H" // clear screen
-                << banner_ascii_art << "\n\n";
-      ShowConfig(loader, inference, app);
-      std::cout << "\n"
-                << instructions << "\n";
-    }
-
-    ReplGemma(model, pool, inner_pool, inference, app.verbosity,
-              /*accept_token=*/[](int)
-              { return true; });
-  }
-
-  std::string decode(gcpp::Gemma &model, hwy::ThreadPool &pool,
-                     hwy::ThreadPool &inner_pool, const InferenceArgs &args,
-                     int verbosity, const gcpp::AcceptFunc &accept_token,
-                     std::string &prompt_string)
-  {
-    std::string generated_text;
-    // Seed the random number generator
-    std::random_device rd;
-    std::mt19937 gen(rd());
-    int prompt_size{};
-    if (model.model_training == ModelTraining::GEMMA_IT)
-    {
-      // For instruction-tuned models: add control tokens.
-      prompt_string = "<start_of_turn>user\n" + prompt_string +
-                      "<end_of_turn>\n<start_of_turn>model\n";
-    }
-    // Encode the prompt string into tokens
-    std::vector<int> prompt;
-    HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
-    // Placeholder for generated token IDs
-    std::vector<int> generated_tokens;
-    // Define lambda for token decoding
-    StreamFunc stream_token = [&generated_tokens](int token,
-                                                  float /* probability */) -> bool
-    {
-      generated_tokens.push_back(token);
-      return true; // Continue generating
-    };
-    // Decode tokens
-    prompt_size = prompt.size();
-    GenerateGemma(model, args, prompt, /*start_pos=*/0, pool, inner_pool,
-                  stream_token, accept_token, gen, verbosity);
-    HWY_ASSERT(
-        model.Tokenizer().Decode(generated_tokens, &generated_text).ok());
-    generated_text = generated_text.substr(prompt_string.size());
-
-    return generated_text;
-  }
-
-  std::string completion(LoaderArgs &loader, InferenceArgs &inference,
-                         AppArgs &app, std::string &prompt_string)
-  {
-    hwy::ThreadPool inner_pool(0);
-    hwy::ThreadPool pool(app.num_threads);
-    if (app.num_threads > 10)
-    {
-      PinThreadToCore(app.num_threads - 1); // Main thread
-
-      pool.Run(0, pool.NumThreads(),
-               [](uint64_t /*task*/, size_t thread)
-               { PinThreadToCore(thread); });
-    }
-    gcpp::Gemma model(loader, pool);
-    return decode(model, pool, inner_pool, inference, app.verbosity,
-                  /*accept_token=*/[](int)
-                  { return true; },
-                  prompt_string);
-  }
-
-} // namespace gcpp
-
-void chat_base(int argc, char **argv)
-{
-  {
-    PROFILER_ZONE("Startup.misc");
-
-    gcpp::LoaderArgs loader(argc, argv);
-    gcpp::InferenceArgs inference(argc, argv);
-    gcpp::AppArgs app(argc, argv);
-
-    if (gcpp::HasHelp(argc, argv))
-    {
-      ShowHelp(loader, inference, app);
-      // return 0;
-    }
-
-    if (const char *error = loader.Validate())
-    {
-      ShowHelp(loader, inference, app);
-      HWY_ABORT("\nInvalid args: %s", error);
-    }
-
-    gcpp::Run(loader, inference, app);
-  }
-  PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
-  // return 1;
-}
-
-std::string completion_base(int argc, char **argv)
-{
-  gcpp::LoaderArgs loader(argc, argv);
-  gcpp::InferenceArgs inference(argc, argv);
-  gcpp::AppArgs app(argc, argv);
-  std::string prompt_string = argv[argc - 1];
-  return gcpp::completion(loader, inference, app, prompt_string);
-}
-
-std::string completion_base_wrapper(const std::vector<std::string> &args,
-                                    std::string &prompt_string)
-{
-  int argc = args.size() + 2; // +1 for the program name
-  std::vector<char *> argv_vec;
-  argv_vec.reserve(argc);
-  argv_vec.push_back(const_cast<char *>("pygemma"));
-
-  for (const auto &arg : args)
-  {
-    argv_vec.push_back(const_cast<char *>(arg.c_str()));
-  }
-  argv_vec.push_back(const_cast<char *>(prompt_string.c_str()));
-  char **argv = argv_vec.data();
-  return completion_base(argc, argv);
-}
-
-void show_help_wrapper()
-{
-  // Assuming ShowHelp does not critically depend on argv content
-  gcpp::LoaderArgs loader(0, nullptr);
-  gcpp::InferenceArgs inference(0, nullptr);
-  gcpp::AppArgs app(0, nullptr);
-
-  ShowHelp(loader, inference, app);
-}
-
-std::string chat_base_wrapper(const std::vector<std::string> &args)
-{
-  int argc = args.size() + 1; // +1 for the program name
-  std::vector<char *> argv_vec;
-  argv_vec.reserve(argc);
-  argv_vec.push_back(const_cast<char *>("pygemma"));
-  for (const auto &arg : args)
-  {
-    argv_vec.push_back(const_cast<char *>(arg.c_str()));
-  }
-
-  char **argv = argv_vec.data();
-
-  chat_base(argc, argv);
-}
-
-PYBIND11_MODULE(pygemma, m)
-{
-  m.doc() = "Pybind11 integration for chat_base function";
-  m.def("chat_base", &chat_base_wrapper,
-        "A wrapper for the chat_base function accepting Python list of "
-        "strings as arguments");
-  m.def("show_help", &show_help_wrapper, "A wrapper for show_help function");
-  m.def("completion", &completion_base_wrapper,
-        "A wrapper for inference function");
+void GemmaWrapper::loadModel(const std::vector<std::string> &args) {
+  int argc = args.size() + 1;  // +1 for the program name
+  std::vector<char *> argv_vec;
+  argv_vec.reserve(argc);
+  argv_vec.push_back(const_cast<char *>("pygemma"));
+  for (const auto &arg : args) {
+    argv_vec.push_back(const_cast<char *>(arg.c_str()));
+  }
+
+  char **argv = argv_vec.data();
+
+  this->m_loader = gcpp::LoaderArgs(argc, argv);
+  this->m_inference = gcpp::InferenceArgs(argc, argv);
+  this->m_app = gcpp::AppArgs(argc, argv);
+
+  PROFILER_ZONE("Run.misc");
+
+  hwy::ThreadPool inner_pool(0);
+  hwy::ThreadPool pool(this->m_app.num_threads);
+  // For many-core, pinning threads to cores helps.
+  if (this->m_app.num_threads > 10) {
+    PinThreadToCore(this->m_app.num_threads - 1);  // Main thread
+
+    pool.Run(0, pool.NumThreads(),
+             [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
+  }
+
+  if (!this->m_model) {
+    this->m_model.reset(new gcpp::Gemma(this->m_loader.tokenizer,
+                                        this->m_loader.compressed_weights,
+                                        this->m_loader.ModelType(), pool));
+  }
+  // auto kvcache = CreateKVCache(loader.ModelType());
+  this->m_kvcache = CreateKVCache(this->m_loader.ModelType());
+
+  if (const char* error = this->m_inference.Validate()) {
+    ShowHelp(this->m_loader, this->m_inference, this->m_app);
+    HWY_ABORT("\nInvalid args: %s", error);
+  }
+
+  if (this->m_app.verbosity >= 1) {
+    const std::string instructions =
+        "*Usage*\n"
+        "  Enter an instruction and press enter (%C resets conversation, "
+        "%Q quits).\n" +
+        (this->m_inference.multiturn == 0
+             ? std::string("  Since multiturn is set to 0, conversation will "
+                           "automatically reset every turn.\n\n")
+             : "\n") +
+        "*Examples*\n"
+        "  - Write an email to grandma thanking her for the cookies.\n"
+        "  - What are some historical attractions to visit around "
+        "Massachusetts?\n"
+        "  - Compute the nth fibonacci number in javascript.\n"
+        "  - Write a standup comedy bit about GPU programming.\n";
+
+    std::cout << "\033[2J\033[1;1H"  // clear screen
+              << kAsciiArtBanner << "\n\n";
+    ShowConfig(this->m_loader, this->m_inference, this->m_app);
+    std::cout << "\n" << instructions << "\n";
+  }
+}
+
+void GemmaWrapper::showConfig() {
+  ShowConfig(this->m_loader, this->m_inference, this->m_app);
+}
+
+void GemmaWrapper::showHelp() {
+  ShowHelp(this->m_loader, this->m_inference, this->m_app);
+}
+
+PYBIND11_MODULE(pygemma, m) {
+  py::class_<GemmaWrapper>(m, "Gemma")
+      .def(py::init<>())
+      .def("show_config", &GemmaWrapper::showConfig)
+      .def("show_help", &GemmaWrapper::showHelp)
+      .def("load_model", [](GemmaWrapper &self,
+                            const std::string &tokenizer,
+                            const std::string &compressed_weights,
+                            const std::string &model) {
+        std::vector<std::string> args = {
+            "--tokenizer", tokenizer,
+            "--compressed_weights", compressed_weights,
+            "--model", model
+        };
+        self.loadModel(args);  // Assuming GemmaWrapper::loadModel accepts std::vector<std::string>
+      }, py::arg("tokenizer"), py::arg("compressed_weights"), py::arg("model"))
+      .def("completion", &GemmaWrapper::completionPrompt);
 }
diff --git a/src/gemma_binding.h b/src/gemma_binding.h
new file mode 100644
index 0000000..7f5da9c
--- /dev/null
+++ b/src/gemma_binding.h
@@ -0,0 +1,46 @@
+#pragma once
+// Command line text interface to gemma.
+
+#include <ctime>
+#include <iostream>
+#include <random>
+#include <string>
+#include <thread>  // NOLINT
+#include <vector>
+
+// copybara:import_next_line:gemma_cpp
+#include "compression/compress.h"
+// copybara:end
+// copybara:import_next_line:gemma_cpp
+#include "gemma.h"  // Gemma
+// copybara:end
+// copybara:import_next_line:gemma_cpp
+#include "util/app.h"
+// copybara:end
+// copybara:import_next_line:gemma_cpp
+#include "util/args.h"  // HasHelp
+// copybara:end
+#include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/highway.h"
+#include "hwy/per_target.h"
+#include "hwy/profiler.h"
+#include "hwy/timer.h"
+
+using namespace gcpp;
+
+class GemmaWrapper {
+ public:
+  // GemmaWrapper();
+  void loadModel(const std::vector<std::string> &args);  // Consider exception safety
+  void showConfig();
+  void showHelp();
+  std::string completionPrompt();
+
+ private:
+  gcpp::LoaderArgs m_loader = gcpp::LoaderArgs(0, nullptr);
+  gcpp::InferenceArgs m_inference = gcpp::InferenceArgs(0, nullptr);
+  gcpp::AppArgs m_app = gcpp::AppArgs(0, nullptr);
+  std::unique_ptr<gcpp::Gemma> m_model;
+  KVCache m_kvcache;
+};
diff --git a/tests/test_chat.py b/tests/test_chat.py
index 43795c7..d2b84f9 100644
--- a/tests/test_chat.py
+++ b/tests/test_chat.py
@@ -1,5 +1,5 @@
 import argparse
-import pygemma
+from pygemma import Gemma
 
 
 def main():
@@ -19,36 +19,19 @@ def main():
         "--model", type=str, required=True, help="Model type identifier."
     )
     parser.add_argument(
-        "--input", type=str, required=False, help="Input text to chat with the model. If None, Switch to Chat mode.",
-        default="Hello."
+        "--input",
+        type=str,
+        required=False,
+        help="Input text to chat with the model. If None, switch to chat mode.",
+        default="Hello.",
     )
 
     # Now using the parsed arguments
     args = parser.parse_args()
-    if args.input is not None:
-        string = pygemma.completion(
-            [
-                "--tokenizer",
-                args.tokenizer,
-                "--compressed_weights",
-                args.compressed_weights,
-                "--model",
-                args.model,
-            ], args.input
-        )
-        print(string)
-    else:
-        return pygemma.chat_base(
-            [
-                "--tokenizer",
-                args.tokenizer,
-                "--compressed_weights",
-                args.compressed_weights,
-                "--model",
-                args.model,
-            ]
-        )
-    # Optionally, show help if needed
-    # pygemma.show_help()
+
+    gemma = Gemma()
+    gemma.show_config()
+    gemma.show_help()
+    gemma.load_model(args.tokenizer, args.compressed_weights, args.model)
 
 
 if __name__ == "__main__":
     main()
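For a quick smoke test of the new class-based API, usage boils down to the sketch below. This is a minimal example, not part of the diff: the asset paths are placeholders for a real gemma.cpp tokenizer/weights pair (e.g. the 2b-it release artifacts named in the help text above), and `completion` is bound to `GemmaWrapper::completionPrompt`, whose implementation is not shown in this diff.

```python
from pygemma import Gemma

# Paths are placeholders; point them at real gemma.cpp assets on disk.
gemma = Gemma()
gemma.load_model(
    tokenizer="tokenizer.spm",
    compressed_weights="2b-it-sfp.sbs",
    model="2b-it",
)
gemma.show_config()  # prints the loader/inference/app configuration
```

The keyword names match the `py::arg("tokenizer")`, `py::arg("compressed_weights")`, and `py::arg("model")` declarations in the `load_model` binding, so they can be passed positionally or by keyword.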