FFT

画像処理・信号処理の分野で幅広く使われているFFT をHalide で実装しました。

主な仕様

  • 入力: 32bit 浮動小数点数配列
  • 出力: 32bit 浮動小数点数配列
  • FPGA サイクル数: 0.25 element/cycle
  • FPGA ゲート数:
    • BRAM: 61
    • DSP: 204
    • FF: 27705
    • LUT: 20361

ソースコード

#include <Halide.h>
#include <Element.h>

using namespace Halide;
using namespace Halide::Element;

class FFT : public Halide::Generator<FFT> {
    Var c{"c"}, i{"i"}, k{"k"};

    GeneratorParam<int32_t> n_{"n", 256};
    GeneratorParam<int32_t> batch_size_{"batch_size", 4};
    ImageParam in{Float(32), 3, "in"};

public:
    Func build()
    {
        const int32_t n = static_cast<int32_t>(n_);
        const int32_t batch_size = static_cast<int32_t>(batch_size_);

        Func weight("weight");
        Expr theta = static_cast<float>(-2.0 * M_PI) * cast<float>(i) / static_cast<float>(n);
        weight(c, i) = select(c ==0, cos(theta), sin(theta));
        
        Func stage("in");
        stage(c, i, k) = in(c, i, k);

        for (int j=0; j<log2(n); ++j) {

            stage = BoundaryConditions::repeat_edge(stage, {{0, 2}, {0, n}, {0, batch_size}});

            Func next_stage("stage" + std::to_string(j));

            const int m = (n >> (j + 1));

            Expr cond = (i % (n >> j)) < m;

            Expr o = select(cond, i + m, i - m);

            ComplexExpr vi = {stage(0, i, k), stage(1, i, k)};
            ComplexExpr vo = {stage(0, o, k), stage(1, o, k)};

            // Case 1
            ComplexExpr v1 = vi + vo;

            // Case 2
            Expr wi = (i % m) * (1<<j);
            ComplexExpr w = {weight(0, wi), weight(1, wi)};
            ComplexExpr v2 = (vo - vi) * w;
            next_stage(c, i, k) = select(cond, select(c == 0, v1.x, v1.y),
                                               select(c == 0, v2.x, v2.y));

            schedule(next_stage, {2, n, batch_size}).unroll(c);

            stage = next_stage;
        }

        // Make bit-reversal 32-bit integer index
        Expr ri = cast<uint32_t>(i);
        ri = (ri & E(0x55555555)) <<  1 | (ri & E(0xAAAAAAAA)) >>  1;
        ri = (ri & E(0x33333333)) <<  2 | (ri & E(0xCCCCCCCC)) >>  2;
        ri = (ri & E(0x0F0F0F0F)) <<  4 | (ri & E(0xF0F0F0F0)) >>  4;
        ri = (ri & E(0x00FF00FF)) <<  8 | (ri & E(0xFF00FF00)) >>  8;
        ri = (ri & E(0x0000FFFF)) << 16 | (ri & E(0xFFFF0000)) >> 16;
        ri = cast<int32_t>(ri >> (32 - log2(n)));

        stage = BoundaryConditions::repeat_edge(stage, {{0, 2}, {0, n}, {0, batch_size}});

        Func out("out");
        out(c, i, k) = stage(c, ri, k);

        schedule(in, {2, n, batch_size});
        schedule(weight, {2, n/2});
        schedule(out, {2, n, batch_size}).unroll(c);
        
        return out;
    }

private:
    Expr E(int32_t v)
    {
        return make_const(UInt(32), static_cast<uint32_t>(v));
    }
};

HALIDE_REGISTER_GENERATOR(FFT, "fft")

 

解説

FFT (Fast Fourier Transform) は離散フーリエ変換を高速に解くためのアルゴリズムであり、古くから画像処理・信号処理において多用されてきました。このコードでは、ベーシックなCooley-Tukey型周波数間引きFFTをHalide DSLで実装しています。

入力は以下の通り、3次元の32ビット浮動小数点数配列です。

ImageParam in{Float(32), 3, "in"};

1次元目は複素数を表すためにのみ用意されており、インデックス0は実数部を、インデックス1は虚数部を表しています。2次元目がデータ系列の数であり、この単位でフーリエ変換を行います。3次元目はバッチサイズで、2次元目で指定した長さのフーリエ変換を何セット行うかを指定します。2次元目、3次元目の大きさは、コンパイル時に決定します。

GeneratorParam<int32_t> n_{"n", 256};
GeneratorParam<int32_t> batch_size_{"batch_size", 4};

FFTの重みは各ステージの計算に先立って計算し、テーブル化しておきます。この計算はFPGA内部で行われ、外部から与える必要はありません。

Func weight("weight");
Expr theta = static_cast<float>(-2.0 * M_PI) * cast<float>(i) / static_cast<float>(n);
weight(c, i) = select(c ==0, cos(theta), sin(theta));

forブロックがFFTの各演算ステージの定義になっています。このfor文はDSLコンパイル時に実行され、DSL定義としては連結されたlog2(n)個のFuncになることに注意してください。このようにC++を使用して簡潔な記述ができるのも、C++ EDSLであるHalide DSLの強みです。

Func stage("in");
stage(c, i, k) = in(c, i, k);
                                                                                       
for (int j=0; j<log2(n); ++j) {
                                                                                       
    stage = BoundaryConditions::repeat_edge(stage, {{0, 2}, {0, n}, {0, batch_size}});
                                                                                       
    Func next_stage("stage" + std::to_string(j));

    <...snip...>

    stage = next_stage;
}

最後に、ビット反転インデックスによるデータの並べ替えを行い、結果を出力します。

Expr ri = cast<uint32_t>(i);
ri = (ri & E(0x55555555)) <<  1 | (ri & E(0xAAAAAAAA)) >>  1;
ri = (ri & E(0x33333333)) <<  2 | (ri & E(0xCCCCCCCC)) >>  2;
ri = (ri & E(0x0F0F0F0F)) <<  4 | (ri & E(0xF0F0F0F0)) >>  4;
ri = (ri & E(0x00FF00FF)) <<  8 | (ri & E(0xFF00FF00)) >>  8;
ri = (ri & E(0x0000FFFF)) << 16 | (ri & E(0xFFFF0000)) >> 16;
ri = cast<int32_t>(ri >> (32 - log2(n)));
                                                                    
stage = BoundaryConditions::repeat_edge(stage, {{0, 2}, {0, n}});
                                                                    
Func out("out");
out(c, i, k) = stage(c, ri, k);

全てのソースコードにアクセスするためには、Githubを、テスト用Linuxイメージの使用法はこちらを参照して下さい。 IPコア形式での提供は近日中の対応を予定しています。

Share on FacebookShare on Google+Tweet about this on TwitterShare on LinkedIn

レビュー

レビューはまだありません。

“FFT” の口コミを投稿します

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

Available Downloads:

ログインするとダウンロードとカスタマイズなどのお問合せができます。