301 lines
7.1 KiB
Go
301 lines
7.1 KiB
Go
package cmd
|
|
|
|
import (
|
|
"context"
|
|
"encoding/csv"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"os/signal"
|
|
"snitch/internal/collector"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"syscall"
|
|
"text/tabwriter"
|
|
"time"
|
|
|
|
"github.com/spf13/cobra"
|
|
)
|
|
|
|
type StatsData struct {
|
|
Timestamp time.Time `json:"ts"`
|
|
Total int `json:"total"`
|
|
ByProto map[string]int `json:"by_proto"`
|
|
ByState map[string]int `json:"by_state"`
|
|
ByProc []ProcessStats `json:"by_proc"`
|
|
ByIf []InterfaceStats `json:"by_if"`
|
|
}
|
|
|
|
// ProcessStats is the connection count attributed to one (PID, process name)
// pair.
type ProcessStats struct {
	// PID is the process ID reported for the connections.
	PID int `json:"pid"`
	// Process is the process name reported for the connections.
	Process string `json:"process"`
	// Count is the number of connections owned by this process.
	Count int `json:"count"`
}
|
|
|
|
// InterfaceStats is the connection count attributed to one network interface.
type InterfaceStats struct {
	// Interface is the interface name reported for the connections.
	Interface string `json:"if"`
	// Count is the number of connections seen on this interface.
	Count int `json:"count"`
}
|
|
|
|
// Flag values for the stats command; they are bound to command-line flags in
// init and read by runStatsCommand.
var (
	// statsOutputFormat selects the renderer: "json", "csv", or anything else
	// for the table renderer.
	statsOutputFormat string
	// statsInterval is the delay between refreshes; zero means one-shot.
	statsInterval time.Duration
	// statsCount is the maximum number of iterations; zero means unlimited.
	statsCount int
	// statsNoHeaders suppresses header output for table and CSV rendering.
	statsNoHeaders bool
)
|
|
|
|
var statsCmd = &cobra.Command{
|
|
Use: "stats [filters...]",
|
|
Short: "Aggregated connection counters",
|
|
Long: `Aggregated connection counters.
|
|
|
|
Filters are specified in key=value format. For example:
|
|
snitch stats proto=tcp state=listening
|
|
|
|
Available filters:
|
|
proto, state, pid, proc, lport, rport, user, laddr, raddr, contains
|
|
`,
|
|
Run: func(cmd *cobra.Command, args []string) {
|
|
runStatsCommand(args)
|
|
},
|
|
}
|
|
|
|
func runStatsCommand(args []string) {
|
|
filters, err := parseFilters(args)
|
|
if err != nil {
|
|
log.Fatalf("Error parsing filters: %v", err)
|
|
}
|
|
filters.IPv4 = ipv4
|
|
filters.IPv6 = ipv6
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
// Handle interrupts gracefully
|
|
sigChan := make(chan os.Signal, 1)
|
|
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
|
|
go func() {
|
|
<-sigChan
|
|
cancel()
|
|
}()
|
|
|
|
count := 0
|
|
for {
|
|
stats, err := generateStats(filters)
|
|
if err != nil {
|
|
log.Printf("Error generating stats: %v", err)
|
|
if statsCount > 0 || statsInterval == 0 {
|
|
return
|
|
}
|
|
time.Sleep(statsInterval)
|
|
continue
|
|
}
|
|
|
|
switch statsOutputFormat {
|
|
case "json":
|
|
printStatsJSON(stats)
|
|
case "csv":
|
|
printStatsCSV(stats, !statsNoHeaders && count == 0)
|
|
default:
|
|
printStatsTable(stats, !statsNoHeaders && count == 0)
|
|
}
|
|
|
|
count++
|
|
if statsCount > 0 && count >= statsCount {
|
|
return
|
|
}
|
|
|
|
if statsInterval == 0 {
|
|
return // One-shot mode
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-time.After(statsInterval):
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
func generateStats(filters collector.FilterOptions) (*StatsData, error) {
|
|
connections, err := collector.GetConnections()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
filteredConnections := collector.FilterConnections(connections, filters)
|
|
|
|
stats := &StatsData{
|
|
Timestamp: time.Now(),
|
|
Total: len(filteredConnections),
|
|
ByProto: make(map[string]int),
|
|
ByState: make(map[string]int),
|
|
ByProc: make([]ProcessStats, 0),
|
|
ByIf: make([]InterfaceStats, 0),
|
|
}
|
|
|
|
procCounts := make(map[string]ProcessStats)
|
|
ifCounts := make(map[string]int)
|
|
|
|
for _, conn := range filteredConnections {
|
|
// Count by protocol
|
|
stats.ByProto[conn.Proto]++
|
|
|
|
// Count by state
|
|
stats.ByState[conn.State]++
|
|
|
|
// Count by process
|
|
if conn.Process != "" {
|
|
key := fmt.Sprintf("%d-%s", conn.PID, conn.Process)
|
|
if existing, ok := procCounts[key]; ok {
|
|
existing.Count++
|
|
procCounts[key] = existing
|
|
} else {
|
|
procCounts[key] = ProcessStats{
|
|
PID: conn.PID,
|
|
Process: conn.Process,
|
|
Count: 1,
|
|
}
|
|
}
|
|
}
|
|
|
|
// Count by interface (placeholder since we don't have interface data yet)
|
|
if conn.Interface != "" {
|
|
ifCounts[conn.Interface]++
|
|
}
|
|
}
|
|
|
|
// Convert process map to sorted slice
|
|
for _, procStats := range procCounts {
|
|
stats.ByProc = append(stats.ByProc, procStats)
|
|
}
|
|
sort.Slice(stats.ByProc, func(i, j int) bool {
|
|
return stats.ByProc[i].Count > stats.ByProc[j].Count
|
|
})
|
|
|
|
// Convert interface map to sorted slice
|
|
for iface, count := range ifCounts {
|
|
stats.ByIf = append(stats.ByIf, InterfaceStats{
|
|
Interface: iface,
|
|
Count: count,
|
|
})
|
|
}
|
|
sort.Slice(stats.ByIf, func(i, j int) bool {
|
|
return stats.ByIf[i].Count > stats.ByIf[j].Count
|
|
})
|
|
|
|
return stats, nil
|
|
}
|
|
|
|
func printStatsJSON(stats *StatsData) {
|
|
jsonOutput, err := json.MarshalIndent(stats, "", " ")
|
|
if err != nil {
|
|
log.Printf("Error marshaling JSON: %v", err)
|
|
return
|
|
}
|
|
fmt.Println(string(jsonOutput))
|
|
}
|
|
|
|
func printStatsCSV(stats *StatsData, headers bool) {
|
|
writer := csv.NewWriter(os.Stdout)
|
|
defer writer.Flush()
|
|
|
|
if headers {
|
|
_ = writer.Write([]string{"timestamp", "metric", "key", "value"})
|
|
}
|
|
|
|
ts := stats.Timestamp.Format(time.RFC3339)
|
|
|
|
_ = writer.Write([]string{ts, "total", "", strconv.Itoa(stats.Total)})
|
|
|
|
for proto, count := range stats.ByProto {
|
|
_ = writer.Write([]string{ts, "proto", proto, strconv.Itoa(count)})
|
|
}
|
|
|
|
for state, count := range stats.ByState {
|
|
_ = writer.Write([]string{ts, "state", state, strconv.Itoa(count)})
|
|
}
|
|
|
|
for _, proc := range stats.ByProc {
|
|
_ = writer.Write([]string{ts, "process", proc.Process, strconv.Itoa(proc.Count)})
|
|
}
|
|
|
|
for _, iface := range stats.ByIf {
|
|
_ = writer.Write([]string{ts, "interface", iface.Interface, strconv.Itoa(iface.Count)})
|
|
}
|
|
}
|
|
|
|
func printStatsTable(stats *StatsData, headers bool) {
|
|
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
|
|
defer w.Flush()
|
|
|
|
if headers {
|
|
fmt.Fprintf(w, "TIMESTAMP\t%s\n", stats.Timestamp.Format(time.RFC3339))
|
|
fmt.Fprintf(w, "TOTAL CONNECTIONS\t%d\n", stats.Total)
|
|
fmt.Fprintln(w)
|
|
}
|
|
|
|
// Protocol breakdown
|
|
if len(stats.ByProto) > 0 {
|
|
if headers {
|
|
fmt.Fprintln(w, "BY PROTOCOL:")
|
|
fmt.Fprintln(w, "PROTO\tCOUNT")
|
|
}
|
|
protocols := make([]string, 0, len(stats.ByProto))
|
|
for proto := range stats.ByProto {
|
|
protocols = append(protocols, proto)
|
|
}
|
|
sort.Strings(protocols)
|
|
for _, proto := range protocols {
|
|
fmt.Fprintf(w, "%s\t%d\n", strings.ToUpper(proto), stats.ByProto[proto])
|
|
}
|
|
fmt.Fprintln(w)
|
|
}
|
|
|
|
// State breakdown
|
|
if len(stats.ByState) > 0 {
|
|
if headers {
|
|
fmt.Fprintln(w, "BY STATE:")
|
|
fmt.Fprintln(w, "STATE\tCOUNT")
|
|
}
|
|
states := make([]string, 0, len(stats.ByState))
|
|
for state := range stats.ByState {
|
|
states = append(states, state)
|
|
}
|
|
sort.Strings(states)
|
|
for _, state := range states {
|
|
fmt.Fprintf(w, "%s\t%d\n", state, stats.ByState[state])
|
|
}
|
|
fmt.Fprintln(w)
|
|
}
|
|
|
|
// Process breakdown (top 10)
|
|
if len(stats.ByProc) > 0 {
|
|
if headers {
|
|
fmt.Fprintln(w, "BY PROCESS (TOP 10):")
|
|
fmt.Fprintln(w, "PID\tPROCESS\tCOUNT")
|
|
}
|
|
limit := 10
|
|
if len(stats.ByProc) < limit {
|
|
limit = len(stats.ByProc)
|
|
}
|
|
for i := 0; i < limit; i++ {
|
|
proc := stats.ByProc[i]
|
|
fmt.Fprintf(w, "%d\t%s\t%d\n", proc.PID, proc.Process, proc.Count)
|
|
}
|
|
}
|
|
}
|
|
|
|
func init() {
|
|
rootCmd.AddCommand(statsCmd)
|
|
statsCmd.Flags().StringVarP(&statsOutputFormat, "output", "o", "table", "Output format (table, json, csv)")
|
|
statsCmd.Flags().DurationVarP(&statsInterval, "interval", "i", 0, "Refresh interval (0 = one-shot)")
|
|
statsCmd.Flags().IntVarP(&statsCount, "count", "c", 0, "Number of iterations (0 = unlimited)")
|
|
statsCmd.Flags().BoolVar(&statsNoHeaders, "no-headers", false, "Omit headers for table/csv output")
|
|
statsCmd.Flags().BoolVarP(&ipv4, "ipv4", "4", false, "Only show IPv4 connections")
|
|
statsCmd.Flags().BoolVarP(&ipv6, "ipv6", "6", false, "Only show IPv6 connections")
|
|
}
|