Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,18 @@ cc_library(
],
)

# Library target that exposes gemma/weights_raw.h; only hdrs are listed here
# (no srcs). Targets that include that header depend on ":weights_raw".
cc_library(
name = "weights_raw",
hdrs = ["gemma/weights_raw.h"],
deps = [
":common",
":weights",
"//compression:compress",
"@hwy//:hwy",
"@hwy//:thread_pool",
],
)

cc_library(
name = "gemma_lib",
srcs = [
Expand Down Expand Up @@ -214,6 +226,7 @@ cc_binary(
":common",
":gemma_lib",
":weights",
":weights_raw",
# Placeholder for internal dep, do not remove.,
"//compression:compress",
"@hwy//:hwy",
Expand Down Expand Up @@ -331,7 +344,7 @@ cc_library(
":common",
":gemma_lib",
":prompt",
":weights",
":weights_raw",
],
)

Expand All @@ -346,7 +359,7 @@ cc_test(
":backprop_scalar",
":prompt",
":sampler",
":weights",
":weights_raw",
"@googletest//:gtest_main",
],
)
Expand All @@ -363,11 +376,9 @@ cc_test(
":backprop_scalar",
":gemma_lib",
":ops",
":prompt",
":sampler",
":weights",
":weights_raw",
"@googletest//:gtest_main",
"//compression:compress",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:thread_pool",
Expand Down
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ set(SOURCES
gemma/ops.h
gemma/weights.cc
gemma/weights.h
gemma/weights_raw.h
util/app.h
util/args.h
)
Expand Down
1 change: 0 additions & 1 deletion backprop/backward-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include <stddef.h>

#include <algorithm>
#include <array>
#include <cmath>

#include "backprop/prompt.h"
Expand Down
2 changes: 0 additions & 2 deletions backprop/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_

#include <vector>

#include "backprop/prompt.h"
#include "gemma/common.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
Expand Down
3 changes: 1 addition & 2 deletions backprop/backward_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@
#include <string.h>

#include <cmath>
#include <complex>
#include <vector>

#include "backprop/common_scalar.h"
#include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/common.h" // EmbeddingScaling
#include "gemma/weights.h"
#include "gemma/weights_raw.h"

namespace gcpp {
template<typename T>
Expand Down
12 changes: 8 additions & 4 deletions backprop/backward_scalar_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@

#include "backprop/backward_scalar.h"

#include <stddef.h>
#include <string.h> // memset

#include <array>
#include <complex>
#include <random>
Expand All @@ -23,6 +26,7 @@
#include "backprop/forward_scalar.h"
#include "backprop/sampler.h"
#include "backprop/test_util.h"
#include "gemma/weights_raw.h"

namespace gcpp {

Expand Down Expand Up @@ -55,8 +59,8 @@ TEST(BackPropTest, MatMulVJP) {
memset(&grad, 0, sizeof(grad));
MatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(),
kRows, kCols, kTokens);
TestGradient(dx, c_x, func, 1e-11, 1e-12,__LINE__);
TestGradient(grad, c_weights, func, 1e-14, 1e-12,__LINE__);
TestGradient(dx, c_x, func, 1e-11, 1e-12, __LINE__);
TestGradient(grad, c_weights, func, 1e-14, 1e-12, __LINE__);
}
}

Expand Down Expand Up @@ -91,8 +95,8 @@ TEST(BackPropTest, MultiHeadMatMulVJP) {
memset(&grad, 0, sizeof(grad));
MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(),
dx.data(), kHeads, kRows, kCols, kTokens);
TestGradient(dx, c_x, func, 1e-15, 1e-13,__LINE__);
TestGradient(grad, c_weights, func, 1e-15, 1e-13,__LINE__);
TestGradient(dx, c_x, func, 1e-15, 1e-13, __LINE__);
TestGradient(grad, c_weights, func, 1e-15, 1e-13, __LINE__);
}
}

Expand Down
9 changes: 2 additions & 7 deletions backprop/backward_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

#include <stddef.h>

#include <algorithm>
#include <array>
#include <complex>
#include <random>
Expand All @@ -29,10 +28,8 @@
#include "backprop/forward_scalar.h"
#include "backprop/sampler.h"
#include "backprop/test_util.h"
#include "compression/compress.h"
#include "gemma/gemma.h"
#include "gemma/weights.h"
#include "hwy/aligned_allocator.h"
#include "gemma/activations.h"
#include "gemma/weights_raw.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

