Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 0 additions & 94 deletions gemma/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -753,71 +753,6 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
c20, c21, c22, c23, c30, c31, c32, c33, tile_c,
stride_c);
}

// Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
// kColsBC is 24k or 3k. Note: B is transposed (column-major).
// This function loops over all tiles (static scheduling). TODO(janwas): we can
// possibly remove this if ThreadPool(0) is as efficient as the loop.
template <size_t kRowsAC, size_t kColsA_RowsB, size_t kColsBC, typename MatT>
void GEMM_4x4_Static(const MatT* HWY_RESTRICT A, const MatT* HWY_RESTRICT B,
MatT* HWY_RESTRICT C) {
const hn::ScalableTag<MatT> d;
const size_t N = hn::Lanes(d); // column step size
constexpr size_t kRegRows = 4;
constexpr size_t kRegCols = 4; // in vectors

static_assert(kRowsAC % kRegRows == 0);
static_assert(kColsBC % kRegCols == 0);
HWY_ASSERT(kColsA_RowsB % (N * kRegCols) == 0);
constexpr size_t kTilesY = kRowsAC / kRegRows;
constexpr size_t kTilesX = kColsBC / kRegCols;
constexpr size_t kTiles = kTilesX * kTilesY;

constexpr size_t kStrideA = kColsA_RowsB;
constexpr size_t kStrideB = kColsA_RowsB; // B is column-major
constexpr size_t kStrideC = kColsBC;

HWY_UNROLL(1)
for (size_t idx_tile = 0; idx_tile < kTiles; ++idx_tile) {
GEMM_4x4_Tile<kRegRows, kColsA_RowsB>(
A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
}
}

// Multi-threaded tiled 4x4 GEMM: every register tile of C is an independent
// task submitted to `pool`, and each task calls GEMM_4x4_Tile.
// Typical sizes: kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and kColsBC is
// 24k or 3k. Note: B is transposed (column-major).
// Tiles are processed in parallel with a work-stealing thread pool.
template <size_t kRowsAC, size_t kColsA_RowsB, size_t kColsBC, typename MatTA,
          typename MatTB, typename OutT>
HWY_NOINLINE void MatMul_4x4(const MatTA* HWY_RESTRICT A,
                             const MatTB* HWY_RESTRICT B, OutT* HWY_RESTRICT C,
                             hwy::ThreadPool& pool) {
  // Process reg-sized tiles of C in parallel. We currently write C directly,
  // which touches more memory than fits in L3. TODO: add another level of
  // loops so that we finish one L3-sized piece of C at a time.
  constexpr size_t kRegRows = 4;
  constexpr size_t kRegCols = 4;  // in vectors

  // C must split evenly into register tiles in both dimensions.
  static_assert(kRowsAC % kRegRows == 0, "kRowsAC must be a multiple of 4");
  static_assert(kColsBC % kRegCols == 0, "kColsBC must be a multiple of 4");

  // Checked with HWY_ASSERT rather than static_assert because `lanes` is not
  // a compile-time constant here. The lane count is taken from A's type.
  const hn::ScalableTag<MatTA> d;
  const size_t lanes = hn::Lanes(d);
  HWY_ASSERT(kColsA_RowsB % (lanes * kRegCols) == 0);

  // Tile grid over C; tasks are numbered 0..kNumTiles-1.
  constexpr size_t kTilesPerRow = kColsBC / kRegCols;
  constexpr size_t kNumTiles = kTilesPerRow * (kRowsAC / kRegRows);

  // Strides, in elements, passed through to the tile kernel. A and B share a
  // stride because B is stored transposed.
  constexpr size_t kStrideA = kColsA_RowsB;
  constexpr size_t kStrideB = kColsA_RowsB;
  constexpr size_t kStrideC = kColsBC;

  // Computes the finished product of one 4x4N tile and writes to C.
  const auto compute_tile = [&](const uint64_t idx_tile,
                                size_t /*thread*/) HWY_ATTR {
    GEMM_4x4_Tile<kRegRows, kColsA_RowsB>(A, B, C, idx_tile, kTilesPerRow,
                                          kStrideA, kStrideB, kStrideC);
  };
  pool.Run(0, kNumTiles, compute_tile);
}

// Tiled 4x4 GEMM. Typically batch_size is 1..512, kColsA_RowsB is 3k or 24k,
// and kColsBC is 24k or 3k. Note: B is transposed (column-major).
// NOTE that batch_size is the number of rows of A and C.
Expand Down Expand Up @@ -869,35 +804,6 @@ HWY_NOINLINE void MatMul_4x4_Batch(
});
}

