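/**
 * WGSL compute shaders for Kitten TTS WebGPU inference.
 *
 * Each export is the WGSL source text of one compute pipeline. The model has
 * these major shader requirements:
 * 1. Embedding lookup (token id → embedding row; `embeddingShader` performs only
 *    this gather — any additional embeddings must be summed separately)
 * 2. Layer Normalization
 * 3. Matrix multiplication (tiled, f32 operands; optional bias, fused-GELU variant).
 *    The shaders read plain f32 weights, so any quantized-weight dequantization
 *    has to happen before the weights are uploaded.
 * 4. Multi-head attention (ALBERT encoder)
 * 5. Conv1d (text encoder CNN, predictor, decoder, generator)
 * 6. LSTM (text encoder, predictor, shared)
 * 7. Instance Normalization (decoder)
 * 8. Adaptive Instance Normalization (AdaIN - style conditioning)
 * 9. ConvTranspose1d (HiFi-GAN upsampling), plus a depthwise variant
 * 10. LeakyReLU, GELU, Tanh, Sigmoid and Snake activations
 * 11. Residual connections (plain add and per-channel alpha-scaled add)
 * 12. Data movement and output stages: softmax, scale, concat, transpose,
 *     reflection padding, nearest-neighbor resize, duration expansion, and
 *     iSTFT waveform synthesis
 */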
/**
 * Token-embedding gather: `output[s, d] = embeddings[input_ids[s], d]`.
 *
 * Dispatch: one invocation per output element (`seq_len * embed_dim`), 256 per
 * workgroup. Ids that are negative or that index past the bound embedding table
 * (rows = `arrayLength(&embeddings) / embed_dim`) produce a zero row instead of an
 * out-of-bounds read. `vocab_size` stays in `Params` for layout compatibility.
 */
export declare const embeddingShader = `
@group(0) @binding(0) var<storage, read> embeddings: array<f32>;
@group(0) @binding(1) var<storage, read> input_ids: array<i32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

struct Params {
  seq_len: u32,
  embed_dim: u32,
  vocab_size: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let seq_idx = idx / params.embed_dim;
  let dim_idx = idx % params.embed_dim;

  if (seq_idx >= params.seq_len) { return; }

  let token_id = input_ids[seq_idx];
  // Bound the lookup by the table that is actually bound: a negative id would wrap
  // to a huge u32 and an out-of-vocabulary id would read past the table.
  let num_rows = arrayLength(&embeddings) / params.embed_dim;
  if (token_id < 0 || u32(token_id) >= num_rows) {
    output[idx] = 0.0;
    return;
  }
  let embed_offset = u32(token_id) * params.embed_dim + dim_idx;
  output[idx] = embeddings[embed_offset];
}
`;
/**
 * Row-wise LayerNorm over `hidden_size` with per-feature affine `gamma` / `beta`.
 *
 * One invocation per row (`batch_size` rows, 256 invocations per workgroup). Each
 * invocation makes three serial passes over its row: mean, biased (divide-by-N)
 * variance, then normalize-and-scale.
 */
export declare const layerNormShader = `
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> gamma: array<f32>;
@group(0) @binding(2) var<storage, read> beta: array<f32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;

struct Params {
  batch_size: u32,
  hidden_size: u32,
  eps: f32,
}
@group(0) @binding(4) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let batch_idx = gid.x;
  if (batch_idx >= params.batch_size) { return; }

  let offset = batch_idx * params.hidden_size;

  // Compute mean
  var sum = 0.0;
  for (var i = 0u; i < params.hidden_size; i++) {
    sum += input[offset + i];
  }
  let mean = sum / f32(params.hidden_size);

  // Compute variance
  var var_sum = 0.0;
  for (var i = 0u; i < params.hidden_size; i++) {
    let diff = input[offset + i] - mean;
    var_sum += diff * diff;
  }
  let variance = var_sum / f32(params.hidden_size);
  let inv_std = 1.0 / sqrt(variance + params.eps);

  // Normalize
  for (var i = 0u; i < params.hidden_size; i++) {
    output[offset + i] = (input[offset + i] - mean) * inv_std * gamma[i] + beta[i];
  }
}
`;
/**
 * Tiled GEMM: `output[M, N] = A[M, K] · B[K, N]`, plus `bias[N]` when `use_bias != 0`.
 * All operands are row-major f32.
 *
 * Workgroups are 16×16 and stage 16×16 tiles of A and B in workgroup memory.
 * `global_invocation_id.x` is the output row and `.y` the output column, so the
 * dispatch is `ceil(M/16) × ceil(N/16)`. The `bias` binding is declared and read
 * (when enabled), so a buffer must be bound even if `use_bias` is 0.
 */
export declare const matmulShader = `
// Tiled matmul with shared memory. TILE=16, each workgroup computes a 16\u00D716 output tile.
// Reduces global memory reads by factor of TILE compared to naive approach.

const TILE: u32 = 16u;

@group(0) @binding(0) var<storage, read> A: array<f32>;
@group(0) @binding(1) var<storage, read> B: array<f32>;
@group(0) @binding(2) var<storage, read> bias: array<f32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;

struct Params {
  M: u32,  // rows of A / output
  K: u32,  // cols of A / rows of B
  N: u32,  // cols of B / output
  use_bias: u32,
}
@group(0) @binding(4) var<uniform> params: Params;

var<workgroup> tileA: array<f32, 256>;  // 16\u00D716
var<workgroup> tileB: array<f32, 256>;  // 16\u00D716

@compute @workgroup_size(16, 16)
fn main(
  @builtin(global_invocation_id) gid: vec3<u32>,
  @builtin(local_invocation_id) lid: vec3<u32>,
) {
  let row = gid.x;
  let col = gid.y;
  let lr = lid.x;
  let lc = lid.y;

  var sum = 0.0;
  let numTiles = (params.K + TILE - 1u) / TILE;

  for (var t = 0u; t < numTiles; t++) {
    // Load tile of A: rows [row_base..+16], cols [t*16..+16]
    let aCol = t * TILE + lc;
    if (row < params.M && aCol < params.K) {
      tileA[lr * TILE + lc] = A[row * params.K + aCol];
    } else {
      tileA[lr * TILE + lc] = 0.0;
    }

    // Load tile of B: rows [t*16..+16], cols [col_base..+16]
    let bRow = t * TILE + lr;
    if (bRow < params.K && col < params.N) {
      tileB[lr * TILE + lc] = B[bRow * params.N + col];
    } else {
      tileB[lr * TILE + lc] = 0.0;
    }

    workgroupBarrier();

    // Accumulate dot product from shared memory
    for (var k = 0u; k < TILE; k++) {
      sum += tileA[lr * TILE + k] * tileB[k * TILE + lc];
    }

    workgroupBarrier();
  }

  if (row < params.M && col < params.N) {
    if (params.use_bias != 0u) {
      sum += bias[col];
    }
    output[row * params.N + col] = sum;
  }
}
`;
/**
 * Direct (non-grouped) 1-D convolution with zero padding, stride and dilation.
 *
 * Layouts: input `[C_in, L_in]`, weight `[C_out, C_in, K]`, optional bias `[C_out]`,
 * output `[C_out, L_out]`. One invocation per output element
 * (`out_channels * output_length`), 256 per workgroup; taps that fall outside
 * `[0, input_length)` are skipped (implicit zero padding).
 */
export declare const conv1dShader = `
@group(0) @binding(0) var<storage, read> input: array<f32>;     // [C_in, L]
@group(0) @binding(1) var<storage, read> weight: array<f32>;    // [C_out, C_in, K]
@group(0) @binding(2) var<storage, read> bias: array<f32>;      // [C_out]
@group(0) @binding(3) var<storage, read_write> output: array<f32>; // [C_out, L_out]

struct Params {
  in_channels: u32,
  out_channels: u32,
  kernel_size: u32,
  input_length: u32,
  output_length: u32,
  padding: u32,
  stride: u32,
  dilation: u32,
  use_bias: u32,
}
@group(0) @binding(4) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let out_ch = idx / params.output_length;
  let out_pos = idx % params.output_length;

  if (out_ch >= params.out_channels) { return; }

  var sum = 0.0;
  for (var ic = 0u; ic < params.in_channels; ic++) {
    for (var k = 0u; k < params.kernel_size; k++) {
      let in_pos_raw = i32(out_pos * params.stride) + i32(k * params.dilation) - i32(params.padding);
      if (in_pos_raw >= 0 && u32(in_pos_raw) < params.input_length) {
        let w_idx = out_ch * params.in_channels * params.kernel_size + ic * params.kernel_size + k;
        let in_idx = ic * params.input_length + u32(in_pos_raw);
        sum += input[in_idx] * weight[w_idx];
      }
    }
  }

  if (params.use_bias != 0u) {
    sum += bias[out_ch];
  }

  output[out_ch * params.output_length + out_pos] = sum;
}
`;
/**
 * Per-channel instance normalization of channel-first `[C, L]` data, without an
 * affine term (scale/shift is applied afterwards by the AdaIN kernels).
 *
 * One invocation per channel, 256 per workgroup; mean and biased variance are
 * computed serially over the channel's `length` samples.
 */
export declare const instanceNormShader = `
@group(0) @binding(0) var<storage, read> input: array<f32>;     // [C, L]
@group(0) @binding(1) var<storage, read_write> output: array<f32>; // [C, L]

struct Params {
  channels: u32,
  length: u32,
  eps: f32,
}
@group(0) @binding(2) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let ch = gid.x;
  if (ch >= params.channels) { return; }

  let offset = ch * params.length;

  // Compute mean
  var sum = 0.0;
  for (var i = 0u; i < params.length; i++) {
    sum += input[offset + i];
  }
  let mean = sum / f32(params.length);

  // Compute variance
  var var_sum = 0.0;
  for (var i = 0u; i < params.length; i++) {
    let diff = input[offset + i] - mean;
    var_sum += diff * diff;
  }
  let variance = var_sum / f32(params.length);
  let inv_std = 1.0 / sqrt(variance + params.eps);

  // Normalize (no scale/bias for instance norm in this model - AdaIN handles that)
  for (var i = 0u; i < params.length; i++) {
    output[offset + i] = (input[offset + i] - mean) * inv_std;
  }
}
`;
/**
 * AdaIN affine step on channel-first `[C, L]` data:
 * `output[c, l] = normed[c, l] * (1 + scale[c]) + bias[c]`.
 *
 * `style_fc` holds C scales followed by C biases. One invocation per element
 * (`channels * length`), 256 per workgroup.
 */
export declare const adainShader = `
@group(0) @binding(0) var<storage, read> normed: array<f32>;    // [C, L] - instance-normed input
@group(0) @binding(1) var<storage, read> style_fc: array<f32>;  // [2*C] - first C = scale, second C = bias
@group(0) @binding(2) var<storage, read_write> output: array<f32>; // [C, L]

struct Params {
  channels: u32,
  length: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let ch = idx / params.length;
  let pos = idx % params.length;

  if (ch >= params.channels) { return; }

  // AdaIN: (1 + gamma) * normed + beta \u2014 the +1 offset is universal across all AdaIN blocks
  // style_fc layout: [scale_0..scale_{C-1}, bias_0..bias_{C-1}]
  let scale = style_fc[ch];
  let bias = style_fc[params.channels + ch];
  output[idx] = normed[idx] * (scale + 1.0) + bias;
}
`;
/**
 * AdaIN affine step on row-major `[rows, C]` data; the channel of element `idx`
 * is `idx % channels`. Same `(1 + scale) * x + bias` formula as the channel-first
 * variant, with `style_fc` = C scales followed by C biases.
 * One invocation per element (`total`), 256 per workgroup.
 */
export declare const adainRowMajorShader = `
@group(0) @binding(0) var<storage, read> normed: array<f32>;    // [rows, C]
@group(0) @binding(1) var<storage, read> style_fc: array<f32>;  // [2*C]
@group(0) @binding(2) var<storage, read_write> output: array<f32>; // [rows, C]

struct Params {
  channels: u32,
  total: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  if (idx >= params.total) { return; }

  // Row-major: channel = idx % channels
  let ch = idx % params.channels;
  let scale = style_fc[ch];
  let bias = style_fc[params.channels + ch];
  output[idx] = normed[idx] * (scale + 1.0) + bias;
}
`;
/**
 * Snake activation with a per-channel alpha on channel-first `[C, L]` data:
 * `y = x + sin²(a·x) / a`.
 *
 * One invocation per element, 256 per workgroup. NOTE(review): there is no epsilon
 * on the division, so an alpha of exactly 0 yields a non-finite result.
 */
export declare const snakeShader = `
@group(0) @binding(0) var<storage, read> input: array<f32>;     // [C, L]
@group(0) @binding(1) var<storage, read> alpha: array<f32>;     // [C] (flattened from [1, C, 1])
@group(0) @binding(2) var<storage, read_write> output: array<f32>; // [C, L]

struct Params {
  channels: u32,
  length: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let ch = idx / params.length;
  let pos = idx % params.length;

  if (ch >= params.channels) { return; }

  let x = input[idx];
  let a = alpha[ch];
  let sin_ax = sin(a * x);
  // Snake: x + (1/a) * sin\u00B2(a * x)
  output[idx] = x + sin_ax * sin_ax / a;
}
`;
/**
 * Element-wise LeakyReLU: `x >= 0 ? x : alpha * x`, with `alpha` from the uniform.
 * One invocation per element (`size`), 256 per workgroup.
 */
export declare const leakyReluShader = `
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct Params {
  size: u32,
  alpha: f32,
}
@group(0) @binding(2) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  if (idx >= params.size) { return; }
  let x = input[idx];
  output[idx] = select(params.alpha * x, x, x >= 0.0);
}
`;
/**
 * Element-wise GELU using the tanh approximation
 * `0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³)))`. The tanh argument is clamped to
 * ±44, which does not change the f32 result (tanh already saturates to ±1 there).
 * One invocation per element (`size`), 256 per workgroup.
 */
export declare const geluShader = `
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct Params {
  size: u32,
}
@group(0) @binding(2) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  if (idx >= params.size) { return; }
  let x = input[idx];
  // GELU approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
  // Clamp tanh arg to prevent exp(2x) overflow in f32 (exp overflows at ~88.72)
  let c = 0.7978845608; // sqrt(2/pi)
  let inner = clamp(c * (x + 0.044715 * x * x * x), -44.0, 44.0);
  output[idx] = 0.5 * x * (1.0 + tanh(inner));
}
`;
/**
 * Element-wise tanh. One invocation per element (`size`), 256 per workgroup.
 *
 * The input is clamped to ±44, as the GELU and LSTM kernels in this file already do.
 * An exp-based tanh lowering overflows for large |x|, and WGSL lets an overflowing
 * runtime f32 operation produce an indeterminate value. tanh(±44) already rounds to
 * exactly ±1.0 in f32, so no finite result changes.
 */
export declare const tanhShader = `
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct Params { size: u32 }
@group(0) @binding(2) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  if (idx >= params.size) { return; }
  // Clamp like the GELU/LSTM kernels: tanh saturates to +-1.0 in f32 well before 44,
  // and the clamp keeps exp-based tanh lowerings from overflowing.
  output[idx] = tanh(clamp(input[idx], -44.0, 44.0));
}
`;
/**
 * Element-wise logistic sigmoid `1 / (1 + e^-x)`. One invocation per element
 * (`size`), 256 per workgroup.
 *
 * The input is clamped to ±44 exactly as the LSTM gate sigmoids in this file are.
 * Without it, `exp(-x)` overflows for x < ~-88.7, and WGSL then permits an
 * indeterminate value rather than +Inf. Results inside [-44, 44] are bit-identical
 * to the unclamped formula.
 */
export declare const sigmoidShader = `
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct Params { size: u32 }
@group(0) @binding(2) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  if (idx >= params.size) { return; }
  // Clamp before exp (same as the LSTM gates) so exp(-x) can never overflow.
  output[idx] = 1.0 / (1.0 + exp(-clamp(input[idx], -44.0, 44.0)));
}
`;
/**
 * ConvTranspose1d in gather form: each invocation computes one output element by
 * finding the input positions whose scatter lands on it
 * (`in_pos = (out_pos + padding - k) / stride`, when that division is exact).
 *
 * Layouts: input `[C_in, L_in]`, weight `[C_in, C_out, K]`, optional bias `[C_out]`,
 * output `[C_out, L_out]`. No dilation, groups or output_padding parameters.
 * One invocation per output element, 256 per workgroup.
 */
export declare const convTranspose1dShader = `
@group(0) @binding(0) var<storage, read> input: array<f32>;     // [C_in, L_in]
@group(0) @binding(1) var<storage, read> weight: array<f32>;    // [C_in, C_out, K]
@group(0) @binding(2) var<storage, read> bias: array<f32>;      // [C_out]
@group(0) @binding(3) var<storage, read_write> output: array<f32>; // [C_out, L_out]

struct Params {
  in_channels: u32,
  out_channels: u32,
  kernel_size: u32,
  input_length: u32,
  output_length: u32,
  stride: u32,
  padding: u32,
  use_bias: u32,
}
@group(0) @binding(4) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let out_ch = idx / params.output_length;
  let out_pos = idx % params.output_length;

  if (out_ch >= params.out_channels) { return; }

  var sum = 0.0;
  for (var ic = 0u; ic < params.in_channels; ic++) {
    for (var k = 0u; k < params.kernel_size; k++) {
      // ConvTranspose: output[out_pos] += input[in_pos] * weight[ic, out_ch, k]
      // where out_pos = in_pos * stride + k - padding
      // so in_pos = (out_pos + padding - k) / stride
      let numerator = i32(out_pos) + i32(params.padding) - i32(k);
      if (numerator >= 0 && u32(numerator) % params.stride == 0u) {
        let in_pos = u32(numerator) / params.stride;
        if (in_pos < params.input_length) {
          let w_idx = ic * params.out_channels * params.kernel_size + out_ch * params.kernel_size + k;
          let in_idx = ic * params.input_length + in_pos;
          sum += input[in_idx] * weight[w_idx];
        }
      }
    }
  }

  if (params.use_bias != 0u) {
    sum += bias[out_ch];
  }

  output[out_ch * params.output_length + out_pos] = sum;
}
`;
/**
 * Depthwise (one filter per channel) ConvTranspose1d in gather form, without bias.
 *
 * Layouts: input `[channels, L_in]`, weight `[channels, 1, K]` (flat `channels * K`),
 * output `[channels, L_out]`. One invocation per output element, 256 per workgroup.
 */
export declare const depthwiseConvTranspose1dShader = `
@group(0) @binding(0) var<storage, read> input: array<f32>;     // [channels, L_in]
@group(0) @binding(1) var<storage, read> weight: array<f32>;    // [channels, 1, K] = [channels * K]
@group(0) @binding(2) var<storage, read_write> output: array<f32>; // [channels, L_out]

struct Params {
  channels: u32,
  kernel_size: u32,
  input_length: u32,
  output_length: u32,
  stride: u32,
  padding: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let ch = idx / params.output_length;
  let out_pos = idx % params.output_length;

  if (ch >= params.channels) { return; }

  var sum = 0.0;
  for (var k = 0u; k < params.kernel_size; k++) {
    let numerator = i32(out_pos) + i32(params.padding) - i32(k);
    if (numerator >= 0 && u32(numerator) % params.stride == 0u) {
      let in_pos = u32(numerator) / params.stride;
      if (in_pos < params.input_length) {
        let w_idx = ch * params.kernel_size + k;
        let in_idx = ch * params.input_length + in_pos;
        sum += input[in_idx] * weight[w_idx];
      }
    }
  }

  output[ch * params.output_length + out_pos] = sum;
}
`;
/**
 * Nearest-neighbor 1-D resize of channel-first `[channels, L_in]` data to
 * `[channels, L_out]`, using `in_pos = floor(out_pos * L_in / L_out)`.
 *
 * One invocation per output element, 256 per workgroup.
 * NOTE(review): `out_pos * input_length` is evaluated in u32 and wraps once it exceeds
 * 2^32 - 1. Confirm that the call sites' lengths stay below that.
 */
export declare const resize1dShader = `
@group(0) @binding(0) var<storage, read> input: array<f32>;      // [channels, L_in]
@group(0) @binding(1) var<storage, read_write> output: array<f32>; // [channels, L_out]

struct Params {
  channels: u32,
  input_length: u32,
  output_length: u32,
}
@group(0) @binding(2) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let ch = idx / params.output_length;
  let out_pos = idx % params.output_length;

  if (ch >= params.channels) { return; }

  // Nearest neighbor: map output position to input position
  let in_pos = out_pos * params.input_length / params.output_length;
  output[ch * params.output_length + out_pos] = input[ch * params.input_length + in_pos];
}
`;
/**
 * Row-wise, max-subtracted softmax over the last dimension (`dim_size`).
 *
 * One invocation per row (`batch_size` rows, 256 per workgroup). The row is
 * processed in three serial passes: max, exp-and-sum (exps are written to
 * `output`), then in-place division by the sum.
 */
export declare const softmaxShader = `
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct Params {
  batch_size: u32,
  dim_size: u32,
}
@group(0) @binding(2) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let batch_idx = gid.x;
  if (batch_idx >= params.batch_size) { return; }

  let offset = batch_idx * params.dim_size;

  // Find max for numerical stability
  var max_val = input[offset];
  for (var i = 1u; i < params.dim_size; i++) {
    max_val = max(max_val, input[offset + i]);
  }

  // Compute exp and sum
  var exp_sum = 0.0;
  for (var i = 0u; i < params.dim_size; i++) {
    let e = exp(input[offset + i] - max_val);
    output[offset + i] = e;
    exp_sum += e;
  }

  // Normalize
  for (var i = 0u; i < params.dim_size; i++) {
    output[offset + i] /= exp_sum;
  }
}
`;
/**
 * Unmasked scaled dot-product attention for Q, K, V and output laid out as
 * `[seq_len, num_heads, head_dim]`.
 *
 * Dispatch mapping (from the code):
 * - `global_invocation_id.x` is the index within the head dimension (workgroups of 64).
 * - `.y` enumerates `head * seq_len + query`.
 * The dispatch is therefore `ceil(head_dim / 64) × (num_heads * seq_len)`.
 *
 * Each invocation recomputes every Q·K score twice: a max pass, then an
 * exp/accumulate pass. That is O(seq_len * head_dim) work per output element.
 * No padding or causal mask is applied.
 */
export declare const mhaShader = `
@group(0) @binding(0) var<storage, read> Q: array<f32>;  // [seq_len, num_heads, head_dim]
@group(0) @binding(1) var<storage, read> K: array<f32>;  // [seq_len, num_heads, head_dim]
@group(0) @binding(2) var<storage, read> V: array<f32>;  // [seq_len, num_heads, head_dim]
@group(0) @binding(3) var<storage, read_write> output: array<f32>; // [seq_len, num_heads, head_dim]

struct Params {
  seq_len: u32,
  num_heads: u32,
  head_dim: u32,
  scale: f32,  // 1/sqrt(head_dim)
}
@group(0) @binding(4) var<uniform> params: Params;

// Workgroup: one per (head, query_pos). Threads iterate over key positions.
// We use a simple approach: each thread computes one output element (head_dim index).

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  // gid.x = dim_idx within head, gid.y = head_idx * seq_len + query_pos
  let dim_idx = gid.x;
  let head_query = gid.y;
  let head_idx = head_query / params.seq_len;
  let q_pos = head_query % params.seq_len;

  if (dim_idx >= params.head_dim || head_idx >= params.num_heads) { return; }

  let hd = params.head_dim;
  let nh = params.num_heads;
  let sl = params.seq_len;

  // Q vector for this (q_pos, head): Q[q_pos * nh * hd + head_idx * hd + ...]
  let q_base = q_pos * nh * hd + head_idx * hd;

  // Compute attention scores: dot(Q[q_pos, head], K[k_pos, head]) for all k_pos
  // Then softmax and weighted sum of V
  // Since we can't do cross-thread softmax easily, each thread computes full attention
  // for one output dimension. This is O(seq_len * head_dim) per thread but simple.

  // Step 1: Compute all attention scores (each thread does this redundantly)
  // For short sequences (< 512) this is fine
  var max_score = -1e10;
  for (var k = 0u; k < sl; k++) {
    let k_base = k * nh * hd + head_idx * hd;
    var score = 0.0;
    for (var d = 0u; d < hd; d++) {
      score += Q[q_base + d] * K[k_base + d];
    }
    score *= params.scale;
    max_score = max(max_score, score);
  }

  // Step 2: Softmax
  var exp_sum = 0.0;
  var weighted_val = 0.0;
  for (var k = 0u; k < sl; k++) {
    let k_base = k * nh * hd + head_idx * hd;
    var score = 0.0;
    for (var d = 0u; d < hd; d++) {
      score += Q[q_base + d] * K[k_base + d];
    }
    score *= params.scale;
    let w = exp(score - max_score);
    exp_sum += w;

    // Accumulate V[k_pos, head, dim_idx] weighted by attention
    let v_base = k * nh * hd + head_idx * hd;
    weighted_val += w * V[v_base + dim_idx];
  }

  let out_idx = q_pos * nh * hd + head_idx * hd + dim_idx;
  output[out_idx] = weighted_val / exp_sum;
}
`;
/**
 * Tiled GEMM with a fused bias + GELU (tanh approximation) epilogue:
 * `output = gelu(A[M, K] · B[K, N] + bias[N])`, all row-major f32.
 *
 * It uses the same 16×16 tiling and dispatch as `matmulShader`: `.x` indexes rows,
 * `.y` indexes columns, and the dispatch is `ceil(M/16) × ceil(N/16)`. Unlike
 * `matmulShader`, the bias is always added; there is no `use_bias` flag.
 */
export declare const matmulGeluShader = `
// Tiled matmul + GELU with shared memory.

const TILE: u32 = 16u;

@group(0) @binding(0) var<storage, read> A: array<f32>;
@group(0) @binding(1) var<storage, read> B: array<f32>;
@group(0) @binding(2) var<storage, read> bias: array<f32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;

struct Params {
  M: u32,
  K: u32,
  N: u32,
}
@group(0) @binding(4) var<uniform> params: Params;

var<workgroup> tileA: array<f32, 256>;
var<workgroup> tileB: array<f32, 256>;

@compute @workgroup_size(16, 16)
fn main(
  @builtin(global_invocation_id) gid: vec3<u32>,
  @builtin(local_invocation_id) lid: vec3<u32>,
) {
  let row = gid.x;
  let col = gid.y;
  let lr = lid.x;
  let lc = lid.y;

  var sum = 0.0;
  let numTiles = (params.K + TILE - 1u) / TILE;

  for (var t = 0u; t < numTiles; t++) {
    let aCol = t * TILE + lc;
    if (row < params.M && aCol < params.K) {
      tileA[lr * TILE + lc] = A[row * params.K + aCol];
    } else {
      tileA[lr * TILE + lc] = 0.0;
    }

    let bRow = t * TILE + lr;
    if (bRow < params.K && col < params.N) {
      tileB[lr * TILE + lc] = B[bRow * params.N + col];
    } else {
      tileB[lr * TILE + lc] = 0.0;
    }

    workgroupBarrier();

    for (var k = 0u; k < TILE; k++) {
      sum += tileA[lr * TILE + k] * tileB[k * TILE + lc];
    }

    workgroupBarrier();
  }

  if (row < params.M && col < params.N) {
    sum += bias[col];
    // GELU activation (clamp tanh arg to prevent f32 exp overflow)
    let c = 0.7978845608;
    let x = sum;
    let inner = clamp(c * (x + 0.044715 * x * x * x), -44.0, 44.0);
    output[row * params.N + col] = 0.5 * x * (1.0 + tanh(inner));
  }
}
`;
/**
 * Element-wise sum `output = a + b` over `size` elements (used for residual adds).
 * One invocation per element, 256 per workgroup.
 */
export declare const addShader = `
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

struct Params { size: u32 }
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  if (idx >= params.size) { return; }
  output[idx] = a[idx] + b[idx];
}
`;
/**
 * Element-wise multiply by a scalar: `output = input * scale`.
 *
 * `Params` carries an explicit `_pad1` word, so `scale` sits at byte offset 8.
 * The host must write it there. One invocation per element, 256 per workgroup.
 */
export declare const scaleShader = `
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct Params {
  size: u32,
  _pad1: u32,
  scale: f32,
}
@group(0) @binding(2) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  if (idx >= params.size) { return; }
  output[idx] = input[idx] * params.scale;
}
`;
/**
 * Concatenates two channel-first tensors along the channel axis:
 * `[C_a, L] ++ [C_b, L] → [C_a + C_b, L]`.
 * One invocation per output element, 256 per workgroup.
 */
export declare const concatChannelsShader = `
@group(0) @binding(0) var<storage, read> a: array<f32>;      // [C_a, L]
@group(0) @binding(1) var<storage, read> b: array<f32>;      // [C_b, L]
@group(0) @binding(2) var<storage, read_write> output: array<f32>; // [C_a + C_b, L]

struct Params {
  channels_a: u32,
  channels_b: u32,
  length: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let total = (params.channels_a + params.channels_b) * params.length;
  if (idx >= total) { return; }

  let ch = idx / params.length;
  let pos = idx % params.length;

  if (ch < params.channels_a) {
    output[idx] = a[ch * params.length + pos];
  } else {
    output[idx] = b[(ch - params.channels_a) * params.length + pos];
  }
}
`;
/**
 * Appends the same vector `b[cols_b]` to every row of `a[rows, cols_a]`, producing
 * `[rows, cols_a + cols_b]`. One invocation per output element, 256 per workgroup.
 */
export declare const concatBroadcastShader = `
@group(0) @binding(0) var<storage, read> a: array<f32>;      // [rows, cols_a]
@group(0) @binding(1) var<storage, read> b: array<f32>;      // [cols_b] \u2014 broadcast to every row
@group(0) @binding(2) var<storage, read_write> output: array<f32>; // [rows, cols_a + cols_b]

struct Params {
  rows: u32,
  cols_a: u32,
  cols_b: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let total_cols = params.cols_a + params.cols_b;
  let total = params.rows * total_cols;
  if (idx >= total) { return; }

  let row = idx / total_cols;
  let col = idx % total_cols;

  if (col < params.cols_a) {
    output[idx] = a[row * params.cols_a + col];
  } else {
    output[idx] = b[col - params.cols_a];
  }
}
`;
/**
 * Reflection padding (edge sample not repeated) of channel-first `[channels, L_in]`
 * data to `[channels, pad_left + L_in + pad_right]`. One invocation per output
 * element, 256 per workgroup.
 *
 * NOTE(review): the index math assumes `pad_left < input_length` and
 * `pad_right < input_length` (the right side computes `input_length - 2 - overshoot`
 * in u32). Larger pads underflow and read out of range.
 */
export declare const reflectionPad1dShader = `
@group(0) @binding(0) var<storage, read> input: array<f32>;      // [channels, L_in]
@group(0) @binding(1) var<storage, read_write> output: array<f32>; // [channels, L_out]

struct Params {
  channels: u32,
  input_length: u32,
  pad_left: u32,
  pad_right: u32,
}
@group(0) @binding(2) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let out_length = params.input_length + params.pad_left + params.pad_right;
  let ch = idx / out_length;
  let out_pos = idx % out_length;

  if (ch >= params.channels) { return; }

  var in_pos: u32;
  if (out_pos < params.pad_left) {
    // Reflected left: position 0 -> pad_left, position 1 -> pad_left-1, etc.
    in_pos = params.pad_left - out_pos;
  } else if (out_pos >= params.pad_left + params.input_length) {
    // Reflected right
    let overshoot = out_pos - params.pad_left - params.input_length;
    in_pos = params.input_length - 2u - overshoot;
  } else {
    in_pos = out_pos - params.pad_left;
  }

  output[ch * out_length + out_pos] = input[ch * params.input_length + in_pos];
}
`;
/**
 * Per-channel scaled residual add on channel-first data:
 * `output[c, l] = current[c, l] + alpha[c] * residual[c, l]`.
 * One invocation per element (`channels * length`), 256 per workgroup.
 */
export declare const alphaResidualShader = `
@group(0) @binding(0) var<storage, read> current: array<f32>;    // conv2 output
@group(0) @binding(1) var<storage, read> residual: array<f32>;   // residual from previous iteration
@group(0) @binding(2) var<storage, read> alpha: array<f32>;      // [1, channels, 1] per-channel alpha
@group(0) @binding(3) var<storage, read_write> output: array<f32>;

struct Params {
  channels: u32,
  length: u32,
}
@group(0) @binding(4) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let ch = idx / params.length;
  if (ch >= params.channels) { return; }

  // output = current + alpha[ch] * residual
  output[idx] = current[idx] + alpha[ch] * residual[idx];
}
`;
/**
 * 2-D transpose: input `[rows, cols]` → output `[cols, rows]`.
 * One invocation per element, 256 per workgroup. Reads are contiguous and writes
 * are strided.
 */
export declare const transposeShader = `
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct Params { rows: u32, cols: u32 }
@group(0) @binding(2) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let total = params.rows * params.cols;
  if (idx >= total) { return; }

  let row = idx / params.cols;
  let col = idx % params.cols;
  output[col * params.rows + row] = input[idx];
}
`;
/**
 * Whole-sequence (bi)directional LSTM in one dispatch, using ONNX gate order
 * (i, o, f, c) and ONNX bias layout: `[num_dir, 8H]`, input biases then recurrent
 * biases.
 *
 * `global_invocation_id.x` selects the hidden unit and `.y` the direction. The
 * weights are expected already transposed relative to ONNX: W is
 * `[num_dir, input_size, 4H]` and R is `[num_dir, H, 4H]`. Output is
 * `[seq_len, num_dir, H]`.
 *
 * At each step an invocation reads the previous step's full hidden vector back from
 * `output`; those values were written by other invocations. Ordering relies on
 * `storageBarrier()`, which only synchronizes invocations of the same workgroup.
 * With `@workgroup_size(256)`, results are therefore only correct when
 * `hidden_size <= 256`, i.e. each direction fits in a single workgroup.
 */
export declare const lstmShader = `
@group(0) @binding(0) var<storage, read> input: array<f32>;   // [seq_len, input_size]
@group(0) @binding(1) var<storage, read> W: array<f32>;       // [num_dir, input_size, 4*hidden]
@group(0) @binding(2) var<storage, read> R: array<f32>;       // [num_dir, hidden, 4*hidden]
@group(0) @binding(3) var<storage, read> bias: array<f32>;    // [num_dir, 8*hidden]
@group(0) @binding(4) var<storage, read_write> output: array<f32>; // [seq_len, num_dir, hidden]

struct Params {
  seq_len: u32,
  input_size: u32,
  hidden_size: u32,
  num_directions: u32,
}
@group(0) @binding(5) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let h_idx = gid.x; // which hidden unit
  let dir = gid.y; // 0=forward, 1=backward
  // NOTE: no early return \u2014 all threads in workgroup must reach storageBarrier()
  let is_valid = h_idx < params.hidden_size && dir < params.num_directions;

  let H = params.hidden_size;
  let H4 = H * 4u;
  let IS = params.input_size;
  let SL = params.seq_len;

  // Use safe indices for inactive threads (they won't write)
  let safe_h = select(0u, h_idx, is_valid);
  let safe_dir = select(0u, dir, is_valid);

  // Gate offsets within 4*hidden: i=0, o=1, f=2, c=3 (ONNX order)
  let gate_i = safe_h;
  let gate_o = H + safe_h;
  let gate_f = 2u * H + safe_h;
  let gate_c = 3u * H + safe_h;

  // Bias offsets: [Wb_i, Wb_o, Wb_f, Wb_c, Rb_i, Rb_o, Rb_f, Rb_c]
  let bias_base = safe_dir * 8u * H;
  var b_wi = 0.0; var b_wo = 0.0; var b_wf = 0.0; var b_wc = 0.0;
  var b_ri = 0.0; var b_ro = 0.0; var b_rf = 0.0; var b_rc = 0.0;
  if (is_valid) {
    b_wi = bias[bias_base + safe_h];
    b_wo = bias[bias_base + H + safe_h];
    b_wf = bias[bias_base + 2u * H + safe_h];
    b_wc = bias[bias_base + 3u * H + safe_h];
    b_ri = bias[bias_base + 4u * H + safe_h];
    b_ro = bias[bias_base + 5u * H + safe_h];
    b_rf = bias[bias_base + 6u * H + safe_h];
    b_rc = bias[bias_base + 7u * H + safe_h];
  }

  var h_val = 0.0; // hidden state for this unit
  var c_val = 0.0; // cell state for this unit

  // Weight base offsets for this direction
  // W: [num_dir, IS, 4H] \u2014 flat stride: dir * IS * H4
  // R: [num_dir, H, 4H]  \u2014 flat stride: dir * H * H4
  let w_base = safe_dir * IS * H4;
  let r_base = safe_dir * H * H4;

  for (var step = 0u; step < SL; step++) {
    if (is_valid) {
      // Forward: t=step, Backward: t=SL-1-step
      let t = select(SL - 1u - step, step, safe_dir == 0u);

      // Compute gates from input: sum over input_size
      var gi = b_wi + b_ri;
      var go = b_wo + b_ro;
      var gf = b_wf + b_rf;
      var gc = b_wc + b_rc;

      // Input contribution: W[dir, j, gate*H+h_idx] \u2014 layout [IS, 4H]
      // x[j] * W[w_base + j * H4 + gate_offset]
      for (var j = 0u; j < IS; j++) {
        let x_val = input[t * IS + j];
        let w_off = w_base + j * H4;
        gi += x_val * W[w_off + gate_i];
        go += x_val * W[w_off + gate_o];
        gf += x_val * W[w_off + gate_f];
        gc += x_val * W[w_off + gate_c];
      }

      // Recurrence contribution: R[dir, j, gate*H+h_idx] \u2014 layout [H, 4H]
      // h_prev[j] * R[r_base + j * H4 + gate_offset]
      if (step > 0u) {
        let prev_t = select(SL - step, step - 1u, safe_dir == 0u);
        let prev_base = prev_t * params.num_directions * H + safe_dir * H;
        for (var j = 0u; j < H; j++) {
          let h_prev = output[prev_base + j];
          let r_off = r_base + j * H4;
          gi += h_prev * R[r_off + gate_i];
          go += h_prev * R[r_off + gate_o];
          gf += h_prev * R[r_off + gate_f];
          gc += h_prev * R[r_off + gate_c];
        }
      }

      // Apply activations
      // Clamp sigmoid inputs to avoid exp overflow (exp(88.72) > f32 max)
      let i_gate = 1.0 / (1.0 + exp(-clamp(gi, -44.0, 44.0))); // sigmoid
      let o_gate = 1.0 / (1.0 + exp(-clamp(go, -44.0, 44.0)));
      let f_gate = 1.0 / (1.0 + exp(-clamp(gf, -44.0, 44.0)));
      // Clamp tanh inputs: tanh uses exp(2x), so |x| > 44 \u2192 exp(88) \u2192 Inf \u2192 NaN
      let c_gate = tanh(clamp(gc, -44.0, 44.0));

      c_val = f_gate * c_val + i_gate * c_gate;
      h_val = o_gate * tanh(clamp(c_val, -44.0, 44.0));

      // Write output: [t, dir, h_idx] \u2192 flat: t * num_dir * H + dir * H + h_idx
      output[t * params.num_directions * H + safe_dir * H + safe_h] = h_val;
    }

    // Barrier: all threads (active and inactive) must reach this point
    // storageBarrier() ensures visibility of storage buffer writes across threads in the workgroup
    storageBarrier();
  }
}
`;
/**
 * Duration-based length regulation, row-major output: each output frame copies the
 * row of the token whose duration span contains it.
 *
 * `cumsum[i]` is the inclusive prefix sum of durations. An upper-bound binary search
 * picks the first token with `cumsum[token] > frame`, so zero-duration tokens are
 * skipped. Layouts: input `[seq_len, dim]`, output `[total_frames, dim]`.
 * One invocation per output element, 256 per workgroup.
 *
 * Frames at or beyond `cumsum[seq_len - 1]` (total_frames larger than the summed
 * durations) repeat the last token instead of reading past `input`. With
 * `seq_len == 0` the output is zero-filled.
 */
export declare const expandRowMajorShader = `
@group(0) @binding(0) var<storage, read> input: array<f32>;      // [seq_len, dim]
@group(0) @binding(1) var<storage, read> cumsum: array<u32>;     // [seq_len] prefix sum of durations
@group(0) @binding(2) var<storage, read_write> output: array<f32>; // [total_frames, dim]

struct Params {
  seq_len: u32,
  dim: u32,
  total_frames: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let total = params.total_frames * params.dim;
  if (idx >= total) { return; }

  // No tokens: nothing to copy (and seq_len - 1u below would wrap).
  if (params.seq_len == 0u) {
    output[idx] = 0.0;
    return;
  }

  let frame = idx / params.dim;
  let d = idx % params.dim;

  // Binary search: find token i where cumsum[i-1] <= frame < cumsum[i]
  var lo: u32 = 0u;
  var hi: u32 = params.seq_len;
  while (lo < hi) {
    let mid = (lo + hi) / 2u;
    if (cumsum[mid] <= frame) {
      lo = mid + 1u;
    } else {
      hi = mid;
    }
  }
  // The search yields seq_len when frame >= cumsum[seq_len - 1]; clamp so such
  // frames repeat the last token instead of reading past the end of input.
  let token = min(lo, params.seq_len - 1u);

  output[idx] = input[token * params.dim + d];
}
`;
/**
 * Duration-based length regulation with a transposing write. It reads row-major
 * token features `[seq_len, dim]` and writes channel-first `[dim, total_frames]`.
 *
 * Token selection is the same upper-bound search over the inclusive prefix sum
 * `cumsum` as in the row-major variant. One invocation per output element, 256 per
 * workgroup. Frames at or beyond `cumsum[seq_len - 1]` repeat the last token instead
 * of reading past `input`. With `seq_len == 0` the output is zero-filled.
 */
export declare const expandChannelFirstShader = `
@group(0) @binding(0) var<storage, read> input: array<f32>;      // [seq_len, dim] row-major
@group(0) @binding(1) var<storage, read> cumsum: array<u32>;     // [seq_len] prefix sum of durations
@group(0) @binding(2) var<storage, read_write> output: array<f32>; // [dim, total_frames] channel-first

struct Params {
  seq_len: u32,
  dim: u32,
  total_frames: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let total = params.total_frames * params.dim;
  if (idx >= total) { return; }

  // No tokens: nothing to copy (and seq_len - 1u below would wrap).
  if (params.seq_len == 0u) {
    output[idx] = 0.0;
    return;
  }

  // Output layout: [dim, total_frames] \u2014 idx = channel * total_frames + frame
  let channel = idx / params.total_frames;
  let frame = idx % params.total_frames;

  // Binary search: find token i where cumsum[i-1] <= frame < cumsum[i]
  var lo: u32 = 0u;
  var hi: u32 = params.seq_len;
  while (lo < hi) {
    let mid = (lo + hi) / 2u;
    if (cumsum[mid] <= frame) {
      lo = mid + 1u;
    } else {
      hi = mid;
    }
  }
  // The search yields seq_len when frame >= cumsum[seq_len - 1]; clamp so such
  // frames repeat the last token instead of reading past the end of input.
  let token = min(lo, params.seq_len - 1u);

  output[idx] = input[token * params.dim + channel];
}
`;
/**
 * iSTFT synthesis from the generator's `conv_post` output to a waveform, written as
 * a gather-style ConvTranspose (one invocation per output sample, 256 per workgroup).
 *
 * For every kernel tap `k` where `(out_pos - k)` is a non-negative multiple of
 * `stride`, it takes frame `t = (out_pos - k) / stride` and sums over `bins`:
 * `exp(mag)·cos(sin(ph))·w_real[b, k] − exp(mag)·sin(sin(ph))·w_imag[b, k]`.
 * Rows `[0, bins)` of `conv_post` are log-magnitudes and rows `[bins, 2·bins)` are
 * phases. The sizes in the WGSL comments (22 / 11 / 20 / 5) are indicative only;
 * the shader reads them from `params`.
 */
export declare const istftShader = `
// iSTFT synthesis: conv_post [22, genLength] \u2192 waveform [waveformLength]
// Gather-based ConvTranspose: each thread computes one output sample
// Fuses: magnitude/phase split, exp, sin(sin(ph)), cos(sin(ph)), ConvTranspose scatter

@group(0) @binding(0) var<storage, read> conv_post: array<f32>;    // [22, gen_length]
@group(0) @binding(1) var<storage, read> weight_real: array<f32>;  // [11, 20]
@group(0) @binding(2) var<storage, read> weight_imag: array<f32>;  // [11, 20]
@group(0) @binding(3) var<storage, read_write> output: array<f32>; // [waveform_length]

struct Params {
  gen_length: u32,
  waveform_length: u32,
  bins: u32,       // 11
  kernel_size: u32, // 20
  stride: u32,     // 5
}
@group(0) @binding(4) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let out_pos = gid.x;
  if (out_pos >= params.waveform_length) { return; }

  var sum: f32 = 0.0;

  // For each kernel tap, check if this output position has a contribution
  for (var k: u32 = 0u; k < params.kernel_size; k = k + 1u) {
    if (out_pos < k) { continue; }
    let rem = out_pos - k;
    if (rem % params.stride != 0u) { continue; }
    let t = rem / params.stride;
    if (t >= params.gen_length) { continue; }

    // For each frequency bin, compute magnitude/phase and accumulate
    for (var b: u32 = 0u; b < params.bins; b = b + 1u) {
      let mag_val = conv_post[b * params.gen_length + t];
      let ph_val = conv_post[(b + params.bins) * params.gen_length + t];

      let mag = exp(mag_val);
      let sin_ph = sin(ph_val);
      let real_comp = mag * cos(sin_ph);
      let imag_comp = mag * sin(sin_ph);

      sum += real_comp * weight_real[b * params.kernel_size + k]
           - imag_comp * weight_imag[b * params.kernel_size + k];
    }
  }

  output[out_pos] = sum;
}
`;
