# kranke - April 2025
# Helper functions for 18.05 studio on climate variability

# Diverging colour palette generator: RdBu.color( n ) returns n hex colours
# interpolated through the anchor colours below (dark blue -> near white -> dark red).
RdBu.color <- colorRampPalette(
  c(
    "#08306B", "#2171B5", "#6BAED6", "#C6DBEF",  # blues, dark to light
    "#F7FBFF",                                   # near-white midpoint
    "#FEE0D2", "#FC9272", "#EF3B2C", "#A50F15"   # reds, light to dark
  )
)

select_region <- function( variable, n4obj, lats, lons ) {
  # Mask `variable` outside the latitude/longitude box defined by lats/lons.
  #
  # Arguments:
  #   variable  array whose first dimension follows n4obj$dim$lon$vals and whose
  #             second dimension follows n4obj$dim$lat$vals; either 2-D
  #             [lon, lat] or 3-D [lon, lat, k] (k presumably time - confirm
  #             against the caller).
  #   n4obj     object exposing the grid coordinates as n4obj$dim$lat$vals and
  #             n4obj$dim$lon$vals (presumably an ncdf4 file handle).
  #   lats      c( south, north ) bounds, both inclusive.
  #   lons      c( west, east ) bounds, both inclusive, in the same longitude
  #             convention as the grid.
  #
  # Returns an array with the same shape as `variable`; every cell outside the
  # box is set to NA, cells inside are unchanged.
  n_dims <- length( dim( variable ) )
  if ( !( n_dims %in% c( 2, 3 ) ) ) {
    # Checked up front: indexing with the wrong number of subscripts would
    # otherwise fail with an obscure "incorrect number of subscripts" error.
    stop( "`variable` must be a 2-D [lon, lat] or 3-D [lon, lat, k] array", call. = FALSE )
  }

  grid_lats   <- n4obj$dim$lat$vals
  grid_lons   <- n4obj$dim$lon$vals
  outside_lat <- !( grid_lats >= lats[ 1 ] & grid_lats <= lats[ 2 ] )
  outside_lon <- !( grid_lons >= lons[ 1 ] & grid_lons <= lons[ 2 ] )

  masked <- variable
  if ( n_dims == 3 ) {
    masked[ outside_lon, , ] <- NA
    masked[ , outside_lat, ] <- NA
  } else {
    masked[ outside_lon, ] <- NA
    masked[ , outside_lat ] <- NA
  }
  masked
}

drawContourMap <- function( lat, lon, variable, lon360 = TRUE, ... ) {
  # Draw a filled-contour map of `variable` overlaid with national boundaries.
  #
  # Arguments:
  #   lat       latitudes of the columns of `variable` (ascending, as required
  #             by filled.contour).
  #   lon       longitudes of the rows of `variable`.
  #   variable  2-D matrix indexed [lon, lat].
  #   lon360    TRUE when `lon` is given in the [0, 360) convention. map() draws
  #             in [-180, 180], so the longitudes are wrapped and the rows of
  #             `variable` reordered to match. FALSE when `lon` is already in
  #             [-180, 180]; the data are then plotted as given.
  #   ...       further arguments passed on to filled.contour().
  if ( lon360 ) {
    # Wrap longitudes >= 180 onto the negative side, then sort so that the
    # x axis is ascending and each data row stays attached to its longitude.
    lonfix               <- lon
    lonfix[ lon >= 180 ] <- lon[ lon >= 180 ] - 360
    idxlon               <- order( lonfix )
    lonvec               <- lonfix[ idxlon ]
  } else {
    # Longitudes are already in the map() convention: keep the order as is.
    idxlon <- seq_along( lon )
    lonvec <- lon
  }
  filled.contour(
    lonvec, lat, variable[ idxlon, ],
    plot.axes = {
      map( "world", add = TRUE, col = "black", fill = FALSE )
      map.axes()
    },
    xlab = 'Longitude', ylab = 'Latitude', ...
  )
}

