% FITPULSE Regress pulse model to pulse resistance estimates.
%
% Optimizers: particleswarm and fmincon.
%
% Strategy: Evaluating the nonlinear pulse resistance is costly, so we first
% regress the linear pulse-resistance model to the lab measurements using
% particleswarm followed by an fmincon hybrid step. The linear
% pulse-resistance surface varies with temperature and SOC but not C-rate.
% We follow this with a full regression of the nonlinear model, seeded with
% the linear result, to capture the C-rate dependence. We use fmincon for
% this purpose.
%
% -- Usage --
% outData = FITPULSE(cellspec,model,Tstring)
% outData = FITPULSE(...,'Np',NP)
%
% -- Input --
% cellspec    = cell specification generated by the labcell() function
% model       = optutil model for regression (see optutil.modelspec)
% Tstring     = string selecting the temperatures at which to perform the
%               model regression. Specify two or more temperatures.
%               Arrhenius equations are used to describe variation of the
%               parameter values with temperature. Example: '0degC 25degC'.
% NP          = number of particles to use in particle swarm optimization
%               (DEFAULT 100; must exceed 10).
%
% -- Output --
% outData     = struct with fields cellData, linearEstimates,
%               nonlinearEstimates, modelspec, origin__, arg__, and
%               timestamp__. The same struct is also saved to
%               cellspec.pls.fitfile.
%
% Copyright (©) 2024 The Regents of the University of Colorado, a body
% corporate. Created by Gregory L. Plett and M. Scott Trimboli of the
% University of Colorado Colorado Springs (UCCS). This work is licensed
% under a Creative Commons "Attribution-ShareAlike 4.0 International" Intl.
% License. https://creativecommons.org/licenses/by-sa/4.0/ 
% This code is provided as a supplement to: Gregory L. Plett and M. Scott
% Trimboli, "Battery Management Systems, Volume III, Physics-Based
% Methods," Artech House, 2024. It is provided "as is", without express or
% implied warranty. Attribution should be given by citing: Gregory L. Plett
% and M. Scott Trimboli, Battery Management Systems, Volume III:
% Physics-Based Methods, Artech House, 2024.        

