#!/usr/bin/env ruby

# Error budget checked by --test in the default mode (without --all).
# Only ever lower this value.
ACCEPTABLE_ERRORS = 39

# Error budget checked by --test when --all is given.
# Only ever lower this value.
ACCEPTABLE_ERRORS_ALL = 777

# Write output immediately instead of buffering it.
[STDOUT, STDERR].each { |stream| stream.sync = true }

# Writes the command-line help text to +out+.
#
# out - an object responding to #puts; the script passes STDOUT for
#       -h/--help and STDERR when reporting an invalid argument.
#
# Returns whatever out.puts returns.
def print_usage(out)
  lines = [
    "Usage: #{$PROGRAM_NAME} [-h|--help] [--all] [--extensions=<list>] [--test]",
    '',
    'Performs leave-one-out cross-validation of the classifier.',
    'By default, outputs results only for samples with ambiguous extensions.',
    'If the --all flag is given, all samples and languages are considered.',
    '',
    'If the --extensions option is used, only the extensions specified in the comma-separated list will be considered.',
    'Extensions in the list must include the starting dot.',
    '',
    'The --test flag can be used to verify that the number of errors is acceptable.',
    '',
    'The -h or --help flag prints this message and exits.'
  ]
  # puts emits each argument on its own line; '' produces a blank line.
  out.puts(*lines)
end

# A help flag anywhere on the command line wins: print the usage and stop
# before any other argument is looked at.
if ARGV.include?('--help') || ARGV.include?('-h')
  print_usage STDOUT
  exit 0
end

# Option defaults; the loop below overrides them from ARGV.
$all = false
$test = false
$exts = []
ARGV.each do |arg|
  if arg == '--all'
    $all = true
  elsif arg == '--test'
    $test = true
  elsif (list = /^--extensions=(.*)$/.match(arg))
    # Drop quote characters, split the list on commas and trim each entry.
    $exts = list[1].delete("'\"").split(',').map(&:strip)
  else
    # Anything unrecognised is fatal: report it, show the usage and fail.
    STDERR.puts "Invalid command line argument: #{arg}"
    STDERR.puts ''
    print_usage STDERR
    exit 1
  end
end

require 'parallel'
require 'set'
require 'linguist'
include Linguist

# Gather the extensions of every heuristic whose last rule has an
# AlwaysMatch pattern (a catch-all); eval skips samples with these extensions.
# With --all the set is left empty, so nothing is skipped.
$skip_extensions = Set.new
unless $all
  Heuristics.all.each do |heuristic|
    # The rules are read straight from the heuristic's @rules instance variable.
    last_rule = heuristic.instance_variable_get(:@rules)[-1]
    next unless last_rule['pattern'].is_a? AlwaysMatch

    $skip_extensions.merge(heuristic.extensions)
  end
end

# Load every sample once up front: each entry gets its file contents under
# :data and the tokenized contents under :tokens, which eval reads later.
$samples = []
Samples.each do |entry|
  contents = File.read(entry[:path])
  entry[:data] = contents
  entry[:tokens] = Tokenizer.tokenize(contents)
  $samples.push(entry)
end

# Runs one leave-one-out round: trains a classifier on every loaded sample
# except +sample+, then classifies +sample+ with it.
#
# sample - a Hash from $samples; reads :path, :extname, :language, :data
#          (and :tokens of the other samples for training).
#
# Returns nil when the sample is skipped (filtered extension, unambiguous
# filename/extension, or no other sample of its language), otherwise a
# String "<relative path> GOOD" or "<relative path> BAD (<guess>)".
def eval(sample)
  # Extensions collected in $skip_extensions are never evaluated.
  return nil if $skip_extensions.include? sample[:extname]

  # Honour the --extensions filter when one was given.
  return nil unless $exts.none? || $exts.include?(File.extname(sample[:path]))

  # With --all the candidate list stays nil, which means "all languages".
  # Otherwise only samples that are ambiguous by filename and by extension
  # are kept, and the candidates are the languages sharing the extension.
  candidates = nil
  unless $all
    return nil if Language.find_by_filename(sample[:path]).map(&:name).length == 1

    candidates = Language.find_by_extension(sample[:path]).map(&:name)
    return nil if candidates.length <= 1
  end

  # Cross-validation needs at least one other sample of the same language.
  same_language = $samples.count { |other| other[:language] == sample[:language] }
  if same_language <= 1
    warn "WARNING: No cross-validation possible for #{sample[:language]} (only one sample)"
    return nil
  end

  # Train on everything except the sample under test.
  db = {}
  $samples.each do |other|
    next if other[:path] == sample[:path]

    Classifier.train!(db, other[:language], other[:tokens])
  end

  # Classify the held-out sample and format the verdict line.
  ranking = Classifier.classify(db, sample[:data], candidates)
  relative_path = sample[:path].sub("#{Samples::ROOT}/", '')
  return "#{relative_path} BAD (Unknown)" if ranking.empty?

  guess = ranking.first[0]
  verdict = sample[:language] == guess ? 'GOOD' : "BAD (#{guess})"
  "#{relative_path} #{verdict}"
end

# Evaluate every sample through Parallel.map, drop the skipped ones (nil)
# and order the remaining result lines.
results = Parallel.map($samples) { |sample| eval(sample) }.compact.sort

# Print every result line; a line counts as an error when it contains ' BAD '.
results.each { |line| STDOUT.puts line }
total_errors = results.count { |line| line.include? ' BAD ' }

# With --test, compare the error count against the threshold for the current
# mode and exit with status 1 when it is exceeded.
if $test
  threshold = $all ? ACCEPTABLE_ERRORS_ALL : ACCEPTABLE_ERRORS
  STDOUT.flush
  STDERR.puts ''
  if total_errors <= threshold
    STDOUT.puts "Number of errors (#{total_errors}) is within the acceptable threshold (#{threshold})"
  else
    STDERR.puts "Number of errors (#{total_errors}) is above the acceptable threshold (#{threshold})"
    exit 1
  end
end
