import { bufferToHex, keccak256 } from 'ethereumjs-util'

/**
 * Merkle tree over a set of byte-string leaves.
 *
 * The constructor copies the leaves, sorts them with `Buffer.compare` and
 * drops adjacent duplicates before building the tree.
 *
 * Each parent node is `keccak256` of its two children, concatenated in sorted
 * byte order (see `sortAndConcat`). Because of that sort, a proof is just a
 * list of sibling hashes with no left/right flags.
 *
 * When a layer has an odd number of nodes, the last node is carried up to the
 * next layer unchanged.
 */
export default class MerkleTree {
  private readonly elements: Buffer[]
  // Maps a leaf's 0x-prefixed hex encoding to its index in `elements` (layer 0).
  private readonly bufferElementPositionIndex: Map<string, number>
  // layers[0] holds the leaves; the last layer holds only the root.
  private readonly layers: Buffer[][]

  /**
   * @param elements Leaves of the tree. The array is copied and not mutated.
   * @throws Error('empty tree') when `elements` is empty.
   */
  constructor(elements: Buffer[]) {
    // Copy before sorting so the caller's array is left untouched; sorting
    // first puts equal buffers next to each other so bufDedup can drop them.
    this.elements = MerkleTree.bufDedup([...elements].sort(Buffer.compare))

    this.bufferElementPositionIndex = new Map(
      this.elements.map((el, index): [string, number] => [bufferToHex(el), index])
    )

    this.layers = this.getLayers(this.elements)
  }

  /**
   * Builds every layer of the tree, from the leaves up to the root.
   *
   * @param elements Leaves used as layer 0. They are not sorted or
   *   deduplicated here.
   * @returns Layers ordered leaf-first; the final layer has exactly one node.
   * @throws Error('empty tree') when `elements` is empty.
   */
  getLayers(elements: Buffer[]): Buffer[][] {
    if (elements.length === 0) {
      throw new Error('empty tree')
    }

    const layers: Buffer[][] = [elements]

    // Hash pairs upward until only the root remains.
    while (layers[layers.length - 1].length > 1) {
      layers.push(this.getNextLayer(layers[layers.length - 1]))
    }

    return layers
  }

  /**
   * Computes the parent layer of `elements`.
   *
   * Nodes are paired as (0,1), (2,3), and so on. An unpaired trailing node is
   * copied up as-is.
   */
  getNextLayer(elements: Buffer[]): Buffer[] {
    return elements.reduce<Buffer[]>((layer, el, idx, arr) => {
      if (idx % 2 === 0) {
        const pair = arr[idx + 1]
        // Promote an unpaired last node unchanged (same result as
        // combinedHash(el, undefined)); otherwise hash the pair.
        layer.push(pair === undefined ? el : MerkleTree.combinedHash(el, pair))
      }
      return layer
    }, [])
  }

  /**
   * Hashes two sibling nodes into their parent.
   *
   * If either argument is falsy (e.g. missing), the other is returned as-is.
   * The children are concatenated in sorted order, so argument order does
   * not affect the result.
   */
  static combinedHash(first: Buffer, second: Buffer): Buffer {
    if (!first) {
      return second
    }
    if (!second) {
      return first
    }

    return keccak256(MerkleTree.sortAndConcat(first, second))
  }

  /** Returns the single node of the top layer. */
  getRoot(): Buffer {
    return this.layers[this.layers.length - 1][0]
  }

  /** Returns the root as a 0x-prefixed hex string. */
  getHexRoot(): string {
    return bufferToHex(this.getRoot())
  }

  /**
   * Collects the sibling hashes needed to recompute the root from `el`.
   *
   * Layers where the node has no sibling (an odd trailing node, or the root)
   * contribute nothing to the proof.
   *
   * @param el A leaf that was passed to the constructor.
   * @returns Sibling nodes ordered from the leaf layer upward.
   * @throws Error('Element does not exist in Merkle tree') if `el` is not a leaf.
   */
  getProof(el: Buffer): Buffer[] {
    const leafIndex = this.bufferElementPositionIndex.get(bufferToHex(el))
    if (leafIndex === undefined) {
      throw new Error('Element does not exist in Merkle tree')
    }

    const proof: Buffer[] = []
    let idx = leafIndex
    for (const layer of this.layers) {
      const pairElement = MerkleTree.getPairElement(idx, layer)
      if (pairElement) {
        proof.push(pairElement)
      }
      // Position of this node's parent in the next layer.
      idx = Math.floor(idx / 2)
    }
    return proof
  }

  /** Same as `getProof`, with each node encoded as a 0x-prefixed hex string. */
  getHexProof(el: Buffer): string[] {
    const proof = this.getProof(el)

    return MerkleTree.bufArrToHexArr(proof)
  }

  /**
   * Returns the sibling of the node at `idx` in `layer`, or null if it has none.
   *
   * Even indices pair with their right neighbour, odd indices with their left.
   */
  private static getPairElement(idx: number, layer: Buffer[]): Buffer | null {
    const pairIdx = idx % 2 === 0 ? idx + 1 : idx - 1

    if (pairIdx < layer.length) {
      return layer[pairIdx]
    } else {
      return null
    }
  }

  /** Drops adjacent equal buffers; `elements` must already be sorted. */
  private static bufDedup(elements: Buffer[]): Buffer[] {
    return elements.filter((el, idx) => {
      return idx === 0 || !elements[idx - 1].equals(el)
    })
  }

  /**
   * Encodes each buffer as a 0x-prefixed hex string.
   *
   * @throws Error('Array is not an array of buffers') if any entry is not a Buffer.
   */
  private static bufArrToHexArr(arr: Buffer[]): string[] {
    if (arr.some((el) => !Buffer.isBuffer(el))) {
      throw new Error('Array is not an array of buffers')
    }

    return arr.map((el) => '0x' + el.toString('hex'))
  }

  /** Concatenates the buffers in `Buffer.compare` order, without mutating `args`. */
  private static sortAndConcat(...args: Buffer[]): Buffer {
    return Buffer.concat([...args].sort(Buffer.compare))
  }
}