function outData = fitPulse(cellspec,model,varargin)

  % Parse and validate arguments. Tstring may be a char vector or (as a
  % backward-compatible extension) a string scalar.
  parser = inputParser;
  parser.addRequired('cellspec',@(x)isstruct(x)&&strcmp(x.origin__,'labcell'));
  parser.addRequired('model',@(x)isscalar(x)&&isstruct(x));
  parser.addRequired('Tstring',@(x)ischar(x)||isStringScalar(x));
  parser.addParameter('Np',100,@(x)isscalar(x)&&x>10);
  parser.parse(cellspec,model,varargin{:});
  arg = parser.Results; % struct of validated arguments

  % Collect cell metadata.
  cellname = arg.cellspec.name;
  processfile = arg.cellspec.pls.processfile;
  fitfile = arg.cellspec.pls.fitfile;
  timestamp = arg.cellspec.timestamp__;

  outp.print('Started fitPulse %s\n',cellname);
  outp.info(' - The regression takes place over two steps. The first uses\n');
  outp.info('   linear methods and is slow. The second seeks to enhance the\n');
  outp.info('   results from the first optimization and uses nonlinear\n');
  outp.info('   methods and is slower. Patience is advised!\n');
  outp.info(' - During regression, the RMSE value that is output is updated\n');
  outp.info('   whenever a better set of parameters is found and gives\n');
  outp.info('   some indication of optimization progress.\n');

  % Load processed pulse data.
  processData = load(processfile);
  cellData = processData.cellData; % struct of param values previously identified
  ecmVect = processData.ecmVect;   % vector struct of ECM param values

  % Parse the requested temperatures, e.g. '0degC 25degC' -> [0 25].
  Ttokens = strsplit(strtrim(char(arg.Tstring)));
  Trequested = cellfun( ...
      @(x)str2double(regexprep(x,'degC$','','ignorecase')),Ttokens);
  if any(isnan(Trequested))
    error(['Could not parse the requested temperatures from Tstring ' ...
        '(expected e.g. ''0degC 25degC''). Cannot continue.']);
  end

  % Select the datasets measured at any of the requested temperatures.
  % ismember keeps the per-dataset mask correct even for a single request.
  ecmTdegC = [ecmVect.TdegC];
  indT = find(ismember(ecmTdegC,Trequested));
  if isempty(indT)
    % Single-argument error() does not translate escapes such as \n,
    % so none are used here.
    error(['Could not find pulse datasets at requested temperatures. ' ...
        'Cannot continue.']);
  end
  if length(indT) < 2
    error(['Need datasets at at least two temperatures to run pulse ' ...
        'regression. Cannot continue.'])
  end
  Tmissing = Trequested(~ismember(Trequested,ecmTdegC));
  if ~isempty(Tmissing)
    warning('fitPulse:missingTemperature', ...
        'No pulse dataset found at %s degC; ignoring.',mat2str(Tmissing));
  end

  % Take Tvect from the selected datasets themselves (in dataset order), so
  % Tvect(k) always describes ecmData(k) regardless of the order or
  % completeness of the temperatures listed in Tstring.
  ecmData = ecmVect(indT);
  Tvect = [ecmData.TdegC];

  % Structure to store all arguments passed to cost function.
  p = struct;

  % Build the per-temperature data set (lab data + OCP at each SOC setpoint).
  % The loop runs backwards so tempSet is sized on the first assignment.
  for kt = length(ecmData):-1:1
    tmp = struct;
    tmp.TdegC = Tvect(kt);
    tmp.socPct = ecmData(kt).socPct;  % vector of SOC setpoints
    tmp.iapp = ecmData(kt).iapp;      % vector of current pulse magnitude setpoints
    tmp.iappLAB = ecmData(kt).I;      % measured iapp matrix (dim1=SOC, dim2=iapp)
    tmp.R0LAB = ecmData(kt).R;        % pulse resistance matrix (dim1=SOC, dim2=iapp)
    tmp.deltaR0LAB = ecmData(kt).ubR - ecmData(kt).lbR; % uncertainty width

    % Fetch OCP curves.
    % NEG - use regressed MSMR model (absolute OCP)
    thetaNEG = cellData.neg.theta0 + ...
        (tmp.socPct/100)*(cellData.neg.theta100-cellData.neg.theta0);
    ocpNEG = MSMR(cellData.neg).ocp('theta',thetaNEG,'TdegC',Tvect(kt));
    tmp.thetaNEG = ocpNEG.theta;
    tmp.UocpNEG = ocpNEG.Uocp;
    % POS - use regressed MSMR model (absolute OCP)
    thetaPOS = cellData.pos.theta0 + ...
        (tmp.socPct/100)*(cellData.pos.theta100-cellData.pos.theta0);
    ocpPOS = MSMR(cellData.pos).ocp('theta',thetaPOS,'TdegC',Tvect(kt));
    tmp.thetaPOS = ocpPOS.theta;
    tmp.UocpPOS = ocpPOS.Uocp;

    tempSet(kt) = tmp;
  end
  p.Tvect = Tvect;
  p.tempSet = tempSet;

  % Add fixed parameters (identified in earlier steps) to the model.
  model = arg.model;
  model.neg.U0 = optutil.param('fix',cellData.neg.U0);
  model.neg.X = optutil.param('fix',cellData.neg.X);
  model.neg.omega = optutil.param('fix',cellData.neg.omega);
  model.neg.theta0 = optutil.param('fix',cellData.neg.theta0);
  model.neg.theta100 = optutil.param('fix',cellData.neg.theta100);
  model.pos.U0 = optutil.param('fix',cellData.pos.U0);
  model.pos.X = optutil.param('fix',cellData.pos.X);
  model.pos.omega = optutil.param('fix',cellData.pos.omega);
  model.pos.theta0 = optutil.param('fix',cellData.pos.theta0);
  model.pos.theta100 = optutil.param('fix',cellData.pos.theta100);
  model.const.Q = optutil.param('fix',cellData.const.Q);
  modelspec = optutil.modelspec(model,'TrefdegC',25);
  p.modelspec = modelspec;

  % Configure and run optimizer.
  % -- PSO+fmincon on linear pulse resistance --
  init = optutil.pack('init',modelspec,[],'coerce');
  lb = optutil.pack('lb',modelspec,[],'coerce');  % vector of lower bounds
  ub = optutil.pack('ub',modelspec,[],'coerce');  % vector of upper bounds
  Np = arg.Np;       % number of particles
  pop0 = init(:).';  % (partial) initial population: the model's init values
  optionsFMINCON = optimoptions(@fmincon,'Display','off',...
      'MaxFunEvals',1e6,'MaxIter',1e3,...
      'TolFun',1e-20,'TolX',1e-20,'TolCon',1e-20);
  optionsPSO = optimoptions(@particleswarm,...
      'Display','off','UseParallel',false,...
      'FunctionTolerance',1e-20,'SwarmSize',Np,...
      'MaxIterations',100,'MaxStallIterations',20,...
      'InitialSwarmMatrix',pop0,'FunValCheck','off',...
      'HybridFcn',{@fmincon,optionsFMINCON});

  outp.info('Running initial pulse regression (linear model)...\n');
  linVect = particleswarm( ...
      @(x)cost(x,p,'linear'),length(init),lb,ub,optionsPSO);
  linParams = optutil.unpack(linVect,modelspec);

  % -- fmincon on nonlinear pulse resistance, seeded with linear result --
  % Silence bvp5c tolerance warnings during this stage. The caller's prior
  % warning state is restored afterwards, even if fmincon throws.
  warnState = warning('off','MATLAB:bvp5c:RelTolNotMet');
  restoreWarnings = onCleanup(@()warning(warnState));

  outp.info('Running second pulse regression (exact nonlinear model)...\n');
  nlVect = fmincon( ...
      @(x)cost(x,p,'exact'),linVect,[],[],[],[],lb,ub,[],optionsFMINCON);
  nlParams = optutil.unpack(nlVect,modelspec);

  clear restoreWarnings; % restore warning state now

  % Save cell parameter values (these apply at TrefdegC=25).
  secnames = fieldnames(nlParams);
  for ks = 1:length(secnames)
    secname = secnames{ks};
    sec = nlParams.(secname);
    paramnames = fieldnames(sec);
    for kp = 1:length(paramnames)
      paramname = paramnames{kp};
      cellData.(secname).(paramname) = sec.(paramname);
    end % for
  end % for
  cellData.const.TdegC = p.modelspec.Tref-273.15;

  % Collect and save output data.
  outData = struct;
  outData.cellData = cellData;
  outData.linearEstimates = linParams;
  outData.nonlinearEstimates = nlParams;
  outData.modelspec = p.modelspec;
  outData.origin__ = 'fitPulse';
  outData.arg__ = arg;
  outData.timestamp__ = timestamp;
  save(fitfile,'-struct','outData');

  outp.print('Finished fitPulse %s\n',cellname);
