import groovy.json.JsonOutput
import java.util.regex.Pattern
import java.util.regex.Matcher
import org.gradle.util.GradleVersion
import org.gradle.api.internal.artifacts.DefaultResolvedDependency

// Snyk dependency resolution script for Gradle.
// Tested on Gradle versions from v2.14 to v6.8.1

// This script does the following: for all the projects in the build file,
// generate a merged configuration of all the available configurations,
// and then list the dependencies as a tree.

// It's the responsibility of the caller to pick the project(s) they are
// interested in from the results.

// CLI usages:
// gradle -q -I init.gradle snykResolvedDepsJson
// gradle -q -I init.gradle snykResolvedDepsJson -Pconfiguration=specificConf -PonlySubProject=sub-project
// gradle -q -I init.gradle snykResolvedDepsJson -Pconfiguration=confNameRegex -PconfAttr=buildtype:debug,usage:java-runtime

// (-q to have clean output, -P supplies args as per https://stackoverflow.com/a/48370451)

// confAttr parameter (supported only in Gradle 3+) is used to perform attribute-based dependency variant matching
// (important for Android: https://developer.android.com/studio/build/dependencies#variant_aware)
// Its value is a comma-separated list of key:value pairs. The "key" is a case-insensitive substring
// of the class name of the attribute (e.g. "buildtype" would match com.android.build.api.attributes.BuildTypeAttr),
// the value should be a case-insensitive stringified value of the attribute

// Output format:
//
// Since Gradle is chatty and often prints a "Welcome" banner even with the -q option,
// the only output lines that matter are:
// - prefixed "SNYKECHO ": should be immediately printed as debug information by the caller
// - prefixed "JSONATTRS ": JSON map of attribute name -> distinct stringified values found on the matching configurations
// - prefixed "JSONDEPS ": JSON representation of the dependency trees for all projects in the following format

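// interface JsonDepsScriptResult {
//   defaultProject: string;
//   defaultProjectKey: string;
//   projects: ProjectsDict;
//   allSubProjectNames: string[];
//   sha1Map?: { [sha1: string]: string }; // only emitted by snykNormalizedResolvedDepsJson
// }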
// interface ProjectsDict {
//   [project: string]: GradleProjectInfo;
// }

// interface GradleProjectInfo {
//   depGraph: DepGraph; // not emitted by this script; presumably filled in by the caller
//   gradleGraph: GradleGraph;
//   targetFile: string;
//   projectVersion: string;
// }
// interface GradleGraph {
//   [id: string]: {
//     name: string;
//     version: string;
//     parentIds: string[];
//   };
// }

// Flat, id-keyed representation of a dependency graph.
// Every node is a map with the keys 'name', 'version' and 'parentIds'
// (the set of ids of the nodes that depend on it).
class GradleGraph {
    String rootId
    Map<String, Map> nodes = [:]

    GradleGraph(String rootId) {
        this.rootId = rootId
    }

    // Registers a node under `key` unless one already exists and returns the
    // stored node. Returns null (and stores nothing) for an empty key or value.
    Map setNode(String key, Map value) {
        if (key && value) {
            return nodes.computeIfAbsent(key) { ignored -> buildNode(value.name, value.version) }
        }
        return null
    }

    // Records that `childId` is depended upon by `parentId`.
    // Invalid edges (see isInvalidEdge) are silently ignored; missing endpoint
    // nodes are created as empty placeholders.
    void setEdge(String parentId, String childId) {
        if (!isInvalidEdge(parentId, childId)) {
            // make sure the parent (typically the root) has an entry of its own
            nodes.computeIfAbsent(parentId) { ignored -> createEmptyNode() }
            def child = nodes.computeIfAbsent(childId) { ignored -> createEmptyNode() }

            if (!(parentId in child.parentIds)) {
                child.parentIds << parentId
            }
        }
    }

    // An edge is rejected when an endpoint is null, when it is a self-loop,
    // or when the parent is neither the root nor an already known node.
    private boolean isInvalidEdge(String parentId, String childId) {
        if (parentId == null || childId == null) {
            return true
        }
        if (parentId == childId) {
            return true
        }
        return parentId != rootId && !nodes.containsKey(parentId)
    }

    private Map createEmptyNode() {
        buildNode(null, null)
    }

    // Single place that defines the shape of a node.
    private Map buildNode(name, version) {
        ['name': name, 'version': version, 'parentIds': [] as Set]
    }
}


// Accumulates a mapping from an artifact's SHA-1 hash to its dependency coordinate.
class Sha1Map {
    def coordinates = [:]

    // Stores (or overwrites) the coordinate recorded for `key`.
    def setCoordinate(key, value) {
        coordinates[key] = value
    }
}

// Returns the lower-case hex SHA-1 digest of the contents of `file`,
// streaming it in 4 KiB chunks so large artifacts are never held in memory.
def hash(File file) {
    def sha1 = java.security.MessageDigest.getInstance('SHA-1')
    file.withInputStream { input ->
        byte[] chunk = new byte[1024 * 4]
        int count
        while ((count = input.read(chunk)) != -1) {
            sha1.update(chunk, 0, count)
        }
    }
    return sha1.digest().encodeHex().toString()
}

// Returns the lower-case hex SHA-1 digest of `str`.
// The string is always encoded as UTF-8 so that the result does not depend on
// the platform default charset of the JVM running Gradle (identical to the
// previous behaviour for ASCII-only input).
def hashString(String str) {
    def md = java.security.MessageDigest.getInstance('SHA-1')
    return md.digest(str.getBytes('UTF-8')).encodeHex().toString()
}

// Recursively records `deps` (and their children) in `graph`.
//
// For every artifact of every dependency a node is registered under the id
// "<group>:<name>:<type>[:<classifier>]@<version>" and an edge from `parentId`
// to that id is added.
//
// deps         - resolved dependencies to walk
// graph        - GradleGraph that receives nodes and edges
// parentId     - id of the node the given deps hang off
// currentChain - mutable collection of ids whose children were already walked;
//                it is never pruned, so each id is expanded at most once
//                (this also stops cycles)
//
// Dependencies whose artifacts cannot be obtained are logged and skipped.
def loadGraph(Iterable deps, GradleGraph graph, parentId, currentChain) {
    deps.each { dep ->
        dep.each { d ->
            def moduleArtifacts
            try {
                moduleArtifacts = d.getModuleArtifacts()
            } catch (Exception ex) {
                debugLog("failed to get artifacts for dep: `$d`")
                // `return` leaves only this closure invocation, i.e. skips this dependency
                return
            }
            moduleArtifacts.each { a ->
                def childName = "${d.moduleGroup}:${d.moduleName}"
                def coordinate = "${childName}:${a.type}"
                if (a.classifier) coordinate += ":${a.classifier}"
                coordinate += "@${d.moduleVersion}"
                // Materialise the id as a java.lang.String: a GString has a different
                // hashCode than the equal String, so using it for the lookup below
                // (graph keys are Strings) would never find an existing node.
                String childId = coordinate.toString()
                if (!graph.nodes.get(childId)) {
                    def childDependency = ['name': childName, 'version': d.moduleVersion]
                    graph.setNode(childId, childDependency)
                }
                if (!currentChain.contains(childId) && d.children) {
                    currentChain.add(childId)
                    loadGraph(d.children, graph, childId, currentChain)
                }
                // added after the recursion; the child node itself was registered above
                graph.setEdge(parentId, childId)
            }
        }
    }
}

// Same traversal as loadGraph, but nodes are keyed by the SHA-1 of the artifact
// file instead of by coordinate. Each hash -> coordinate pair is also recorded
// in `sha1Map`. When an artifact has no file that can be hashed, the hash of
// the coordinate string is used instead.
def loadSha1MapGraph(Iterable deps, GradleGraph graph, parentId, currentChain, sha1Map) {
    deps.each { dep ->
        dep.each { resolved ->
            def artifacts
            try {
                artifacts = resolved.getModuleArtifacts()
            } catch (Exception ex) {
                debugLog("failed to get artifacts for dep: `${resolved}`")
                return
            }
            for (artifact in artifacts) {
                def groupAndName = "${resolved.moduleGroup}:${resolved.moduleName}"
                def classifierPart = artifact.classifier ? ":${artifact.classifier}" : ''
                def coordinate = "${groupAndName}:${artifact.type}${classifierPart}@${resolved.moduleVersion}"

                def digest = ""
                try {
                    digest = hash(artifact.getFile())
                    sha1Map.setCoordinate(digest, coordinate);
                } catch (Exception e) {
                    // Some artifacts have no file to hash, e.g. a reference to
                    // another project such as
                    //   implementation project(':utilities')
                    // Fall back to hashing the coordinate itself.
                    debugLog("Failed to hash artifact ${artifact.id}, error=${e}")
                    digest = hashString(coordinate)
                    sha1Map.setCoordinate(digest, coordinate);
                }

                if (digest && !graph.nodes.get(digest)) {
                    graph.setNode(digest, ['name': groupAndName, 'version': resolved.moduleVersion])
                }

                boolean alreadyExpanded = currentChain.contains(digest)
                if (!alreadyExpanded && resolved.children) {
                    currentChain.add(digest)
                    loadSha1MapGraph(resolved.children, graph, digest, currentChain, sha1Map)
                }
                graph.setEdge(parentId, digest)
            }
        }
    }
}

// Builds the coordinate-keyed node map for the given first-level dependencies.
// All of them are attached to a synthetic 'root-node' parent.
def getGradleGraph(Iterable deps) {
    def root = 'root-node'
    def gradleGraph = new GradleGraph(root)
    loadGraph(deps, gradleGraph, root, new HashSet())
    return gradleGraph.nodes
}

// Builds the SHA-1-keyed node map for the given first-level dependencies and
// returns it together with the (possibly freshly created) hash -> coordinate map
// as [graph: ..., sha1Map: ...].
def getGradleGraphWithSha1Map(Iterable deps, Sha1Map sha1Map) {
    sha1Map = sha1Map ?: new Sha1Map()
    def root = 'root-node'
    def gradleGraph = new GradleGraph(root)
    loadSha1MapGraph(deps, gradleGraph, root, new HashSet(), sha1Map)

    return [graph: gradleGraph.nodes, sha1Map: sha1Map]
}

// Prints `msg` with the "SNYKECHO " prefix, but only when the DEBUG
// environment variable is set to a non-empty value.
def debugLog(msg) {
    if (System.getenv('DEBUG')) {
        println("SNYKECHO $msg")
    }
}

// Tells whether configuration `conf` is compatible with the attribute filter.
//
// confAttrSpec is a list of [keySubstring, expectedValue] pairs (both already
// lower-cased by the caller). A configuration is rejected when it carries an
// attribute whose lower-cased name contains keySubstring but whose lower-cased
// string value differs from expectedValue. Attributes that match no filter
// entry are ignored, and so are objects without an `attributes` property.
def matchesAttributeFilter(conf, confAttrSpec) {
    if (!conf.hasProperty('attributes')) {
        // Gradle before version 3 does not support attributes
        return true
    }
    def confAttributes = conf.attributes
    boolean satisfied = true
    for (attribute in confAttributes.keySet()) {
        def actualValue = confAttributes.getAttribute(attribute).toString().toLowerCase()
        for (filterEntry in confAttrSpec) {
            // attribute.name is a class name, e.g. com.android.build.api.attributes.BuildTypeAttr
            boolean nameHit = attribute.name.toLowerCase().contains(filterEntry[0])
            if (nameHit && actualValue != filterEntry[1]) {
                satisfied = false
            }
        }
    }
    return satisfied
}

// Selects the configurations whose name matches `confNameFilter` (regex find)
// and, when an attribute spec is supplied, that also pass matchesAttributeFilter.
def findMatchingConfigs(confs, confNameFilter, confAttrSpec) {
    def nameMatched = confs.findAll { candidate -> candidate.name =~ confNameFilter }
    // Without an attribute spec the name match is the final answer
    return confAttrSpec == null
        ? nameMatched
        : nameMatched.findAll { candidate -> matchesAttributeFilter(candidate, confAttrSpec) }
}

// Returns the configurations of `proj` that match the name/attribute filters
// and can actually be resolved.
//
// On Gradle < 3.0 every matching configuration is returned as-is. On newer
// versions only configurations with `canBeResolved` set are kept, and each one
// is probed by accessing its resolvedConfiguration; those that throw are
// logged and skipped.
def findProjectConfigs(proj, confNameFilter, confAttrSpec) {
    def candidates = findMatchingConfigs(proj.configurations, confNameFilter, confAttrSpec)

    boolean legacyGradle = GradleVersion.current() < GradleVersion.version('3.0')
    if (legacyGradle) {
        for (conf in proj.configurations) {
            debugLog("conf.name=${conf.name}")
        }
        return candidates
    }

    for (conf in proj.configurations) {
        debugLog("conf.name=${conf.name}; conf.canBeResolved=${conf.canBeResolved}; conf.canBeConsumed=${conf.canBeConsumed}")
    }

    // Only a configuration that `canBeResolved` lets us compute a dependency graph,
    // because it carries all the information needed for resolution to happen.
    // See Gradle docs: https://docs.gradle.org/current/userguide/declaring_dependencies.html#sec:resolvable-consumable-configs
    def resolvable = []
    for (candidate in candidates) {
        if (!candidate.canBeResolved) {
            continue
        }
        try {
            // Touch resolvedConfiguration to weed out configs that may cause issues in strict lock mode
            candidate.resolvedConfiguration
            resolvable << candidate
            debugLog("Added resolvedConfiguration for conf.name=${candidate.name}")
        } catch (Exception ex) {
            // Deliberately swallowed: a failing configuration is simply left out
            debugLog("Skipping config ${candidate.name} due to resolvedConfiguration error.")
        }
    }
    debugLog("resolvableConfigs=${resolvable}")
    return resolvable
}
// Resolves every given configuration and returns one result object per
// configuration: the ResolvedConfiguration itself when resolution had no
// errors, otherwise its LenientConfiguration (a partial result).
List getResolvedConfigs(resolvableConfigs){
    List results = []
    for (config in resolvableConfigs) {
        ResolvedConfiguration strict = config.getResolvedConfiguration()
        boolean failed = strict.hasError()
        debugLog("config `${config.name}' resolution has errors: ${failed}")
        if (failed) {
            // even if some dependencies fail to resolve, we prefer a partial result to none
            LenientConfiguration lenient = strict.getLenientConfiguration()
            debugLog("Partially resolved config `${config.name}' with: ${lenient.firstLevelModuleDependencies}")
            debugLog("Couldn't resolve: ${lenient.getUnresolvedModuleDependencies()}")
            results.add(lenient)
            continue
        }
        results.add(strict)
        debugLog("Fully resolved config `${config.name}' with deps: ${strict.firstLevelModuleDependencies}")
    }
    return results
}

// Converts a Gradle project path such as ":libs:core" into a slash-separated
// key ("libs/core"): colons become slashes, then one leading slash and all
// trailing slashes are stripped. The root path ":" becomes "".
String formatPath(path) {
    def slashed = path.replace(':', '/')
    def edgeSlashes = ~/(^\/+?)|(\/+$)/
    return slashed.replaceAll(edgeSlashes, '')
}

// True when `path` is the path of the build's root project.
Boolean isRootPath(path){
    def rootPath = rootProject.path
    return path == rootPath
}

// We are attaching this task to every project, as this is the only reliable way to run it
// when we start with a subproject build.gradle. As a consequence, we need to make sure we
// only ever run it once, for the "starting" project.
// Shared by both tasks below and by every project they are registered on:
// the first task action that runs flips it, every later one returns immediately.
def snykDepsConfExecuted = false
// Registers the two Snyk tasks (snykResolvedDepsJson and
// snykNormalizedResolvedDepsJson) on every project of the build.
allprojects { Project currProj ->
    debugLog("Current project: $currProj.name")
    // -PonlySubProject=<name|path|.> restricts the scan to one project (null = scan all)
    String onlyProj = project.hasProperty('onlySubProject') ? onlySubProject : null
    // -Pconfiguration=<regex> is compiled case-insensitively; default matches every configuration name
    def confNameFilter = (project.hasProperty('configuration')
        ? Pattern.compile(configuration, Pattern.CASE_INSENSITIVE)
        : /.*/
    )
    // -PconfAttr=k1:v1,k2:v2 becomes a lower-cased list of [key, value] arrays (null when absent)
    // NOTE(review): an entry without ':' yields a one-element array; matchesAttributeFilter
    // indexes element [1] - confirm such input cannot reach it.
    def confAttrSpec = (project.hasProperty('confAttr')
        ? confAttr.toLowerCase().split(',').collect { it.split(':') }
        : null
    )

    // Emits "JSONATTRS ..." and "JSONDEPS ..." lines; graph nodes are keyed by
    // dependency coordinate (see getGradleGraph).
    task snykResolvedDepsJson {
        doLast { task ->
            if (snykDepsConfExecuted) {
                return
            }

            snykDepsConfExecuted = true
            debugLog('snykResolvedDepsJson task is executing via doLast')
            // debugLog("onlyProj=$onlyProj; confNameFilter=$confNameFilter; confAttrSpec=$confAttrSpec")
            // First pass: scan all configurations that match the attribute filter and collect all attributes
            // from them, to use unambiguous values of the attributes on the merged configuration.
            //
            // Why we need to scan all sub-projects: if a project A depends on B, and only B has some
            // configurations with attribute C, we still might need attribute C in our configuration
            // when resolving the project A, so that it selects a concrete variant of dependency B.
            //
            // NOTE(review): within this block allConfigurationAttributes is only written to and used
            // for the containsKey check; only attributesAsStrings is emitted - confirm it is still needed.
            // NOTE(review): conf.attributes is accessed here without the hasProperty guard used in
            // matchesAttributeFilter - verify behaviour on Gradle versions before 3.
            def allConfigurationAttributes = [:] // Map<Attribute<?>, Set<?>>
            def attributesAsStrings = [:] // Map<String, Set<string>>
            rootProject.allprojects.each { proj ->
                findMatchingConfigs(proj.configurations, confNameFilter, confAttrSpec)
                    .each { conf ->
                        def attrs = conf.attributes
                        attrs.keySet().toList().each({ attr ->
                            def value = attrs.getAttribute(attr)
                            if (!allConfigurationAttributes.containsKey(attr)) {
                                allConfigurationAttributes[attr] = new HashSet()
                                attributesAsStrings[attr.name] = new HashSet()
                            }
                            allConfigurationAttributes[attr].add(value)
                            attributesAsStrings[attr.name].add(value.toString())
                        })
                    }
            }
            // name of target project - either specified by --file or the root of a multi-module build
            String defaultProjectName = currProj.name
            // the path of the above, preferred over name because path is unique
            String defaultProjectKey = isRootPath(currProj.path) ? defaultProjectName : formatPath(currProj.path)
            // collect all subprojects names that are not target
            List allSubProjectNames = []
            allprojects
                .findAll({ it.path != currProj.path })
                .each({
                    String projKey = formatPath(it.path)
                    allSubProjectNames.add(projKey)
                })
            // Predicate: scan everything when no onlySubProject was given, otherwise only the
            // project matched by '.', by name, or by formatted path
            def shouldScanProject = {
                onlyProj == null ||
                (onlyProj == '.' && it.path == currProj.path) ||
                it.name == onlyProj ||
                formatPath(it.path) == onlyProj
            }
            def projectsDict = [:]

            debugLog("defaultProjectName=$defaultProjectName; defaultProjectKey=$defaultProjectKey; allSubProjectNames=$allSubProjectNames")

            // These will be used to suggest attribute filtering to the user if the scan fails
            // due to ambiguous resolution of dependency variants
            def jsonAttrs = JsonOutput.toJson(attributesAsStrings)
            println("JSONATTRS $jsonAttrs")

            // Second pass: resolve the matching configurations of each selected project and
            // convert their first-level dependencies into a graph
            rootProject.allprojects.findAll(shouldScanProject).each { proj ->
                debugLog("processing project: name=$proj.name; path=$proj.path")

                def resolvableConfigs = findProjectConfigs(proj, confNameFilter, confAttrSpec)
                List resolvedConfigs = getResolvedConfigs(resolvableConfigs)

                if (resolvedConfigs.isEmpty() && !resolvableConfigs.isEmpty()) {
                    throw new RuntimeException('Configurations: ' + resolvableConfigs.collect { it.name } +
                            ' for project ' + proj + ' could not be resolved.')
                }
                List nonemptyFirstLevelDeps = []
                resolvedConfigs.each { nonemptyFirstLevelDeps.addAll(it.getFirstLevelModuleDependencies()) }

                debugLog("non-empty first level deps for project `$proj.name': $nonemptyFirstLevelDeps")
                debugLog('converting gradle graph to snyk-graph format')

                def projGraph = getGradleGraph(nonemptyFirstLevelDeps)
                String projKey = formatPath(proj.path)
                // if project is root ":", it's formatted to "" - we can't have an empty string as a key so use default name
                if (projKey == "") {
                    debugLog("project path is empty (proj.path=$proj.path)! will use defaultProjectName=$defaultProjectName")
                    projKey = defaultProjectKey
                }

                projectsDict[projKey] = [
                    'targetFile': findProject(proj.path).buildFile.toString(),
                    'gradleGraph': projGraph,
                    'projectVersion': proj.version.toString()
                ]
            }

            def result = [
                'defaultProject': defaultProjectName,
                'defaultProjectKey': defaultProjectKey,
                'projects': projectsDict,
                'allSubProjectNames': allSubProjectNames
            ]
            def jsonDeps = JsonOutput.toJson(result)
            println("JSONDEPS $jsonDeps")
        }
    }

    // Same flow as snykResolvedDepsJson, except that graph nodes are keyed by artifact SHA-1
    // (see getGradleGraphWithSha1Map) and the result additionally carries the 'sha1Map'
    // (hash -> coordinate) accumulated across all scanned projects.
    task snykNormalizedResolvedDepsJson {

        doLast { task ->
            if (snykDepsConfExecuted) {
                return
            }

            snykDepsConfExecuted = true
            debugLog('snykNormalizedResolvedDepsJson task is executing via doLast')
            // debugLog("onlyProj=$onlyProj; confNameFilter=$confNameFilter; confAttrSpec=$confAttrSpec")
            // First pass: scan all configurations that match the attribute filter and collect all attributes
            // from them, to use unambiguous values of the attributes on the merged configuration.
            //
            // Why we need to scan all sub-projects: if a project A depends on B, and only B has some
            // configurations with attribute C, we still might need attribute C in our configuration
            // when resolving the project A, so that it selects a concrete variant of dependency B.
            def allConfigurationAttributes = [:] // Map<Attribute<?>, Set<?>>
            def attributesAsStrings = [:] // Map<String, Set<string>>
            rootProject.allprojects.each { proj ->
                findMatchingConfigs(proj.configurations, confNameFilter, confAttrSpec)
                    .each { Configuration conf ->
                        def attrs = conf.attributes
                        attrs.keySet().each({ attr ->
                            def value = attrs.getAttribute(attr)
                            if (!allConfigurationAttributes.containsKey(attr)) {
                                allConfigurationAttributes[attr] = new HashSet()
                                attributesAsStrings[attr.name] = new HashSet()
                            }
                            allConfigurationAttributes[attr].add(value)
                            attributesAsStrings[attr.name].add(value.toString())
                        })
                    }
            }

            // name of the project this task instance belongs to
            String defaultProjectName = currProj.name
            // unique key for that project: its name when it is the root, otherwise its formatted path
            String defaultProjectKey = isRootPath(currProj.path) ? defaultProjectName : formatPath(currProj.path)
            // formatted paths of all other projects
            List allSubProjectNames = []
            allprojects
                .findAll({ it.path != currProj.path })
                .each({
                    String projKey = formatPath(it.path)
                    allSubProjectNames.add(projKey)
                })
            def shouldScanProject = {
                onlyProj == null ||
                (onlyProj == '.' && it.path == currProj.path) ||
                it.name == onlyProj ||
                formatPath(it.path) == onlyProj
            }
            def projectsDict = [:]

            debugLog("defaultProjectName=$defaultProjectName; allSubProjectNames=$allSubProjectNames")

            // These will be used to suggest attribute filtering to the user if the scan fails
            // due to ambiguous resolution of dependency variants
            def jsonAttrs = JsonOutput.toJson(attributesAsStrings)
            println("JSONATTRS $jsonAttrs")

            // one hash -> coordinate map shared by all scanned projects
            def sha1Map = new Sha1Map()

            rootProject.allprojects.findAll(shouldScanProject).each { proj ->
                debugLog("processing project: name=$proj.name; path=$proj.path")

                def resolvableConfigs = findProjectConfigs(proj, confNameFilter, confAttrSpec)
                List resolvedConfigs = getResolvedConfigs(resolvableConfigs)
                if (!resolvableConfigs.isEmpty() && resolvedConfigs.isEmpty()) {
                    throw new RuntimeException('Configurations: ' + resolvableConfigs.collect { it.name } +
                            ' for project ' + proj + ' could not be resolved.')
                }

                List nonemptyFirstLevelDeps = []
                resolvedConfigs.each { nonemptyFirstLevelDeps.addAll(it.getFirstLevelModuleDependencies()) }

                debugLog("non-empty first level deps for project `$proj.name': $nonemptyFirstLevelDeps")
                debugLog('converting gradle graph to snyk-graph format')

                def projGraph = getGradleGraphWithSha1Map(nonemptyFirstLevelDeps, sha1Map)
                String projKey = formatPath(proj.path)

                // the root path ":" formats to "", which cannot be used as a key
                if (projKey == "") {
                    debugLog("project path is empty (proj.path=$proj.path)! will use defaultProjectName=$defaultProjectName")
                    projKey = defaultProjectKey
                }

                projectsDict[projKey] = [
                    'targetFile': findProject(proj.path).buildFile.toString(),
                    'gradleGraph': projGraph.graph,
                    'projectVersion': proj.version.toString()
                ]
            }

            def result = [
                'defaultProject': defaultProjectName,
                'defaultProjectKey': defaultProjectKey,
                'projects': projectsDict,
                'allSubProjectNames': allSubProjectNames,
                'sha1Map': sha1Map.coordinates
            ]

            def jsonDeps = JsonOutput.toJson(result)
            println("JSONDEPS $jsonDeps")
        }
    }
}
