/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import java.util.logging.Logger;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.optimisers.AdaDelta;
import org.tribuo.math.optimisers.AdaGrad;
import org.tribuo.math.optimisers.AdaGradRDA;
import org.tribuo.math.optimisers.Adam;
import org.tribuo.math.optimisers.ParameterAveraging;
import org.tribuo.math.optimisers.Pegasos;
import org.tribuo.math.optimisers.RMSProp;
import org.tribuo.math.optimisers.SGD;

/**
 * Configuration options (exposed through OLCUT {@link Option} annotations) for selecting and
 * parameterising a {@link StochasticGradientOptimiser}.
 * <p>
 * The hyperparameter fields are shared between several optimisers. The Javadoc on each field
 * lists the optimisers that {@link #getOptimiser()} passes it to. Fields that do not apply to
 * the selected optimiser are ignored.
 */
public class GradientOptimiserOptions implements Options {
    private static final Logger logger = Logger.getLogger(GradientOptimiserOptions.class.getName());

    /**
     * Which optimiser {@link #getOptimiser()} constructs. Defaults to
     * {@link StochasticGradientOptimiserType#ADAGRAD}.
     */
    @Option(longName = "sgo-type", usage = "Selects the gradient optimiser. Defaults to ADAGRAD.")
    private StochasticGradientOptimiserType optimiserType = StochasticGradientOptimiserType.ADAGRAD;

    /**
     * Learning rate. It is passed to AdaGrad, AdaGradRDA, Adam, Pegasos, RMSProp and all three
     * SGD variants (constant, linear decay, sqrt decay).
     */
    @Option(longName = "sgo-learning-rate", usage = "Learning rate for AdaGrad, AdaGradRDA, Adam, Pegasos, RMSProp, SGD.")
    public double learningRate = 0.18;

    /**
     * Epsilon. It is passed to AdaDelta, AdaGrad, AdaGradRDA and Adam.
     */
    @Option(longName = "sgo-epsilon", usage = "Epsilon for AdaDelta, AdaGrad, AdaGradRDA, Adam.")
    public double epsilon = 0.066;

    /**
     * Rho. It is passed to AdaDelta, RMSProp and all three SGD factories. The SGD factories
     * always receive it, regardless of {@link #momentum}.
     * NOTE(review): presumably SGD only uses it when momentum is enabled. Confirm in {@code SGD}.
     */
    @Option(longName = "sgo-rho", usage = "Rho for RMSProp, AdaDelta, SGD with Momentum.")
    public double rho = 0.95;

    /**
     * Lambda. It is passed only to Pegasos.
     */
    @Option(longName = "sgo-lambda", usage = "Lambda for Pegasos.")
    public double lambda = 0.01;

    /**
     * If {@code true}, {@link #getOptimiser()} wraps the selected optimiser in
     * {@link ParameterAveraging}.
     */
    @Option(longName = "sgo-parameter-averaging", usage = "Use parameter averaging.")
    public boolean paramAve = false;

    /**
     * Momentum type passed to the SGD factories. It is ignored by every non-SGD optimiser.
     */
    @Option(longName = "sgo-momentum", usage = "Use momentum in SGD.")
    public SGD.Momentum momentum = SGD.Momentum.NONE;

    /**
     * Builds the optimiser described by these options.
     * <p>
     * The optimiser is created from {@link #optimiserType} using the relevant hyperparameter
     * fields. If {@link #paramAve} is set, the result is wrapped in {@link ParameterAveraging},
     * and an info message is logged.
     *
     * @return the configured optimiser, possibly wrapped in {@link ParameterAveraging}.
     * @throws IllegalArgumentException if {@link #optimiserType} is not handled by the switch.
     *         This can only happen if a new enum constant is added without updating this method.
     */
    public StochasticGradientOptimiser getOptimiser() {
        StochasticGradientOptimiser sgo;
        switch (this.optimiserType) {
            case ADADELTA:
                sgo = new AdaDelta(this.rho, this.epsilon);
                break;
            case ADAGRAD:
                sgo = new AdaGrad(this.learningRate, this.epsilon);
                break;
            case ADAGRADRDA:
                sgo = new AdaGradRDA(this.learningRate, this.epsilon);
                break;
            case ADAM:
                sgo = new Adam(this.learningRate, this.epsilon);
                break;
            case PEGASOS:
                sgo = new Pegasos(this.learningRate, this.lambda);
                break;
            case RMSPROP:
                sgo = new RMSProp(this.learningRate, this.rho);
                break;
            case CONSTANTSGD:
                sgo = SGD.getSimpleSGD(this.learningRate, this.rho, this.momentum);
                break;
            case LINEARSGD:
                sgo = SGD.getLinearDecaySGD(this.learningRate, this.rho, this.momentum);
                break;
            case SQRTSGD:
                sgo = SGD.getSqrtDecaySGD(this.learningRate, this.rho, this.momentum);
                break;
            default:
                // Required so that 'sgo' is definitely assigned. It also fails loudly if the
                // enum gains a constant that this switch does not handle.
                throw new IllegalArgumentException("Unhandled StochasticGradientOptimiser type: " + this.optimiserType);
        }
        if (this.paramAve) {
            logger.info("Using parameter averaging");
            return new ParameterAveraging(sgo);
        }
        return sgo;
    }

    /**
     * The optimisers that {@link GradientOptimiserOptions#getOptimiser()} can construct.
     */
    public enum StochasticGradientOptimiserType {
        /** {@code AdaDelta}, built from rho and epsilon. */
        ADADELTA,
        /** {@code AdaGrad}, built from learning rate and epsilon. */
        ADAGRAD,
        /** {@code AdaGradRDA}, built from learning rate and epsilon. */
        ADAGRADRDA,
        /** {@code Adam}, built from learning rate and epsilon. */
        ADAM,
        /** {@code Pegasos}, built from learning rate and lambda. */
        PEGASOS,
        /** {@code RMSProp}, built from learning rate and rho. */
        RMSPROP,
        /** SGD via {@code SGD.getSimpleSGD(learningRate, rho, momentum)}. */
        CONSTANTSGD,
        /** SGD via {@code SGD.getLinearDecaySGD(learningRate, rho, momentum)}. */
        LINEARSGD,
        /** SGD via {@code SGD.getSqrtDecaySGD(learningRate, rho, momentum)}. */
        SQRTSGD
    }
}

