// Copyright © 2025 Advanced Micro Devices, Inc.
// SPDX-License-Identifier: MIT

// clang-format off
#pragma once

#include <aotriton/config.h>
#include <aotriton/dtypes.h>
#include <aotriton/util.h>
#include <aotriton/runtime.h>
#include <cstdint>
#include <functional>
#include <string>
#include <tuple>
#include <vector>

namespace AOTRITON_NS::v3::flash {

// Unlike KernelDescription, Operator must have its own parameter class.
//
// Plain aggregate holding the arguments of one attention-backward operator
// call. Every member carries a default member initializer, so a
// default-constructed or partially brace-initialized object never exposes
// indeterminate pointers or scalars. The defaults equal what value
// initialization (`OpAttnBwdParams{}`) produces. Member order is unchanged,
// and the type is still an aggregate (C++14 and later).
//
// The struct declares no destructor, so it does not manage the lifetime of
// the TensorView objects it points to.
//
// NOTE(review): the per-group descriptions below are inferred from member
// names only - confirm them against the kernel sources before relying on them.
struct OpAttnBwdParams {
    // Presumably the forward-pass inputs (query, key, value, bias) - TODO confirm.
    const TensorView<4>* Q = nullptr;
    const TensorView<4>* K = nullptr;
    const TensorView<4>* V = nullptr;
    const TensorView<4>* B = nullptr;
    float                sm_scale = 0.0f;

    // Presumably the forward output and its incoming gradient - TODO confirm.
    const TensorView<4>* Out = nullptr;
    const TensorView<4>* DO = nullptr;

    // Presumably the gradients written by the backward pass - TODO confirm.
    const TensorView<4>* DK = nullptr;
    const TensorView<4>* DV = nullptr;
    const TensorView<4>* DQ = nullptr;
    const TensorView<4>* DB = nullptr;

    // Presumably auxiliary per-row buffers shared between kernels - TODO confirm.
    const TensorView<2>* L = nullptr;
    const TensorView<2>* D = nullptr;

    // Head counts, sequence-length information and head dimension.
    int32_t              num_head_q = 0;
    int32_t              num_head_k = 0;
    const TensorView<1>* cu_seqlens_q = nullptr;
    const TensorView<1>* cu_seqlens_k = nullptr;
    int32_t              num_seqlens = 0;
    int32_t              max_seqlen_q = 0;
    int32_t              max_seqlen_k = 0;
    int32_t              head_dim = 0;

    // Dropout probability and Philox random-number-generator state.
    float                dropout_p = 0.0f;
    const TensorView<0>* philox_seed_ptr = nullptr;
    const TensorView<0>* philox_offset1 = nullptr;
    uint64_t             philox_offset2 = 0;

    // Attention window extents.
    int32_t              Window_left = 0;
    int32_t              Window_right = 0;

    // Upper-case members: presumably compile-time-style kernel selectors
    // (functional parameters) - TODO confirm.
    int16_t              BLOCK_DMODEL = 0;
    int8_t               CAUSAL_TYPE = 0;
    bool                 ENABLE_DROPOUT = false;
    bool                 PADDED_HEAD = false;
    int8_t               BIAS_TYPE = 0;
};

struct OpAttnBwdContext {
    const OpAttnBwdParams *params = nullptr;
    enum BackendEnum : int32_t {
        None = -1,
        kMetro_TritonSplit = 0,
        kShim_BwdKernelFuse = 1,
        Max = 2
    };
    static constexpr BackendEnum fallback_backend = kMetro_TritonSplit;
    BackendEnum backend_index = BackendEnum::None;

#if AOTRITON_BUILD_FOR_TUNING
    int _has_preferred_backend = -1;
    static constexpr int _total_number_of_backends = BackendEnum::Max;
    const char* _backend_name = nullptr;
#endif

    // One more layer of dispatcher of functionals is added due to
    // 1. Individual kernel may use fewer arguments
    // 2. Metro kernel needs overall performance numbers over individual kernels.
    // 3. Even metro kernel only has one kernel, another set LUT is need to
    //    determine which metro kernel (or backend) need to be used
    int64_t godel_number() const;
    static std::tuple<int, int> get_archmod_number(Gpu gpu);
    static constexpr int kMaxGodelNumber = 576;

    hipError_t lookup_optimal(Gpu gpu);
    // Unlike Triton kernel, Operator's launch need gpu argument to eventually
    // call backend's lookup_optimal
    hipError_t launch(Gpu gpu, hipStream_t stream) const;
private:
    typedef void (*OpTuneTableEntry)(OpAttnBwdContext& context, int mod_number);
    static OpTuneTableEntry optune_table[][ kMaxGodelNumber ];

    typedef hipError_t (*BackendLauncher)(const OpAttnBwdContext& context,
                                          Gpu gpu,
                                          hipStream_t stream);
    static BackendLauncher launcher_table[ BackendEnum::Max ];
};

// Namespace for the operator's tuning helpers. It currently contains no
// declarations in this header.
namespace optune {

// TODO: declare_list_of_deduplicated_lut_functions
// NOTE(review): the TODO above looks like an unexpanded code-generator
// placeholder for the deduplicated LUT function declarations - confirm
// whether the generator is expected to fill it in.



}

}

// vim: set fileencoding=utf-8

