package com.mindspore.flclient.model;

import com.mindspore.Model;
import com.mindspore.flclient.common.FLLoggerGenerater;
import java.util.Map;
import java.util.logging.Logger;

/* loaded from: input_file:com/mindspore/flclient/model/LossCallback.class */
/**
 * Training callback that tracks the per-batch loss and the average loss of each epoch.
 *
 * <p>After every step the scalar loss is read from the model outputs, logged and added to a
 * running sum; at the end of every epoch the average over that epoch's successful steps is logged
 * and stored so that it can be retrieved via {@link #getUploadLoss()}.
 */
public class LossCallback extends Callback {
    private static final Logger logger = FLLoggerGenerater.getModelLogger(LossCallback.class.toString());

    /** Sum of the losses of the successful steps in the current epoch; reset in {@link #epochEnd()}. */
    private float lossSum;

    /** Average loss of the most recently finished epoch; 0 until an epoch has finished. */
    private float uploadLoss;

    /**
     * Creates a loss callback bound to the given model.
     *
     * @param model the model passed to the {@link Callback} base class
     */
    public LossCallback(Model model) {
        super(model);
        this.lossSum = 0.0f;
        this.uploadLoss = 0.0f;
    }

    /**
     * No work is needed before a step.
     *
     * @return always {@code Status.SUCCESS}
     */
    @Override
    public Status stepBegin() {
        return Status.SUCCESS;
    }

    /**
     * Reads the loss of the step that just finished, logs it and adds it to the epoch sum.
     *
     * @return {@code Status.NULLPTR} if no loss output is found, {@code Status.FAILED} if the loss
     *     output is empty or NaN, otherwise {@code Status.SUCCESS}
     */
    @Override
    public Status stepEnd() {
        // NOTE(review): getOutputsBySize(1) presumably returns the outputs that hold exactly one
        // element, keyed by tensor name; the first entry is taken as the loss. If the model has
        // several single-element outputs, the choice depends on the map's iteration order; confirm.
        Map<String, float[]> outputsBySize = getOutputsBySize(1);
        if (outputsBySize == null || outputsBySize.isEmpty()) {
            logger.severe("cannot find loss tensor");
            return Status.NULLPTR;
        }
        float[] lossValues = outputsBySize.values().iterator().next();
        if (lossValues == null || lossValues.length < 1) {
            // Report an empty tensor distinctly instead of mislabelling it as a NaN loss.
            logger.severe("loss tensor is empty");
            return Status.FAILED;
        }
        float loss = lossValues[0];
        if (Float.isNaN(loss)) {
            logger.severe("loss is nan");
            return Status.FAILED;
        }
        logger.info("batch:" + this.steps + ",loss:" + loss);
        this.lossSum += loss;
        this.steps++;
        return Status.SUCCESS;
    }

    /**
     * No work is needed before an epoch.
     *
     * @return always {@code Status.SUCCESS}
     */
    @Override
    public Status epochBegin() {
        return Status.SUCCESS;
    }

    /**
     * Logs the average loss of the epoch, stores it as the upload loss, and resets the per-epoch
     * step counter and loss sum before advancing the epoch counter.
     *
     * @return always {@code Status.SUCCESS}
     */
    @Override
    public Status epochEnd() {
        float averageLoss;
        if (this.steps > 0) {
            averageLoss = this.lossSum / this.steps;
        } else {
            // Without this guard 0f / 0 evaluates to NaN, which would be logged and uploaded.
            logger.warning("no successful step in epoch " + this.epochs + ", average loss set to 0");
            averageLoss = 0.0f;
        }
        logger.info("----------epoch:" + this.epochs + ",average loss:" + averageLoss + "----------");
        setUploadLoss(averageLoss);
        this.steps = 0;
        this.epochs++;
        this.lossSum = 0.0f;
        return Status.SUCCESS;
    }

    /**
     * Returns the average loss of the most recently finished epoch.
     *
     * @return the stored upload loss (0 before the first epoch ends)
     */
    public float getUploadLoss() {
        return this.uploadLoss;
    }

    /**
     * Overrides the stored upload loss.
     *
     * @param f the new upload loss value
     */
    public void setUploadLoss(float f) {
        this.uploadLoss = f;
    }
}
