QG3 Extension
This page documents the QG3-based SFNO layers and utility functions defined to work with the QG3 package.
ESM_PINO.SFNO — Method
SFNO(
ggsh::QG3.GaussianGridtoSHTransform,
shgg::QG3.SHtoGaussianGridTransform;
in_channels,
out_channels,
hidden_channels,
n_layers,
num_encoder_layers,
num_decoder_layers,
lifting_channel_ratio,
projection_channel_ratio,
channel_mlp_expansion,
activation,
positional_embedding,
inner_skip,
outer_skip,
operator_type,
use_norm,
downsampling_factor,
modes,
gpu,
batch_size,
soft_gating,
bias
) -> SFNO{E, L, B, P, ESM_PINOQG3Ext.ESM_PINOQG3} where {E, L, B, P}
Spherical Fourier Neural Operator (SFNO) layer combining positional embeddings, spectral kernels, and channel MLPs.
This layer implements the SFNO architecture on the sphere, optionally using Zonal Symmetric Kernels (ZSK) following the approach described in Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere.
Arguments
- `ggsh::QG3.GaussianGridtoSHTransform`: Precomputed grid-to-SH transform.
- `shgg::QG3.SHtoGaussianGridTransform`: Precomputed SH-to-grid transform.
- Other keyword arguments are the same as for the primary constructor, except `modes`, which defaults to `ggsh.output_size[1]`. There is no need to specify `batch_size` or `gpu`, as these are handled by the transforms.
Returns
SFNO: A Lux-compatible container layer.
Details
- Constructs lifting, SFNO blocks, and projection layers compatible with Lux.jl.
- Positional embeddings are appended if `positional_embedding="grid"`.
- Supports both CPU and GPU execution.
- Zonal Symmetric Kernels (ZSK) reduce the number of parameters and improve stability on spherical domains.
Example
using Lux, QG3, Random, NNlib, LuxCUDA
# Load precomputed QG3 parameters
qg3ppars = QG3.load_precomputed_params()[2]
# Input: [lat, lon, channels, batch]
x = rand(Float32, 32, 64, 3, 10)
# Construct SFNO layer using secondary constructor
ggsh = QG3.GaussianGridtoSHTransform(qg3ppars, 32, N_batch=size(x,4))
shgg = QG3.SHtoGaussianGridTransform(qg3ppars, 32, N_batch=size(x,4))
model2 = SFNO(ggsh, shgg;
modes=15,
in_channels=3,
out_channels=3,
hidden_channels=32,
n_layers=4,
lifting_channel_ratio=2,
projection_channel_ratio=2,
channel_mlp_expansion=2.0,
positional_embedding="no_grid",
outer_skip=true,
zsk=true
)
# Setup parameters and state
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, model2)
# Forward pass
y, st = model2(x, ps, st)
# Compute gradients
using Zygote
gr = Zygote.gradient(ps -> sum(model2(x, ps, st)[1]), ps)

ESM_PINO.SFNO — Method
Modified SFNO constructor with variable encoder/decoder depths.
New Arguments
- `num_encoder_layers::Int=2`: Number of layers in the encoder (lifting).
- `num_decoder_layers::Int=2`: Number of layers in the decoder (projection).
Notes
- When `num_encoder_layers=2` and `num_decoder_layers=2`, this behaves identically to the original implementation.
- `lifting_channel_ratio` controls the hidden dimension ratio for the encoder.
- `projection_channel_ratio` controls the hidden dimension ratio for the decoder.
ESM_PINO.SFNO_Block — Method
SFNO_Block(
channels::Int64,
ggsh::QG3.GaussianGridtoSHTransform,
shgg::QG3.SHtoGaussianGridTransform;
modes,
expansion_factor,
activation,
skip,
operator_type,
use_norm,
soft_gating,
bias
) -> ESM_PINO.SFNO_Block{ESM_PINOQG3Ext.ESM_PINOQG3}
A block that combines a spherical kernel with a channel MLP. Expects input in (spatial..., channel, batch) format.
Arguments
- `channels::Int`: Number of input/output channels.
- `ggsh::GaussianGridtoSHTransform`: Transformation from Gaussian grid to spherical harmonics.
- `shgg::SHtoGaussianGridTransform`: Transformation from spherical harmonics back to the Gaussian grid.
- `modes::Int=ggsh.output_size[1]`: Number of spherical harmonic modes to retain.
- `expansion_factor::Real=2.0`: Expansion factor for the ChannelMLP.
- `activation`: Activation function applied after combining spatial and spectral branches (default: `NNlib.gelu`).
- `skip::Bool=true`: Whether to include a skip connection.
- `zsk::Bool=false`: Whether to use Zonal Symmetric Kernels (ZSK).
Returns
`SFNO_Block`: A Lux-compatible layer operating on 4D arrays `[lat, lon, channels, batch]`.
Fields
- `spherical_kernel::SphericalKernel`: Spherical kernel layer.
- `channel_mlp::ChannelMLP`: Channel-wise MLP layer.
- `channels::Int`: Number of input/output channels.
- `skip::Bool`: Whether to include a skip connection.
- `activation::Function`: Activation function applied after the block.
Details
- The input is processed by a SphericalKernel followed by a ChannelMLP.
- If `skip` is true, the input is added to the output (residual connection).
ESM_PINO.SFNO_Block — Method
SFNO_Block(
channels::Int64,
pars::QG3ModelParameters;
modes,
batch_size,
expansion_factor,
activation,
skip,
gpu,
operator_type,
use_norm,
soft_gating,
bias
) -> ESM_PINO.SFNO_Block{ESM_PINOQG3Ext.ESM_PINOQG3}
A block that combines a spherical kernel with a channel MLP. Expects input in (spatial..., channel, batch) format.
Arguments
- `channels::Int`: Number of input/output channels.
- `pars::QG3ModelParameters`: Precomputed QG3 model parameters.
- `modes::Int=pars.L`: Number of spherical harmonic modes to retain.
- `batch_size::Int=1`: Batch size for the transforms.
- `expansion_factor::Real=2.0`: Expansion factor for the ChannelMLP.
- `activation`: Activation function applied after combining spatial and spectral branches (default: `NNlib.gelu`).
- `skip::Bool=true`: Whether to include a skip connection.
- `gpu::Bool=true`: Whether to use the GPU.
- `zsk::Bool=false`: Whether to use Zonal Symmetric Kernels (ZSK).
Returns
`SFNO_Block`: A Lux-compatible layer operating on 4D arrays `[lat, lon, channels, batch]`.
Fields
- `spherical_kernel::SphericalKernel`: Spherical kernel layer.
- `channel_mlp::ChannelMLP`: Channel-wise MLP layer.
- `channels::Int`: Number of input/output channels.
- `skip::Bool`: Whether to include a skip connection.
- `activation::Function`: Activation function applied after the block.
Details
- The input is processed by a SphericalKernel followed by a ChannelMLP.
- If `skip` is true, the input is added to the output (residual connection).
ESM_PINO.SphericalKernel — Method
SphericalKernel(
hidden_channels::Int64,
ggsh::QG3.GaussianGridtoSHTransform,
shgg::QG3.SHtoGaussianGridTransform;
use_norm,
modes,
operator_type,
inner_mixing,
bias
) -> ESM_PINO.SphericalKernel{ESM_PINOQG3Ext.ESM_PINOQG3}
Construct a SphericalKernel layer using precomputed transforms.
Arguments
- `hidden_channels::Int`: Number of channels.
- `ggsh::GaussianGridtoSHTransform`: Transformation from Gaussian grid to spherical harmonics.
- `shgg::SHtoGaussianGridTransform`: Transformation from spherical harmonics back to the Gaussian grid.
- `activation`: Activation function applied after combining spatial and spectral branches (default: `NNlib.gelu`).
- `modes::Int=ggsh.output_size[1]`: Number of spherical harmonic modes to retain.
- `zsk::Bool=false`: Whether to use Zonal Symmetric Kernels (ZSK).
Returns
`SphericalKernel`: A Lux-compatible layer operating on 4D arrays `[lat, lon, channels, batch]`.
Fields
- `spatial_conv::P`: 1x1 convolution operating directly in the spatial domain.
- `spherical_conv::SphericalConv`: Spherical convolution layer.
- `norm::Union{Lux.InstanceNorm, Lux.NoOpLayer}`: Optional normalization layer.
ESM_PINO.SphericalKernel — Method
SphericalKernel(
hidden_channels::Int64,
pars::QG3ModelParameters;
use_norm,
modes,
batch_size,
gpu,
operator_type,
inner_mixing,
bias
) -> ESM_PINO.SphericalKernel{ESM_PINOQG3Ext.ESM_PINOQG3}
Combines a SphericalConv layer with a 1x1 convolution in parallel, followed by an activation function. Expects input in (spatial..., channel, batch) format.
Arguments
- `hidden_channels`: Number of channels.
- `pars`: Precomputed QG3 model parameters (`QG3ModelParameters`).
- `activation`: Activation function applied after combining spatial and spectral branches (default: `NNlib.gelu`).
- `modes`: Number of spherical harmonic modes to retain (default: `pars.L`).
- `batch_size`: Batch size for the transforms (default: 1).
- `gpu`: Whether to use the GPU (default: true).
- `zsk`: Whether to use Zonal Symmetric Kernels (ZSK) (default: false).
Returns
`SphericalKernel`: A Lux-compatible layer operating on 4D arrays `[lat, lon, channels, batch]`.
Fields
- `spatial_conv::P`: 1x1 convolution operating directly in the spatial domain.
- `spherical_conv::SphericalConv`: Spherical convolution layer.
- `norm::Union{Lux.InstanceNorm, Lux.NoOpLayer}`: Optional normalization layer.
Details
- The input is processed in parallel by a 1x1 convolution and a spherical convolution.
- Outputs from both branches are summed and passed through the activation.
- Useful for mixing local (spatial) and global (spectral) information.
ESM_PINOQG3Ext.GaussianGridInfo — Type
GaussianGridInfo

Structure containing information about a Gaussian grid resolution.
Fields
- `truncation::Int`: Spectral truncation number (e.g., 31 for T31).
- `nlat::Int`: Number of latitude points.
- `nlon::Int`: Number of longitude points.
- `km_at_equator::Float64`: Approximate grid spacing at the equator in km.
- `deg_at_equator::Float64`: Approximate grid spacing at the equator in degrees.
- `description::String`: Human-readable description.
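As an illustration, the two spacing fields can be approximated from `nlon` alone. The helpers below are hypothetical (not part of the package); the 40,075 km constant is the approximate equatorial circumference of the Earth.

```julia
# Hypothetical helpers showing how the spacing fields relate to nlon.
km_at_equator_approx(nlon::Int) = 40075.017 / nlon   # km per grid step at the equator
deg_at_equator_approx(nlon::Int) = 360.0 / nlon      # degrees per grid step

km_at_equator_approx(96)    # ≈ 417 km for a T31 grid (48 x 96)
deg_at_equator_approx(96)   # 3.75 degrees
```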
ESM_PINOQG3Ext.QG3_Physics_Parameters — Method
QG3_Physics_Parameters(
;
n_lat,
modes,
batch_size,
gpu
) -> ESM_PINOQG3Ext.QG3_Physics_Parameters
Helper constructor to pass as an empty default to `train_model`.
ESM_PINOQG3Ext.RemapPlan — Type
RemapPlan

Precomputed plan for efficient array remapping. Stores source and destination indices to avoid recomputation.
ESM_PINOQG3Ext.calculate_gaussian_grid_size — Method
calculate_gaussian_grid_size(truncation::Int) -> Tuple{Int, Int}

Calculate Gaussian grid dimensions from the spectral truncation number using standard formulas.
For a spectral truncation T, the standard relationships are:
- nlat = (truncation + 1) * 3 / 2 (for reduced grids, varies slightly)
- nlon = 2 * nlat (for regular grids)
Arguments
truncation::Int: Spectral truncation number
Returns
Tuple{Int, Int}: (nlat, nlon)
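A minimal self-contained sketch of the stated relationships (the actual implementation may deviate slightly for reduced grids, as noted above; `grid_size_sketch` is an illustrative name, not the package function):

```julia
# Sketch: nlat = 3(T+1)/2 and nlon = 2*nlat (quadratic Gaussian grid rule).
function grid_size_sketch(truncation::Int)
    nlat = div(3 * (truncation + 1), 2)
    nlon = 2 * nlat
    return (nlat, nlon)
end

grid_size_sketch(31)  # (48, 96)
grid_size_sketch(63)  # (96, 192)
```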
ESM_PINOQG3Ext.compute_ACC — Method
compute_ACC(X_pred, X_true, lats, ltm; l=1)

Compute the Anomaly Correlation Coefficient (ACC) for a variable at time step `l`.
Arguments
- `X_pred`: Predicted values; array with dimensions [lat, lon, time] or [lat, lon].
- `X_true`: True values; array with dimensions [lat, lon, time] or [lat, lon].
- `lats`: Latitude values in degrees for each latitude grid point (length NLat).
- `ltm`: Long-term mean to subtract (array with dimensions [lat, lon]).
Returns
- ACC value (scalar between -1 and 1)
ESM_PINOQG3Ext.compute_ACC — Method
compute_ACC(X_pred, X_true, pars::QG3.QG3ModelParameters, ltm)

Compute the Anomaly Correlation Coefficient (ACC) for a variable at each time step.
Arguments
- `X_pred`: Predicted values; array with dimensions [lat, lon, time] or [lat, lon].
- `X_true`: True values; array with dimensions [lat, lon, time] or [lat, lon].
- `pars::QG3.QG3ModelParameters`: QG3 model parameters containing latitude weights.
- `ltm`: Long-term mean to subtract (array with dimensions [lat, lon]).
Returns
- ACC value(s): a scalar if the input is 2D, or a vector of length `time` if the input is 3D.
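For reference, a self-contained sketch of a latitude-weighted ACC for a single 2D field. The weighting choice here (plain cos-latitude weights) is an assumption for illustration and may differ from the package implementation:

```julia
# Sketch of a cos-latitude-weighted anomaly correlation coefficient.
function acc_sketch(X_pred, X_true, lats, ltm)
    w = cosd.(lats)            # one weight per latitude row (assumption)
    a = X_pred .- ltm          # predicted anomaly
    b = X_true .- ltm          # true anomaly
    num = sum(w .* a .* b)
    den = sqrt(sum(w .* a .^ 2) * sum(w .* b .^ 2))
    return num / den
end

# A perfect forecast has ACC = 1 regardless of the weights.
X = rand(4, 8)
lats = [-60.0, -20.0, 20.0, 60.0]
acc_sketch(X, X, lats, fill(0.1, 4, 8))  # ≈ 1.0
```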
ESM_PINOQG3Ext.create_remap_plan — Method
create_remap_plan(l::Int, c::Int)

Create a precomputed plan for remapping arrays from size 2l+1 to 2c.
ESM_PINOQG3Ext.fine_tuning — Method
fine_tuning(
x::AbstractArray,
target::AbstractArray,
x_val::AbstractArray,
target_val::AbstractArray,
model::SFNO,
ps::NamedTuple,
st::NamedTuple;
seed,
n_steps,
nepochs,
num_examples,
num_valid,
lr_0,
gpu,
parameters,
use_physics,
geometric,
spectral,
α,
β,
logging
) -> LuxCore.StatefulLuxLayerImpl.StatefulLuxLayer{Val{true}, SFNO{E, L, B, P, Q}, _A, NamedTuple{names, T}} where {E, L, B, P, Q, _A, names, T<:Tuple}
Fine-tune a pretrained SFNO model using an autoregressive (AR) loss function. This procedure is typically applied after initial training to improve multi-step forecast accuracy.
The function performs a short fine-tuning loop with autoregressive supervision, optionally including a physics-informed loss component.
Arguments
- `x::AbstractArray`: Input data tensor.
- `target::AbstractArray`: Target data tensor with shape `(lat, lon, channels, batch, time)`.
- `model::SFNO`: Pretrained `SFNO` model to be fine-tuned.
- `ps::NamedTuple`: Model parameters (from previous training).
- `st::NamedTuple`: Model internal state.
Keywords
- `n_steps::Int=2`: Number of autoregressive steps in the loss function.
- `maxiters::Int=5`: Maximum number of fine-tuning iterations.
- `lr_0::Float64=1e-5`: Learning rate for fine-tuning.
- `parameters::QG3_Physics_Parameters=QG3_Physics_Parameters()`: Physical parameters used in the loss.
- `use_physics::Bool=true`: Include the physics-informed component in the loss if `true`.
- `geometric::Bool=true`: Use the geometric formulation of the physics loss.
- `α::Float32=0.7f0`: Weighting factor between the physics and data loss terms.
Returns
StatefulLuxLayer{true}: Fine-tuned model instance with updated parameters and states.
Notes
- The target tensor must have five dimensions, with the number of autoregressive steps as the fifth dimension.
- The time dimension (`size(target, 5)`) must match the number of autoregressive steps (`n_steps`).
Example
# Fine-tune a pretrained SFNO model for multi-step forecasting
ft_model = fine_tuning(x_val, y_val, pretrained_model, ps, st;
n_steps=3, maxiters=10, lr_0=1e-5)
# Evaluate fine-tuned model
pred = ft_model(x_val)

ESM_PINOQG3Ext.gaussian_resolution_to_grid — Method
gaussian_resolution_to_grid(resolution::AbstractString) -> Tuple{Int, Int}

Convert a Gaussian grid resolution string (e.g., "T31", "T63") to an (nlat, nlon) tuple.
Arguments
resolution::AbstractString: Grid resolution in format "TN" where N is truncation number
Returns
Tuple{Int, Int}: (number of latitude points, number of longitude points)
Examples
julia> gaussian_resolution_to_grid("T31")
(48, 96)
julia> gaussian_resolution_to_grid("T63")
(96, 192)
julia> gaussian_resolution_to_grid("T255")
(256, 512)

Throws
`ArgumentError`: If the resolution is not recognized.
ESM_PINOQG3Ext.get_truncation_from_nlat — Method
get_truncation_from_nlat(nlat::Int) -> Int

Retrieve the spectral truncation number for a given number of latitude points.
Arguments
nlat::Int: Number of latitude points
Returns
Int: Spectral truncation number (e.g., 31 for T31)
Examples
julia> get_truncation_from_nlat(48)
31
julia> get_truncation_from_nlat(96)
63
julia> get_truncation_from_nlat(256)
255

Throws
`ArgumentError`: If `nlat` does not match any known Gaussian grid resolution.
ESM_PINOQG3Ext.inverse_reorderQG3_indexes — Method
inverse_reorderQG3_indexes(A::AbstractMatrix)

Inverse of `reorderQG3_indexes`: shifts each column upward by j÷2 positions.
GPU-compatible and Zygote-differentiable.
ESM_PINOQG3Ext.inverse_reorderQG3_indexes_4d — Method
inverse_reorderQG3_indexes_4d(A::AbstractArray{T,4})

Inverse of `reorderQG3_indexes_4d`.
GPU-compatible and Zygote-differentiable.
ESM_PINOQG3Ext.make_QG3_loss — Method
make_QG3_loss(pars::QG3_Physics_Parameters;
α=0.5f0,
use_physics::Bool=true,
geometric::Bool=false)

Create a composite QG3 loss function suitable for Lux training. Returns a callable `(model, ps, st, (input, target)) -> (loss, st, metrics)`.
ESM_PINOQG3Ext.make_autoregressive_loss — Method
make_autoregressive_loss(QG3_loss::Function; steps::Int, sequential::Bool=true)

Create an autoregressive loss function that rolls out predictions over `steps` and accumulates the loss defined by `QG3_loss`.
Arguments
- `QG3_loss`: A loss function of the form `(model, ps, st, (input, target)) -> (loss, st, details)`.
- `steps`: Number of autoregressive rollout steps.
- `sequential`: If true, predictions are fed back sequentially (standard autoregressive rollout); if false, all predictions are computed and compared in batch for efficiency.
Returns
- A loss function
(model, ps, st, (u_t1, targets)) -> (loss, st, details)
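The sequential rollout can be sketched with a toy step function standing in for `(model, ps, st)`; `ar_loss_sketch` and `step` below are illustrative names, not part of the package API:

```julia
# Sequential autoregressive rollout: feed each prediction back in and
# accumulate a per-step squared-error loss against the targets.
function ar_loss_sketch(step, u0, targets)
    loss = 0.0
    u = u0
    for y in targets
        u = step(u)                 # prediction becomes the next input
        loss += sum(abs2, u .- y)
    end
    return loss / length(targets)
end

# A perfect model (doubling map with matching targets) gives zero loss.
step(u) = 2 .* u
u0 = [1.0, 2.0]
targets = [2 .* u0, 4 .* u0, 8 .* u0]
ar_loss_sketch(step, u0, targets)  # 0.0
```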
ESM_PINOQG3Ext.preprocess_data — Method
preprocess_data(solu::AbstractArray;
noise_level::Real=0.0,
normalize::Bool=true,
channelwise::Bool=false,
to_gpu::Bool=false,
noise_type::Symbol=:gaussian,
slice_range::Union{Nothing, UnitRange, Tuple}=nothing,
train_fraction::Real=0.8
)

Preprocess simulation data with normalization, noise injection, and train/validation splitting.
Arguments
solu::AbstractArray: Raw simulation data array of shape (lat, lon, channel, time)
Keyword Arguments
- `noise_level::Real=0.0`: Standard deviation of the noise to add (0.0 = no noise).
- `normalize::Bool=true`: Whether to normalize the data.
- `channelwise::Bool=false`: If true, normalize each channel independently.
- `to_gpu::Bool=false`: If true, transfer the data to the GPU after preprocessing.
- `noise_type::Symbol=:gaussian`: Type of noise (`:gaussian`, `:uniform`, `:salt_pepper`).
- `slice_range::Union{Nothing, UnitRange, Tuple}=nothing`: Optional range(s) used to slice the data.
- `train_fraction::Real=0.8`: Fraction of the data to use for training (the rest is used for validation).
Returns
- `q_0_train`: Training initial conditions.
- `q_evolved_train`: Training evolved states.
- `q_0_val`: Validation initial conditions.
- `q_evolved_val`: Validation evolved states.
- `μ`: Mean(s) used for normalization.
- `σ`: Std(s) used for normalization.
- `normalization_params`: NamedTuple with normalization metadata.
ESM_PINOQG3Ext.preprocess_data — Method
preprocess_data(;
noise_level::Real=0.0,
normalize::Bool=true,
channelwise::Bool=false,
to_gpu::Bool=false,
noise_type::Symbol=:gaussian,
slice_range::Union{Nothing, UnitRange, Tuple}=nothing,
train_fraction::Real=0.8
)

Preprocess simulation data with normalization, noise injection, and train/validation splitting.
Keyword Arguments
- `noise_level::Real=0.0`: Standard deviation of the noise to add (0.0 = no noise).
- `normalize::Bool=true`: Whether to normalize the data.
- `channelwise::Bool=false`: If true, normalize each channel independently.
- `to_gpu::Bool=false`: If true, transfer the data to the GPU after preprocessing.
- `noise_type::Symbol=:gaussian`: Type of noise (`:gaussian`, `:uniform`, `:salt_pepper`).
- `slice_range::Union{Nothing, UnitRange, Tuple}=nothing`: Optional range(s) used to slice the data.
- `train_fraction::Real=0.8`: Fraction of the data to use for training (the rest is used for validation).
Returns
- `q_0_train`: Training initial conditions.
- `q_evolved_train`: Training evolved states.
- `q_0_val`: Validation initial conditions.
- `q_evolved_val`: Validation evolved states.
- `μ`: Mean(s) used for normalization.
- `σ`: Std(s) used for normalization.
- `normalization_params`: NamedTuple with normalization metadata.
Examples
# Basic usage with 80% train, 20% validation
q0_tr, qe_tr, q0_val, qe_val, μ, σ, params = preprocess_data(
normalize=true,
train_fraction=0.8
)
# Custom train/val split with noise
q0_tr, qe_tr, q0_val, qe_val, μ, σ, params = preprocess_data(
noise_level=0.01,
normalize=true,
channelwise=true,
train_fraction=0.7,
to_gpu=true
)

ESM_PINOQG3Ext.qg3pars_constructor_helper — Method
qg3pars_constructor_helper(
L::Int64,
n_lat::Int64;
n_lon,
iters,
tol,
NF
) -> QG3ModelParameters{Float32, Int64, Vector{Float32}, Matrix{Float32}}
Helper function that wraps the `QG3ModelParameters` constructor for a Gaussian grid. Generates latitude/longitude points and initializes empty topography and land/sea masks. Used mainly to handle SH transforms.
Arguments
- `L::Int`: Spectral truncation level (maximum degree).
- `n_lat::Int`: Number of Gaussian latitudes.
Keywords
- `n_lon::Int=2*n_lat`: Number of longitudes (default: twice the latitude count).
- `iters::Int=100`: Maximum number of iterations for Gaussian grid convergence.
- `tol::Real=1e-8`: Convergence tolerance.
- `NF::Type{<:AbstractFloat}=Float32`: Number format for outputs.
Returns
QG3ModelParameters: Model parameters including grid coordinates, topography (h),
and land-sea mask (LS).
Example
pars = qg3pars_constructor_helper(42, 64)

ESM_PINOQG3Ext.remap_array_components_fast — Method
remap_array_components_fast(arr::AbstractArray{T,4}, plan::RemapPlan)

Fast array remapping using a precomputed plan.
ESM_PINOQG3Ext.reorderQG3_indexes — Method
reorderQG3_indexes(A::AbstractMatrix)

Reorder a 2D array by applying circular shifts to each column: column j is shifted downward by j÷2 positions.
GPU-compatible and Zygote-differentiable.
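A self-contained sketch of the described shifts using `circshift`, assuming 1-based column indices (the package's exact index convention may differ); the `*_sketch` names are illustrative, not the package API:

```julia
# Column j shifted downward by j÷2, and the inverse shift back upward.
reorder_sketch(A) = hcat((circshift(A[:, j], j ÷ 2) for j in axes(A, 2))...)
inverse_reorder_sketch(A) = hcat((circshift(A[:, j], -(j ÷ 2)) for j in axes(A, 2))...)

A = reshape(collect(1:20), 5, 4)
B = reorder_sketch(A)
inverse_reorder_sketch(B) == A  # true: the two maps are exact inverses
```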
ESM_PINOQG3Ext.reorderQG3_indexes_4d — Method
reorderQG3_indexes_4d(A::AbstractArray{T,4})

Reorder a 4D array by applying the 2D reordering to dimensions 2 and 3. Dimensions 1 and 4 are preserved.
GPU-compatible and Zygote-differentiable.
ESM_PINOQG3Ext.stack_time_steps — Method
stack_time_steps(
data::AbstractArray{T, 4},
time_steps::Int64;
dt,
N_sims,
gpu
) -> Any
Convert a 4D tensor of sequential data (lat, lon, channels, batch) into a 5D tensor suitable for autoregressive training or evaluation. The function constructs overlapping sequences along the batch dimension, each containing time_steps consecutive snapshots.
Arguments
- `data::AbstractArray{T,4}`: Input data tensor with dimensions `(lat, lon, channels, batch)`.
- `time_steps::Int`: Number of consecutive time steps to include in each sequence.
Returns
`AbstractArray{T,5}`: A 5D tensor of shape `(lat, lon, channels, time_steps, n_sequences)`, where `n_sequences = batch - time_steps + 1`.
Notes
- The resulting array can be used as autoregressive training targets for multi-step prediction.
- Sequences are created by sliding a window of length `time_steps` along the batch axis.
Example
# Input: 4D array with 10 time samples
data = rand(Float32, 64, 128, 3, 10)
# Stack into 5D sequences of 4 time steps each
seq_data = stack_time_steps(data, 4)
@assert size(seq_data) == (64, 128, 3, 4, 7)

ESM_PINOQG3Ext.train_model — Method
train_model(
x::AbstractArray,
target::AbstractArray,
pars::QG3ModelParameters;
seed,
maxiters,
batchsize,
modes,
in_channels,
out_channels,
hidden_channels,
n_layers,
lifting_channel_ratio,
projection_channel_ratio,
channel_mlp_expansion,
activation,
positional_embedding,
inner_skip,
outer_skip,
operator_type,
use_norm,
downsampling_factor,
lr_0,
gpu,
parameters,
use_physics,
geometric,
α
) -> LuxCore.StatefulLuxLayerImpl.StatefulLuxLayer{Val{true}, SFNO{E, L, B, P, Q}, _A, NamedTuple{names, T}} where {E, L, B, P, Q, _A, names, T<:Tuple}
Train an SFNO model, optionally using a combined data-driven (simple MSE or geometrically weighted) and physics-informed loss. This function initializes the model, optimizer, and training loop, and performs iterative optimization of the model parameters.
Both standard data loss and an optional physics-informed term are tracked during training.
Arguments
- `x::AbstractArray`: Input training data tensor.
- `target::AbstractArray`: Target (ground truth) data tensor.
- `pars::QG3ModelParameters`: Model configuration including grid and spectral parameters.
Keywords
- `seed::Int=0`: Random seed for reproducibility.
- `maxiters::Int=20`: Number of training iterations.
- `batchsize::Int=256`: Mini-batch size for training.
- `modes::Int=pars.L`: Spectral truncation level.
- `in_channels::Int=3`: Number of input channels.
- `out_channels::Int=3`: Number of output channels.
- `hidden_channels::Int=256`: Width of the hidden feature layers.
- `n_layers::Int=4`: Number of SFNO layers.
- `lifting_channel_ratio::Int=2`: Ratio of lifting layer expansion.
- `projection_channel_ratio::Int=2`: Ratio of projection layer contraction.
- `channel_mlp_expansion::Number=2.0`: Expansion factor in channel MLP blocks.
- `activation`: Activation function used in SFNO blocks (default: `NNlib.gelu`).
- `positional_embedding::AbstractString="grid"`: Type of positional embedding (`"grid"` or `"no_grid"`).
- `inner_skip::Bool=true`: Whether to enable residual connections inside SFNO blocks.
- `outer_skip::Bool=true`: Whether to enable skip connections between the lifting output and the projection input.
- `zsk::Bool=false`: Use the zonally symmetric kernel formulation if `true`.
- `use_norm::Bool=false`: Apply normalization layers inside SFNO blocks.
- `downsampling_factor::Int=2`: Ratio of downsampling between layers.
- `lr_0::Float64=1e-3`: Initial learning rate for the optimizer.
- `parameters::QG3_Physics_Parameters=QG3_Physics_Parameters(pars, batch_size=batchsize)`: Physical parameters used in the QG3 loss.
- `use_physics::Bool=true`: Whether to include the physics-informed loss component.
- `geometric::Bool=true`: Use the geometrically weighted formulation for the data loss.
- `α::Float32=0.7f0`: Weighting factor between the physics loss and the data loss.
Returns
StatefulLuxLayer{true}: Trained SFNO model containing learned parameters and internal state.
Example
# Initialize parameters and data
pars = qg3pars_constructor_helper(42, 64)
x, y = generate_training_data(pars)
# Train SFNO model
trained_model = train_model(x, y, pars; maxiters=100, batchsize=128, lr_0=5e-4)
# Perform inference
pred = trained_model(x)

ESM_PINOQG3Ext.transfer_SFNO_model — Method
transfer_SFNO_model(
model::SFNO,
qg3ppars::QG3ModelParameters;
batch_size,
gpu
) -> SFNO{E, L, B, P, ESM_PINOQG3Ext.ESM_PINOQG3} where {E, L, B, P}
Construct a new SFNO model that replicates the architecture and parameters of an existing model, but adapts them to a new discretization (qg3ppars) and batch size. This function preserves all spectral modes, channels, and hyperparameters while adjusting the internal transform plans to match the new grid configuration.
Arguments
- `model::SFNO`: Source SFNO model whose architecture and parameters will be cloned.
- `qg3ppars`: Target problem parameters (e.g., grid, spectral resolution).
Keywords
- `batch_size::Int`: Optional new batch size. Defaults to the batch size inferred from `model.sfno_blocks.layers.layer_1.spherical_kernel.spherical_conv.plan.ggsh.FT_4d.plan.input_size[4]`.
Returns
`SFNO`: A new model instance with the same architecture as `model`, configured for the target discretization and batch size.
Example
# Original model (batch size = 32)
model = SFNO(orig_pars; batch_size=32, ...)
# Transfer model to a finer grid and larger batch
new_model = transfer_SFNO_model(model, new_pars; batch_size=64)
# Forward pass with transferred weights
ŷ = new_model(x, ps, st)