Files
snitch/cmd/stats.go

301 lines
6.9 KiB
Go

package cmd
import (
"context"
"encoding/csv"
"encoding/json"
"fmt"
"log"
"os"
"os/signal"
"snitch/internal/collector"
"sort"
"strconv"
"strings"
"syscall"
"text/tabwriter"
"time"
"github.com/spf13/cobra"
)
type StatsData struct {
Timestamp time.Time `json:"ts"`
Total int `json:"total"`
ByProto map[string]int `json:"by_proto"`
ByState map[string]int `json:"by_state"`
ByProc []ProcessStats `json:"by_proc"`
ByIf []InterfaceStats `json:"by_if"`
}
type ProcessStats struct {
PID int `json:"pid"`
Process string `json:"process"`
Count int `json:"count"`
}
type InterfaceStats struct {
Interface string `json:"if"`
Count int `json:"count"`
}
// stats-specific flags
var (
statsOutputFormat string
statsInterval time.Duration
statsCount int
statsNoHeaders bool
)
var statsCmd = &cobra.Command{
Use: "stats [filters...]",
Short: "Aggregated connection counters",
Long: `Aggregated connection counters.
Filters are specified in key=value format. For example:
snitch stats proto=tcp state=listening
Available filters:
proto, state, pid, proc, lport, rport, user, laddr, raddr, contains
`,
Run: func(cmd *cobra.Command, args []string) {
runStatsCommand(args)
},
}
func runStatsCommand(args []string) {
filters, err := BuildFilters(args)
if err != nil {
log.Fatalf("Error parsing filters: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Handle interrupts gracefully
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
go func() {
<-sigChan
cancel()
}()
count := 0
for {
stats, err := generateStats(filters)
if err != nil {
log.Printf("Error generating stats: %v", err)
if statsCount > 0 || statsInterval == 0 {
return
}
time.Sleep(statsInterval)
continue
}
switch statsOutputFormat {
case "json":
printStatsJSON(stats)
case "csv":
printStatsCSV(stats, !statsNoHeaders && count == 0)
default:
printStatsTable(stats, !statsNoHeaders && count == 0)
}
count++
if statsCount > 0 && count >= statsCount {
return
}
if statsInterval == 0 {
return // One-shot mode
}
select {
case <-ctx.Done():
return
case <-time.After(statsInterval):
continue
}
}
}
func generateStats(filters collector.FilterOptions) (*StatsData, error) {
filteredConnections, err := FetchConnections(filters)
if err != nil {
return nil, err
}
stats := &StatsData{
Timestamp: time.Now(),
Total: len(filteredConnections),
ByProto: make(map[string]int),
ByState: make(map[string]int),
ByProc: make([]ProcessStats, 0),
ByIf: make([]InterfaceStats, 0),
}
procCounts := make(map[string]ProcessStats)
ifCounts := make(map[string]int)
for _, conn := range filteredConnections {
// Count by protocol
stats.ByProto[conn.Proto]++
// Count by state
stats.ByState[conn.State]++
// Count by process
if conn.Process != "" {
key := fmt.Sprintf("%d-%s", conn.PID, conn.Process)
if existing, ok := procCounts[key]; ok {
existing.Count++
procCounts[key] = existing
} else {
procCounts[key] = ProcessStats{
PID: conn.PID,
Process: conn.Process,
Count: 1,
}
}
}
// Count by interface (placeholder since we don't have interface data yet)
if conn.Interface != "" {
ifCounts[conn.Interface]++
}
}
// Convert process map to sorted slice
for _, procStats := range procCounts {
stats.ByProc = append(stats.ByProc, procStats)
}
sort.Slice(stats.ByProc, func(i, j int) bool {
return stats.ByProc[i].Count > stats.ByProc[j].Count
})
// Convert interface map to sorted slice
for iface, count := range ifCounts {
stats.ByIf = append(stats.ByIf, InterfaceStats{
Interface: iface,
Count: count,
})
}
sort.Slice(stats.ByIf, func(i, j int) bool {
return stats.ByIf[i].Count > stats.ByIf[j].Count
})
return stats, nil
}
func printStatsJSON(stats *StatsData) {
jsonOutput, err := json.MarshalIndent(stats, "", " ")
if err != nil {
log.Printf("Error marshaling JSON: %v", err)
return
}
fmt.Println(string(jsonOutput))
}
func printStatsCSV(stats *StatsData, headers bool) {
writer := csv.NewWriter(os.Stdout)
defer writer.Flush()
if headers {
_ = writer.Write([]string{"timestamp", "metric", "key", "value"})
}
ts := stats.Timestamp.Format(time.RFC3339)
_ = writer.Write([]string{ts, "total", "", strconv.Itoa(stats.Total)})
for proto, count := range stats.ByProto {
_ = writer.Write([]string{ts, "proto", proto, strconv.Itoa(count)})
}
for state, count := range stats.ByState {
_ = writer.Write([]string{ts, "state", state, strconv.Itoa(count)})
}
for _, proc := range stats.ByProc {
_ = writer.Write([]string{ts, "process", proc.Process, strconv.Itoa(proc.Count)})
}
for _, iface := range stats.ByIf {
_ = writer.Write([]string{ts, "interface", iface.Interface, strconv.Itoa(iface.Count)})
}
}
func printStatsTable(stats *StatsData, headers bool) {
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
defer w.Flush()
if headers {
fmt.Fprintf(w, "TIMESTAMP\t%s\n", stats.Timestamp.Format(time.RFC3339))
fmt.Fprintf(w, "TOTAL CONNECTIONS\t%d\n", stats.Total)
fmt.Fprintln(w)
}
// Protocol breakdown
if len(stats.ByProto) > 0 {
if headers {
fmt.Fprintln(w, "BY PROTOCOL:")
fmt.Fprintln(w, "PROTO\tCOUNT")
}
protocols := make([]string, 0, len(stats.ByProto))
for proto := range stats.ByProto {
protocols = append(protocols, proto)
}
sort.Strings(protocols)
for _, proto := range protocols {
fmt.Fprintf(w, "%s\t%d\n", strings.ToUpper(proto), stats.ByProto[proto])
}
fmt.Fprintln(w)
}
// State breakdown
if len(stats.ByState) > 0 {
if headers {
fmt.Fprintln(w, "BY STATE:")
fmt.Fprintln(w, "STATE\tCOUNT")
}
states := make([]string, 0, len(stats.ByState))
for state := range stats.ByState {
states = append(states, state)
}
sort.Strings(states)
for _, state := range states {
fmt.Fprintf(w, "%s\t%d\n", state, stats.ByState[state])
}
fmt.Fprintln(w)
}
// Process breakdown (top 10)
if len(stats.ByProc) > 0 {
if headers {
fmt.Fprintln(w, "BY PROCESS (TOP 10):")
fmt.Fprintln(w, "PID\tPROCESS\tCOUNT")
}
limit := 10
if len(stats.ByProc) < limit {
limit = len(stats.ByProc)
}
for i := 0; i < limit; i++ {
proc := stats.ByProc[i]
fmt.Fprintf(w, "%d\t%s\t%d\n", proc.PID, proc.Process, proc.Count)
}
}
}
func init() {
rootCmd.AddCommand(statsCmd)
// stats-specific flags
statsCmd.Flags().StringVarP(&statsOutputFormat, "output", "o", "table", "Output format (table, json, csv)")
statsCmd.Flags().DurationVarP(&statsInterval, "interval", "i", 0, "Refresh interval (0 = one-shot)")
statsCmd.Flags().IntVarP(&statsCount, "count", "c", 0, "Number of iterations (0 = unlimited)")
statsCmd.Flags().BoolVar(&statsNoHeaders, "no-headers", false, "Omit headers for table/csv output")
// shared filter flags
addFilterFlags(statsCmd)
}