/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.rexp;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.rexp.Converter;
import org.jpmml.rexp.ModelConverter;
import org.jpmml.rexp.RExp;
import org.jpmml.rexp.RExpEncoder;
import org.jpmml.rexp.RGenericVector;
import org.jpmml.rexp.RStringVector;

/**
 * Converts an R {@code caretEnsemble} object into a PMML model chain.
 *
 * <p>The wrapped object must have two generic-vector elements:
 * <ul>
 *   <li>{@code models}: a named list of base models. Each entry must carry a {@code finalModel}
 *   element, so presumably each entry is a caret {@code train} object.</li>
 *   <li>{@code ens_model}: a single object with a {@code finalModel} element. It is encoded as the
 *   last model of the chain.</li>
 * </ul>
 *
 * <p>Each base model becomes one segment of the chain. That segment exposes a single non-final
 * output field, named after the model's entry in {@code models}.
 */
public class CaretEnsembleConverter
extends Converter<RGenericVector> {
    public CaretEnsembleConverter(RGenericVector caretEnsemble) {
        super(caretEnsemble);
    }

    /**
     * Encodes the ensemble as a PMML document that contains a single mining-model chain.
     *
     * @param encoder top-level encoder; the fields collected by every base model's own encoder are
     *     added to it
     * @return the encoded PMML document
     * @throws IllegalArgumentException if a base model's label is neither continuous nor
     *     categorical, or if a base model's mining function is neither regression nor
     *     classification
     */
    @Override
    public PMML encodePMML(RExpEncoder encoder) {
        RGenericVector caretEnsemble = (RGenericVector)this.getObject();
        RGenericVector models = caretEnsemble.getGenericElement("models");
        RGenericVector ensModel = caretEnsemble.getGenericElement("ens_model");
        RStringVector modelNames = models.names();

        // Base-model segments first, then the ensemble model as the final segment of the chain.
        List<Model> segmentationModels = new ArrayList<>();

        // Continuous-label schemas are anonymized. Presumably this keeps each base model from
        // binding the shared target field; TODO confirm against Schema#toAnonymousSchema.
        // Categorical-label schemas are kept unchanged because the classification branch below
        // still needs the label's category values.
        Function<Schema, Schema> segmentSchemaFunction = schema -> {
            Label label = schema.getLabel();
            if (label instanceof ContinuousLabel) {
                return schema.toAnonymousSchema();
            }
            if (label instanceof CategoricalLabel) {
                return schema;
            }
            throw new IllegalArgumentException("Expected a continuous or categorical label, got " + label);
        };

        for (int i = 0; i < models.size(); ++i) {
            RGenericVector baseModel = models.getGenericValue(i);
            Conversion conversion = this.encodeTrainModel(baseModel, segmentSchemaFunction);

            // Each base model was encoded with its own encoder. Merge that encoder's fields into
            // the top-level encoder so the final document declares them.
            encoder.addFields(conversion.getEncoder());

            Model segmentModel = conversion.getModel();
            FieldName name = FieldName.create((String)modelNames.getValue(i));
            OutputField outputField = createSegmentOutputField(name, conversion.getSchema(), segmentModel);

            segmentModel.setOutput(new Output().addOutputFields(outputField));
            segmentationModels.add(segmentModel);
        }

        // The ensemble model's schema is used as-is (no schema function).
        Model ensembleModel = this.encodeTrainModel(ensModel, null).getModel();
        segmentationModels.add(ensembleModel);

        MiningModel miningModel = MiningModelUtil.createModelChain(segmentationModels);
        return encoder.encodePMML(miningModel);
    }

    /**
     * Creates the non-final output field that publishes one base model's prediction to later
     * segments of the chain.
     *
     * <p>Regression models expose their predicted value as a continuous double. Classification
     * models must have exactly two categories and expose the probability of the category at
     * index 1. NOTE(review): presumably index 1 matches the class probability that caretEnsemble
     * feeds into its ensemble model; confirm.
     *
     * @param name output field name (the base model's name)
     * @param segmentSchema schema the base model was encoded with
     * @param segmentModel the encoded base model
     * @return the output field
     * @throws IllegalArgumentException if the mining function is neither regression nor
     *     classification
     */
    private static OutputField createSegmentOutputField(FieldName name, Schema segmentSchema, Model segmentModel) {
        MiningFunction miningFunction = segmentModel.getMiningFunction();
        switch (miningFunction) {
            case REGRESSION:
                return ModelUtil.createPredictedField(name, OpType.CONTINUOUS, DataType.DOUBLE)
                    .setFinalResult(Boolean.FALSE);
            case CLASSIFICATION: {
                CategoricalLabel categoricalLabel = (CategoricalLabel)segmentSchema.getLabel();
                SchemaUtil.checkSize(2, categoricalLabel);
                return ModelUtil.createProbabilityField(name, DataType.DOUBLE, categoricalLabel.getValue(1))
                    .setFinalResult(Boolean.FALSE);
            }
            default:
                throw new IllegalArgumentException("Unsupported mining function: " + miningFunction);
        }
    }

    /**
     * Encodes the {@code finalModel} element of {@code train} with a fresh, private encoder.
     *
     * @param train R object that carries a {@code finalModel} element
     * @param schemaFunction optional transformation applied to the schema before the model is
     *     encoded; {@code null} means the schema is used unchanged
     * @return the private encoder, the (possibly transformed) schema, and the encoded model
     */
    private Conversion encodeTrainModel(RGenericVector train, Function<Schema, Schema> schemaFunction) {
        RExp finalModel = (RExp)train.getElement("finalModel");
        ModelConverter converter = (ModelConverter)CaretEnsembleConverter.newConverter(finalModel);

        RExpEncoder encoder = new RExpEncoder();
        converter.encodeSchema(encoder);

        Schema schema = encoder.createSchema();
        if (schemaFunction != null) {
            schema = schemaFunction.apply(schema);
        }

        Model model = converter.encodeModel(schema);
        return new Conversion(encoder, schema, model);
    }

    /**
     * Immutable result of {@link #encodeTrainModel}: the encoder used, the schema the model was
     * encoded against, and the resulting model.
     */
    private static class Conversion {
        private final RExpEncoder encoder;
        private final Schema schema;
        private final Model model;

        private Conversion(RExpEncoder encoder, Schema schema, Model model) {
            this.encoder = encoder;
            this.schema = schema;
            this.model = model;
        }

        public RExpEncoder getEncoder() {
            return this.encoder;
        }

        public Schema getSchema() {
            return this.schema;
        }

        public Model getModel() {
            return this.model;
        }
    }
}

