diff --git a/v2/pkg/runner/runner.go b/v2/pkg/runner/runner.go index 5b0407a90..83887ea94 100644 --- a/v2/pkg/runner/runner.go +++ b/v2/pkg/runner/runner.go @@ -8,11 +8,13 @@ import ( "os" "path" "regexp" + "strconv" "strings" "github.com/pkg/errors" "github.com/projectdiscovery/gologger" + contextutil "github.com/projectdiscovery/utils/context" fileutil "github.com/projectdiscovery/utils/file" mapsutil "github.com/projectdiscovery/utils/maps" @@ -73,7 +75,8 @@ func NewRunner(options *Options) (*Runner, error) { // RunEnumeration wraps RunEnumerationWithCtx with an empty context func (r *Runner) RunEnumeration() error { - return r.RunEnumerationWithCtx(context.Background()) + ctx, _ := contextutil.WithValues(context.Background(), contextutil.ContextArg("All"), contextutil.ContextArg(strconv.FormatBool(r.options.All))) + return r.RunEnumerationWithCtx(ctx) } // RunEnumerationWithCtx runs the subdomain enumeration flow on the targets specified @@ -105,7 +108,8 @@ func (r *Runner) RunEnumerationWithCtx(ctx context.Context) error { // EnumerateMultipleDomains wraps EnumerateMultipleDomainsWithCtx with an empty context func (r *Runner) EnumerateMultipleDomains(reader io.Reader, writers []io.Writer) error { - return r.EnumerateMultipleDomainsWithCtx(context.Background(), reader, writers) + ctx, _ := contextutil.WithValues(context.Background(), contextutil.ContextArg("All"), contextutil.ContextArg(strconv.FormatBool(r.options.All))) + return r.EnumerateMultipleDomainsWithCtx(ctx, reader, writers) } // EnumerateMultipleDomainsWithCtx enumerates subdomains for multiple domains diff --git a/v2/pkg/subscraping/sources/crtsh/crtsh.go b/v2/pkg/subscraping/sources/crtsh/crtsh.go index 28fec2ec4..d01a6e99f 100644 --- a/v2/pkg/subscraping/sources/crtsh/crtsh.go +++ b/v2/pkg/subscraping/sources/crtsh/crtsh.go @@ -5,6 +5,7 @@ import ( "context" "database/sql" "fmt" + "strconv" "strings" "time" @@ -14,6 +15,7 @@ import ( _ "github.com/lib/pq" "github.com/projectdiscovery/subfinder/v2/pkg/subscraping" + contextutil "github.com/projectdiscovery/utils/context" ) type subdomain struct { @@ -40,7 +42,7 @@ func (s *Source) Run(ctx context.Context, domain string, session *subscraping.Se close(results) }(time.Now()) - count := s.getSubdomainsFromSQL(domain, session, results) + count := s.getSubdomainsFromSQL(ctx, domain, session, results) if count > 0 { return } @@ -50,7 +52,7 @@ func (s *Source) Run(ctx context.Context, domain string, session *subscraping.Se return results } -func (s *Source) getSubdomainsFromSQL(domain string, session *subscraping.Session, results chan subscraping.Result) int { +func (s *Source) getSubdomainsFromSQL(ctx context.Context, domain string, session *subscraping.Session, results chan subscraping.Result) int { db, err := sql.Open("postgres", "host=crt.sh user=guest dbname=certwatch sslmode=disable binary_parameters=yes") if err != nil { results <- subscraping.Result{Source: s.Name(), Type: subscraping.Error, Error: err} @@ -60,7 +62,14 @@ func (s *Source) getSubdomainsFromSQL(domain string, session *subscraping.Sessio defer db.Close() - query := `WITH ci AS ( + limitClause := "" + if all, ok := ctx.Value(contextutil.ContextArg("All")).(contextutil.ContextArg); ok { + if allBool, err := strconv.ParseBool(string(all)); err == nil && !allBool { + limitClause = "LIMIT 10000" + } + } + + query := fmt.Sprintf(`WITH ci AS ( SELECT min(sub.CERTIFICATE_ID) ID, min(sub.ISSUER_CA_ID) ISSUER_CA_ID, array_agg(DISTINCT sub.NAME_VALUE) NAME_VALUES, @@ -71,7 +80,8 @@ func (s *Source) getSubdomainsFromSQL(domain string, session *subscraping.Sessio FROM (SELECT * FROM certificate_and_identities cai WHERE plainto_tsquery('certwatch', $1) @@ identities(cai.CERTIFICATE) - AND cai.NAME_VALUE ILIKE ('%' || $1 || '%') + AND cai.NAME_VALUE ILIKE ('%%' || $1 || '%%') + %s ) sub GROUP BY sub.CERTIFICATE ) @@ -84,7 +94,7 @@ func (s *Source) getSubdomainsFromSQL(domain string, session *subscraping.Sessio ) le ON TRUE, ca WHERE ci.ISSUER_CA_ID = ca.ID - ORDER BY le.ENTRY_TIMESTAMP DESC NULLS LAST;` + ORDER BY le.ENTRY_TIMESTAMP DESC NULLS LAST;`, limitClause) rows, err := db.Query(query, domain) if err != nil { results <- subscraping.Result{Source: s.Name(), Type: subscraping.Error, Error: err}