Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions gemma/compress_weights.cc
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,7 @@ void CompressWeights(const Path& weights_path,
WeightsF<TConfig>* weights =
reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
Compressor compressor(pool);
ForEachTensor</*kHaveRaw=*/true, TConfig, LayerF<TConfig>>(
weights, *c_weights, compressor);
ForEachTensor<TConfig, LayerF<TConfig>>(weights, *c_weights, compressor);
compressor.AddScales(weights->scales.data(), weights->scales.size());
compressor.WriteAll(pool, compressed_weights_path);

Expand Down
3 changes: 1 addition & 2 deletions gemma/weights.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ struct LoadCompressedWeightsT {

std::array<float, TConfig::kNumTensorScales> scales;
CacheLoader loader(weights);
const void* raw_weights = nullptr; // ForEachTensor requires const.
ForEachTensor</*kHaveRaw=*/false, TConfig>(raw_weights, *c_weights, loader);
ForEachTensor<TConfig>(nullptr, *c_weights, loader);
loader.LoadScales(scales.data(), scales.size());
if (!loader.ReadAll(pool)) {
HWY_ABORT("Failed to load model weights.");
Expand Down
16 changes: 9 additions & 7 deletions gemma/weights.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_

#include <stddef.h>

#include "compression/compress.h"
#include "gemma/common.h"
#include "gemma/configs.h"
Expand Down Expand Up @@ -221,17 +223,17 @@ void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights);
}

// Calls func(name, float*, CompressedArray&) for each tensor. float* is
// null if !kHaveRaw, in which case raw_weights can be nullptr. This happens
// when loading weights from BlobStore. If kHaveRaw, then RawLayer must be
// specified and we pass a float* pointing to the raw float weights for that
// tensor for use by compress_weights.cc.
// null if raw_weights is nullptr, e.g., when loading weights from BlobStore.
// Otherwise, RawLayer must be specified and we pass a float* pointing to the
// raw float weights for that tensor for use by compress_weights.cc.
//
// This avoids repeating the list of tensors between loading and compressing,
// while also avoiding dependency on raw_weights.h.
template <bool kHaveRaw, class TConfig, class RawLayer = void,
class RawWeights = void, class Func>
void ForEachTensor(const RawWeights* raw_weights,
template <class TConfig, class RawLayer = void, class RawWeightsPtr, class Func>
void ForEachTensor(RawWeightsPtr raw_weights,
CompressedWeights<TConfig>& c_weights, Func& func) {
constexpr bool kHaveRaw = !hwy::IsSame<RawWeightsPtr, nullptr_t>();

GEMMA_CALL_TOP_FUNC("c_embedding", embedder_input_embedding);
GEMMA_CALL_TOP_FUNC("c_final_norm", final_norm_scale);

Expand Down