// Simple scalar matrix multiply: treats `a` as a row-major kM x kN matrix and
// `b` as a row-major kN x kK matrix, converts each element to float, and
// ACCUMULATES the products into the row-major kM x kK matrix `out` (callers
// wanting a plain product must zero `out` first).
// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
// ops_test across instruction sets.
template <size_t kM, size_t kN, size_t kK, typename MatTA, typename MatTB>
HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
                           const MatTB* HWY_RESTRICT b,
                           float* HWY_RESTRICT out) {
  for (size_t row = 0; row < kM; ++row) {
    const MatTA* a_row = a + row * kN;
    float* out_row = out + row * kK;
    // The shared dimension is the middle loop so that the innermost loop
    // walks contiguous elements of both `b` and `out`.
    for (size_t inner = 0; inner < kN; ++inner) {
      const float a_val = hwy::ConvertScalarTo<float>(a_row[inner]);
      const MatTB* b_row = b + inner * kK;
      for (size_t col = 0; col < kK; ++col) {
        const float b_val = hwy::ConvertScalarTo<float>(b_row[col]);
        out_row[col] += a_val * b_val;
      }
    }
  }
}

// Overload of MatMulSlow for a `b` matrix stored as an SfpStream: the kN x kK
// elements of `b_sfp_stream` are first decompressed into a temporary aligned
// float buffer, then the float-pointer MatMulSlow is invoked on that buffer.
// As with that overload, results are accumulated into `out`.
template <size_t kM, size_t kN, size_t kK, typename MatTA>
HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
                           const SfpStream* HWY_RESTRICT b_sfp_stream,
                           float* HWY_RESTRICT out) {
  constexpr size_t kNumB = kK * kN;  // element count of b

  // Scratch storage for the decompressed b; freed when this function returns.
  hwy::AlignedFreeUniquePtr<float[]> b_float =
      hwy::AllocateAligned<float>(kNumB);

  const hn::ScalableTag<float> df;
  CompressTraits<SfpStream>::Decompress(df, /*in_capacity=*/0, b_sfp_stream,
                                        0, b_float.get(), kNumB);

  MatMulSlow<kM, kN, kK>(a, b_float.get(), out);
}

// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
// ops_test across instruction sets.
template <size_t kN, size_t kK, typename MatTA, typename MatTB>
Expand Down
54 changes: 5 additions & 49 deletions gemma/ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -528,54 +528,6 @@ void AssertClose(const MatT* HWY_RESTRICT expected,
}
}

template <size_t kM, size_t kN, size_t kK, typename MatTA,
typename MatTB = MatTA>
void TestTiledMatMul() {
hwy::ThreadPool pool(hwy::ThreadPool::MaxThreads());
std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
GenerateMatHeap<MatTA, kM, kN>(0, pool);
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b =
GenerateMatHeap<MatTB, kN, kK>(0, pool);
std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
GenerateZeroMatHeap<float, kM, kK>(pool);
std::unique_ptr<CompressedArray<float, kM * kK>> c_slow_batch =
GenerateZeroMatHeap<float, kM, kK>(pool);

MatMulSlow<kM, kN, kK>(a->data(), b->data(), c_slow->data());
MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(), c_slow_batch->data());
AssertClose(c_slow->data(), c_slow_batch->data(), kM * kK);

hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
GenerateTransposeMatHeap<MatTB, kN, kK>(0, pool);
MatMul_4x4<kM, kN, kK>(a->data(), b_trans->data(), c.get(), pool);

AssertClose(c_slow->data(), c.get(), kM * kK);
}

void TestAllTiledMatMul() {
// medium-sized square test
TestTiledMatMul<512, 512, 512, float>();
TestTiledMatMul<512, 512, 512, hwy::bfloat16_t>();
TestTiledMatMul<512, 512, 512, float, hwy::bfloat16_t>();
TestTiledMatMul<512, 512, 512, hwy::bfloat16_t, float>();
TestTiledMatMul<512, 512, 512, float, SfpStream>();
TestTiledMatMul<512, 512, 512, hwy::bfloat16_t, SfpStream>();

// minimal non-square test
TestTiledMatMul<4, 128, 4, float>();
TestTiledMatMul<4, 128, 4, hwy::bfloat16_t>();
TestTiledMatMul<4, 128, 4, float, hwy::bfloat16_t>();
TestTiledMatMul<4, 128, 4, hwy::bfloat16_t, float>();
TestTiledMatMul<32, 128, 32, float, SfpStream>();
TestTiledMatMul<32, 128, 32, hwy::bfloat16_t, SfpStream>();

// large-scale test
// TODO(philculliton): investigate rounding issues with large matrices.
// Causes test timeout.
// TestTiledMatMul<512, 24576, 3072, float>();
}

template <size_t kM, size_t kN, size_t kK, typename MatTA,
typename MatTB = MatTA>
void TestTiledBatchMatMul() {
Expand Down Expand Up @@ -638,6 +590,11 @@ void TestAllTiledBatchMatMul() {
TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, float>();
TestTiledBatchMatMul<1, 128, 32, float, SfpStream>();
TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, SfpStream>();

// large-scale test
// TODO(philculliton): investigate rounding issues with large matrices.
// Causes test timeout.
// TestTiledBatchMatMul<512, 24576, 3072, float>();
}

void TestMatVecAdd() {
Expand Down Expand Up @@ -746,7 +703,6 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllTiledBatchMatMul);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllTiledMatMul);
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop);
Expand Down