calc_coslat_wmean <- function( variable, lat ) {
  # Cosine-of-latitude weighted mean of a gridded field, one value per slice
  # of the third dimension.
  #
  # Each grid cell is weighted by cos( latitude ) so that the average accounts
  # for the smaller area represented by cells at higher latitudes. NA cells are
  # excluded and the weights are renormalised over the remaining cells.
  #
  # Arguments:
  #   variable  array indexed [lon, lat, k]; a 2-D [lon, lat] matrix is treated
  #             as a single slice.
  #   lat       latitudes in degrees, one per column (second dimension) of
  #             `variable`.
  #
  # Returns a numeric vector of length dim( variable )[ 3 ]; a slice with no
  # valid cell yields NaN.
  dims <- dim( variable )
  if ( length( dims ) == 2 ) {
    dims            <- c( dims, 1L )
    dim( variable ) <- dims
  }
  if ( length( dims ) != 3 ) {
    stop( "`variable` must be a 2-D [lon, lat] or 3-D [lon, lat, k] array", call. = FALSE )
  }
  if ( length( lat ) != dims[ 2 ] ) {
    stop( "length( lat ) must equal the second dimension of `variable`", call. = FALSE )
  }

  # c( variable[ , , i ] ) unrolls with the first (lon) index running fastest,
  # so every latitude has to be repeated once per longitude.
  coslat <- cos( rep( lat, each = dims[ 1 ] ) * pi / 180 )

  vapply( seq_len( dims[ 3 ] ), function( i ) {
    field         <- c( variable[ , , i ] )
    idx_non_na    <- !is.na( field )
    weight_factor <- coslat[ idx_non_na ] / mean( coslat[ idx_non_na ] )
    mean( field[ idx_non_na ] * weight_factor )
  }, numeric( 1 ) )
}

remove_monthly_clm <- function( tvec, variable ) {
  # Remove the monthly climatology:
  # for every calendar month present in `tvec`, the mean of `variable` over all
  # samples falling in that month is subtracted from those samples, producing
  # anomalies relative to the climatological baseline.
  #
  # Arguments:
  #   tvec      vector of dates accepted by month(), one per sample.
  #   variable  numeric vector aligned with `tvec`.
  #
  # Returns a numeric vector of length( tvec ). The monthly mean is computed
  # without na.rm, so one NA in a calendar month makes every anomaly of that
  # month NA. Samples whose month is NA are returned as NA.
  months <- month( tvec )
  anm    <- rep( NA_real_, length( tvec ) )
  for ( m in unique( months ) ) {
    idxmonth        <- which( months == m )
    anm[ idxmonth ] <- variable[ idxmonth ] - mean( variable[ idxmonth ] )
  }
  anm
}

calc_rolling_mean <- function( x, window, min_periods = 3 ) {
  # Centered moving average of the vector `x`.
  #
  # The window around position k spans k - floor( window / 2 ) to
  # k + floor( window / 2 ), truncated at the ends of the series (so an even
  # `window` effectively covers window + 1 samples). NA values are ignored in
  # the average; positions whose window holds fewer than `min_periods` non-NA
  # values are left as NA.
  n_obs      <- length( x )
  half_width <- window %/% 2
  positions  <- seq_len( n_obs )
  lower      <- pmax( 1, positions - half_width )
  upper      <- pmin( n_obs, positions + half_width )
  smoothed   <- rep( NA, n_obs )
  for ( k in positions ) {
    chunk <- x[ lower[ k ]:upper[ k ] ]
    if ( sum( !is.na( chunk ) ) < min_periods ) {
      next
    }
    smoothed[ k ] <- mean( chunk, na.rm = TRUE )
  }
  smoothed
}


extract_true_sequences <- function( logical_vec ) {
  # Helper function for plot_nino34_index_timeseries
  # Splits the positions of the TRUE values of `logical_vec` into runs of
  # consecutive indices (each run defines one El Nino / La Nina event).
  #
  # Returns an unnamed list with one integer vector of indices per run, in
  # order of appearance; an empty list when there is no TRUE. NA entries are
  # dropped by which() and therefore break a run like a FALSE would.
  true_indices <- which( logical_vec )
  if ( length( true_indices ) == 0 ) {
    return( list() )
  }
  # A new run starts at the first index and wherever the gap to the previous
  # TRUE index is not exactly 1; the cumulative sum labels each run. This also
  # handles a single TRUE, where diff() is empty.
  run_id <- cumsum( c( TRUE, diff( true_indices ) != 1 ) )
  unname( split( true_indices, run_id ) )
}


plot_nino34_index_timeseries <- function( tvec, variable, threshold = 0.4, ... ) {
  # Plot the Nino 3.4 index as a time series with ENSO phases highlighted.
  #
  # Draws `variable` against `tvec` as a line, fills in red the area between
  # the curve and +threshold wherever variable >= threshold (El Nino), fills in
  # blue the area between the curve and -threshold wherever
  # variable <= -threshold (La Nina), and adds a solid zero line plus dotted
  # lines at +/- threshold.
  #
  # Arguments:
  #   tvec       x values (time), one per sample.
  #   variable   index values aligned with `tvec`.
  #   threshold  positive magnitude delimiting the two phases (default 0.4).
  #   ...        further arguments passed on to plot().

  # Set up the plot
  plot( tvec, variable, type = "l", ylab = "Temperature Anomaly (K)", xlab = NA, ... )

  # Shade every run of consecutive samples flagged in `is_event` between the
  # curve and `baseline`. Iterating over the list itself (rather than over
  # 1:length) makes a series without any event a no-op instead of an error.
  shade_events <- function( is_event, baseline, fill ) {
    for ( idx in extract_true_sequences( is_event ) ) {
      xvalues <- c( tvec[ idx ], rev( tvec[ idx ] ) )
      yvalues <- c( variable[ idx ], rep( baseline, length( idx ) ) )
      polygon( xvalues, yvalues, col = fill, border = NA )
    }
  }

  # Fill El Nino (positive anomalies)
  shade_events( variable >= threshold, threshold, 'red' )

  # Fill La Nina (negative anomalies)
  shade_events( variable <= -threshold, -threshold, 'blue' )

  # Add horizontal lines
  abline( h = 0, col = "black", lwd = 1 )
  abline( h = c( threshold, -threshold ), col = "black", lty = "dotted", lwd = 0.5 )
}