Expand All @@ -52,8 +49,6 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {

namespace hn = hwy::HWY_NAMESPACE;

void TestMatMulVJP() {
static const size_t kRows = 8;
static const size_t kCols = 64;
Expand Down
9 changes: 4 additions & 5 deletions backprop/forward-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
#include <stddef.h>
#include <stdint.h>

#include <array>
#include <cmath>
#include <vector>

#include "gemma/activations.h"
#include "gemma/common.h"
Expand All @@ -40,11 +40,11 @@
#endif

#include "gemma/ops.h"
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

template <typename ArrayT>
void InputEmbedding(const ArrayT& weights, const std::vector<int>& prompt,
Expand Down Expand Up @@ -202,11 +202,10 @@ void ApplyForwardLayer(const LayerT<TConfig>& weights,
activations.ffw_hidden_gated.data() + pos * kFFHiddenDim;
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
DF df;
for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) {
const auto y = Load(df, out + i);
const auto x = Load(df, out_mul + i);
const auto y = hn::Load(df, out + i);
const auto x = hn::Load(df, out_mul + i);
hn::Store(hn::Mul(x, Gelu(df, y)), df, out_gated + i);
}
}
Expand Down
2 changes: 1 addition & 1 deletion backprop/forward_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/common.h" // EmbeddingScaling
#include "gemma/weights.h"
#include "gemma/weights_raw.h"

namespace gcpp {

Expand Down
33 changes: 3 additions & 30 deletions backprop/test_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,43 +16,16 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_

#include <stddef.h>

#include <array>
#include <complex>
#include <random>

#include "gemma/weights.h"
#include "gtest/gtest.h"
#include "gemma/weights_raw.h"

namespace gcpp {

// Overwrites every element of `x`, in index order, with a sample drawn from a
// normal distribution with mean 0 and standard deviation `stddev`.
// `gen` supplies the randomness and is advanced by the draws.
template <typename T, size_t kLen>
void RandInit(std::array<T, kLen>& x, T stddev, std::mt19937& gen) {
  std::normal_distribution<T> normal(0.0, stddev);
  for (T& value : x) {
    value = normal(gen);
  }
}

// Randomly initializes the tensors of a single layer `w` by forwarding each
// one to RandInit with the same `stddev` and generator.
// The call order is significant: every call advances `gen`, so the sequence
// below determines which random values end up in which tensor.
template <typename T, typename TConfig>
void RandInit(Layer<T, TConfig>& w, T stddev, std::mt19937& gen) {
  const auto init = [&](auto& tensor) { RandInit(tensor, stddev, gen); };
  init(w.pre_attention_norm_scale);
  init(w.attn_vec_einsum_w);
  init(w.qkv_einsum_w);
  init(w.pre_ffw_norm_scale);
  init(w.gating_einsum_w);
  init(w.linear_w);
}

// Randomly initializes a whole weight set `w`: first the input embedding,
// then the final norm scale, then each of the TConfig::kLayers layers in
// ascending index order. All draws come from `gen`, so this order fixes
// which random values each tensor receives.
template <typename T, typename TConfig>
void RandInit(Weights<T, TConfig>& w, T stddev, std::mt19937& gen) {
  RandInit(w.embedder_input_embedding, stddev, gen);
  RandInit(w.final_norm_scale, stddev, gen);
  for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
    RandInit(*w.GetLayer(layer), stddev, gen);
  }
}

template<typename T, typename U, size_t kLen>
void Complexify(const std::array<T, kLen>& x,
std::array<std::complex<U>, kLen>& c_x) {
Expand Down
4 changes: 3 additions & 1 deletion gemma/compress_weights.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "compression/io.h" // Path
#include "gemma/common.h" // Model
#include "gemma/weights.h"
#include "gemma/weights_raw.h"
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
Expand Down Expand Up @@ -317,7 +318,8 @@ void CompressWeights(const Path& weights_path,
WeightsF<TConfig>* weights =
reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
Compressor compressor(pool);
ForEachTensor<TConfig>(weights, *c_weights, compressor);
ForEachTensor</*kHaveRaw=*/true, TConfig, LayerF<TConfig>>(
weights, *c_weights, compressor);
compressor.AddScales(weights->scales.data(), weights->scales.size());
compressor.WriteAll(pool, compressed_weights_path);

Expand Down
5 changes: 2 additions & 3 deletions gemma/gemma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ constexpr bool kShowTokenization = false;
// Must be aligned.
template <class TConfig, size_t kBatchSize>
struct Activations {
using LayerConfig = LayerF<TConfig>;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads;
Expand Down Expand Up @@ -979,8 +978,8 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
}

// Destructor. Passes weights_u8_ to CallForModelAndWeight, selecting the
// functor instantiation from model_type_ and weight_type_.
// NOTE(review): two calls appear here (DeleteLayersPtrs and
// DeleteCompressedWeights) on the same weights_u8_. This looks like the
// removed/added pair of a diff hunk rather than an intentional double
// release — confirm that only one of them is meant to remain.
Gemma::~Gemma() {
CallForModelAndWeight<DeleteLayersPtrs>(model_type_, weight_type_,
weights_u8_);
CallForModelAndWeight<DeleteCompressedWeights>(model_type_, weight_type_,
weights_u8_);
}

void Gemma::Generate(const RuntimeConfig& runtime_config,
Expand Down
4 changes: 2 additions & 2 deletions gemma/weights.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

#include "gemma/weights.h"

#include <algorithm>
#include <cstdlib>

#include "compression/compress.h"
Expand Down Expand Up @@ -47,7 +46,8 @@ struct LoadCompressedWeightsT {

std::array<float, TConfig::kNumTensorScales> scales;
CacheLoader loader(weights);
ForEachTensor<TConfig>(nullptr, *c_weights, loader);
const void* raw_weights = nullptr; // ForEachTensor requires const.
ForEachTensor</*kHaveRaw=*/false, TConfig>(raw_weights, *c_weights, loader);
loader.LoadScales(scales.data(), scales.size());
if (!loader.ReadAll(pool)) {
HWY_ABORT("Failed to load model weights.");
Expand Down
Loading