end


% Utility functions -------------------------------------------------------

%COST Cost function for the model regression.
%
% J = COST(vect,param,solnMethod) unpacks the optimization vector vect via
% param.modelspec, evaluates the model at each temperature in param.Tvect,
% predicts the pulse resistance with getPulseResistance, and returns
%   J = sqrt( sum over temperatures of mean(((R0LAB-R0MOD)./deltaR0LAB).^2) ),
% i.e. residuals weighted by the lab-measurement uncertainty deltaR0LAB.
%
% solnMethod is 'linear' or 'exact' (default 'exact'). In the linear case
% the resistance is computed once per SOC setpoint (with a unit current)
% and replicated across all measured pulse magnitudes.
%
% Persistent state records the best cost/model found so far and manages
% progress output (console backspacing and, when outp.plot is true, a
% figure). It is reset whenever solnMethod differs from the previous
% call's method (i.e. at the start of each regression stage), independent
% of plotting, and also whenever a new figure must be created.
function J = cost(vect,param,solnMethod)

  if ~exist('solnMethod','var')
    % Method of finding the pulse resistance, either 'linear' or 'exact'.
    solnMethod = 'exact';
  end
  useLinear = strcmpi(solnMethod,'linear');

  Tvect = param.Tvect;
  ntemp = length(Tvect);

  persistent minJ bestModel bscount bscount2 evalnum solmethod fig lines_R0LAB lines_R0MOD;
  if isempty(minJ),      minJ = Inf;               end
  if isempty(bestModel), bestModel = struct;       end
  if isempty(bscount),   bscount = 0;              end
  if isempty(bscount2),  bscount2 = 0;             end
  if isempty(evalnum),   evalnum = 0;              end
  if isempty(solmethod), solmethod = solnMethod;   end

  % A new regression stage must start with fresh progress tracking even
  % when plotting is off. Otherwise the nonlinear stage would inherit the
  % linear stage's minJ, which suppresses its progress output. It would
  % also inherit a stale backspace count and erase unrelated console text.
  newStage = ~strcmp(solnMethod,solmethod);
  makeFigure = outp.plot && (newStage || isempty(fig) || ~ishandle(fig));
  if newStage || makeFigure
    minJ = Inf;
    bestModel = struct;
    bscount = 0;
    bscount2 = 0;
    evalnum = 0;
    solmethod = solnMethod;
    outp.charcount(0);
  end

  if makeFigure
    fig = figure( ...
        'Name','Pulse Regression', ...
        'WindowStyle','normal');
    lines_R0LAB = cell(ntemp,1);
    lines_R0MOD = cell(ntemp,1);
    for kt = 1:ntemp
      socPct = param.tempSet(kt).socPct;
      colors = cool(length(socPct));
      lines_R0LAB{kt} = gobjects(length(socPct),1);
      lines_R0MOD{kt} = gobjects(length(socPct),1);
      subplot(1,ntemp,kt);
      % One marker series (lab) and one line series (model) per SOC setpoint.
      for kz = 1:length(socPct)
        lines_R0LAB{kt}(kz) = plot(NaN,NaN,'o','Color',colors(kz,:)); hold on;
      end % for soc
      for kz = 1:length(socPct)
        lines_R0MOD{kt}(kz) = plot(NaN,NaN,'-','Color',colors(kz,:)); hold on;
      end % for soc
      xlabel('Pulse Magnitude [C rate]', ...
      'Interpreter','latex');
      ylabel('$R_\mathrm{0}$ [$\mathrm{\Omega}$]','Interpreter','latex');
      title(sprintf('%.2fdegC',Tvect(kt)));
      % Invisible black proxies used only for the legend entries.
      legh = gobjects(2,1);
      legh(1) = plot(NaN,NaN,'ko');
      legh(2) = plot(NaN,NaN,'k-');
      if kt == 1
        legend(legh,'Lab','Model','Location','northeast');
      end
    end % for
    thesisFormat;
    drawnow;
  end

  % Get struct of parameter values.
  model = optutil.unpack(vect,param.modelspec);

  % Iterate temperatures, accumulating cost.
  J = 0;
  R0LAB = cell(ntemp,1);
  R0MOD = cell(ntemp,1);
  for kt = 1:ntemp
    TdegC = Tvect(kt);
    paramT = param.tempSet(kt);
    modelT = optutil.evaltemp(model,param.modelspec,TdegC);

    % Fetch pulse resistance measured in the lab.
    R0LAB{kt} = paramT.R0LAB;
    deltaR0LAB = paramT.deltaR0LAB;

    % Compute pulse resistance predicted by the model.
    ocpData = paramT;
    socPct = paramT.socPct;
    iappLAB = paramT.iappLAB;
    if useLinear
      % Only solve at one current value - R0 invariant with current in
      % the linear case!
      iapp = ones(length(socPct),1);
      R0lin = getPulseResistance( ...
          modelT,socPct,iapp,TdegC,ocpData,useLinear);
      R0MOD{kt} = repmat(R0lin,1,size(iappLAB,2));
    else
      R0MOD{kt} = getPulseResistance( ...
          modelT,socPct,iappLAB,TdegC,ocpData,useLinear);
    end

    % Accumulate the mean squared error, weighted by lab uncertainty.
    err = (R0LAB{kt}-R0MOD{kt})./deltaR0LAB;
    J = J + mean(err.^2,"all");
  end
  J = sqrt(J);

  % Output best solution.
  if J < minJ
    minJ = J;
    bestModel = model;
    % Delete iteration counter if present.
    if bscount2 > 0
      outp.info(repmat('\b',1,bscount2));
      bscount2 = 0;
    end
    % Delete previous iteration's output if present.
    if bscount > 0
      outp.info(repmat('\b',1,bscount));
    end
    outp.charcount(0); % start tracking the number of characters output
    outp.info('  RMSE  : %-20.4f\n',J);
    if outp.debug
      flatmodel = optutil.flattenstruct(bestModel);
      paramnames = fieldnames(flatmodel);
      paramnamesEact = paramnames(endsWith(paramnames,'_Eact'));
      paramnamesR = paramnames(contains(paramnames,'__R')&~endsWith(paramnames,'_Eact'));
      for pr = 1:length(paramnamesR)
        pname = paramnamesR{pr};
        outp.info('  %-20s:  %10.3g mΩ\n',pname,flatmodel.(pname)*1000);
      end % for
      for pe = 1:length(paramnamesEact)
        pname = paramnamesEact{pe};
        meta = param.modelspec.params.(pname(1:end-5)); % strip '_Eact'
        Eact = flatmodel.(pname)/1000;
        EactString = sprintf('%10.3f ',Eact);
        outp.info('  %-20s: %s%s kJ\n',pname,meta.tempcoeff,EactString);
      end % for
    end
    bscount = outp.charcount(); % save number of chars output so we can
                                % backspace next iteration

    if outp.plot && ishandle(fig)
      for kt = 1:ntemp
        socPct = param.tempSet(kt).socPct;
        iappLAB = param.tempSet(kt).iappLAB;
        for kz = 1:length(socPct)
          set(lines_R0LAB{kt}(kz),'XData',iappLAB(kz,:)/model.const.Q,'Ydata',R0LAB{kt}(kz,:));
          set(lines_R0MOD{kt}(kz),'XData',iappLAB(kz,:)/model.const.Q,'Ydata',R0MOD{kt}(kz,:));
        end % for soc
      end % for temp
      drawnow;
    end % if
  end % if

  % Output optimization metadata.
  if outp.debug
    if bscount2 > 0
      outp.info(repmat('\b',1,bscount2));
    end
    outp.charcount(0); % start tracking the number of characters output
    outp.info('  Function evalulation count: %10d\n',evalnum+1);
    bscount2 = outp.charcount(); % save number of chars output so we can
                                 % backspace next iteration
  end

  evalnum = evalnum + 1;
  outp.charcount(0); % reset char count
end % cost()