downsample_quarterly <- function( tvec, variable ) {
  # Helper function to downsample monthly time-series data to seasonal data.
  # Convention is NDJ, FMA, MJJ, ASO.
  #
  # Assumption: `tvec` is monthly, starts in January and ends in December.
  # The series is padded with two leading NAs (November and December of the
  # previous year) and one trailing NA (January of the following year) so that
  # consecutive triplets line up with the seasons; the first and last NDJ
  # averages are therefore based on one and two months respectively.
  #
  # Returns a data.frame with one row per season:
  #   quarter  season label ("NDJ", "FMA", "MJJ", "ASO", repeating)
  #   year     year of the first unpadded month of the season, i.e. the January
  #            year for NDJ; the final NDJ is labelled with the last year + 1
  #   val      mean of the available (non-NA) months of the season
  NT     <- length( tvec )
  months <- month( tvec )
  years  <- year( tvec )

  # The labels below are only correct under the documented assumption, so
  # flag a violation instead of silently returning mislabelled seasons.
  if ( NT > 0 && !isTRUE( months[ 1 ] == 1 && months[ NT ] == 12 ) ) {
    warning( "downsample_quarterly() assumes `tvec` starts in January and ends in December; ",
             "quarter and year labels may be wrong", call. = FALSE )
  }

  qAvg    <- colMeans( matrix( c( NA, NA, variable, NA ), 3 ), na.rm = TRUE )
  qLabels <- c( "NDJ", "FMA", "MJJ", "ASO" )
  data.frame(
    quarter = rep( qLabels, length.out = ( NT + 3 ) / 3 ),
    year    = c( years[ seq( 1, NT, 3 ) ], years[ NT ] + 1 ),
    val     = qAvg
  )
}

shift_vector <- function( x, lag = 1 ) {
  # Function to shift a vector to compute lag correlations.
  #
  # Arguments:
  #   x    vector to shift.
  #   lag  number of positions: positive moves the values towards the end
  #        (result[ i ] = x[ i - lag ]), negative towards the start
  #        (result[ i ] = x[ i - lag ] as well, i.e. x[ i + |lag| ]), 0 returns
  #        `x` unchanged.
  #
  # Returns a vector of the same length as `x`; vacated positions are NA.
  # When |lag| >= length( x ) the result is entirely NA.
  n <- length( x )
  # Cap the shift at n and build the kept range with seq_len(): `1:( n - lag )`
  # would count backwards once lag >= n and return a vector of the wrong length.
  k    <- min( abs( lag ), n )
  pad  <- rep( NA, k )
  kept <- seq_len( n - k )
  if ( lag > 0 ) {
    return( c( pad, x[ kept ] ) )
  }
  if ( lag < 0 ) {
    return( c( x[ kept + k ], pad ) )
  }
  x
}

get_running_corr <- function( x, y, window = 13, min_periods = 5 ) {
  # Centered running correlation between the vectors `x` and `y`.
  #
  # For each position the correlation (cor() with use = "complete.obs") is
  # computed over the samples from floor( window / 2 ) positions before to
  # floor( window / 2 ) positions after it, truncated at the ends of the
  # series. Positions whose window holds fewer than `min_periods` pairs with
  # both values non-NA are left as NA.
  #
  # Returns a vector of length( x ).
  n_obs      <- length( x )
  half_width <- window %/% 2
  running    <- rep( NA, n_obs )
  for ( centre in seq_len( n_obs ) ) {
    span    <- max( 1, centre - half_width ):min( n_obs, centre + half_width )
    x_span  <- x[ span ]
    y_span  <- y[ span ]
    n_pairs <- sum( !( is.na( x_span ) | is.na( y_span ) ) )
    if ( n_pairs < min_periods ) {
      next
    }
    running[ centre ] <- cor( x_span, y_span, use = "complete.obs" )
  }
  running
}
