/**
 * UGENE - Integrated Bioinformatics Tools.
 * Copyright (C) 2008-2024 UniPro <ugene@unipro.ru>
 * http://ugene.net
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 * MA 02110-1301, USA.
 */

#include "ExtractMSAConsensusWorker.h"

#include <QScopedPointer>

#include <U2Algorithm/BuiltInConsensusAlgorithms.h>
#include <U2Algorithm/MsaConsensusAlgorithmRegistry.h>
#include <U2Algorithm/MsaConsensusUtils.h>

#include <U2Core/AppContext.h>
#include <U2Core/FailTask.h>
#include <U2Core/U2AssemblyDbi.h>
#include <U2Core/U2OpStatusUtils.h>
#include <U2Core/U2SafePoints.h>

#include <U2Designer/DelegateEditors.h>

#include <U2Lang/ActorPrototypeRegistry.h>
#include <U2Lang/BaseActorCategories.h>
#include <U2Lang/BasePorts.h>
#include <U2Lang/BaseSlots.h>
#include <U2Lang/BaseTypes.h>
#include <U2Lang/WorkflowEnv.h>

#include <U2View/ExportConsensusTask.h>

namespace U2 {
namespace LocalWorkflow {

// Workflow element identifiers. Each factory passes its own id to DomainFactory and uses it
// as the id of the element descriptor built in init().
const QString ExtractMSAConsensusSequenceWorkerFactory::ACTOR_ID("extract-msa-consensus-sequence");
const QString ExtractMSAConsensusStringWorkerFactory::ACTOR_ID("extract-msa-consensus-string");

namespace {
// Attribute identifiers: used to declare the element attributes, to key their property
// delegates, and to read the configured values with getValue<>() when a task is created.
const QString ALGO_ATTR_ID("algorithm");
const QString THRESHOLD_ATTR_ID("threshold");
// Declared and read by the sequence worker only; the string worker has no such attribute.
const QString GAPS_ATTR_ID("keep-gaps");
}  // namespace

/** Creates the worker for the given actor. No consensus task exists until createTask() makes one. */
ExtractMSAConsensusWorker::ExtractMSAConsensusWorker(Actor* actor)
    : BaseWorker(actor),
      extractMsaConsensus(nullptr) {
}

/** Intentionally empty: the ports are looked up on demand by the methods that use them. */
void ExtractMSAConsensusWorker::init() {
}

/**
 * Performs one scheduling step of the worker.
 *
 * With no pending input message, finish() is called and nullptr is returned.
 * Otherwise the alignment is taken from the input port and a consensus extraction task is
 * created, remembered in extractMsaConsensus and returned. If the alignment cannot be read,
 * a FailTask carrying the error text is returned instead.
 */
Task* ExtractMSAConsensusWorker::tick() {
    if (!hasMsa()) {
        finish();
        return nullptr;
    }

    U2OpStatusImpl status;
    const Msa alignment = takeMsa(status);
    CHECK_OP(status, new FailTask(status.getError()));

    extractMsaConsensus = createTask(alignment);
    return extractMsaConsensus;
}

void ExtractMSAConsensusWorker::sl_taskFinished() {
    auto t = dynamic_cast<ExtractMSAConsensusTaskHelper*>(sender());
    CHECK(t != nullptr, );
    CHECK(t->isFinished() && !t->hasError(), );
    CHECK(!t->isCanceled(), );

    sendResult(context->getDataStorage()->getDataHandler(t->getResult()));
}

/** Intentionally empty: this worker releases nothing here. */
void ExtractMSAConsensusWorker::cleanup() {
}

bool ExtractMSAConsensusWorker::hasMsa() const {
    const IntegralBus* port = ports[BasePorts::IN_MSA_PORT_ID()];
    SAFE_POINT(port != nullptr, "NULL msa port", false);
    return port->hasMessage();
}

/**
 * Takes the next message from the input alignment port and returns the alignment it refers to.
 *
 * @param os receives an error if the message has no alignment slot or the alignment object
 *           cannot be obtained from the workflow data storage.
 * @return the alignment, or a default-constructed Msa when an error was set.
 */
Msa ExtractMSAConsensusWorker::takeMsa(U2OpStatus& os) {
    const Message m = getMessageAndSetupScriptValues(ports[BasePorts::IN_MSA_PORT_ID()]);
    const QVariantMap data = m.getData().toMap();
    const QString msaSlotId = BaseSlots::MULTIPLE_ALIGNMENT_SLOT().getId();
    if (!data.contains(msaSlotId)) {
        os.setError(tr("Empty msa slot"));
        return {};
    }
    const SharedDbiDataHandler dbiId = data[msaSlotId].value<SharedDbiDataHandler>();

    // NOTE(review): StorageUtils::getMsaObject() is understood to return a newly allocated
    // object owned by the caller - confirm against U2Lang/DbiDataStorage. The scoped pointer
    // releases it on every return path; previously it was never deleted.
    const QScopedPointer<const MsaObject> obj(StorageUtils::getMsaObject(context->getDataStorage(), dbiId));
    if (obj.isNull()) {
        os.setError(tr("Error with msa object"));
        return {};
    }
    // The returned Msa is copied out before the scoped pointer destroys the object.
    return obj->getAlignment();
}

///////////////////////////////////////////////////////////////////////
// ExtractMSAConsensusStringWorker
/** Creates the text-producing worker; all state is set up by the base class constructor. */
ExtractMSAConsensusStringWorker::ExtractMSAConsensusStringWorker(Actor* actor)
    : ExtractMSAConsensusWorker(actor) {
}

void ExtractMSAConsensusStringWorker::finish() {
    IntegralBus* inPort = ports[BasePorts::IN_MSA_PORT_ID()];
    SAFE_POINT(inPort != nullptr, "NULL msa port", );
    SAFE_POINT(inPort->isEnded(), "The msa is not ended", );
    IntegralBus* outPort = ports[BasePorts::OUT_TEXT_PORT_ID()];
    SAFE_POINT(outPort != nullptr, "NULL text port", );

    outPort->setEnded();
    setDone();
}

void ExtractMSAConsensusStringWorker::sendResult(const SharedDbiDataHandler& /*seqId*/) {
    QVariantMap data;
    data[BaseSlots::TEXT_SLOT().getId()] = extractMsaConsensus->getResultAsText();
    IntegralBus* outPort = ports[BasePorts::OUT_TEXT_PORT_ID()];

    SAFE_POINT(outPort != nullptr, "NULL text port", );

    outPort->put(Message(outPort->getBusType(), data));
}

/**
 * Builds the consensus task for the given alignment from the configured algorithm and
 * threshold attributes, stores it in extractMsaConsensus and hooks its state-change signal
 * to sl_taskFinished().
 */
ExtractMSAConsensusTaskHelper* ExtractMSAConsensusStringWorker::createTask(const Msa& msa) {
    const QString algorithmId = getValue<QString>(ALGO_ATTR_ID);
    const int thresholdValue = getValue<int>(THRESHOLD_ATTR_ID);
    const U2DbiRef storageDbi = context->getDataStorage()->getDbiRef();

    // The string worker has no "keep-gaps" attribute, so gaps are always kept.
    auto task = new ExtractMSAConsensusTaskHelper(algorithmId, thresholdValue, true, msa, storageDbi);
    extractMsaConsensus = task;
    connect(task, SIGNAL(si_stateChanged()), SLOT(sl_taskFinished()));
    return task;
}

///////////////////////////////////////////////////////////////////////
// ExtractMSAConsensusSequenceWorker
/** Creates the sequence-producing worker; all state is set up by the base class constructor. */
ExtractMSAConsensusSequenceWorker::ExtractMSAConsensusSequenceWorker(Actor* actor)
    : ExtractMSAConsensusWorker(actor) {
}

void ExtractMSAConsensusSequenceWorker::finish() {
    IntegralBus* inPort = ports[BasePorts::IN_MSA_PORT_ID()];
    SAFE_POINT(inPort != nullptr, "NULL msa port", );
    SAFE_POINT(inPort->isEnded(), "The msa is not ended", );
    IntegralBus* outPort = ports[BasePorts::OUT_SEQ_PORT_ID()];
    SAFE_POINT(outPort != nullptr, "NULL sequence port", );

    outPort->setEnded();
    setDone();
}

void ExtractMSAConsensusSequenceWorker::sendResult(const SharedDbiDataHandler& seqId) {
    QVariantMap data;
    data[BaseSlots::DNA_SEQUENCE_SLOT().getId()] = qVariantFromValue<SharedDbiDataHandler>(seqId);
    IntegralBus* outPort = ports[BasePorts::OUT_SEQ_PORT_ID()];
    SAFE_POINT(outPort != nullptr, "NULL sequence port", );

    outPort->put(Message(outPort->getBusType(), data));
}

/**
 * Builds the consensus task for the given alignment from the configured algorithm, threshold
 * and keep-gaps attributes, stores it in extractMsaConsensus and hooks its state-change
 * signal to sl_taskFinished().
 */
ExtractMSAConsensusTaskHelper* ExtractMSAConsensusSequenceWorker::createTask(const Msa& msa) {
    const QString algorithmId = getValue<QString>(ALGO_ATTR_ID);
    const int thresholdValue = getValue<int>(THRESHOLD_ATTR_ID);
    const bool gapsKept = getValue<bool>(GAPS_ATTR_ID);
    const U2DbiRef storageDbi = context->getDataStorage()->getDbiRef();

    auto task = new ExtractMSAConsensusTaskHelper(algorithmId, thresholdValue, gapsKept, msa, storageDbi);
    extractMsaConsensus = task;
    connect(task, SIGNAL(si_stateChanged()), SLOT(sl_taskFinished()));
    return task;
}

///////////////////////////////////////////////////////////////////////
// ExtractMSAConsensusTaskHelper
/**
 * Stores the parameters of one consensus extraction.
 *
 * @param algoId    id of the consensus algorithm, resolved later in createAlgorithm().
 * @param threshold threshold passed to the algorithm.
 * @param keepGaps  when false, '-' characters are removed from sequence-like results.
 * @param msa       source alignment; the task stores the result of msa->getCopy().
 * @param targetDbi database the resulting sequence is imported into.
 */
ExtractMSAConsensusTaskHelper::ExtractMSAConsensusTaskHelper(const QString& algoId, int threshold, bool keepGaps, const Msa& msa, const U2DbiRef& targetDbi)
    : Task(ExtractMSAConsensusTaskHelper::tr("Extract consensus"), TaskFlags_NR_FOSCOE),
      algoId(algoId),
      threshold(threshold),
      keepGaps(keepGaps),
      msa(msa->getCopy()),
      targetDbi(targetDbi)
// resultText and resultSequence are not initialized here; prepare() fills them.
{
}

/** Name for the resulting sequence: the alignment name followed by "_consensus". */
QString ExtractMSAConsensusTaskHelper::getResultName() const {
    return msa->getName() + "_consensus";
}

void ExtractMSAConsensusTaskHelper::prepare() {
    QSharedPointer<MsaConsensusAlgorithm> algo(createAlgorithm());
    SAFE_POINT_EXT(algo != nullptr, setError("Wrong consensus algorithm"), );

    MsaConsensusUtils::updateConsensus(msa, resultText, algo.data());
    if (!keepGaps && algo->getFactory()->isSequenceLikeResult()) {
        resultText.replace("-", "");
    }

    if (algo->getFactory()->isSequenceLikeResult()) {
        U2SequenceImporter seqImporter;
        seqImporter.startSequence(stateInfo, targetDbi, U2ObjectDbi::ROOT_FOLDER, getResultName(), false);
        seqImporter.addBlock(resultText.data(), resultText.length(), stateInfo);
        resultSequence = seqImporter.finalizeSequence(stateInfo);
    }
}

/** Reference to the imported consensus sequence: the target database plus the sequence id. */
U2EntityRef ExtractMSAConsensusTaskHelper::getResult() const {
    return U2EntityRef(targetDbi, resultSequence.id);
}

/**
 * Creates the consensus algorithm identified by algoId for the stored alignment and applies
 * the configured threshold to it.
 *
 * @return the new algorithm, or nullptr with the task error set when the registry is
 *         missing, the algorithm id is unknown, or the factory produced no algorithm.
 */
MsaConsensusAlgorithm* ExtractMSAConsensusTaskHelper::createAlgorithm() {
    MsaConsensusAlgorithmRegistry* registry = AppContext::getMSAConsensusAlgorithmRegistry();
    SAFE_POINT_EXT(registry != nullptr, setError("NULL registry"), nullptr);

    MsaConsensusAlgorithmFactory* factory = registry->getAlgorithmFactory(algoId);
    if (factory == nullptr) {
        const QString message = ExtractMSAConsensusTaskHelper::tr("Unknown consensus algorithm: ") + algoId;
        setError(message);
        return nullptr;
    }

    MsaConsensusAlgorithm* algorithm = factory->createAlgorithm(msa, false);
    SAFE_POINT_EXT(algorithm != nullptr, setError("NULL algorithm"), nullptr);

    algorithm->setThreshold(threshold);
    return algorithm;
}

/** Returns the consensus text produced by prepare() (gaps already stripped if that applied). */
QByteArray ExtractMSAConsensusTaskHelper::getResultAsText() const {
    return resultText;
}

///////////////////////////////////////////////////////////////////////
// ExtractMSAConsensusSequenceWorkerFactory
/** Registers this factory under the sequence worker's ACTOR_ID. */
ExtractMSAConsensusSequenceWorkerFactory::ExtractMSAConsensusSequenceWorkerFactory()
    : DomainFactory(ACTOR_ID) {
}

/** Creates a new sequence-producing consensus worker for the given actor. */
Worker* ExtractMSAConsensusSequenceWorkerFactory::createWorker(Actor* actor) {
    return new ExtractMSAConsensusSequenceWorker(actor);
}

void ExtractMSAConsensusSequenceWorkerFactory::init() {
    MsaConsensusAlgorithmRegistry* reg = AppContext::getMSAConsensusAlgorithmRegistry();
    SAFE_POINT(reg != nullptr, "NULL registry", );

    const Descriptor desc(ACTOR_ID,
                          ExtractMSAConsensusSequenceWorker::tr("Extract Consensus from Alignment as Sequence"),
                          ExtractMSAConsensusSequenceWorker::tr("Extract the consensus sequence from the incoming multiple sequence alignment."));

    QList<PortDescriptor*> ports;
    {
        Descriptor inD(BasePorts::IN_MSA_PORT_ID(),
                       ExtractMSAConsensusStringWorker::tr("Input alignment"),
                       ExtractMSAConsensusStringWorker::tr("A alignment which consensus should be extracted"));
        QMap<Descriptor, DataTypePtr> inData;
        inData[BaseSlots::MULTIPLE_ALIGNMENT_SLOT()] = BaseTypes::MULTIPLE_ALIGNMENT_TYPE();
        ports << new PortDescriptor(inD, DataTypePtr(new MapDataType(BasePorts::IN_MSA_PORT_ID(), inData)), true);

        Descriptor outD(BasePorts::OUT_SEQ_PORT_ID(),
                        ExtractMSAConsensusSequenceWorker::tr("Consensus sequence"),
                        ExtractMSAConsensusSequenceWorker::tr("Provides resulting consensus as a sequence"));

        QMap<Descriptor, DataTypePtr> outData;
        outData[BaseSlots::DNA_SEQUENCE_SLOT()] = BaseTypes::DNA_SEQUENCE_TYPE();
        ports << new PortDescriptor(outD, DataTypePtr(new MapDataType(BasePorts::OUT_SEQ_PORT_ID(), outData)), false, true);
    }

    QList<Attribute*> attrs;
    QMap<QString, PropertyDelegate*> delegates;
    {
        const Descriptor algoDesc(ALGO_ATTR_ID,
                                  ExtractMSAConsensusSequenceWorker::tr("Algorithm"),
                                  ExtractMSAConsensusSequenceWorker::tr("The algorithm of consensus extracting."));
        const Descriptor thresholdDesc(THRESHOLD_ATTR_ID,
                                       ExtractMSAConsensusSequenceWorker::tr("Threshold"),
                                       ExtractMSAConsensusSequenceWorker::tr("The threshold of the algorithm."));
        const Descriptor gapsDesc(GAPS_ATTR_ID,
                                  ExtractMSAConsensusSequenceWorker::tr("Keep gaps"),
                                  ExtractMSAConsensusSequenceWorker::tr("Set this parameter if the result consensus must keep the gaps."));

        auto thr = new Attribute(thresholdDesc, BaseTypes::NUM_TYPE(), true, 100);
        auto algo = new Attribute(algoDesc, BaseTypes::STRING_TYPE(), true, BuiltInConsensusAlgorithms::STRICT_ALGO);
        attrs << algo << thr << new Attribute(gapsDesc, BaseTypes::BOOL_TYPE(), true, true);

        QVariantMap algos;
        QVariantMap m;
        QVariantList visibleRelationList;
        m["minimum"] = 0;
        m["maximum"] = 100;
        auto thrDelegate = new SpinBoxDelegate(m);
        foreach (const QString& algoId, reg->getAlgorithmIds()) {
            MsaConsensusAlgorithmFactory* f = reg->getAlgorithmFactory(algoId);
            if (f->isSequenceLikeResult()) {
                algos[f->getName()] = algoId;
                if (f->supportsThreshold()) {
                    visibleRelationList.append(algoId);
                }
            }
        }
        thr->addRelation(new VisibilityRelation(ALGO_ATTR_ID, visibleRelationList));
        algo->addRelation(new SpinBoxDelegatePropertyRelation(THRESHOLD_ATTR_ID));
        delegates[ALGO_ATTR_ID] = new ComboBoxDelegate(algos);
        delegates[THRESHOLD_ATTR_ID] = thrDelegate;
    }

    ActorPrototype* proto = new IntegralBusActorPrototype(desc, ports, attrs);
    proto->setPrompter(new ExtractMSAConsensusWorkerPrompter());
    proto->setEditor(new DelegateEditor(delegates));

    WorkflowEnv::getProtoRegistry()->registerProto(BaseActorCategories::CATEGORY_ALIGNMENT(), proto);
    DomainFactory* localDomain = WorkflowEnv::getDomainRegistry()->getById(LocalDomainFactory::ID);
    localDomain->registerEntry(new ExtractMSAConsensusSequenceWorkerFactory());
}

///////////////////////////////////////////////////////////////////////
// ExtractMSAConsensusStringWorkerFactory
/** Registers this factory under the string worker's ACTOR_ID. */
ExtractMSAConsensusStringWorkerFactory::ExtractMSAConsensusStringWorkerFactory()
    : DomainFactory(ACTOR_ID) {
}

/** Creates a new text-producing consensus worker for the given actor. */
Worker* ExtractMSAConsensusStringWorkerFactory::createWorker(Actor* actor) {
    return new ExtractMSAConsensusStringWorker(actor);
}

void ExtractMSAConsensusStringWorkerFactory::init() {
    MsaConsensusAlgorithmRegistry* reg = AppContext::getMSAConsensusAlgorithmRegistry();
    SAFE_POINT(reg != nullptr, "NULL registry", );

    const Descriptor desc(ACTOR_ID,
                          ExtractMSAConsensusSequenceWorker::tr("Extract Consensus from Alignment as Text"),
                          ExtractMSAConsensusSequenceWorker::tr("Extract the consensus string from the incoming multiple sequence alignment."));

    QList<PortDescriptor*> ports;
    {
        Descriptor inD(BasePorts::IN_MSA_PORT_ID(),
                       ExtractMSAConsensusStringWorker::tr("Input alignment"),
                       ExtractMSAConsensusStringWorker::tr("A alignment which consensus should be extracted"));

        QMap<Descriptor, DataTypePtr> inData;
        inData[BaseSlots::MULTIPLE_ALIGNMENT_SLOT()] = BaseTypes::MULTIPLE_ALIGNMENT_TYPE();
        ports << new PortDescriptor(inD, DataTypePtr(new MapDataType(BasePorts::IN_MSA_PORT_ID(), inData)), true);

        Descriptor outD(BasePorts::OUT_TEXT_PORT_ID(),
                        ExtractMSAConsensusStringWorker::tr("Consensus"),
                        ExtractMSAConsensusStringWorker::tr("Provides resulting consensus as a text"));

        QMap<Descriptor, DataTypePtr> outData;
        outData[BaseSlots::TEXT_SLOT()] = BaseTypes::STRING_TYPE();
        ports << new PortDescriptor(outD, DataTypePtr(new MapDataType(BasePorts::OUT_TEXT_PORT_ID(), outData)), false, true);
    }

    QList<Attribute*> attrs;
    QMap<QString, PropertyDelegate*> delegates;
    {
        const Descriptor algoDesc(ALGO_ATTR_ID,
                                  ExtractMSAConsensusSequenceWorker::tr("Algorithm"),
                                  ExtractMSAConsensusSequenceWorker::tr("The algorithm of consensus extracting."));
        const Descriptor thresholdDesc(THRESHOLD_ATTR_ID,
                                       ExtractMSAConsensusSequenceWorker::tr("Threshold"),
                                       ExtractMSAConsensusSequenceWorker::tr("The threshold of the algorithm."));
        auto thr = new Attribute(thresholdDesc, BaseTypes::NUM_TYPE(), true, 100);
        auto algo = new Attribute(algoDesc, BaseTypes::STRING_TYPE(), true, BuiltInConsensusAlgorithms::DEFAULT_ALGO);
        attrs << algo << thr;

        QVariantList visibleRelationList;
        QVariantMap algos;
        QVariantMap m;
        m["minimum"] = 0;
        m["maximum"] = 100;
        auto thrDelegate = new SpinBoxDelegate(m);
        foreach (const QString& algoId, reg->getAlgorithmIds()) {
            MsaConsensusAlgorithmFactory* f = reg->getAlgorithmFactory(algoId);
            if (!f->isSequenceLikeResult()) {
                algos[f->getName()] = algoId;
                if (f->supportsThreshold()) {
                    visibleRelationList.append(algoId);
                }
            }
        }
        thr->addRelation(new VisibilityRelation(ALGO_ATTR_ID, visibleRelationList));
        algo->addRelation(new SpinBoxDelegatePropertyRelation(THRESHOLD_ATTR_ID));
        delegates[ALGO_ATTR_ID] = new ComboBoxDelegate(algos);
        delegates[THRESHOLD_ATTR_ID] = thrDelegate;
    }

    ActorPrototype* proto = new IntegralBusActorPrototype(desc, ports, attrs);
    proto->setPrompter(new ExtractMSAConsensusWorkerPrompter());
    proto->setEditor(new DelegateEditor(delegates));

    WorkflowEnv::getProtoRegistry()->registerProto(BaseActorCategories::CATEGORY_ALIGNMENT(), proto);
    DomainFactory* localDomain = WorkflowEnv::getDomainRegistry()->getById(LocalDomainFactory::ID);
    localDomain->registerEntry(new ExtractMSAConsensusStringWorkerFactory());
}

///////////////////////////////////////////////////////////////////////
// ExtractMSAConsensusWorkerPrompter
/** Prompter shared by both consensus elements; the actor is forwarded to PrompterBase. */
ExtractMSAConsensusWorkerPrompter::ExtractMSAConsensusWorkerPrompter(Actor* actor)
    : PrompterBase<ExtractMSAConsensusWorkerPrompter>(actor) {
}

/**
 * Builds the element description shown to the user: a sentence naming the selected
 * algorithm, with the algorithm rendered as a hyperlink to its attribute.
 */
QString ExtractMSAConsensusWorkerPrompter::composeRichDoc() {
    const QString algorithmName = getParameter(ALGO_ATTR_ID).toString();
    const QString algorithmLink = getHyperlink(ALGO_ATTR_ID, algorithmName);
    const QString pattern = ExtractMSAConsensusSequenceWorker::tr("Extracts the consensus sequence from the incoming alignment(s) using the %1 algorithm.");
    return pattern.arg(algorithmLink);
}

/** Returns a heap-allocated copy of this relation, made with the copy constructor. */
SpinBoxDelegatePropertyRelation* SpinBoxDelegatePropertyRelation::clone() const {
    return new SpinBoxDelegatePropertyRelation(*this);
}

/**
 * Recomputes the dependent (threshold) value after the influencing (algorithm) value changed.
 *
 * The dependent tags are first refreshed from the selected algorithm, then the current
 * dependent value is clamped to the "minimum".."maximum" range stored in those tags.
 *
 * @return the clamped value, or dependentValue unchanged when depTags is null.
 */
QVariant SpinBoxDelegatePropertyRelation::getAffectResult(const QVariant& influencingValue, const QVariant& dependentValue, DelegateTags* /*infTags*/, DelegateTags* depTags) const {
    if (depTags == nullptr) {
        return dependentValue;
    }

    updateDelegateTags(influencingValue, depTags);

    const int lowerBound = depTags->get("minimum").toInt();
    const int upperBound = depTags->get("maximum").toInt();
    const int clamped = qBound(lowerBound, dependentValue.toInt(), upperBound);
    return clamped;
}

void SpinBoxDelegatePropertyRelation::updateDelegateTags(const QVariant& influencingValue, DelegateTags* dependentTags) const {
    MsaConsensusAlgorithmRegistry* reg = AppContext::getMSAConsensusAlgorithmRegistry();
    SAFE_POINT(reg != nullptr, "NULL registry", );
    MsaConsensusAlgorithmFactory* consFactory = reg->getAlgorithmFactory(influencingValue.toString());
    if (!consFactory) {
        return;
    }
    if (dependentTags != nullptr) {
        dependentTags->set("minimum", consFactory->getMinThreshold());
        dependentTags->set("maximum", consFactory->getMaxThreshold());
    }
}

}  // namespace LocalWorkflow
}  // namespace U2
