Tiling a Multi-Function Program
Finally, we can start tiling programs to produce fused variants.
Let's go back to the annotation function that we saw previously.
// Describes this function's data-access pattern to gern: the output can be
// tiled along x, and producing one output subset {x, step} requires the
// matching input subset {x, step}.
// NOTE(review): `input`, `output`, and `step` are members declared elsewhere
// in the enclosing class (not shown in this snippet).
Annotation getAnnotation() override {
// Symbolic index that walks over the output.
Variable x("x");
return annotate(
// Tileable tells us what dimensions of the output
// can be tiled: x starts at 0, runs up to output["size"],
// and presumably advances by `step`.
Tileable(x = Expr(0), output["size"], step,
// Produces indicates what subset (pointed at
// by the fields we saw in the last chapter)
// of the output we are describing.
Produces::Subset(output, {x, step}),
// What subsets of the input do we need to
// consume to produce the required output subset?
Consumes::Subset(input, {x, step})));
}
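Now, to generate a fused version, we simply tile the output's dimension. As expected, only fields marked as `Tileable` in the annotation can be tiled. Note that the tile size `t` is a symbolic variable, so its concrete value must now be supplied when the program is evaluated.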
#include <cstdint>
#include <cstdio>

#include "helpers.h"
#include "library/array/impl/cpu-array.h"
using namespace gern;
int main() {
  // ***** PROGRAM DEFINITION *****
  // Our first, simple program: two chained add_1 calls
  // (input -> temp -> output), fused by tiling the output.
  auto input = mk_array("input");
  auto temp = mk_array("temp");
  auto output = mk_array("output");
  // Symbolic tile size; its concrete value is supplied at evaluation time.
  Variable t("t");
  gern::annot::add_1 add_1;
  Composable program({
      // Tile the output!
      // t indicates what the tile
      // size should be
      Tile(output["size"], t)(
          add_1(input, temp),
          add_1(temp, output)),
  });

  // ***** PROGRAM EVALUATION *****
  // Single source of truth for the array length, shared by both allocations
  // and by the sanity check below so they cannot drift apart.
  constexpr int kSize = 10;
  library::impl::ArrayCPU a(kSize);
  a.ascending();
  library::impl::ArrayCPU b(kSize);
  // The generated loop steps by t and queries t-wide tiles without clamping
  // the last one, so keep t a divisor of kSize (unless query clamps
  // internally -- NOTE(review): confirm).
  int64_t t_val = 2;
  auto runner = compile_program(program);
  // temp is not bound here: the fused program allocates it per tile.
  runner.evaluate(
      {
          {"output", &b},
          {"input", &a},
          {"t", &t_val}, // <-- t must also be provided now!
      });

  // SANITY CHECK: every output element is its input plus 1, twice.
  // Checked explicitly rather than with assert so the check still runs
  // when the example is built with NDEBUG.
  for (int i = 0; i < kSize; i++) {
    if (a.data[i] + 2 != b.data[i]) {
      std::fprintf(stderr, "SANITY CHECK FAILED at index %d\n", i);
      return 1;
    }
  }
  return 0;
}
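Now we can look at the generated tiled program. Notice that `temp` is no longer a parameter of the function: it is allocated (and destroyed) for each tile inside the loop.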
// Generated fused and tiled version of add_1(input, temp); add_1(temp, output).
// `temp` is no longer a parameter: it is materialized per tile inside the loop.
void function_7(library::impl::ArrayCPU &input, library::impl::ArrayCPU &output, int64_t t) {
// One iteration per output tile; the tile start advances by t.
// NOTE(review): the last tile is not clamped to output.size, so presumably
// t must divide output.size unless query() clamps internally -- confirm.
for (int64_t _gern_x_2_6 = 0; (_gern_x_2_6 < output.size); _gern_x_2_6 = (_gern_x_2_6 + t)) {
// Presumably the t elements of output starting at _gern_x_2_6.
auto _query_output_8 = output.query(_gern_x_2_6, t);
// Per-tile intermediate buffer holding the first add_1's result.
library::impl::ArrayCPU temp = library::impl::ArrayCPU::allocate(_gern_x_2_6, t);
// The matching input window for this tile.
auto _query_input_9 = input.query(_gern_x_2_6, t);
library::impl::add_1(_query_input_9, temp);
library::impl::add_1(temp, _query_output_8);
// Released explicitly at the end of each tile.
// NOTE(review): if add_1 can throw, temp would not be destroyed.
temp.destroy();
}
}